In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from monai.config import print_config
from monai.data import DataLoader
from torch.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import tqdm


from generative.inferers import DiffusionInferer, ControlNetDiffusionInferer
from generative.networks.nets import DiffusionModelUNet, ControlNet
from generative.networks.schedulers import DDPMScheduler

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random
from PIL import Image
from torchvision import transforms
import torchvision
import pandas as pd
from skimage.metrics import structural_similarity as ssim_fn
from skimage.metrics import peak_signal_noise_ratio as psnr_fn
import wandb


from torch import nn


from tqdm import trange, tqdm



# Report MONAI and dependency versions so the run environment is recorded with the outputs.
print_config()

In [None]:
from generative.networks.nets import DiffusionModelUNet, ControlNet


In [None]:
# Only PROJECT_ROOT is machine-specific; every other location is derived from it.
PROJECT_ROOT = '/home/andrea_moschetto/FlowMatching-MREConversion'
DATAPATH = os.path.join(PROJECT_ROOT, 'data')  # must contain t1/ and t2/ (see UnifiedBrainDataset)
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs')
CHECKPOINTS_PATH = os.path.join(PROJECT_ROOT, 'baseline_checkpoints')

In [None]:
class UnifiedBrainDataset(Dataset):
    """Paired T1/T2 brain-image dataset with a deterministic train/val/test split.

    ``root_dir`` must contain ``t1/`` and ``t2/`` sub-directories; files with the same
    name in both are treated as one (T1, T2) pair. Pairs are sorted by filename,
    shuffled with ``seed`` and split 80% / 5% / remainder into train / val / test, so
    splits built from the same ``root_dir`` and ``seed`` never overlap.

    Each item is ``{"t1": ..., "t2": ..., "filename": str}``; the images are grayscale
    (mode ``"L"``) PIL images, passed through ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None, split="train", seed=42):
        assert split in ["train", "val", "test"], "split must be 'train', 'val' or 'test'"
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        self.seed = seed
        self.samples = self._create_file_pairs()
        self._split_dataset()

    def _create_file_pairs(self):
        """Return ``(t1_path, t2_path)`` tuples, sorted by filename, for names present in both folders."""
        t1_dir = os.path.join(self.root_dir, "t1")
        t2_dir = os.path.join(self.root_dir, "t2")
        common_files = sorted(set(os.listdir(t1_dir)) & set(os.listdir(t2_dir)))
        return [(os.path.join(t1_dir, fname), os.path.join(t2_dir, fname)) for fname in common_files]

    def _split_dataset(self):
        """Shuffle ``self.samples`` with ``self.seed`` and keep only the slice for ``self.split``."""
        # A private RNG yields the same permutation as ``random.seed(seed); random.shuffle(...)``
        # but does not reset the global ``random`` state used by the rest of the notebook.
        random.Random(self.seed).shuffle(self.samples)

        n_total = len(self.samples)
        n_train = int(n_total * 0.80)
        n_val = int(n_total * 0.05)

        if self.split == "train":
            self.samples = self.samples[:n_train]
        elif self.split == "val":
            self.samples = self.samples[n_train:n_train + n_val]
        elif self.split == "test":
            # Test gets everything left over, so rounding never drops samples.
            self.samples = self.samples[n_train + n_val:]

    @staticmethod
    def _load_grayscale(path):
        """Open ``path`` and return an independent mode-"L" image, closing the file handle."""
        with Image.open(path) as img:
            return img.convert("L")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        t1_path, t2_path = self.samples[idx]
        t1_image = self._load_grayscale(t1_path)
        t2_image = self._load_grayscale(t2_path)

        if self.transform:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)

        return {
            "t1": t1_image,
            "t2": t2_image,
            "filename": os.path.basename(t1_path)
        }

In [None]:
def _make_unet(in_channels: int) -> DiffusionModelUNet:
    """Build a 2D DiffusionModelUNet with ``in_channels`` inputs and a single output channel."""
    return DiffusionModelUNet(spatial_dims=2, in_channels=in_channels, out_channels=1)


# Flow matching, direct variant: input is the current state x (1 channel); predicts delta_x_t.
fm_direct = _make_unet(in_channels=1)
# Flow matching, noise variant: input is [noise, T1]; predicts delta_x_t for the noise
# channel only, conditioned on T1.
fm_noise = _make_unet(in_channels=2)
# Diffusion backbone paired with the ControlNet below.
control_diff = _make_unet(in_channels=1)
controlnet = ControlNet(
    spatial_dims=2,
    in_channels=1,
    conditioning_embedding_num_channels=(32,),
)
# Two input channels; presumably noise + T1 for concat-conditioned sampling — confirm.
dm = _make_unet(in_channels=2)
# UNet whose weights are loaded from the pix2pix generator checkpoint.
pixgen = _make_unet(in_channels=1)

In [2]:
# Checkpoint locations.
# Baselines (diffusion / ControlNet / pix2pix) live in baseline_checkpoints/.
controldiff_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_diffusion-t2-brain300e_164_best.pth'
controlnet_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_controlnet-t1t2-brain300e_164_best.pth'
pixgen_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_pix2pix-t1t2-brain300e_169__generator_best.pth'
# NOTE(review): diffusion_path is the same file as controldiff_path, yet `dm` is built with
# in_channels=2 while `control_diff` uses in_channels=1, so loading this checkpoint into
# `dm` will likely fail with a size mismatch. Confirm the intended concat-conditioned checkpoint.
diffusion_path = '/home/andrea_moschetto/FlowMatching-MREConversion/baseline_checkpoints/checkpoint_diffusion-t2-brain300e_164_best.pth'

# Flow-matching models live in checkpoints/.
fm_noise_path = '/home/andrea_moschetto/FlowMatching-MREConversion/checkpoints/checkpoint_unetflow-noiset1t2-s300e_46_best.pth'
fm_direct_path = '/home/andrea_moschetto/FlowMatching-MREConversion/checkpoints/checkpoint_unetflow-t1t2-s300e_122_best.pth'

In [None]:
#flow matching noise
from torch import Tensor


def noise_euler_step(model: nn.Module, x_t: Tensor, t_start: Tensor, t_end: Tensor):
    """One explicit Euler step of the noise-conditioned flow ODE.

    Args:
        model: Velocity network called as ``model(x_t, t_start)``; it sees both channels
            and returns the velocity for the noise channel.
        x_t: State ``[B, 2, H, W]``: channel 0 is the evolving sample, channel 1 the T1 condition.
        t_start: Per-sample start times, shape ``(B,)``.
        t_end: Per-sample end times, shape ``(B,)``.

    Returns:
        The next state ``[B, 2, H, W]``: only channel 0 is advanced, and the condition is copied through.
    """
    # Per-sample step size, broadcast over channel and spatial dims.
    delta_t = (t_end - t_start).view(-1, 1, 1, 1)

    v_hat = model(x_t, t_start)  # the model expects t as a (B,) tensor

    x_t_noise = x_t[:, 0:1, :, :]  # [B, 1, H, W]
    x_t_cond = x_t[:, 1:2, :, :]   # [B, 1, H, W], the T1 condition

    x_next_noise = x_t_noise + delta_t * v_hat
    return torch.cat([x_next_noise, x_t_cond], dim=1)  # [B, 2, H, W]

@torch.no_grad()
def noise_generate(model: nn.Module, x_cond: Tensor, n_steps: int = 20, generator: "torch.Generator | None" = None):
    """Sample from the noise-conditioned flow by Euler integration from t=0 to t=1.

    Args:
        model: Velocity network (see ``noise_euler_step``); switched to eval mode.
        x_cond: Conditioning T1 images ``[B, 1, H, W]``.
        n_steps: Number of uniform Euler steps; must be >= 1.
        generator: Optional ``torch.Generator`` for reproducible starting noise. It must
            live on the same device as ``x_cond``. With ``None``, the global RNG is used.

    Returns:
        Generated images ``[B, 1, H, W]``.

    Raises:
        ValueError: if ``n_steps < 1``. Zero steps would silently return pure noise.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    model.eval()

    device = x_cond.device
    batch_size = x_cond.shape[0]

    time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device, dtype=torch.float32)

    if generator is None:
        noise = torch.randn_like(x_cond)
    else:
        noise = torch.randn(x_cond.shape, generator=generator, device=device, dtype=x_cond.dtype)

    x = torch.cat([noise, x_cond], dim=1)  # [B, 2, H, W]
    for i in range(n_steps):
        t_start = time_steps[i].expand(batch_size)
        t_end = time_steps[i + 1].expand(batch_size)
        x = noise_euler_step(model, x_t=x, t_start=t_start, t_end=t_end)

    return x[:, 0:1, :, :]  # [B, 1, H, W]

# flow matching direct
def direct_euler_step(model: nn.Module, x_t: Tensor, t_start: Tensor, t_end: Tensor):
    """One explicit Euler step ``x_{t_end} = x_t + (t_end - t_start) * v(x_t, t_start)``.

    Args:
        model: Velocity network called as ``model(x_t, t_start)``.
        x_t: Current state, batch-first.
        t_start: Per-sample start times, shape ``(B,)``.
        t_end: Per-sample end times, shape ``(B,)``.

    Returns:
        The advanced state, same shape as ``x_t``.
    """
    # Per-sample step size, broadcast over the remaining dims (assumes 4D [B, C, H, W]).
    delta_t = (t_end - t_start).view(-1, 1, 1, 1)

    v_hat = model(x_t, t_start)  # the model expects t as a (B,) tensor

    return x_t + delta_t * v_hat


@torch.no_grad()
def direct_generate(model: nn.Module, x_T: Tensor, n_steps: int = 20):
    """Transport ``x_T`` (the source image) along the learned flow from t=0 to t=1.

    Args:
        model: Velocity network (see ``direct_euler_step``); switched to eval mode.
        x_T: Starting images, e.g. T1 ``[B, 1, H, W]``.
        n_steps: Number of uniform Euler steps; must be >= 1.

    Returns:
        The integrated images, same shape as ``x_T``.

    Raises:
        ValueError: if ``n_steps < 1``. Zero steps would silently return the input unchanged.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    model.eval()

    device = x_T.device
    batch_size = x_T.shape[0]

    time_steps = torch.linspace(
        0.0, 1.0, n_steps + 1, device=device, dtype=torch.float32)

    x = x_T
    for i in range(n_steps):
        t_start = time_steps[i].expand(batch_size)
        t_end = time_steps[i + 1].expand(batch_size)
        x = direct_euler_step(model, x_t=x, t_start=t_start, t_end=t_end)

    return x

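# ControlNet: sampling is delegated to ControlNetDiffusionInferer.sample (see the next cell).
# Pix2pix: a single forward pass of the generator (direct generation).
# Diffusion: sampling is delegated to DiffusionInferer.sample (see the next cell).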


In [None]:
# control net

@torch.no_grad()
def control_generate_and_save_predictions(cn_model: ControlNet, df_model: DiffusionModelUNet, inferer: ControlNetDiffusionInferer, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    """Sample T2 predictions with a ControlNet-guided diffusion model and save them.

    For each batch, Gaussian noise is denoised by ``df_model`` with ``cn_model`` receiving
    the T1 image as ``cn_cond``. Every sample is written to ``<output_dir>/<filename>.pt``
    as a dict with ``filename``, ``input`` (T1), ``target`` (T2) and ``prediction``.

    Args:
        cn_model: ControlNet fed with the T1 image.
        df_model: Diffusion UNet being denoised.
        inferer: Inferer whose ``scheduler`` drives sampling.
        test_loader: Yields dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device for the batch tensors; the models are expected to already be there.
        output_dir: Destination directory (created if missing).
        just_one_batch: Stop after the first batch (smoke test).

    Returns:
        List of the saved sample dicts.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Both networks take part in sampling, so both must be in inference mode.
    cn_model.eval()
    df_model.eval()

    # Follow the actual device: CUDA autocast has no effect on CPU tensors, and requesting
    # it without CUDA only produces a warning.
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    all_outputs = []

    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        t1_cond = batch["t1"].to(device)       # [B, 1, H, W]
        t2_targets = batch["t2"].to(device)    # [B, 1, H, W]
        filenames = batch["filename"]          # list of strings (length B)
        noise = torch.randn_like(t2_targets)   # already on the targets' device

        with autocast(device_type=device_type, enabled=use_amp):
            gen_images = inferer.sample(
                input_noise=noise,
                diffusion_model=df_model,
                controlnet=cn_model,
                scheduler=inferer.scheduler,
                cn_cond=t1_cond,
                verbose=False,
            )

        for i in range(t1_cond.size(0)):
            sample = {
                "filename": filenames[i],
                "input": t1_cond[i].cpu(),         # torch.Tensor [1, H, W]
                "target": t2_targets[i].cpu(),
                "prediction": gen_images[i].cpu()
            }

            torch.save(sample, os.path.join(output_dir, f"{filenames[i]}.pt"))
            all_outputs.append(sample)

        if just_one_batch:
            break
        # wandb.log errors out when no run was started; no wandb.init() is visible in this notebook.
        if wandb.run is not None:
            wandb.log({"prediction_progress": idx})

    return all_outputs

# pix2pix
@torch.no_grad()
def pix_generate_and_save_predictions(generator: DiffusionModelUNet, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    """Translate T1 -> T2 with a pix2pix generator (one forward pass) and save the results.

    Every sample is written to ``<output_dir>/<filename>.pt`` as a dict with ``filename``,
    ``input`` (T1), ``target`` (T2) and ``prediction``.

    Args:
        generator: A DiffusionModelUNet (called with a fixed timestep) or any module
            called as ``generator(t1)``.
        test_loader: Yields dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device for the batch tensors; the generator is expected to already be there.
        output_dir: Destination directory (created if missing).
        just_one_batch: Stop after the first batch (smoke test).

    Returns:
        List of the saved sample dicts.
    """
    os.makedirs(output_dir, exist_ok=True)
    generator.eval()

    all_outputs = []

    # The decorator already disables autograd; no inner torch.no_grad() is needed.
    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        real_A = batch["t1"].to(device)  # [B, 1, H, W]
        real_B = batch["t2"].to(device)  # [B, 1, H, W]
        filenames = batch["filename"]    # list of strings (length B)

        if isinstance(generator, DiffusionModelUNet):
            # The UNet takes a timestep argument; a constant t=0 makes it a one-shot generator.
            gen_images = generator(x=real_A, timesteps=torch.zeros(real_A.shape[0], device=device))
        else:
            gen_images = generator(real_A)  # [B, 1, H, W]

        for i in range(real_A.size(0)):
            sample = {
                "filename": filenames[i],
                "input": real_A[i].cpu(),         # torch.Tensor [1, H, W]
                "target": real_B[i].cpu(),
                "prediction": gen_images[i].cpu()
            }

            torch.save(sample, os.path.join(output_dir, f"{filenames[i]}.pt"))
            all_outputs.append(sample)

        if just_one_batch:
            break
        # wandb.log errors out when no run was started; no wandb.init() is visible in this notebook.
        if wandb.run is not None:
            wandb.log({"prediction_progress": idx})

    return all_outputs

# diffusion model
@torch.no_grad()
def diff_generate_and_save_predictions(model: nn.Module, inferer: DiffusionInferer, test_loader: DataLoader, device: str, output_dir: str = OUTPUT_DIR, just_one_batch: bool = False):
    """Sample T2 predictions from a concat-conditioned diffusion model and save them.

    Noise is denoised with the T1 image concatenated as conditioning (``mode='concat'``).
    Every sample is written to ``<output_dir>/<filename>.pt`` as a dict with ``filename``,
    ``input`` (T1), ``target`` (T2) and ``prediction``.

    Args:
        model: Diffusion UNet whose input channels cover noise + conditioning.
        inferer: Inferer whose ``scheduler`` drives sampling.
        test_loader: Yields dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device for the batch tensors; the model is expected to already be there.
        output_dir: Destination directory (created if missing).
        just_one_batch: Stop after the first batch (smoke test).

    Returns:
        List of the saved sample dicts.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()

    all_outputs = []

    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        t1_cond = batch["t1"].to(device)      # [B, 1, H, W]
        t2_target = batch["t2"].to(device)    # [B, 1, H, W]
        noise = torch.randn_like(t2_target)   # already on the target's device
        filenames = batch["filename"]         # list of strings (length B)

        # verbose=False: avoid a nested per-timestep progress bar for every batch
        # (matches the ControlNet sampler).
        gen_image = inferer.sample(input_noise=noise, diffusion_model=model,
                                   scheduler=inferer.scheduler, mode='concat', conditioning=t1_cond,
                                   verbose=False)

        for i in range(t1_cond.size(0)):
            sample = {
                "filename": filenames[i],
                "input": t1_cond[i].cpu(),         # torch.Tensor [1, H, W]
                "target": t2_target[i].cpu(),
                "prediction": gen_image[i].cpu()
            }

            torch.save(sample, os.path.join(output_dir, f"{filenames[i]}.pt"))
            all_outputs.append(sample)
        if just_one_batch:
            break
        # wandb.log errors out when no run was started; no wandb.init() is visible in this notebook.
        if wandb.run is not None:
            wandb.log({"prediction_progress": idx})

    return all_outputs

# flow matching direct
@torch.no_grad()
def direct_generate_and_save_predictions(model, test_loader, device, output_dir=OUTPUT_DIR, just_one_batch=False, n_steps=300):
    """Translate T1 -> T2 with the direct flow-matching model and save the results.

    Each T1 batch is integrated with ``direct_generate``. Every sample is written to
    ``<output_dir>/<filename>.pt`` as a dict with ``filename``, ``input`` (T1),
    ``target`` (T2) and ``prediction``.

    Args:
        model: Velocity network for the direct flow; switched to eval mode.
        test_loader: Yields dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device for the batch tensors; the model is expected to already be there.
        output_dir: Destination directory (created if missing).
        just_one_batch: Stop after the first batch (smoke test).
        n_steps: Euler steps per sample (default 300, the previously hard-coded value).

    Returns:
        List of the saved sample dicts.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()

    all_outputs = []

    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        t1 = batch["t1"].to(device)           # [B, 1, H, W]
        t2 = batch["t2"].to(device)           # [B, 1, H, W]
        filenames = batch["filename"]         # list of strings (length B)

        x_gen = direct_generate(model, x_T=t1, n_steps=n_steps)

        for i in range(t1.size(0)):
            sample = {
                "filename": filenames[i],
                "input": t1[i].cpu(),         # torch.Tensor [1, H, W]
                "target": t2[i].cpu(),
                "prediction": x_gen[i].cpu()
            }

            torch.save(sample, os.path.join(output_dir, f"{filenames[i]}.pt"))
            all_outputs.append(sample)
        if just_one_batch:
            break
        # Progress logging like the other generators, but only when a wandb run exists.
        if wandb.run is not None:
            wandb.log({"prediction_progress": idx})

    return all_outputs

#flow matching noise
@torch.no_grad()
def noise_generate_and_save_predictions(model, test_loader, device, output_dir=OUTPUT_DIR, just_one_batch=False, n_steps=300):
    """Sample T2 predictions with the noise-conditioned flow-matching model and save them.

    Each batch is generated with ``noise_generate`` conditioned on T1. Every sample is
    written to ``<output_dir>/<filename>.pt`` as a dict with ``filename``, ``input`` (T1),
    ``target`` (T2) and ``prediction``.

    Args:
        model: Velocity network for the noise-conditioned flow; switched to eval mode.
        test_loader: Yields dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device for the batch tensors; the model is expected to already be there.
        output_dir: Destination directory (created if missing).
        just_one_batch: Stop after the first batch (smoke test).
        n_steps: Euler steps per sample (default 300, the previously hard-coded value).

    Returns:
        List of the saved sample dicts.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()

    all_outputs = []

    for idx, batch in enumerate(tqdm(test_loader, desc="Generating Predictions")):
        t1 = batch["t1"].to(device)           # [B, 1, H, W]
        t2 = batch["t2"].to(device)           # [B, 1, H, W]
        filenames = batch["filename"]         # list of strings (length B)

        x_gen = noise_generate(model, x_cond=t1, n_steps=n_steps)

        for i in range(t1.size(0)):
            sample = {
                "filename": filenames[i],
                "input": t1[i].cpu(),         # torch.Tensor [1, H, W]
                "target": t2[i].cpu(),
                "prediction": x_gen[i].cpu()
            }

            torch.save(sample, os.path.join(output_dir, f"{filenames[i]}.pt"))
            all_outputs.append(sample)
        if just_one_batch:
            break
        # wandb.log errors out when no run was started; no wandb.init() is visible in this notebook.
        if wandb.run is not None:
            wandb.log({"prediction_progress": idx})

    return all_outputs


In [None]:
def _load_weights(model: nn.Module, path: str, key: str = 'model_state_dict') -> dict:
    """Load ``torch.load(path)[key]`` (mapped to CPU) into ``model`` and return the whole checkpoint."""
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt[key])
    return ckpt


# Flow-matching models
_load_weights(fm_direct, fm_direct_path)
_load_weights(fm_noise, fm_noise_path)

# ControlNet pipeline: the backbone and the ControlNet inferers share one DDPM scheduler.
num_train_timesteps = 1000
_load_weights(control_diff, controldiff_path)
control_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
control_diff_inferer = DiffusionInferer(control_scheduler)
_load_weights(controlnet, controlnet_path)
controlnet_inferer = ControlNetDiffusionInferer(control_scheduler)

# Pix2pix generator weights are stored under a different key.
_load_weights(pixgen, pixgen_path, key='generator_state_dict')

# Concat-conditioned diffusion model.
# NOTE(review): diffusion_path points at the same file as controldiff_path, but `dm` has
# in_channels=2 while `control_diff` has in_channels=1. Expect a size-mismatch error here
# until the intended checkpoint path is confirmed.
checkpoint = _load_weights(dm, diffusion_path)
dm_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
dm_inferer = DiffusionInferer(dm_scheduler)

In [None]:
# Test-time preprocessing: zero-pad by (left, top, right, bottom) = (5, 3, 5, 3), then
# convert the grayscale PIL image to a float tensor scaled to [0, 1].
_PAD_LTRB = (5, 3, 5, 3)
transform = transforms.Compose(
    [
        transforms.Pad(padding=_PAD_LTRB, fill=0),
        transforms.ToTensor(),
    ]
)

# Held-out "test" slice of the dataset's seeded shuffle, iterated in a fixed order.
test_dataset = UnifiedBrainDataset(
    root_dir=DATAPATH,
    transform=transform,
    split="test",
)
test_loader = DataLoader(
    test_dataset,
    batch_size=6,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Flow matching, direct variant: run the sampler over the whole test split.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batches are moved to `device` inside the sampler, so the model must live there too.
fm_direct = fm_direct.to(device)

direct_outputs = direct_generate_and_save_predictions(
    model=fm_direct,
    test_loader=test_loader,
    device=device,
    # One sub-directory per method: every generator saves `<filename>.pt`, so a shared
    # OUTPUT_DIR would let one method's predictions overwrite another's.
    output_dir=os.path.join(OUTPUT_DIR, "fm_direct"),
)
# Assigned rather than left as the cell's last expression, so the notebook does not
# render the full list of tensors.