- Caveats:
  - Do the images carry phase information (two channels: real/imag) or only magnitude (one channel)? With a single channel we risk magnitude bias, i.e. k-space has been collapsed for visualization.
  - Condition on k-space mask & measured k-space
  - Loss: the standard denoising loss (MSE on the noise prediction) works well ⚡️; so far we use `v_prediction`.
    - Optionally add perceptual/SSIM auxiliary losses on the final output when supervised paired training is available. However, perceptual losses can encourage hallucination; keep them as a regularizer.
  - Data consistency when sampling from $p(x|y)$: this is posterior sampling with a learned prior. Alternating stochastic denoising (prior) with a deterministic projection (likelihood) only approximates sampling from the posterior.
    - Train a standard diffusion model on HF images to learn the denoiser $\epsilon_\theta$.
    - During sampling, incorporate the measurement constraint with a data-consistency projection at each denoising step: after a denoising step outputs an image $\tilde{x}$, replace the low-frequency k-space of $\tilde{x}$ with the observed k-space from $y$ (or take a gradient step to minimize $\|A(\tilde{x}) - y\|^2$). This enforces fidelity to the measured low frequencies and reduces hallucination (see [ScoreMD](https://arxiv.org/pdf/2110.05243)).
  - Overfitting to scanner / protocol: include multi-center/multi-protocol data or use domain-adaptation.
  - Mismatch between simulated and real low-freq data: when training on synthetically low-passed images (ideal LPF) but testing on different acquisition patterns, performance drops. Use acquisition-matched simulation or train with a diversity of realistic forward operators and noise levels.
  - Training on magnitude images loses phase information and may produce inconsistent k-space and artifacts. Prefer complex-domain training if the downstream task requires phase.
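
The projection step from the data-consistency caveat can be sketched in NumPy as follows. This is a minimal single-coil sketch, not the notebook's `generate_mri_slices_partial_latent_align_dc`; the helper names (`low_freq_mask`, `to_kspace`, `to_image`, `data_consistency`) and the centered rectangular mask are assumptions made for illustration.

```python
import numpy as np

def low_freq_mask(shape, keep_fraction=0.25):
    """Boolean mask keeping the central `keep_fraction` of k-space along each axis (assumed pattern)."""
    h, w = shape
    kh, kw = max(1, int(h * keep_fraction)), max(1, int(w * keep_fraction))
    mask = np.zeros(shape, dtype=bool)
    top, left = (h - kh) // 2, (w - kw) // 2
    mask[top:top + kh, left:left + kw] = True
    return mask

def to_kspace(image):
    """Centered, orthonormal 2D FFT (image -> k-space)."""
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(image), norm="ortho"))

def to_image(kspace):
    """Centered, orthonormal 2D inverse FFT (k-space -> image)."""
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace), norm="ortho"))

def data_consistency(x_tilde, y_kspace, mask):
    """Overwrite the measured k-space entries of the estimate with the observed values."""
    k_est = to_kspace(x_tilde)
    k_dc = np.where(mask, y_kspace, k_est)
    return to_image(k_dc)

# Toy usage: a random complex "image" stands in for a denoiser output.
rng = np.random.default_rng(0)
x_true = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
mask = low_freq_mask(x_true.shape)
y = to_kspace(x_true) * mask          # simulated low-frequency measurement
x_tilde = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
x_dc = data_consistency(x_tilde, y, mask)
```

After the projection, the estimate agrees with the measurement on the mask while the unmeasured high frequencies still come from the prior. A hard replacement like this assumes noiseless measurements; with noisy data a weighted blend or a gradient step is the usual softer variant.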

What to do when utilising a novel architecture:
- Data domain: train on complex images (two channels: real/imag) or on coil-combined images (e.g., SENSE-reconstructed). Using complex data preserves phase information and avoids magnitude bias.
- Latent Diffusion: use a VAE pretrained on MRI data rather than training your own network.
- Network: UNet backbone with time embedding (sinusoidal) and optional cross-attention conditioning on low-res measurement (see SR3 / Cascaded Diffusion literature). Use residual blocks, group normalization, and attention at mid/high resolutions.
- Conditioning: provide low-frequency image (zero-filled inverse FFT) as conditional input by channel-concatenation or via cross-attention. Alternatively condition on k-space mask & measured k-space.
- Noise schedule: cosine or linear schedule, common defaults from DDPM/EDM. Adjust max noise to match data dynamic range.
- Complex handling: either use two channels (real/imag) and let the network learn both, or use a magnitude + phase decomposition and predict real/imag reconstructions.
- Loss: standard denoising loss (MSE on noise prediction) works well. Optionally add perceptual/SSIM auxiliary losses on the final output when supervised paired training is available. Be careful: perceptual losses can encourage hallucination; keep them auxiliary.
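
The noise-schedule and loss items above can be made concrete with a small NumPy sketch. It builds the cosine schedule of the improved-DDPM formulation and the velocity target used with `v_prediction`. The function names are assumptions for illustration; the notebook itself relies on `DDPMScheduler.add_noise` and `DDPMScheduler.get_velocity` rather than on code like this.

```python
import numpy as np

def cosine_alpha_bar(num_steps=1000, s=0.008):
    """Cumulative signal level alpha_bar_t for each of the `num_steps` timesteps (cosine schedule)."""
    t = np.arange(num_steps + 1) / num_steps
    f = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    # Per-step betas, clipped so the final steps stay numerically stable.
    betas = np.clip(1 - alpha_bar[1:] / alpha_bar[:-1], 0.0, 0.999)
    return np.cumprod(1 - betas)

def add_noise(x0, noise, alpha_bar_t):
    """Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1 - alpha_bar_t) * noise

def v_target(x0, noise, alpha_bar_t):
    """Velocity target: v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x0."""
    return np.sqrt(alpha_bar_t) * noise - np.sqrt(1 - alpha_bar_t) * x0

# Toy usage on a random latent-shaped array.
rng = np.random.default_rng(0)
alpha_bar = cosine_alpha_bar()
x0 = rng.standard_normal((2, 4, 8, 8))
eps = rng.standard_normal(x0.shape)
a = alpha_bar[300]
x_t = add_noise(x0, eps, a)
v = v_target(x0, eps, a)
# The clean sample can be recovered from (x_t, v):
x0_rec = np.sqrt(a) * x_t - np.sqrt(1 - a) * v
```

The last line shows why `v_prediction` is convenient: both the clean sample and the noise are linear functions of `x_t` and the predicted velocity, so the target stays well scaled at every noise level.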

- Initialize sampling with upsampled low-res image rather than pure Gaussian noise. This speeds convergence and reduces hallucination (especially for high upscaling factors).
- Data consistency is crucial — models without DC can hallucinate high-frequency content inconsistent with measured low-freq info.
- Multi-coil: integrate coil sensitivities explicitly and perform DC in coil k-space; solve for image using CG with the learned prior used as a denoiser (plug-and-play CG).
- Patch vs full images: training on patches speeds training; ensure the patch size supports the highest receptive field your UNet needs for contextual structures.
- Normalization: normalize per-volume intensity (e.g. whiten or scale by robust max). Keep track of normalization to invert at test time.
- Stability: use EMA of weights for sampling; mixed precision training and gradient clipping help.
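
The normalization item above can be sketched as a pair of helpers that scale by a robust maximum and keep the scale for later inversion. The function names and the 99.5th percentile are assumptions chosen for illustration, not values taken from this notebook's dataset classes.

```python
import numpy as np

def normalize_volume(volume, percentile=99.5):
    """Scale a volume by a robust maximum; return the scaled volume and the scale used."""
    scale = float(np.percentile(np.abs(volume), percentile))
    if scale <= 0:
        scale = 1.0  # guard against empty or all-zero volumes
    return volume / scale, scale

def denormalize_volume(volume_norm, scale):
    """Invert `normalize_volume` with the stored scale."""
    return volume_norm * scale

# Toy usage: a volume with a few bright outlier voxels.
rng = np.random.default_rng(0)
volume = rng.random((8, 32, 32)) * 1000.0
volume[0, 0, :3] = 50_000.0           # outliers that would dominate a plain max
volume_norm, scale = normalize_volume(volume)
restored = denormalize_volume(volume_norm, scale)
```

Using a percentile rather than the plain maximum keeps a handful of outlier voxels from compressing the rest of the intensity range. Values above the percentile end up slightly larger than 1, so clip them if a strict `[0, 1]` range is required downstream.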

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os
import gc
import math
import random
from typing import Union
from tqdm import tqdm
import wandb
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from transformers import AutoTokenizer
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F

In [None]:
from MRIDiffusion.t2iadapter.config import (T2IConfig)
from MRIDiffusion.t2iadapter.t2iadapter import (
    Adapter_XL,
)
from MRIDiffusion.t2iadapter.MRIProjector import MRIProjector, LatentMRIProjector, InverseProjectorLearned
from MRIDiffusion.t2iadapter.utils import (
    import_model_class_from_model_name_or_path,
    compute_embeddings_sd1x5,
    plot_generated_and_ground_truth,
    log_configs,
    generate_mri_slices_partial_dc,
    print_trainable_parameters,
    generate_mri_slices_partial_latent_align_dc,
)
from MRIDiffusion.slicedMRI import DatasetConfig, PairedMRI_MiniDataset, FastMRILazyDataset
from MRIDiffusion.eval import MRIEvaluator
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

In [None]:
debug: bool = False

### PairedMRIDataset - Setup

In [None]:
train_val_test_split = (0.8, 0.0, 0.2)
assert math.isclose(sum(train_val_test_split), 1.0), "Dataset split should sum to one"

shared_config: dict = {
    "data_dir": Path("./brain_fastMRI_DICOM"),
    "mode": "train",
    "fractions": train_val_test_split,
    "target_size": (512, 512),
    "contrast_filter": "T2",
    "strength_filter": "3.0T",
    "scale_factor": 4.0,
    "fastMRI_manifest_json": "/content/drive/MyDrive/Colab Notebooks/MasterInfo/GenAI/fast_MRI_brain_patient_records_manifest.json",
}
config = DatasetConfig(**shared_config)
t2i_config = T2IConfig(
    train_batch_size=16,
    test_batch_size=4,
    partial_start_step=350,
    max_train_steps=12000,
    pretrained_vae_model_name_or_path="microsoft/mri-autoencoder-v0.1",
)
# train_dataset = PairedMRI_MiniDataset(config=DatasetConfig(**shared_config), verbose=1)
train_dataset = FastMRILazyDataset(config=DatasetConfig(**shared_config))
shared_config["mode"] = "test"
test_dataset = FastMRILazyDataset(config=DatasetConfig(**shared_config))
# test_dataset = PairedMRI_MiniDataset(config=DatasetConfig(**shared_config), verbose=1)
train_loader = DataLoader(
    train_dataset, batch_size=t2i_config.train_batch_size, shuffle=True, num_workers=2
)
test_loader = DataLoader(
    test_dataset, batch_size=t2i_config.test_batch_size, shuffle=True, num_workers=2
)

In [None]:
train_val_test_split = [0.8, 0.0, 0.2]
assert math.isclose(sum(train_val_test_split), 1.0), "Dataset split should sum to one"

shared_config: dict = {
    "data_dir": Path("./mri_dataset/Data/3T data"),
    "mode": "train",
    "fractions": train_val_test_split,
    "slice_axis": 2,
    "do_registration": True,
    "do_n4": False,
}
config = DatasetConfig(**shared_config)
# Partial diffusion in the style of SDEdit: sampling starts from a noised LR latent at partial_start_step
t2i_config = T2IConfig(train_batch_size=16, test_batch_size=4, partial_start_step=250)
train_dataset = PairedMRI_MiniDataset(config=DatasetConfig(**shared_config), verbose=1)
shared_config["mode"] = "test"
test_dataset = PairedMRI_MiniDataset(config=DatasetConfig(**shared_config), verbose=1)
train_loader = DataLoader(train_dataset, batch_size=t2i_config.train_batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=t2i_config.test_batch_size, shuffle=True, num_workers=2)

In [None]:
if debug:
    def inspect_sample(sample):
        hr = sample["hr"]
        lr = sample["lr"]
        # convert to numpy (handle torch tensors)
        if isinstance(hr, torch.Tensor):
            hr = hr.detach().cpu().numpy()
        if isinstance(lr, torch.Tensor):
            lr = lr.detach().cpu().numpy()

        print("HR shape:", hr.shape, "LR shape:", lr.shape, " dtype:", hr.dtype)
        print(
            "HR min/max/mean:",
            float(np.nanmin(hr)),
            float(np.nanmax(hr)),
            float(np.nanmean(hr)),
        )
        print(
            "LR min/max/mean:",
            float(np.nanmin(lr)),
            float(np.nanmax(lr)),
            float(np.nanmean(lr)),
        )
        print("HR NaNs:", int(np.isnan(hr).sum()), "LR NaNs:", int(np.isnan(lr).sum()))
        print("HR unique (sample):", np.unique(hr.ravel()[:200]).tolist())
        # quick histogram (coarse)
        hist, edges = np.histogram(hr.ravel(), bins=10)
        print("HR hist bins:", hist, "edges:", np.round(edges, 4))

    num_plots = 40
    offset = 0
    fig, axes = plt.subplots(
        nrows=num_plots,
        ncols=2,
        figsize=(12, num_plots * 6),  # Adjust figure size based on number of plots
        dpi=100,
    )
    if num_plots == 1:
        axes = axes[np.newaxis, :]
    fig.suptitle(
        f"HR vs. LR (First {num_plots} Slices)", fontsize=16, y=1.02
    )
    for i in range(num_plots):
        data = train_dataset[offset + i]
        print("=== SAMPLE", i, "===")
        inspect_sample(data)  # reuse the sample instead of loading it a second time
        ax_hr = axes[i, 0]
        # HR slices are grayscale (H, W), so use the 'gray' colormap.
        # We assume the HR slices have been normalized to [0, 1] or similar.
        ax_hr.imshow(np.squeeze(data["hr"]), cmap="gray")
        ax_hr.set_title(f"HR Slice {i+1}", fontsize=12)
        ax_hr.axis("off")
        ax_lr = axes[i, 1]
        ax_lr.imshow(np.squeeze(data["lr"]), cmap="gray")
        ax_lr.set_title(f"LR Slice {i+1}", fontsize=12)
        ax_lr.axis("off")
    plt.tight_layout(rect=[0, 0, 1, 1.01])  # Adjust layout to make space for suptitle
    plt.show()

### Training Setup

- SDEdit (Stochastic Differential Editing): the LR latents are not replaced by pure noise. They are treated as a slightly noisy state: noise is injected only up to an intermediate timestep (`partial_start_step`, 350 or 250 in the configs above), just enough to break the blur and let the model add texture, and the reverse process denoises from there.
- Follows "MRI Super-Resolution with Partial Diffusion Models" by Zhao et al., who found that the latents of LF and HF MRI scans become indistinguishable after a certain amount of noise.
- MRIProjector: uses a pseudo-colorization mapping that preserves the high dynamic range of MRI while distributing the signal across the R, G, and B channels in a way the VAE can compress better.
  - Alternatively, one could train a mapping from the latent space of the Microsoft MRI VAE to that of the SD 1.5 VAE.
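
The partial-diffusion idea above can be sketched in NumPy: noise the LR input up to an intermediate timestep, then run the reverse process only over the remaining timesteps. The plain linear beta schedule, the evenly spaced sampler, and all function names are assumptions for illustration; the notebook itself loads its schedule from the pretrained `DDPMScheduler` config.

```python
import numpy as np

def linear_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative signal level for a plain linear beta schedule (assumed values)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def sdedit_init(x_lr, t_start, alpha_bar, rng):
    """Noise the low-resolution input up to t_start instead of starting from pure noise."""
    noise = rng.standard_normal(x_lr.shape)
    a = alpha_bar[t_start]
    return np.sqrt(a) * x_lr + np.sqrt(1.0 - a) * noise

def partial_timesteps(t_start, num_inference_steps, num_train_steps=1000):
    """Descending timesteps of an evenly spaced sampler, restricted to t <= t_start."""
    stride = num_train_steps // num_inference_steps
    full = np.arange(0, num_train_steps, stride)[::-1]
    return full[full <= t_start]

# Toy usage with the values that appear in this notebook's sampling calls.
rng = np.random.default_rng(0)
alpha_bar = linear_alpha_bar()
x_lr = rng.standard_normal((4, 64, 64))   # stands in for an LR latent
x_start = sdedit_init(x_lr, t_start=250, alpha_bar=alpha_bar, rng=rng)
timesteps = partial_timesteps(t_start=250, num_inference_steps=500)
```

Because the start state still carries part of the LR signal, the sampler only has to traverse the low-noise end of the chain, which is where texture and fine detail are decided.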

In [None]:
logging_dir = Path(t2i_config.output_dir, t2i_config.logging_dir)
accelerator_project_config = ProjectConfiguration(
    project_dir=t2i_config.output_dir, logging_dir=t2i_config.logging_dir
)
accelerator = Accelerator(
    gradient_accumulation_steps=t2i_config.gradient_accumulation_steps,
    mixed_precision=t2i_config.mixed_precision,
    log_with=t2i_config.report_to,
    project_config=accelerator_project_config,
)
set_seed(t2i_config.seed)
os.makedirs(t2i_config.output_dir, exist_ok=True)

# load the tokenizers
tokenizer = AutoTokenizer.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=t2i_config.revision,
    use_fast=False,
)
# load the correct scheduler and models
text_encoder_cls = import_model_class_from_model_name_or_path(
    t2i_config.pretrained_model_name_or_path,
    t2i_config.revision,
    subfolder="text_encoder",
)
# Load scheduler and models
# timesteps defaults to 1000
noise_scheduler = DDPMScheduler.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="scheduler",
    prediction_type=t2i_config.ddpm_scheduler_prediction_type,  # velocity prediction
    timestep_spacing=t2i_config.ddpm_scheduler_timestep_spacing,  # for zero-SNR
    rescale_betas_zero_snr=t2i_config.ddpm_scheduler_rescale_betas_zero_snr,  # enforces pure noise at t=1000
)
text_encoder = text_encoder_cls.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="text_encoder",
    revision=t2i_config.revision,
)
vae_path = (
    t2i_config.pretrained_model_name_or_path
    if t2i_config.pretrained_vae_model_name_or_path is None
    else t2i_config.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
    vae_path,  # TODO: play around with this microsoft/mri-autoencoder-v0.1
    subfolder="vae" if t2i_config.pretrained_vae_model_name_or_path is None else None,
    revision=t2i_config.revision,
)
unet = UNet2DConditionModel.from_pretrained(
    t2i_config.pretrained_model_name_or_path,
    subfolder="unet",
    revision=t2i_config.revision,
)

# These stay frozen and are never trained, to avoid mode collapse (see the ControlNet paper)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

# Single quotes inside the f-string: nested double quotes are a SyntaxError before Python 3.12
print(f"Using VAE: {vae.config['_name_or_path']}, UNET: {unet.config['_name_or_path']}, Scheduler Steps: {noise_scheduler.config['num_train_timesteps']}")

In [None]:
if t2i_config.enable_xformers_memory_efficient_attention:
    import xformers # pyright: ignore[reportMissingImports]
    from packaging import version
    xformers_version = version.parse(xformers.__version__)
    if xformers_version == version.parse("0.0.16"):
        print(
            "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
        )
    unet.enable_xformers_memory_efficient_attention()
if t2i_config.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if t2i_config.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True
if t2i_config.scale_lr:
    learning_rate = (
        t2i_config.learning_rate
        * t2i_config.gradient_accumulation_steps
        * t2i_config.train_batch_size
        * accelerator.num_processes
    )
else:
    learning_rate = t2i_config.learning_rate
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if t2i_config.use_8bit_adam:
    try:
        import bitsandbytes as bnb # pyright: ignore[reportMissingImports]
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

In [None]:
adapter = Adapter_XL(
    channels=[320, 640, 1280, 1280],
    nums_rb=3,
    cin=3 * 64,
    ksize=3,
    sk=True,
    use_conv=True,
).to(accelerator.device)
is_special_vae = (
    t2i_config.pretrained_vae_model_name_or_path == "microsoft/mri-autoencoder-v0.1"
)
output_dims = 2 if is_special_vae else 3
latent_projector = LatentMRIProjector(
    spatial_in=128, spatial_out=64, in_channels=4, out_channels=4
).to(accelerator.device)
mri_projector = MRIProjector(output_dims=2).to(accelerator.device)
inverse_latent_projector = InverseProjectorLearned(spatial_in=64, spatial_out=128).to(
    accelerator.device
)
params_to_optimize = (
    list(adapter.parameters())
    + list(latent_projector.parameters())
    + list(mri_projector.parameters())
    + list(inverse_latent_projector.parameters())
)
optimizer = optimizer_class(
    params_to_optimize,
    lr=learning_rate,
    betas=(t2i_config.adam_beta1, t2i_config.adam_beta2),
    weight_decay=t2i_config.adam_weight_decay,
    eps=t2i_config.adam_epsilon,
)

# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
# Move vae, unet and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
if t2i_config.pretrained_vae_model_name_or_path is not None:
    vae.to(accelerator.device, dtype=weight_dtype)
else:
    vae.to(accelerator.device, dtype=torch.float32)
unet.to(accelerator.device)
text_encoder.to(accelerator.device, dtype=weight_dtype)

In [None]:
# Let's first compute all the embeddings so that we can free up the text encoders from memory.
text_encoders = [text_encoder]
tokenizers = [tokenizer]
gc.collect()
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
# num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(1e7 / t2i_config.gradient_accumulation_steps)
if t2i_config.max_train_steps is None:
    t2i_config.max_train_steps = (
        t2i_config.num_train_epochs * num_update_steps_per_epoch
    )
    overrode_max_train_steps = True
lr_scheduler = get_scheduler(
    t2i_config.lr_scheduler_name,
    optimizer=optimizer,
    num_warmup_steps=t2i_config.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=t2i_config.max_train_steps * accelerator.num_processes,
    num_cycles=t2i_config.lr_num_cycles,
    power=t2i_config.lr_power,
)
ema_adapter = EMAModel(
    adapter.parameters(),
    model_cls=adapter.__class__,  # Custom class, passing it here helps with some utilities
    decay=0.9999,
)
# Prepare everything with our `accelerator`.
adapter, latent_projector, mri_projector, inverse_latent_projector, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    adapter, latent_projector, mri_projector, inverse_latent_projector, unet, optimizer, train_loader, lr_scheduler
)

if accelerator.is_main_process:
    tracker_config = dict(vars(t2i_config))
    # accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

total_batch_size = (
    t2i_config.train_batch_size
    * accelerator.num_processes
    * t2i_config.gradient_accumulation_steps
)

# Single quotes inside the f-string: nested double quotes are a SyntaxError before Python 3.12
print(f"Using VAE: {vae.config['_name_or_path']}, UNET: {unet.config['_name_or_path']}, Scheduler Steps: {noise_scheduler.config['num_train_timesteps']}")
print_trainable_parameters(latent_projector, name="MRI Latent Projector")
print_trainable_parameters(adapter, name="T2I-Adapter")
print_trainable_parameters(mri_projector, name="MRI-Projector")
print_trainable_parameters(inverse_latent_projector, name="Inverse MRI Latent Projector")

### Training Loop

In [None]:
wandb.login()

In [None]:
run = wandb.init(
    entity="hannes-leonhard",
    project="mir-sr",
    config=log_configs(t2i_config=t2i_config, dataset_config=config),
)

In [None]:
global_step = 0
first_epoch = 0
initial_global_step = 0
loss_history = []
random.seed(t2i_config.seed)
vis_idx = random.randrange(0, len(test_dataset))
print(f"Using test entry {vis_idx} for training visualization")
psnr = PeakSignalNoiseRatio(data_range=1.0).to(accelerator.device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(accelerator.device)

random.seed(t2i_config.seed)
progress_bar = tqdm(
    range(0, t2i_config.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    disable=not accelerator.is_local_main_process,
)
prompt_embeds: Union[torch.Tensor, None] = None

for epoch in range(first_epoch, t2i_config.num_train_epochs):
    adapter.train()
    latent_projector.train()
    mri_projector.train()
    inverse_latent_projector.train()
    # Iterate the dataloader returned by accelerator.prepare, not the raw train_loader
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(adapter):
            bsz = batch["hr"].shape[0]
            # Use projector for HR (Target) and LR (Condition)
            hr_rgb = mri_projector(batch["hr"].to(accelerator.device).float())
            condition = mri_projector(batch["lr"].to(accelerator.device).float())

            latents = vae.encode(hr_rgb.to(vae.dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            latents = latents.to(weight_dtype)
            # project MRI latents to RGB latents
            latents_projected = latent_projector(latents)
            # calculate inverse latent loss
            latent_projected_aug = latents_projected + 0.01 * torch.randn_like(
                latents_projected
            )
            latents_inverse = inverse_latent_projector(latent_projected_aug)
            latents_loss = t2i_config.lambda_latent * F.mse_loss(
                latents_inverse, latents
            )
            latent_image_loss = t2i_config.lambda_latent_image * F.l1_loss(
                vae.decode(latents_inverse / vae.config.scaling_factor), hr_rgb
            )
            # Noise generation
            noise = torch.randn_like(latents_projected)
            bsz = latents_projected.shape[0]

            # SDEdit-style partial diffusion: only sample timesteps below partial_start_step
            timesteps = torch.randint(
                0,
                t2i_config.partial_start_step,
                (bsz,),
                device=latents_projected.device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(
                latents_projected, noise, timesteps
            )
            if prompt_embeds is None:
                prompt_embeds = compute_embeddings_sd1x5(
                    batch=batch,
                    proportion_empty_prompts=0.1,
                    text_encoders=text_encoders,
                    tokenizers=tokenizers,
                    accelerator=accelerator,
                )["prompt_embeds"]
            # Adapter conditioning
            down_block_additional_residuals = adapter(condition)
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=prompt_embeds,
                down_block_additional_residuals=down_block_additional_residuals,
            ).sample

            # Loss Computation
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(
                    latents_projected, noise, timesteps
                )
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            total_loss = loss + latents_loss + latent_image_loss
            loss_history.append(loss.detach().item())
            accelerator.backward(total_loss)

            if accelerator.sync_gradients:
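                # Clip every module that the optimizer updates, including latent_projector
                params_to_clip = (
                    list(adapter.parameters())
                    + list(latent_projector.parameters())
                    + list(inverse_latent_projector.parameters())
                    + list(mri_projector.parameters())
                )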
                accelerator.clip_grad_norm_(params_to_clip, t2i_config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                if accelerator.is_main_process:
                    unwrapped_adapter = accelerator.unwrap_model(adapter)
                    ema_adapter.step(unwrapped_adapter.parameters())
                optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            # --- Validation / Visualization ---
            if accelerator.is_main_process:
                if global_step % t2i_config.media_reporting_step == 0:
                    item = test_dataset[vis_idx]
                    # Add batch dimension
                    item["hr"] = item["hr"].unsqueeze(0)
                    item["lr"] = item["lr"].unsqueeze(0)
                    item["txt"] = [item["txt"]]

                    with torch.no_grad():
                        prompt_embeds_eval = compute_embeddings_sd1x5(
                            batch=item,
                            proportion_empty_prompts=0.0,
                            text_encoders=text_encoders,
                            tokenizers=tokenizers,
                            accelerator=accelerator,
                            is_train=False,
                        )["prompt_embeds"]
                    images, _ = generate_mri_slices_partial_latent_align_dc(
                        batch=item,
                        adapter=adapter,
                        mri_projector=mri_projector,
                        latent_projector=latent_projector,
                        inverse_latent_projector=inverse_latent_projector,
                        unet=unet,
                        vae=vae,
                        noise_scheduler=noise_scheduler,
                        prompt_embeds=prompt_embeds_eval,
                        start_step=t2i_config.partial_start_step,  # Start from t=partial_start_step
                        num_inference_steps=500,  # Scheduler will slice this
                        weight_dtype=weight_dtype,
                        accelerator=accelerator,
                        use_data_consistency=True,
                        dc_reduction_factor=1.7,
                        taper=0.12,
                        apply_final_pixel_dc=False,
                    )

                    adapter.train()
                    mri_projector.train()
                    latent_projector.train()
                    inverse_latent_projector.train()
                    views = []
                    images_gt = wandb.Image(
                        item["hr"].numpy().transpose((1, 2, 0)),
                        caption="GT MRI",
                    )
                    images_lr = wandb.Image(
                        item["lr"].numpy().transpose((1, 2, 0)),
                        caption="Low-Res MRI",
                    )
                    views.append(images_gt)
                    views.append(images_lr)
                    channels = images.shape[3]
                    for dim in range(channels):
                        images_gen = wandb.Image(
                            images[0, :, :, dim][:, :, np.newaxis],
                            caption=f"Gen MRI (axis {dim}) - Partial Diff",
                        )
                        views.append(images_gen)
                    run.log({"validation views": views})

                    # Metrics calculation
                    if item["hr"].ndim == 3:
                        gt = item["hr"].unsqueeze(1).expand(1, channels, 512, 512)
                    else:
                        gt = item["hr"].expand(1, channels, 512, 512)

                    metrics_val = MRIEvaluator.eval_all_metrics(
                        ground_truth=gt.to(accelerator.device),
                        generated=torch.from_numpy(images)
                        .permute(0, 3, 1, 2)
                        .to(accelerator.device),
                        psnr=psnr,
                        ssim=ssim,
                    )
                    run.log(
                        {
                            "val_hfen": metrics_val[0],
                            "val_nmse": metrics_val[1],
                            "val_psnr": metrics_val[2],
                            "val_ssim": metrics_val[3],
                        }
                    )

                if global_step % t2i_config.checkpointing_steps == 0:
                    save_path = os.path.join(
                        t2i_config.output_dir, f"checkpoint-{global_step}"
                    )
                    os.makedirs(save_path, exist_ok=True)
                    unwrapped_adapter = accelerator.unwrap_model(adapter)
                    ema_adapter.store(unwrapped_adapter.parameters())
                    ema_adapter.copy_to(unwrapped_adapter.parameters())
                    torch.save(
                        unwrapped_adapter.state_dict(),
                        os.path.join(save_path, "adapter.pt"),
                    )
                    ema_adapter.restore(unwrapped_adapter.parameters())
                    print(f"Saved EMA checkpoint to {save_path}")

        logs = {
            "loss": total_loss.detach().item(),
            "diffusion_loss": loss.detach().item(),
            "latent_loss": latents_loss.detach().item(),
            "latent_image_loss": latent_image_loss.detach().item(),
            "lr": lr_scheduler.get_last_lr()[0],
            "epoch": epoch,
        }
        run.log(logs)
        progress_bar.set_postfix(**logs)

# After training is done
if accelerator.is_main_process:
    unwrapped_adapter = accelerator.unwrap_model(adapter)
    # Final EMA Save
    ema_adapter.copy_to(unwrapped_adapter.parameters())
    torch.save(
        unwrapped_adapter.state_dict(),
        os.path.join(t2i_config.output_dir, "adapter_tuned.pt"),
    )
    torch.save(
        mri_projector.state_dict(),
        os.path.join(t2i_config.output_dir, "mri_projector.pt"),
    )
    torch.save(
        inverse_latent_projector.state_dict(),
        os.path.join(t2i_config.output_dir, "inverse_latent_projector.pt"),
    )
    print("Training finished. Saved final T2I-Adapter EMA.")

### Evaluation

- Run over the whole test dataset and aggregate PSNR, SSIM, HFEN and NMSE
- Plot testing views
- Log aggregated metrics to wandb
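
The aggregation in the first bullet can be illustrated with two of the four metrics. This is a NumPy sketch under assumed definitions (NMSE as squared error normalized by ground-truth energy, PSNR with a fixed `data_range`); the notebook's `MRIEvaluator` may define them differently, and the function names here are hypothetical.

```python
import numpy as np

def compute_nmse(gt, pred):
    """Normalized MSE: squared error divided by the energy of the ground truth."""
    return float(np.sum((pred - gt) ** 2) / np.sum(gt ** 2))

def compute_psnr(gt, pred, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to `data_range`."""
    mse = float(np.mean((pred - gt) ** 2))
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))

def aggregate(per_batch):
    """Average a list of (nmse, psnr) tuples over batches."""
    return np.asarray(per_batch, dtype=float).mean(axis=0)

# Toy usage: two "batches" with a constant error of 0.1 and 0.05.
gt = np.full((4, 4), 0.5)
results = [
    (compute_nmse(gt, gt + 0.1), compute_psnr(gt, gt + 0.1)),
    (compute_nmse(gt, gt + 0.05), compute_psnr(gt, gt + 0.05)),
]
avg_nmse, avg_psnr = aggregate(results)
```

Averaging per-batch values weights a smaller final batch the same as a full one; accumulating per-sample values instead avoids that if the dataset size is not a multiple of the batch size.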

In [None]:
metrics = []
psnr = PeakSignalNoiseRatio(data_range=1.0).to(accelerator.device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(accelerator.device)

for batch in test_loader:
    with torch.no_grad():
        prompt_embeds_eval = compute_embeddings_sd1x5(
            batch=batch,
            proportion_empty_prompts=0.0,  # Use 0.0 for evaluation when skipping CFG
            text_encoders=text_encoders,
            tokenizers=tokenizers,
            accelerator=accelerator,
        )["prompt_embeds"]
    image_batch_np, postprocessed = generate_mri_slices_partial_latent_align_dc(
        batch=batch,
        adapter=adapter,
        mri_projector=mri_projector,
        latent_projector=latent_projector,  # passed here as in the training-time validation call
        inverse_latent_projector=inverse_latent_projector,
        unet=unet,
        vae=vae,
        noise_scheduler=noise_scheduler,
        prompt_embeds=prompt_embeds_eval,
        start_step=t2i_config.partial_start_step,  # Start from t=partial_start_step
        num_inference_steps=500,  # Scheduler will slice this
        weight_dtype=weight_dtype,
        accelerator=accelerator,
        use_data_consistency=True,
        dc_reduction_factor=1.7,
        taper=0.12,
        apply_final_pixel_dc=False,
    )
    channels = image_batch_np.shape[3]
    batch_size = batch["hr"].shape[0]  # the last batch may be smaller than test_batch_size
    if batch["hr"].ndim == 3:
        gt = (
            batch["hr"]
            .unsqueeze(1)
            .expand((batch_size, channels, 512, 512))
            .to(accelerator.device)
        )
    else:
        # No trailing comma after .to(...): it would turn gt into a 1-tuple
        gt = (
            batch["hr"]
            .expand((batch_size, 512, 512, channels))
            .permute(0, 3, 1, 2)
            .to(accelerator.device)
        )
    metrics_batch = MRIEvaluator.eval_all_metrics(
        ground_truth=gt,
        generated=torch.from_numpy(image_batch_np)
        .permute(0, 3, 1, 2)
        .to(accelerator.device),
        psnr=psnr,
        ssim=ssim,
    )  # returns hfen, nmse, psnr, ssim
    metrics.append(metrics_batch)

metrics_np = np.array(metrics)  # shape: (num_batches, 4)
avg_metrics = metrics_np.mean(axis=0)
run.summary["hfen"] = float(avg_metrics[0])
run.summary["nmse"] = float(avg_metrics[1])
run.summary["psnr"] = float(avg_metrics[2])
run.summary["ssim"] = float(avg_metrics[3])


plot_generated_and_ground_truth(
    generated_slices_np=image_batch_np, batch=batch, num_images_to_show=6
)

<Figure size 640x480 with 8 Axes>

In [None]:
run.finish()