In [3]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from monai.config import print_config
from monai.data import DataLoader
from torch.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import tqdm


from generative.inferers import DiffusionInferer, ControlNetDiffusionInferer
from generative.networks.nets import DiffusionModelUNet, ControlNet
from generative.networks.schedulers import DDPMScheduler

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random
from PIL import Image
from torchvision import transforms
import torchvision
import pandas as pd
from skimage.metrics import structural_similarity as ssim_fn
from skimage.metrics import peak_signal_noise_ratio as psnr_fn
import wandb


from torch import nn


from tqdm import trange, tqdm



print_config()

  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)


MONAI version: 1.3.2
Numpy version: 2.0.1
Pytorch version: 2.5.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/miniconda3/envs/medical/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.25.0
scipy version: 1.15.3
Pillow version: 11.1.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.1
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.0
pandas version: 2.2.3
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INST

In [4]:
from generative.networks.nets import DiffusionModelUNet, ControlNet


In [5]:
DATAPATH = DATAPATH = '/home/andrea_moschetto/FlowMatching-MREConversion/data'
OUTPUT_DIR = "/home/andrea_moschetto/FlowMatching-MREConversion/outputs"
CHECKPOINTS_PATH = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints'

In [9]:
class UnifiedBrainDataset(Dataset):
    def __init__(self, root_dir, transform=None, split="train", seed=42):
        assert split in ["train", "val", "test"], "split must be 'train', 'val' or 'test'"
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        self.seed = seed
        self.samples = self._create_file_pairs()
        self._split_dataset()

    def _create_file_pairs(self):
        t1_dir = os.path.join(self.root_dir, "t1")
        t2_dir = os.path.join(self.root_dir, "t2")

        t1_files = set(os.listdir(t1_dir))
        t2_files = set(os.listdir(t2_dir))
        common_files = list(t1_files.intersection(t2_files))
        common_files.sort()

        pairs = [(os.path.join(t1_dir, fname), os.path.join(t2_dir, fname)) for fname in common_files]
        return pairs

    def _split_dataset(self):
        random.seed(self.seed)
        random.shuffle(self.samples)

        n_total = len(self.samples)
        n_train = int(n_total * 0.80)
        n_val = int(n_total * 0.05)
        n_test = n_total - n_train - n_val

        if self.split == "train":
            self.samples = self.samples[:n_train]
        elif self.split == "val":
            self.samples = self.samples[n_train:n_train + n_val]
        elif self.split == "test":
            self.samples = self.samples[n_train + n_val:]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        t1_path, t2_path = self.samples[idx]
        t1_image = Image.open(t1_path).convert("L")
        t2_image = Image.open(t2_path).convert("L")

        if self.transform:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)

        return {
            "t1": t1_image,
            "t2": t2_image,
            "filename": os.path.basename(t1_path)
        }

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fm_direct = DiffusionModelUNet(
    spatial_dims=2,  #  2D
    in_channels=1,  # x
    out_channels=1  # predice delta_x_t
)
fm_direct = fm_direct.to(device)
fm_noise = DiffusionModelUNet(
    spatial_dims=2,  #  2D
    in_channels=2,  #  noise + t1 
    out_channels=1  #  predice delta_x_t solo sul noise (condizionato da t1)
)
fm_noise = fm_noise.to(device)
control_diff = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1
)
control_diff = control_diff.to(device)
controlnet = ControlNet(
    spatial_dims=2,
    in_channels=1,
    conditioning_embedding_num_channels=(32, )
)
controlnet = controlnet.to(device)
dm = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=2,
    out_channels=1
)
dm = dm.to(device)
pixgen = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1
)
pixgen = pixgen.to(device)

In [23]:
# Controlnet
controldiff_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_diffusion-t2-brain300e_164_best.pth'
controlnet_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_controlnet-t1t2-brain300e_164_best.pth'
fm_noise_path = '/home/andrea_moschetto/FlowMatching-MREConversion/checkpoints/checkpoint_unetflow-noiset1t2-s300e_46_best.pth'
fm_direct_path = '/home/andrea_moschetto/FlowMatching-MREConversion/checkpoints/checkpoint_unetflow-t1t2-s300e_122_best.pth'
pixgen_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_unet2pix-t1t2-brain300e_142__generator_best.pth'
diffusion_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_diffusion-t1t2-brains300e_200_best.pth'

In [24]:
#flow matching noise
from torch import Tensor


def noise_euler_step(model: DiffusionModelUNet, x_t: Tensor, t_start: Tensor, t_end: Tensor):
    # delta_t shape (B, 1, 1, 1)
    delta_t = (t_end - t_start).view(-1, 1, 1, 1)
    
    # model si aspetta t come tensor (B,)
    v_hat = model(x_t, t_start)
    
    x_t_noise = x_t[:, 0:1,:, :] # [B, 1, H, W]
    x_t_cond = x_t[:, 1:2, :, :] # [B, 1, H, W], che è T1
    
    x_next_noise = x_t_noise + delta_t * v_hat
    
    x_next = torch.cat([x_next_noise, x_t_cond], dim=1) # [B, 2, H, W]
    return x_next

@torch.no_grad()
def noise_generate(model: nn.Module, x_cond: Tensor, n_steps: int = 20):
    model.eval()
    
    device = x_cond.device
    batch_size = x_cond.shape[0]
    
    time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device, dtype=torch.float32)
    
    x = torch.cat([torch.randn_like(x_cond,device=device), x_cond], dim=1) # [B, 2, H, W]
    for i in range(n_steps):
        t_start = time_steps[i].expand(batch_size)
        t_end = time_steps[i + 1].expand(batch_size)
        x = noise_euler_step(model, x_t=x, t_start=t_start, t_end=t_end)
    
    return x[:, 0:1, :, :] # [B, 1, H, W]

# flow matching direct
def direct_euler_step(model: DiffusionModelUNet, x_t: Tensor, t_start: Tensor, t_end: Tensor):
    # delta_t shape (B, 1, 1, 1)
    delta_t = (t_end - t_start).view(-1, 1, 1, 1)

    # model si aspetta t come tensor (B,)
    v_hat = model(x_t, t_start)

    x_next = x_t + delta_t * v_hat

    return x_next


@torch.no_grad()
def direct_generate(model: nn.Module, x_T: Tensor, n_steps: int = 20):
    model.eval()

    device = x_T.device
    batch_size = x_T.shape[0]

    time_steps = torch.linspace(
        0.0, 1.0, n_steps + 1, device=device, dtype=torch.float32)

    x = x_T
    for i in range(n_steps):
        t_start = time_steps[i].expand(batch_size)
        t_end = time_steps[i + 1].expand(batch_size)
        x = direct_euler_step(model, x_t=x, t_start=t_start, t_end=t_end)

    return x

# controlnet
    # on inferer

#pix2pix
    # direct generation

#diffusion
    # on inferer


In [25]:
# control net

@torch.no_grad()
def control_generate_and_save_predictions(cn_model: ControlNet, df_model: DiffusionModelUNet, inferer: ControlNetDiffusionInferer, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    start = time.time()
    df_model.eval()

    all_outputs = []
    imgsperbatch = None
    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        imgsperbatch = len(batch["t1"])
        t1_cond = batch["t1"].to(device)           # [B, 1, H, W]
        t2_targets = batch["t2"].to(device)           # [B, 1, H, W]
        noise = torch.randn_like(t2_targets).to(device)          # [B, 1, H, W]
        filenames = batch["filename"]         # list of strings (length B)
        
        gen_images = inferer.sample(input_noise=noise, diffusion_model=df_model, controlnet=cn_model, scheduler=inferer.scheduler, cn_cond=t1_cond, verbose=False)

            
        if just_one_batch:
            break

    end = time.time()
    # Calculate average time per img on a single batch
    avg = (end - start) / imgsperbatch
    return avg

# pix2pix
@torch.no_grad()
def pix_generate_and_save_predictions(generator: DiffusionModelUNet, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    start = time.time()
    generator.eval()

    all_outputs = []
    imgsperbatch = None
    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        imgsperbatch = len(batch["t1"])
        with torch.no_grad():
            real_A = batch["t1"].to(device)  # [B, 1, H, W]
            real_B = batch["t2"].to(device)  # [B, 1, H, W]
            filenames = batch["filename"]  # list of strings (length B)
            # Generate fake T2 images

            gen_images = generator(x=real_A, timesteps=torch.zeros(real_A.shape[0], device=device))
            
        if just_one_batch:
            break
    
    end = time.time()
    
    avg = (end - start) / imgsperbatch
    return avg

# diffusion model
@torch.no_grad()
def diff_generate_and_save_predictions(model: nn.Module, inferer: DiffusionInferer, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    start = time.time()
    imgsperbatch = None
    model.eval()

    all_outputs = []

    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        imgsperbatch = len(batch["t1"])
        t1_cond = batch["t1"].to(device)           # [B, 1, H, W]
        t2_target = batch["t2"].to(device)           # [B, 1, H, W]
        noise = torch.randn_like(t2_target).to(device)          # [B, 1, H, W]
        filenames = batch["filename"]         # list of strings (length B)

        gen_image = inferer.sample(input_noise=noise, diffusion_model=model,
                                   scheduler=inferer.scheduler, mode='concat', conditioning=t1_cond)

        if just_one_batch:
            break

    end = time.time()
    avg = (end - start) / imgsperbatch
    return avg

# flow matching direct
@torch.no_grad()
def direct_generate_and_save_predictions(model, test_loader, device, output_dir=OUTPUT_DIR, just_one_batch=False):
    start = time.time()
    imgsperbatch = None
    model.eval()

    all_outputs = []

    for batch in tqdm(test_loader, desc="Generating Predictions"):
        imgsperbatch = len(batch["t1"])
        t1 = batch["t1"].to(device)           # [B, 1, H, W]
        t2 = batch["t2"].to(device)           # [B, 1, H, W]
        filenames = batch["filename"]         # list of strings (length B)

        x_gen = direct_generate(model, x_T=t1, n_steps=300)

        if just_one_batch:
            break

    end = time.time()
    avg = (end - start) / imgsperbatch
    return avg

#flow matching noise
@torch.no_grad()
def noise_generate_and_save_predictions(model, test_loader, device, output_dir=OUTPUT_DIR, just_one_batch=False):
    start = time.time()
    imgsperbatch = None
    model.eval()
    
    all_outputs = []

    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        imgsperbatch = len(batch["t1"])
        t1 = batch["t1"].to(device)           # [B, 1, H, W]
        t2 = batch["t2"].to(device)           # [B, 1, H, W]
        filenames = batch["filename"]         # list of strings (length B)

        x_gen = noise_generate(model, x_cond=t1, n_steps=300)
        # print(t2.shape, x_gen.shape)

        if just_one_batch:
            break


    end = time.time()
    avg = (end - start) / imgsperbatch
    return avg


In [29]:
checkpoint = torch.load(fm_direct_path, map_location=device)
fm_direct.load_state_dict(checkpoint['model_state_dict'])

checkpoint = torch.load(fm_noise_path, map_location=device)
fm_noise.load_state_dict(checkpoint['model_state_dict'])
# --------------------------------------
num_train_timesteps = 1000
checkpoint = torch.load(controldiff_path, map_location=device)
control_diff.load_state_dict(checkpoint['model_state_dict'])
control_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
control_diff_inferer = DiffusionInferer(control_scheduler)

checkpoint = torch.load(controlnet_path, map_location=device)
controlnet.load_state_dict(checkpoint['model_state_dict'])
controlnet_inferer = ControlNetDiffusionInferer(control_scheduler)
# --------------------------------------

checkpoint = torch.load(pixgen_path, map_location=device)
pixgen.load_state_dict(checkpoint['model_state_dict'])

checkpoint = torch.load(diffusion_path, map_location=device)
dm.load_state_dict(checkpoint['model_state_dict'])
dm_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
dm_inferer = DiffusionInferer(dm_scheduler)

  checkpoint = torch.load(fm_direct_path, map_location=device)
  checkpoint = torch.load(fm_noise_path, map_location=device)
  checkpoint = torch.load(controldiff_path, map_location=device)
  checkpoint = torch.load(controlnet_path, map_location=device)
  checkpoint = torch.load(pixgen_path, map_location=device)
  checkpoint = torch.load(diffusion_path, map_location=device)


In [30]:
transform = transforms.Compose([
    transforms.Pad(padding=(5, 3, 5, 3), fill=0),
    transforms.ToTensor(),  # Normalize to [0, 1]
])
test_dataset = UnifiedBrainDataset(root_dir=DATAPATH, transform=transform, split="test")
test_loader = DataLoader(test_dataset, batch_size=6, num_workers=2, shuffle=False)

In [31]:
#direct

direct_avg = direct_generate_and_save_predictions(
    model=fm_direct,
    test_loader=test_loader,
    device=device,
    output_dir=None,
    just_one_batch=True
)

noise_avg = noise_generate_and_save_predictions(
    model=fm_noise,
    test_loader=test_loader,
    device=device,
    output_dir=None,
    just_one_batch=True
)

diff_avg = diff_generate_and_save_predictions(
    model=dm,
    inferer=dm_inferer,
    test_loader=test_loader,
    device=device,
    output_dir=None,
    just_one_batch=True
)

pix_avg = pix_generate_and_save_predictions(
    generator=pixgen,
    test_loader=test_loader,
    device=device,
    output_dir=None,
    just_one_batch=True
)

control_avg = control_generate_and_save_predictions(
    cn_model=controlnet,
    df_model=control_diff,
    inferer=controlnet_inferer,
    test_loader=test_loader,
    device=device,
    output_dir=None,
    just_one_batch=True
)

Generating Predictions:   0%|                                                     | 0/55 [01:09<?, ?it/s]
Generating Predictions:   0%|                                                     | 0/55 [01:11<?, ?it/s]
100%|████████████████████████████████████████████████████████████████| 1000/1000 [04:07<00:00,  4.04it/s]
Generating Predictions:   0%|                                                     | 0/55 [04:07<?, ?it/s]
Generating Predictions:   0%|                                                     | 0/55 [00:00<?, ?it/s]
Generating Predictions:   0%|                                                     | 0/55 [05:47<?, ?it/s]


In [33]:
# Raccogliamo i risultati
results = {
    "FM_Direct": direct_avg,
    "FM_Noise": noise_avg,
    "Diffusion": diff_avg,
    "Pix2Pix": pix_avg,
    "ControlNet": control_avg
}

# Stampa tabellare
print(f"{'Model':<12} | {'Avg Time (s/img)':>17}")
print('-' * 32)
for name, avg_time in results.items():
    print(f"{name:<12} | {avg_time:>17.4f}")


Model        |  Avg Time (s/img)
--------------------------------
FM_Direct    |           11.6146
FM_Noise     |           11.9371
Diffusion    |           41.3273
Pix2Pix      |            0.0539
ControlNet   |           57.8672


In [34]:
import torch

def get_model_state_dict_size(path, key='model_state_dict'):
    """
    Calcola la memoria occupata da un state_dict PyTorch salvato su disco.

    Args:
        path (str): Percorso del file contenente il checkpoint.
        key (str): Chiave del dizionario che contiene il state_dict. Default: 'model_state_dict'.

    Returns:
        tuple: (numero totale di parametri, dimensione in MB)
    """
    checkpoint = torch.load(path, map_location='cpu')

    # Se è un dizionario e contiene il key
    if isinstance(checkpoint, dict) and key in checkpoint:
        state_dict = checkpoint[key]
    else:
        # Altrimenti assume che sia direttamente lo state_dict
        state_dict = checkpoint

    total_params = 0
    total_size_bytes = 0

    for tensor in state_dict.values():
        total_params += tensor.numel()
        total_size_bytes += tensor.numel() * tensor.element_size()

    total_size_mb = total_size_bytes / (1024 ** 2)

    return total_params, total_size_mb


# Paths dei modelli
paths = {
    "ControlDiffusion": "/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_diffusion-t2-brain300e_164_best.pth",
    "ControlNet": "/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_controlnet-t1t2-brain300e_164_best.pth",
    "FM_Noise": "/home/andrea_moschetto/FlowMatching-MREConversion/checkpoints/checkpoint_unetflow-noiset1t2-s300e_46_best.pth",
    "FM_Direct": "/home/andrea_moschetto/FlowMatching-MREConversion/checkpoints/checkpoint_unetflow-t1t2-s300e_122_best.pth",
    "Pix2Pix": "/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_unet2pix-t1t2-brain300e_142__generator_best.pth",
    "Diffusion": "/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_diffusion-t1t2-brains300e_200_best.pth"
}

# Stampa dei risultati
print("Model Summary (parameters and size in MB):\n")
for name, path in paths.items():
    try:
        params, size_mb = get_model_state_dict_size(path)
        print(f"{name:<12}: {params:,} params, {size_mb:.2f} MB")
    except Exception as e:
        print(f"{name:<12}: Error reading checkpoint → {e}")


Model Summary (parameters and size in MB):

ControlDiffusion: 2,328,449 params, 8.88 MB
ControlNet  : 927,296 params, 3.54 MB
FM_Noise    : 2,328,737 params, 8.88 MB
FM_Direct   : 2,328,449 params, 8.88 MB


  checkpoint = torch.load(path, map_location='cpu')


Pix2Pix     : 2,328,449 params, 8.88 MB
Diffusion   : 2,328,737 params, 8.88 MB
