In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F


from monai.config import print_config
from monai.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm


from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random
from PIL import Image
from torchvision import transforms
import torchvision
import pandas as pd
from skimage.metrics import structural_similarity as ssim_fn
import wandb


from torch import nn


from tqdm import trange, tqdm



print_config()

  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)


MONAI version: 1.3.2
Numpy version: 2.0.1
Pytorch version: 2.5.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/miniconda3/envs/medical/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.25.0
scipy version: 1.15.3
Pillow version: 11.1.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.1
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.0
pandas version: 2.2.3
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INST

In [2]:
class UnifiedBrainDataset(Dataset):
    """Paired T1/T2 slice dataset with a deterministic train/val/test split.

    ``root_dir`` must contain two sub-directories, ``t1`` and ``t2``. Only
    file names present in *both* directories are used; each pair is matched
    by identical file name.

    The sorted list of pairs is shuffled with ``seed`` and then cut into
    80% train, 5% val and the remainder as test. Because the file list is
    sorted before shuffling, every instance built with the same ``seed`` on
    the same directory content sees the same partition.

    Args:
        root_dir: Directory holding the ``t1`` and ``t2`` sub-directories.
        transform: Optional callable applied independently to the T1 and the
            T2 image. NOTE: a transform with internal randomness would be
            drawn separately for each modality.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        seed: Seed of the shuffle that defines the split.
    """

    def __init__(self, root_dir, transform=None, split="train", seed=42):
        assert split in ["train", "val", "test"], "split must be 'train', 'val' or 'test'"
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        self.seed = seed
        self.samples = self._create_file_pairs()
        self._split_dataset()

    def _create_file_pairs(self):
        """Return sorted ``(t1_path, t2_path)`` tuples for names found in both folders."""
        t1_dir = os.path.join(self.root_dir, "t1")
        t2_dir = os.path.join(self.root_dir, "t2")

        # Sorting makes the later seeded shuffle independent of listdir order.
        common_files = sorted(set(os.listdir(t1_dir)) & set(os.listdir(t2_dir)))

        return [
            (os.path.join(t1_dir, fname), os.path.join(t2_dir, fname))
            for fname in common_files
        ]

    def _split_dataset(self):
        """Shuffle ``self.samples`` with ``self.seed`` and keep only the requested split."""
        # A private generator yields the same permutation as
        # ``random.seed(seed); random.shuffle(...)`` but leaves the global
        # ``random`` state untouched for the rest of the program.
        random.Random(self.seed).shuffle(self.samples)

        n_total = len(self.samples)
        n_train = int(n_total * 0.80)
        n_val = int(n_total * 0.05)

        if self.split == "train":
            self.samples = self.samples[:n_train]
        elif self.split == "val":
            self.samples = self.samples[n_train:n_train + n_val]
        elif self.split == "test":
            # Everything left after train and val.
            self.samples = self.samples[n_train + n_val:]

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_grayscale(path):
        """Open ``path`` and return it as a single-channel ("L") PIL image."""
        # The context manager releases the file handle as soon as the
        # converted copy has been produced.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, idx):
        """Return a dict with keys ``"t1"``, ``"t2"`` and ``"filename"``."""
        t1_path, t2_path = self.samples[idx]
        t1_image = self._load_grayscale(t1_path)
        t2_image = self._load_grayscale(t2_path)

        if self.transform:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)

        return {
            "t1": t1_image,
            "t2": t2_image,
            "filename": os.path.basename(t1_path)
        }

In [3]:
# Filesystem locations used throughout the notebook.
# DATAPATH: dataset root (expects `t1/` and `t2/` sub-directories).
DATAPATH = '/home/andrea_moschetto/flow_matching_t1t2/data'
# OUTPUT_DIR: where generated predictions are written as .pt files.
OUTPUT_DIR = "/home/andrea_moschetto/flow_matching_t1t2/outputs"
# CHECKPOINTS_PATH: where the best training checkpoint is stored.
CHECKPOINTS_PATH = '/home/andrea_moschetto/flow_matching_t1t2/baseline_checkpoints'

In [4]:
# Preprocessing shared by every split: zero-pad the PIL image
# (left=5, top=3, right=5, bottom=3) and convert it to a tensor.
transform = transforms.Compose([
    transforms.Pad(padding=(5, 3, 5, 3), fill=0),
    transforms.ToTensor(),  # Normalize to [0, 1]
])


def _build_split(split, shuffle):
    """Create the dataset and its DataLoader for one split."""
    dataset = UnifiedBrainDataset(root_dir=DATAPATH, transform=transform, split=split)
    loader = DataLoader(dataset, batch_size=6, num_workers=2, shuffle=shuffle)
    return dataset, loader


# Only the training loader shuffles; val/test keep a fixed order.
train_dataset, train_loader = _build_split("train", shuffle=True)
val_dataset, val_loader = _build_split("val", shuffle=False)
test_dataset, test_loader = _build_split("test", shuffle=False)

In [5]:
def percnorm(arr, lperc=5, uperc=99.5):
    """
    Remove outlier intensities from a brain component,
    similar to Tukey's fences method.

    Values above the ``uperc`` percentile are set to that percentile and
    values below the ``lperc`` percentile are set to that percentile.

    The input is NOT modified: clipping is done on a copy. (Clipping in place
    would also alter any tensor that shares memory with ``arr``, e.g. one
    obtained via ``tensor.numpy()``.)

    Args:
        arr: numpy array of intensities.
        lperc: lower percentile (0-100).
        uperc: upper percentile (0-100).

    Returns:
        A new array with the same shape and dtype as ``arr``, clipped to the
        two percentile bounds.
    """
    upperbound = np.percentile(arr, uperc)
    lowerbound = np.percentile(arr, lperc)
    # Work on a private copy so the caller's buffer stays intact.
    clipped = np.array(arr, copy=True)
    clipped[clipped > upperbound] = upperbound
    clipped[clipped < lowerbound] = lowerbound
    return clipped

def normalize(img):
    """Min-max rescale ``img`` using its global minimum and maximum.

    A small epsilon (1e-8) is added to the range so a constant image does
    not divide by zero. Only ``.min()``, ``.max()`` and arithmetic are used,
    and the input is not modified.
    """
    lowest = img.min()
    value_range = img.max() - lowest + 1e-8
    return (img - lowest) / value_range


In [None]:
# Make sure the checkpoint directory exists. `exist_ok=True` avoids the
# check-then-create race of a separate `os.path.exists` test.
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)


def _noise_prediction_loss(model, inferer, criterion, batch, device, num_timesteps):
    """Compute the noise-prediction loss for one batch of T1/T2 pairs."""
    t2_targets = batch["t2"].to(device)  # target modality
    t1_cond = batch["t1"].to(device)     # conditioning modality
    noise = torch.randn_like(t2_targets)

    # One random timestep per sample in [0, num_timesteps).
    timesteps = torch.randint(
        0, num_timesteps, (t2_targets.shape[0],), device=t2_targets.device
    ).long()

    noise_pred = inferer(inputs=t2_targets, diffusion_model=model, noise=noise,
                         timesteps=timesteps, condition=t1_cond, mode='concat')
    return criterion(noise_pred.float(), noise.float())


def train_diffusion(model: DiffusionModelUNet, inferer: DiffusionInferer, train_loader: DataLoader, val_loader: DataLoader, project: str, exp_name: str, notes: str, n_epochs: int = 10, lr : float = 1e-3, generation_steps: int = 100):
    """Train a T1-conditioned diffusion model and log the run to Weights & Biases.

    Each epoch runs one pass over ``train_loader`` (Adam + MSE between
    predicted and sampled noise) and one pass over ``val_loader``. Whenever
    the mean validation loss improves, a preview image (T1 | T2 | generated)
    is logged and a checkpoint is written to ``CHECKPOINTS_PATH``; the
    previous best checkpoint file is deleted.

    Args:
        model: Network passed to the inferer as ``diffusion_model``.
        inferer: Called for the training forward pass; its ``sample`` method
            and ``scheduler`` attribute are used for the preview image.
        train_loader: Yields dicts with ``"t1"`` and ``"t2"`` tensors.
        val_loader: Same format; must yield at least one batch.
        project: W&B project name.
        exp_name: W&B run name, also embedded in the checkpoint file name.
        notes: Free-text W&B notes.
        n_epochs: Number of epochs.
        lr: Adam learning rate.
        generation_steps: Exclusive upper bound for the randomly drawn
            training timesteps. NOTE(review): presumably this should equal
            the scheduler's number of training timesteps - confirm at the
            call site.

    Returns:
        None. The model is trained in place.
    """
    with wandb.init(
        project=project,
        name=exp_name,
        notes=notes,
        tags=["flow", "brain", "diffusion"],
        config={
            'model': model.__class__.__name__,
            'epochs': n_epochs,
            'batch_size': train_loader.batch_size,
            'num_workers': train_loader.num_workers,
            'optimizer': 'Adam',
            'learning_rate': lr,
            'loss_function': 'MSELoss',
            'generation_steps': generation_steps,
            'device': str(torch.cuda.get_device_name(0)
                          if torch.cuda.is_available() else "CPU"),
        }
    ):

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using", torch.cuda.get_device_name(0)
                if torch.cuda.is_available() else "CPU")

        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        best_val_loss = float("inf")
        best_model_path = None
        start_time = time.time()
        for e in trange(n_epochs, desc="Epochs"):
            start_e_time = time.time()

            # Training
            model.train()
            train_losses = []
            for batch in train_loader:
                optimizer.zero_grad()
                loss = _noise_prediction_loss(
                    model, inferer, criterion, batch, device, generation_steps)
                train_losses.append(loss.item())

                loss.backward()
                optimizer.step()
            wandb.log({"train_loss": sum(train_losses) / len(train_losses)})

            # Validation
            model.eval()
            val_losses = []
            with torch.no_grad():
                for batch in val_loader:
                    val_loss = _noise_prediction_loss(
                        model, inferer, criterion, batch, device, generation_steps)
                    val_losses.append(val_loss.item())
            batch_val_loss = sum(val_losses) / len(val_losses)
            wandb.log({"val_loss": batch_val_loss})

            # True division: floor division would log 0 for sub-minute epochs.
            e_time = time.time() - start_e_time
            wandb.log({"epoch_time_minutes": e_time / 60})

            # Checkpoint (only when validation improved)
            if batch_val_loss < best_val_loss:
                sample_batch = next(iter(val_loader))  # just one batch
                t1_cond = sample_batch["t1"][0].unsqueeze(0).to(device)
                t2_target = sample_batch["t2"][0].unsqueeze(0).to(device)

                with torch.no_grad():
                    # Fresh noise shaped like the single preview item. Reusing
                    # the noise of the last validation batch would give a
                    # batch-size mismatch with the one-item conditioning.
                    sample_noise = torch.randn_like(t2_target)
                    gen_image = inferer.sample(input_noise=sample_noise, diffusion_model=model,
                                               scheduler=inferer.scheduler, mode='concat',
                                               conditioning=t1_cond)
                    images = torch.cat([t1_cond, t2_target, normalize(gen_image)], dim=0)
                    grid = torchvision.utils.make_grid(images, nrow=3)
                    wandb.log({
                        "generation": [wandb.Image(grid, caption=f"Epoch {e+1}")]
                    })

                path = f'{CHECKPOINTS_PATH}/checkpoint_{exp_name}_{e+1}.pth'
                torch.save({
                    'epoch': e+1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, path)
                # Keep only the best checkpoint on disk.
                if best_model_path is not None and os.path.exists(best_model_path):
                    os.remove(best_model_path)
                best_model_path = path
                best_val_loss = batch_val_loss
        end_time = time.time()
        elapsed_time = end_time - start_time
        # True division so runs shorter than an hour are not reported as 0.
        wandb.log({"total_running_hours": elapsed_time / 3600})
        print(f"Training completed in {elapsed_time // 60:.0f}m {elapsed_time % 60:.0f}s")
    print("Training complete.")

In [None]:
@torch.no_grad()
def generate_and_save_predictions(model: nn.Module, inferer: DiffusionInferer, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    """Sample a prediction for every item of ``test_loader`` and save it to disk.

    For each batch, noise shaped like the T2 target is drawn and
    ``inferer.sample`` is run with the T1 image as ``conditioning``
    (``mode='concat'``). One ``<filename>.pt`` file per item is written to
    ``output_dir``, holding a dict with the keys ``"filename"``, ``"input"``
    (T1), ``"target"`` (T2) and ``"prediction"``.

    Args:
        model: Network passed to the inferer as ``diffusion_model``; it is
            put in eval mode and moved to ``device``.
        inferer: Provides ``sample`` and ``scheduler``.
        test_loader: Yields dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device on which sampling runs.
        output_dir: Destination directory (created if missing).
        just_one_batch: If True, stop after the first batch.

    Returns:
        List of the saved per-item dicts (CPU tensors).
    """
    os.makedirs(output_dir, exist_ok=True)
    # The batches are moved to `device`, so the model has to live there too.
    model.to(device)
    model.eval()

    all_outputs = []

    for batch in tqdm(test_loader, desc="Generating Predictions"):
        t1_cond = batch["t1"].to(device)
        t2_target = batch["t2"].to(device)
        noise = torch.randn_like(t2_target)
        filenames = batch["filename"]         # one name per batch item

        gen_image = inferer.sample(input_noise=noise, diffusion_model=model,
                                   scheduler=inferer.scheduler, mode='concat', conditioning=t1_cond)

        for i in range(t1_cond.size(0)):
            # `.clone()` gives each slice its own storage. Without it, when
            # `device` is the CPU the slice is a view of the batch and
            # torch.save would serialise the whole batch into every file.
            sample = {
                "filename": filenames[i],
                "input": t1_cond[i].cpu().clone(),
                "target": t2_target[i].cpu().clone(),
                "prediction": gen_image[i].cpu().clone()
            }

            torch.save(sample, os.path.join(output_dir, f"{filenames[i]}.pt"))
            all_outputs.append(sample)
        if just_one_batch:
            break

    return all_outputs

In [8]:
class PredictionDataset(Dataset):
    """Dataset over the ``.pt`` prediction files stored in one directory.

    Each file is expected to hold a dict with at least the keys
    ``"prediction"`` and ``"target"``. Files are served in sorted file-name
    order.

    Args:
        directory: Folder to scan (non-recursively) for ``.pt`` files.

    Raises:
        ValueError: If the directory contains no ``.pt`` file.
    """

    def __init__(self, directory):
        super().__init__()
        self.directory = directory
        self.files = sorted(
            f for f in os.listdir(directory) if f.endswith('.pt')
        )
        if not self.files:
            raise ValueError(f"No .pt files found in directory: {directory}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(prediction, target)`` loaded from the ``idx``-th file."""
        file_path = os.path.join(self.directory, self.files[idx])
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered file cannot execute code on load.
        # map_location="cpu" makes loading independent of the saving device.
        data = torch.load(file_path, map_location="cpu", weights_only=True)
        pred = data["prediction"]
        gt = data["target"]
        return pred, gt


In [9]:
def compute_ssim_from_dataset(dataset):
    """Compute per-item SSIM and MSE over a dataset of (prediction, target) pairs.

    * MSE is computed on the raw tensors exactly as returned by ``dataset``.
    * SSIM is computed on copies that are percentile-clipped (``percnorm``)
      and min-max rescaled (``normalize``), with ``data_range=1.0``.

    One example pair (index 4, or the last item when the dataset has fewer
    than five items) is shown side by side with matplotlib.

    Args:
        dataset: Indexable object with ``len()`` whose items are
            ``(pred, gt)`` tensors.

    Returns:
        ``pandas.DataFrame`` with columns ``Metric``, ``Mean`` and
        ``Variance``; row 0 is SSIM and row 1 is MSE.
    """
    ssim_scores = []
    mse_scores = []

    example_pred = None
    example_gt = None

    n_items = len(dataset)
    # Clamp the example index so short datasets still yield a plot.
    example_idx = min(4, n_items - 1)

    for i in range(n_items):
        pred, gt = dataset[i]

        # MSE first, on the untouched tensors.
        mse_scores.append(F.mse_loss(pred, gt).item())

        # `.numpy()` shares memory with the tensor; copy before any
        # intensity processing so `pred` / `gt` are never altered.
        pred_np = pred.squeeze().cpu().numpy().copy()
        gt_np = gt.squeeze().cpu().numpy().copy()

        pred_np = normalize(percnorm(pred_np))
        gt_np = normalize(percnorm(gt_np))

        ssim_scores.append(ssim_fn(pred_np, gt_np, data_range=1.0))

        # Store one example for visualization
        if i == example_idx:
            example_pred = pred_np
            example_gt = gt_np

    ssim_scores = np.array(ssim_scores)
    mse_scores = np.array(mse_scores)

    summary = pd.DataFrame({
        "Metric": ["SSIM", "MSE"],
        "Mean": [ssim_scores.mean(), mse_scores.mean()],
        "Variance": [ssim_scores.var(), mse_scores.var()]
    })

    # Visualize example (skipped only when the dataset is empty)
    if example_pred is not None:
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        axs[0].imshow(example_gt, cmap='gray')
        axs[0].set_title("Ground Truth")
        axs[0].axis("off")

        axs[1].imshow(example_pred, cmap='gray')
        axs[1].set_title("Prediction")
        axs[1].axis("off")

        plt.suptitle("Example Comparison")
        plt.tight_layout()
        plt.show()

    return summary

In [10]:
# Use the GPU when present, otherwise fall back to the CPU (the same rule
# `train_diffusion` applies internally).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_train_timesteps = 1000

# 2-D UNet with two input channels (the inferer is called with
# mode='concat', which concatenates the conditioning image to the noisy
# image along the channel axis) and one output channel.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=2,
    out_channels=1
)

scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
inferer = DiffusionInferer(scheduler)

In [None]:
# End-to-end run: train, generate test-set predictions, evaluate them and
# log the evaluation metrics to a separate W&B run.
exp_name = "diffusion-t1t2-brain150e"
predictions_dir = f'{OUTPUT_DIR}/{exp_name}'

train_diffusion(
    model=model,
    inferer=inferer,
    train_loader=train_loader,
    val_loader=val_loader,
    project="FlowMatching-Baselines",
    exp_name=exp_name,
    notes="Training a diffusion model for T1-T2 brain image generation.",
    n_epochs=150,
    lr=1e-4,
    generation_steps=num_train_timesteps
)

generate_and_save_predictions(
    model=model,
    inferer=inferer,
    test_loader=test_loader,
    device=device,
    output_dir=predictions_dir,
)

out_dataset = PredictionDataset(predictions_dir)
with wandb.init(
    project = 'FlowMatching-Baselines',
    name=f'evaluation-{exp_name}',
    notes="Evaluation of the diffusion model on the test set.",
):
    summary = compute_ssim_from_dataset(out_dataset)
    wandb.log({"eval/metrics": wandb.Table(dataframe=summary)})
    # Row 0 of the summary is SSIM, row 1 is MSE.
    for metric_key, column, row in (
        ("eval/ssim_mean", "Mean", 0),
        ("eval/mse_mean", "Mean", 1),
        ("eval/ssim_var", "Variance", 0),
        ("eval/mse_var", "Variance", 1),
    ):
        wandb.log({metric_key: summary[column][row]})
summary

[34m[1mwandb[0m: Currently logged in as: [33mandreamoschetto99[0m ([33mandreamoschetto99-university-of-catania[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using Tesla T4


100%|████████████████████████████████████████████████████████████████| 1000/1000 [00:42<00:00, 23.33it/s]
Epochs:   1%|▎                                                       | 1/150 [04:16<10:36:22, 256.26s/it]