In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os
import gc
import math
from tqdm import tqdm

from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from transformers import AutoTokenizer
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

In [None]:
from t2iadapter import (
    Adapter_XL,
    T2IConfig,
    import_model_class_from_model_name_or_path,
    compute_embeddings_sd1x5,
    generate_mri_slices,
    plot_generated_and_ground_truth
)
from slicedMRI import DatasetConfig, PairedMRIDataset
from eval import MRIEvaluator

In [None]:
# Optional one-off environment patches (uncomment and run once if the imports above fail).
# 1) Patch basicsr's degradations.py to import `rgb_to_grayscale` from
#    `torchvision.transforms.functional` instead of `torchvision.transforms.functional_tensor`
#    (presumably the latter is unavailable in the installed torchvision).
#    Upstream fix: https://github.com/XPixelGroup/BasicSR/pull/650
#!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' /usr/local/lib/python3.12/dist-packages/basicsr/data/degradations.py
# 2) Patch T2I-Adapter's models/unet.py to import UNet2DConditionModel & co. from
#    `diffusers.models.unets.unet_2d_condition` (presumably their location in newer diffusers).
# !sed -i 's/from diffusers.models.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput, logger/from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput, logger/' /content/T2I-Adapter/models/unet.py

# When True, the dataset-inspection cell (`if debug:`) prints per-sample statistics and plots HR/LR slices.
debug: bool = True

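### PairedMRIDataset - Setup

- Load paired high-resolution (HR) and low-resolution (LR) MRI slices from `./mri_dataset/`.
- Split the data into train / validation / test fractions of 0.8 / 0.0 / 0.2 (no validation split is used).
- Extract slices along axis 2, with registration enabled and N4 correction (`do_n4`) disabled.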

In [None]:
# Fractions of the data assigned to the train / validation / test splits.
train_val_test_split = [0.8, 0.0, 0.2]
# Floating-point sums are inexact; compare with a tolerance instead of `==`.
assert math.isclose(sum(train_val_test_split), 1.0), "Dataset split should sum up to one"

# Settings shared by every split; only "mode" differs between the datasets below.
shared_config: dict = {
    "data_dir": Path("./mri_dataset/"),
    "mode": "train",
    "fractions": train_val_test_split,
    "slice_axis": 2,
    "do_registration": True,
    "do_n4": False,
}
# Build each split's config from a merged copy instead of mutating `shared_config`,
# so re-running this cell always yields the same configuration.
config = DatasetConfig(**{**shared_config, "mode": "train"})
test_config = DatasetConfig(**{**shared_config, "mode": "test"})

train_dataset = PairedMRIDataset(config=config, verbose=1)
test_dataset = PairedMRIDataset(config=test_config, verbose=1)

# NOTE(review): batch_size=4 is hard-coded here, while `t2i_config.train_batch_size`
# is what the LR scaling and total_batch_size computations use later — keep them in sync.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
# Evaluation does not need shuffling; a fixed order makes metrics and plots reproducible.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
if debug:

    def _to_numpy(value):
        """Return `value` as a NumPy array, detaching and moving torch tensors to CPU."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def inspect_sample(sample):
        """Print shape, dtype, range, NaN count and a coarse histogram of one HR/LR pair.

        Args:
            sample: a dataset item; must provide "hr" and "lr" entries
                (torch tensors or array-likes).
        """
        hr = _to_numpy(sample["hr"])
        lr = _to_numpy(sample["lr"])

        print("HR shape:", hr.shape, "LR shape:", lr.shape, " dtype:", hr.dtype)
        print(
            "HR min/max/mean:",
            float(np.nanmin(hr)),
            float(np.nanmax(hr)),
            float(np.nanmean(hr)),
        )
        print(
            "LR min/max/mean:",
            float(np.nanmin(lr)),
            float(np.nanmax(lr)),
            float(np.nanmean(lr)),
        )
        print("HR NaNs:", int(np.isnan(hr).sum()), "LR NaNs:", int(np.isnan(lr).sum()))
        print("HR unique (sample):", np.unique(hr.ravel()[:200]).tolist())
        # quick histogram (coarse)
        hist, edges = np.histogram(hr.ravel(), bins=10)
        print("HR hist bins:", hist, "edges:", np.round(edges, 4))

    # Inspect the training split (the previously referenced `full_dataset` was never defined).
    inspect_dataset = train_dataset
    offset = 0
    # Never index past the end of the dataset.
    num_plots = min(40, len(inspect_dataset) - offset)

    if num_plots > 0:
        # squeeze=False keeps `axes` 2-D even when num_plots == 1.
        fig, axes = plt.subplots(
            nrows=num_plots,
            ncols=2,
            figsize=(12, num_plots * 6),  # Adjust figure size based on number of plots
            dpi=100,
            squeeze=False,
        )
        fig.suptitle(
            f"HR vs. LR (First {num_plots} Training Slices)", fontsize=16, y=1.02
        )
        for i in range(num_plots):
            # Load each sample once and reuse it for both the statistics and the plots.
            data = inspect_dataset[offset + i]
            print("=== SAMPLE", i, "===")
            inspect_sample(data)

            ax_hr, ax_lr = axes[i, 0], axes[i, 1]
            # Slices are single-channel intensity images: drop singleton dims, show in gray.
            ax_hr.imshow(np.squeeze(_to_numpy(data["hr"])), cmap="gray")
            ax_hr.set_title(f"HR Slice {i+1}", fontsize=12)
            ax_hr.axis("off")
            ax_lr.imshow(np.squeeze(_to_numpy(data["lr"])), cmap="gray")
            ax_lr.set_title(f"LR Slice {i+1}", fontsize=12)
            ax_lr.axis("off")
        fig.tight_layout(rect=[0, 0, 1, 1.01])  # Adjust layout to make space for suptitle
        plt.show()
    else:
        print("Dataset has no samples at the requested offset; nothing to inspect.")

### Training Setup

In [None]:
t2i_config = T2IConfig()

# Tracker logs go under <output_dir>/<logging_dir>, as in the diffusers training scripts.
logging_dir = Path(t2i_config.output_dir, t2i_config.logging_dir)
accelerator_project_config = ProjectConfiguration(
    project_dir=t2i_config.output_dir,
    # Use the joined path computed above; it was previously built but never passed.
    logging_dir=logging_dir,
)
accelerator = Accelerator(
    gradient_accumulation_steps=t2i_config.gradient_accumulation_steps,
    mixed_precision=t2i_config.mixed_precision,
    log_with=t2i_config.report_to,
    project_config=accelerator_project_config,
)
set_seed(t2i_config.seed)
os.makedirs(t2i_config.output_dir, exist_ok=True)

# load the tokenizers
tokenizer = AutoTokenizer.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=t2i_config.revision,
    use_fast=False,
)
# Resolve the text-encoder class declared by the pretrained checkpoint.
text_encoder_cls = import_model_class_from_model_name_or_path(
    t2i_config.pretrained_model_name_or_path,
    t2i_config.revision,
    subfolder="text_encoder",
)
# Load scheduler and models.
# NOTE(review): the UNet is frozen below, so its output parameterization is fixed by the
# pretrained checkpoint. If that checkpoint was trained with epsilon prediction (e.g. SD 1.x),
# overriding prediction_type to "v_prediction" leaves the adapter alone to bridge the
# mismatch — confirm this is intended.
noise_scheduler = DDPMScheduler.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="scheduler",
    prediction_type="v_prediction",  # velocity prediction
    timestep_spacing="trailing",  # for zero-SNR
    rescale_betas_zero_snr=True,  # enforces pure noise at t=1000
)
text_encoder = text_encoder_cls.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="text_encoder",
    revision=t2i_config.revision,
)
# Use a separately configured VAE if given, otherwise the one bundled with the base model.
vae_path = (
    t2i_config.pretrained_model_name_or_path
    if t2i_config.pretrained_vae_model_name_or_path is None
    else t2i_config.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
    vae_path,  # TODO: play around with this microsoft/mri-autoencoder-v0.1
    subfolder="vae" if t2i_config.pretrained_vae_model_name_or_path is None else None,
    revision=t2i_config.revision,
)
unet = UNet2DConditionModel.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="unet",
    revision=t2i_config.revision,
)

# Freeze the pretrained components: only the adapter is optimized. Keeping the
# pretrained weights locked is meant to prevent mode collapse (see the ControlNet paper).
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

In [None]:
# Optional memory/speed features for the frozen UNet.
if t2i_config.enable_xformers_memory_efficient_attention:
    import xformers  # pyright: ignore[reportMissingImports]
    from packaging import version

    xformers_version = version.parse(xformers.__version__)
    if xformers_version == version.parse("0.0.16"):
        print(
            "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
        )
    unet.enable_xformers_memory_efficient_attention()
if t2i_config.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if t2i_config.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Optionally scale the base LR by the effective global batch size.
# NOTE(review): this uses t2i_config.train_batch_size, while train_loader is built with
# a hard-coded batch_size=4 — confirm the two agree.
if t2i_config.scale_lr:
    learning_rate = (
        t2i_config.learning_rate
        * t2i_config.gradient_accumulation_steps
        * t2i_config.train_batch_size
        * accelerator.num_processes
    )
else:
    learning_rate = t2i_config.learning_rate

# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if t2i_config.use_8bit_adam:
    try:
        import bitsandbytes as bnb  # pyright: ignore[reportMissingImports]
    except ImportError as exc:
        # Chain the original error so the real import failure stays in the traceback.
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        ) from exc
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

In [None]:
# The trainable T2I-Adapter. Its four output widths (320/640/1280/1280) are the
# residuals fed into the UNet's down blocks.
# NOTE(review): cin=3 * 64 presumably corresponds to a 3-channel condition after an
# 8x pixel-unshuffle inside Adapter_XL — confirm against the t2iadapter module.
adapter = Adapter_XL(
    channels=[320, 640, 1280, 1280],
    nums_rb=3,
    cin=3 * 64,
    ksize=3,
    sk=True,
    use_conv=True,
)
adapter = adapter.to(accelerator.device)

# Only the adapter's parameters are optimized.
params_to_optimize = adapter.parameters()
optimizer = optimizer_class(
    params_to_optimize,
    lr=learning_rate,
    betas=(t2i_config.adam_beta1, t2i_config.adam_beta2),
    weight_decay=t2i_config.adam_weight_decay,
    eps=t2i_config.adam_epsilon,
)

# Inference-only weights follow the accelerator's mixed-precision mode;
# any mode other than fp16/bf16 keeps full fp32.
weight_dtype = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}.get(accelerator.mixed_precision, torch.float32)

# The default VAE stays in fp32 to avoid NaN losses; an explicitly configured VAE
# is cast to weight_dtype like the text encoder.
vae.to(
    accelerator.device,
    dtype=(
        weight_dtype
        if t2i_config.pretrained_vae_model_name_or_path is not None
        else torch.float32
    ),
)
# The UNet is only moved, not cast: it keeps its loaded dtype and is handed to
# accelerator.prepare later (presumably relying on accelerate's autocast).
unet.to(accelerator.device)
text_encoder.to(accelerator.device, dtype=weight_dtype)

In [None]:
# Text encoders/tokenizers are kept resident: prompt embeddings are computed per batch
# inside the training loop. Release any cached memory left over from model loading.
text_encoders = [text_encoder]
tokenizers = [tokenizer]
gc.collect()
torch.cuda.empty_cache()

# Scheduler and math around the number of training steps.
assert len(train_loader) > 0, "Training dataloader is empty"
overrode_max_train_steps = False
# One optimizer update happens every `gradient_accumulation_steps` batches.
# (Previously a hard-coded 1e7 placeholder, which made max_train_steps and the
# LR schedule meaningless.)
num_update_steps_per_epoch = math.ceil(
    len(train_loader) / t2i_config.gradient_accumulation_steps
)
if t2i_config.max_train_steps is None:
    t2i_config.max_train_steps = (
        t2i_config.num_train_epochs * num_update_steps_per_epoch
    )
    overrode_max_train_steps = True
lr_scheduler = get_scheduler(
    t2i_config.lr_scheduler_name,
    optimizer=optimizer,
    num_warmup_steps=t2i_config.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=t2i_config.max_train_steps * accelerator.num_processes,
    num_cycles=t2i_config.lr_num_cycles,
    power=t2i_config.lr_power,
)
# Exponential moving average of the adapter weights; this is what the checkpoints save.
ema_adapter = EMAModel(
    adapter.parameters(),
    model_cls=adapter.__class__,  # Custom class, passing it here helps with some utilities
    decay=0.9999,
)
# Prepare everything with our `accelerator`.
adapter, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    adapter, unet, optimizer, train_loader, lr_scheduler
)

# `prepare` may shard the dataloader across processes, changing its length:
# recompute the step/epoch bookkeeping from the prepared loader.
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / t2i_config.gradient_accumulation_steps
)
if overrode_max_train_steps:
    t2i_config.max_train_steps = (
        t2i_config.num_train_epochs * num_update_steps_per_epoch
    )
# Run enough epochs to reach max_train_steps.
t2i_config.num_train_epochs = math.ceil(
    t2i_config.max_train_steps / num_update_steps_per_epoch
)

if accelerator.is_main_process:
    # NOTE(review): trackers are never initialized (init_trackers is commented out),
    # so `report_to` currently has no effect.
    tracker_config = dict(vars(t2i_config))
    # accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

# NOTE(review): informational only; it uses t2i_config.train_batch_size, not the
# batch_size the DataLoader was actually built with.
total_batch_size = (
    t2i_config.train_batch_size
    * accelerator.num_processes
    * t2i_config.gradient_accumulation_steps
)

### Training Loop

In [None]:
global_step = 0
first_epoch = 0
initial_global_step = 0
loss_history = []  # loss of every micro-batch, plotted periodically below

progress_bar = tqdm(
    range(0, t2i_config.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, t2i_config.num_train_epochs):
    adapter.train()
    # Iterate the accelerator-prepared loader. The raw `train_loader` is not sharded
    # across processes, and accelerate cannot detect its end for gradient sync.
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(adapter):
            bsz = batch["hr"].shape[0]
            h, w = batch["hr"].shape[-2:]
            # Assumes single-channel slices shaped (bsz, 1, H, W), with LR already at
            # HR size; `expand` repeats the channel 3x for the VAE/adapter — TODO confirm.
            hr_slices = batch["hr"].to(accelerator.device).float()
            hr_slices = hr_slices.expand(bsz, 3, h, w)
            lr_slices = batch["lr"].to(accelerator.device).float()
            condition = lr_slices.expand(bsz, 3, h, w)

            # VAE Encoding (Target). The VAE is frozen, so no autograd graph is needed.
            # NOTE(review): SD VAEs usually expect inputs in [-1, 1]; confirm the dataset range.
            with torch.no_grad():
                latents = vae.encode(hr_slices.to(vae.dtype)).latent_dist.sample()
            latents = (latents * vae.config.scaling_factor).to(weight_dtype)

            # Noise generation: sample one random diffusion timestep per example.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Text embedding (presumably ~10% of prompts are replaced by empty prompts).
            prompt_embeds = compute_embeddings_sd1x5(
                batch=batch,
                proportion_empty_prompts=0.1,
                text_encoders=text_encoders,
                tokenizers=tokenizers,
            )["prompt_embeds"]

            # Adapter conditioning: adapter residuals are injected into the frozen UNet's down blocks.
            down_block_additional_residuals = adapter(condition)
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=prompt_embeds,
                down_block_additional_residuals=down_block_additional_residuals,
            ).sample

            # The regression target depends on the scheduler's parameterization.
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            loss_history.append(loss.detach().item())

            accelerator.backward(loss)

            # Only on the micro-step that completes an accumulation cycle.
            if accelerator.sync_gradients:
                params_to_clip = adapter.parameters()
                accelerator.clip_grad_norm_(params_to_clip, t2i_config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                if accelerator.is_main_process:
                    unwrapped_adapter = accelerator.unwrap_model(adapter)
                    ema_adapter.step(unwrapped_adapter.parameters())
                optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            if accelerator.is_main_process:
                if global_step % 100 == 0:  # Increased logging interval slightly
                    plt.figure(figsize=(10, 5))
                    plt.plot(loss_history)
                    plt.title("Training Loss")
                    plt.savefig(os.path.join(t2i_config.output_dir, "loss_curve.png"))
                    plt.close()
                if global_step % t2i_config.checkpointing_steps == 0:
                    save_path = os.path.join(
                        t2i_config.output_dir, f"checkpoint-{global_step}"
                    )
                    os.makedirs(save_path, exist_ok=True)
                    # Temporarily swap in the EMA weights, save them, then restore the live weights.
                    unwrapped_adapter = accelerator.unwrap_model(adapter)
                    ema_adapter.store(unwrapped_adapter.parameters())  # Backup
                    ema_adapter.copy_to(unwrapped_adapter.parameters())  # Load EMA
                    torch.save(
                        unwrapped_adapter.state_dict(),
                        os.path.join(save_path, "adapter.pt"),
                    )
                    ema_adapter.restore(unwrapped_adapter.parameters())  # Restore
                    print(f"Saved EMA checkpoint to {save_path}")

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)

        # Stop once the configured number of optimizer updates has been reached.
        if global_step >= t2i_config.max_train_steps:
            break
    if global_step >= t2i_config.max_train_steps:
        break

progress_bar.close()
# Make sure every process has finished its last step before the main process saves.
accelerator.wait_for_everyone()

# After training is done: overwrite the live adapter weights with the EMA weights
# (the evaluation below therefore uses them) and save the result.
if accelerator.is_main_process:
    unwrapped_adapter = accelerator.unwrap_model(adapter)
    # Final EMA Save
    ema_adapter.copy_to(unwrapped_adapter.parameters())
    torch.save(unwrapped_adapter.state_dict(), os.path.join(t2i_config.output_dir, "adapter_tuned.pt"))
    print("Training finished. Saved final T2I-Adapter EMA.")

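### Evaluation

- Generate slices for every test batch with the trained adapter (its weights were replaced by the EMA weights at the end of training).
- Score the generated slices against the HR ground truth with PSNR and SSIM (`data_range=1.0`).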

In [None]:
metrics = []  # one entry per test batch, as returned by MRIEvaluator.eval_all_metrics
psnr = PeakSignalNoiseRatio(data_range=1.0).to(accelerator.device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(accelerator.device)

# Inference only: switch the adapter to eval mode and disable autograd. The adapter's
# parameters still require grad, so without no_grad every denoising step would build a graph.
adapter.eval()
image_batch_np = None
with torch.no_grad():
    for batch in test_loader:
        prompt_embeds_eval = compute_embeddings_sd1x5(
            batch=batch,
            proportion_empty_prompts=0.0,  # Use 0.0 for evaluation when skipping CFG
            text_encoders=text_encoders,
            tokenizers=tokenizers,
        )["prompt_embeds"]
        image_batch_np, postprocessed = generate_mri_slices(
            batch=batch,
            adapter=adapter,
            unet=unet,
            vae=vae,
            noise_scheduler=noise_scheduler,
            prompt_embeds=prompt_embeds_eval,
            num_inference_steps=500,
            weight_dtype=weight_dtype,
            accelerator=accelerator,
        )
        # Put ground truth on the same device as the generated slices and the metric modules.
        metrics_batch = MRIEvaluator.eval_all_metrics(
            ground_truth=batch["hr"].to(accelerator.device),
            generated=torch.from_numpy(image_batch_np).to(accelerator.device),
            psnr=psnr,
            ssim=ssim,
        )
        metrics.append(metrics_batch)

# Visualize the last evaluated batch, if there was one.
if image_batch_np is None:
    print("Test loader is empty; nothing to plot.")
else:
    plot_generated_and_ground_truth(
        generated_slices_np=image_batch_np, batch=batch, num_images_to_show=6
    )

Error in callback <function _draw_all_if_interactive at 0x11692a090> (for post_execute), with arguments args (),kwargs {}:


RecursionError: maximum recursion depth exceeded

RecursionError: maximum recursion depth exceeded

<Figure size 640x480 with 8 Axes>