In [None]:
import os
import random
import torch
import time
from torch import nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import Tensor
from PIL import Image

from generative.networks.nets import DiffusionModelUNet
import wandb
from matplotlib import pyplot as plt
from tqdm import trange, tqdm
import numpy as np
import pandas as pd

from skimage.metrics import structural_similarity as ssim_fn

In [None]:
# Root folder of the dataset; it must contain the paired "t1/" and "t2/" sub-folders.
DATAPATH = '/home/andrea_moschetto/flow_matching_t1t2/data'
# Base folder for saved predictions (one sub-folder per experiment name).
OUTPUT_DIR = "/home/andrea_moschetto/flow_matching_t1t2/outputs"

In [11]:
class UnifiedBrainDataset(Dataset):
    """Paired T1/T2 image dataset with a deterministic train/val/test split.

    ``root_dir`` must contain two sub-folders, ``t1`` and ``t2``. Files are
    paired by identical file name; files present in only one folder are ignored.
    The sorted pairs are shuffled with ``seed`` and cut into 80% train,
    5% validation and the remainder as test, so the three splits built with
    the same ``seed`` are disjoint and cover every pair.

    Args:
        root_dir: Folder containing the ``t1`` and ``t2`` sub-folders.
        transform: Optional callable applied independently to both images.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        seed: Seed of the shuffle that defines the split.
    """

    def __init__(self, root_dir, transform=None, split="train", seed=42):
        assert split in ["train", "val", "test"], "split must be 'train', 'val' or 'test'"
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        self.seed = seed
        self.samples = self._create_file_pairs()
        self._split_dataset()

    def _create_file_pairs(self):
        """Return the sorted list of ``(t1_path, t2_path)`` for names found in both folders."""
        t1_dir = os.path.join(self.root_dir, "t1")
        t2_dir = os.path.join(self.root_dir, "t2")

        # Sorting makes the pre-shuffle order independent of os.listdir ordering.
        common_files = sorted(set(os.listdir(t1_dir)) & set(os.listdir(t2_dir)))

        return [(os.path.join(t1_dir, fname), os.path.join(t2_dir, fname)) for fname in common_files]

    def _split_dataset(self):
        """Shuffle ``self.samples`` with ``self.seed`` and keep only the requested split."""
        # A private Random instance yields the same permutation as
        # random.seed(seed) + random.shuffle(...) but leaves the process-wide
        # RNG state untouched, so building a dataset no longer reseeds `random`.
        random.Random(self.seed).shuffle(self.samples)

        n_total = len(self.samples)
        n_train = int(n_total * 0.80)
        n_val = int(n_total * 0.05)

        if self.split == "train":
            self.samples = self.samples[:n_train]
        elif self.split == "val":
            self.samples = self.samples[n_train:n_train + n_val]
        elif self.split == "test":
            # Everything after train + val (absorbs the rounding remainder).
            self.samples = self.samples[n_train + n_val:]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"t1": ..., "t2": ..., "filename": ...}`` for the pair at ``idx``."""
        t1_path, t2_path = self.samples[idx]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block ends.
        with Image.open(t1_path) as t1_file:
            t1_image = t1_file.convert("L")
        with Image.open(t2_path) as t2_file:
            t2_image = t2_file.convert("L")

        if self.transform:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)

        return {
            "t1": t1_image,
            "t2": t2_image,
            "filename": os.path.basename(t1_path)
        }


In [12]:
def euler_step(model: nn.Module, x_t: Tensor, t_start: Tensor, t_end: Tensor):
    """Advance ``x_t`` from ``t_start`` to ``t_end`` with one explicit Euler step.

    Args:
        model: Callable invoked as ``model(x_t, t_start)`` that returns the
            predicted velocity, broadcastable against ``x_t``.
        x_t: Current state; the step size is reshaped to ``(B, 1, 1, 1)``,
            so a 4-D batch is expected.
        t_start: Per-sample start times, shape ``(B,)``.
        t_end: Per-sample end times, shape ``(B,)``.

    Returns:
        ``x_t + (t_end - t_start) * model(x_t, t_start)``.
    """
    # Per-sample step size, reshaped to (B, 1, 1, 1) so it broadcasts over x_t.
    delta_t = (t_end - t_start).view(-1, 1, 1, 1)

    # The model receives the time as a (B,) tensor.
    v_hat = model(x_t, t_start)

    return x_t + delta_t * v_hat

@torch.no_grad()
def generate(model: nn.Module, x_T: Tensor, n_steps: int = 20):
    """Integrate the model's velocity field from t=0 to t=1 starting at ``x_T``.

    The model is switched to eval mode, then ``n_steps`` Euler steps are taken
    on a uniform time grid over [0, 1]. Gradients are disabled for the whole call.

    Args:
        model: Velocity model forwarded to ``euler_step``.
        x_T: Starting batch; its first dimension is the batch size and its
            device is used for the time grid.
        n_steps: Number of Euler steps.

    Returns:
        The state reached after the last step.
    """
    model.eval()

    n_samples = x_T.shape[0]
    # n_steps + 1 grid points delimit n_steps consecutive intervals on [0, 1].
    grid = torch.linspace(0.0, 1.0, n_steps + 1, device=x_T.device, dtype=torch.float32)

    state = x_T
    for t_from, t_to in zip(grid[:-1], grid[1:]):
        # Every sample in the batch shares the same scalar time.
        state = euler_step(
            model,
            x_t=state,
            t_start=t_from.expand(n_samples),
            t_end=t_to.expand(n_samples),
        )

    return state


In [14]:
def log_generation(model, epoch, device, n_steps: int, reference_image=None, use_wandb=True):
    """Generate from a reference T1 image and build a comparison grid.

    Args:
        model: Model passed to ``generate``.
        epoch: Zero-based epoch index; shown as ``epoch + 1`` in the caption.
        device: Device on which generation and grid assembly run.
        n_steps: Number of integration steps passed to ``generate``.
        reference_image: Tuple ``(reference_t1, reference_t2)``; t1 is the model
            input and t2 the ground truth. Required.
        use_wandb: When true, the grid is logged to wandb under ``"generation"``.

    Returns:
        ``(grid, x_gen)``: the image grid (input, ground truth, generated) and
        the generated tensor.

    Raises:
        ValueError: If ``reference_image`` is ``None``.
    """
    if reference_image is None:
        raise ValueError(
            "reference_image must be provided when generating from t1.")

    with torch.no_grad():
        reference_t1, reference_t2 = reference_image  # t1 = input, t2 = ground truth
        # The model input is t1.
        x_T = reference_t1.to(device)
        x_gen = generate(model=model, x_T=x_T, n_steps=n_steps)

        # x_gen lives on `device`; move both references there as well so that
        # torch.cat does not fail when the caller passed tensors from another device.
        # Panel order: t1 (input), t2 (ground truth), x_gen (generated).
        images = torch.cat([x_T, reference_t2.to(device), x_gen], dim=0)

        grid = torchvision.utils.make_grid(images, nrow=3, normalize=True)

    if use_wandb:
        wandb.log({
            "generation": [wandb.Image(grid, caption=f"Epoch {epoch+1}")]
        })

    return grid, x_gen


def show_grid(grid):
    """Display an image-grid tensor of layout [C, H, W] in an 8x8 inch figure without axes."""
    # matplotlib wants channels last: rearrange [C, H, W] into [H, W, C].
    channels_last = grid.permute(1, 2, 0).cpu().numpy()
    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(channels_last)
    ax.axis('off')
    plt.show()

In [15]:
# Folder where train_flow writes its checkpoints.
CHECKPOINTS_PATH = '/home/andrea_moschetto/flow_matching_t1t2/checkpoints'
# exist_ok=True avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)


def _flow_matching_loss(model, batch, criterion, device):
    """Return the velocity-regression loss for one batch of paired t1/t2 images."""
    x_0 = batch["t1"].to(device)  # source sample
    x_1 = batch["t2"].to(device)  # target sample

    # One random time per sample, reshaped to (B, 1, 1, 1) for broadcasting.
    t = torch.rand(x_0.shape[0], device=device)
    t_img = t.view(-1, 1, 1, 1)

    x_t = (1 - t_img) * x_0 + t_img * x_1  # point on the straight line x_0 -> x_1
    dx_t = x_1 - x_0                       # constant velocity along that line

    return criterion(model(x_t, t), dx_t)


def train_flow(model: DiffusionModelUNet, train_loader: DataLoader, val_loader: DataLoader, project: str, exp_name: str, notes: str, n_epochs: int = 10, lr : float = 1e-3, generation_steps: int = 100):
    """Train ``model`` to regress the t1 -> t2 velocity and track the run in wandb.

    Each epoch runs one pass over ``train_loader`` (Adam, MSE loss) and one
    pass over ``val_loader``. Whenever the mean validation loss improves, a
    sample generation is produced, a checkpoint is written to
    ``CHECKPOINTS_PATH`` and the previous best checkpoint is deleted.

    Note that validation draws fresh random times every epoch, so the
    validation loss (and therefore checkpoint selection) is stochastic.

    Args:
        model: Network called as ``model(x_t, t)``.
        train_loader: Loader yielding dicts with ``"t1"`` and ``"t2"`` tensors.
        val_loader: Loader with the same batch layout, used for validation.
        project: wandb project name.
        exp_name: wandb run name; also embedded in checkpoint file names.
        notes: Free-text notes attached to the wandb run.
        n_epochs: Number of training epochs.
        lr: Adam learning rate.
        generation_steps: Integration steps used for the logged sample generation.

    Returns:
        Path of the best checkpoint, or ``None`` if no checkpoint was written.

    Raises:
        ValueError: If either loader yields no batches.
    """
    device_name = str(torch.cuda.get_device_name(0)
                      if torch.cuda.is_available() else "CPU")

    with wandb.init(
        project=project,
        name=exp_name,
        notes=notes,
        tags=["flow", "brain", "diffusion"],
        config={
            'model': model.__class__.__name__,
            'epochs': n_epochs,
            'batch_size': train_loader.batch_size,
            'num_workers': train_loader.num_workers,
            'optimizer': 'Adam',
            'learning_rate': lr,
            'loss_function': 'MSELoss',
            'generation_steps': generation_steps,
            'device': device_name,
        }
    ):

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using", device_name)

        model.to(device)
        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        best_val_loss = float("inf")
        best_model_path = None
        start_time = time.time()
        for e in trange(n_epochs, desc="Epochs"):
            start_e_time = time.time()

            # Training
            model.train()
            train_losses = []
            for batch in tqdm(train_loader, desc=f"Training epoch {e}"):
                optimizer.zero_grad()
                loss = _flow_matching_loss(model, batch, criterion, device)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            if not train_losses:
                raise ValueError("train_loader yielded no batches")

            # Validation
            model.eval()
            val_losses = []
            with torch.no_grad():
                for batch in val_loader:
                    val_losses.append(
                        _flow_matching_loss(model, batch, criterion, device).item())
            if not val_losses:
                raise ValueError("val_loader yielded no batches")
            batch_val_loss = sum(val_losses) / len(val_losses)

            # Everything for this epoch is collected here and logged with a
            # single wandb.log call, so all metrics of one epoch share one step.
            epoch_log = {
                "epoch": e + 1,
                "train_loss": sum(train_losses) / len(train_losses),
                "val_loss": batch_val_loss,
                # True division: floor division reported 0 for sub-minute epochs.
                "epoch_time_minutes": (time.time() - start_e_time) / 60,
            }

            # Checkpoint
            if batch_val_loss < best_val_loss:
                sample_batch = next(iter(val_loader))  # just one batch
                reference_t2 = sample_batch["t2"][0].unsqueeze(0).to(device)  # first sample, batch dim restored
                reference_t1 = sample_batch["t1"][0].unsqueeze(0).to(device)
                reference = (reference_t1, reference_t2)
                # use_wandb=False: the grid is attached to this epoch's log entry below.
                grid, _ = log_generation(model, epoch=e, device=device, n_steps=generation_steps,
                                         reference_image=reference, use_wandb=False)
                epoch_log["generation"] = [wandb.Image(grid, caption=f"Epoch {e+1}")]

                path = f'{CHECKPOINTS_PATH}/checkpoint_{exp_name}_{e+1}.pth'
                torch.save({
                    'epoch': e+1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, path)
                # Keep only the best checkpoint on disk.
                if best_model_path is not None and os.path.exists(best_model_path):
                    os.remove(best_model_path)
                best_model_path = path
                best_val_loss = batch_val_loss

            wandb.log(epoch_log)

        elapsed_time = time.time() - start_time
        wandb.log({"total_running_hours": elapsed_time / 3600})
        print(f"Training completed in {elapsed_time // 60:.0f}m {elapsed_time % 60:.0f}s")
    print("Training complete.")
    return best_model_path

In [122]:
@torch.no_grad()
def generate_and_save_predictions(model, test_loader, device, output_dir=OUTPUT_DIR, n_steps: int = 300):
    """Generate a prediction for every sample of ``test_loader`` and save it to disk.

    For each sample a dict with keys ``"filename"``, ``"input"``, ``"target"``
    and ``"prediction"`` is written to ``<output_dir>/<filename>.pt`` with
    ``torch.save``.

    Args:
        model: Model passed to ``generate``; it is put in eval mode.
        test_loader: Loader yielding dicts with ``"t1"``, ``"t2"`` and ``"filename"``.
        device: Device on which generation runs.
        output_dir: Destination folder (created if missing).
        n_steps: Number of integration steps passed to ``generate``.

    Returns:
        List of all saved sample dicts (tensors on CPU).
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()

    all_outputs = []

    for batch in tqdm(test_loader, desc="Generating Predictions"):
        t1 = batch["t1"].to(device)
        t2 = batch["t2"].to(device)
        filenames = batch["filename"]         # one name per sample in the batch

        x_gen = generate(model, x_T=t1, n_steps=n_steps)

        for name, source, target, prediction in zip(filenames, t1, t2, x_gen):
            # clone(): when `device` is the CPU, .cpu() returns a view into the
            # batch tensor and torch.save would serialize the whole batch
            # storage into every per-sample file.
            sample = {
                "filename": name,
                "input": source.cpu().clone(),
                "target": target.cpu().clone(),
                "prediction": prediction.cpu().clone()
            }

            torch.save(sample, os.path.join(output_dir, f"{name}.pt"))
            all_outputs.append(sample)

    return all_outputs


In [126]:
class PredictionDataset(Dataset):
    """Dataset over the ``.pt`` sample files stored in a directory.

    Each file is expected to hold a dict with at least the keys
    ``"prediction"`` and ``"target"``. Files are served in sorted name order.

    Args:
        directory: Folder to scan (non-recursively) for ``.pt`` files.

    Raises:
        ValueError: If the folder contains no ``.pt`` file.
    """

    def __init__(self, directory):
        super().__init__()
        self.directory = directory
        self.files = sorted(
            f for f in os.listdir(directory) if f.endswith('.pt')
        )
        if not self.files:
            raise ValueError(f"No .pt files found in directory: {directory}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(prediction, target)`` loaded from the ``idx``-th file."""
        file_path = os.path.join(self.directory, self.files[idx])
        # map_location="cpu" lets files be read on machines without the device
        # the tensors were saved from.
        # NOTE(security): torch.load unpickles the file; only point this
        # dataset at directories whose contents you produced or trust.
        data = torch.load(file_path, map_location="cpu")
        pred = data["prediction"]
        gt = data["target"]
        return pred, gt


In [130]:
def percnorm(arr, lperc=5, uperc=99.5):
    """
    Remove outlier intensities from a brain component,
    similar to Tukey's fences method.

    Values above the ``uperc`` percentile are set to that percentile and
    values below the ``lperc`` percentile are set to that percentile.
    The input is left untouched; a clipped copy is returned.

    Args:
        arr: Array of intensities.
        lperc: Lower percentile (0-100).
        uperc: Upper percentile (0-100).

    Returns:
        A new array with the same shape and dtype as ``arr``.
    """
    upperbound = np.percentile(arr, uperc)
    lowerbound = np.percentile(arr, lperc)
    # Work on a copy: callers may pass a view that shares memory with other
    # data (e.g. the result of Tensor.numpy()), which must not be modified.
    clipped = np.array(arr, copy=True)
    clipped[clipped > upperbound] = upperbound
    clipped[clipped < lowerbound] = lowerbound
    return clipped

def normalize(img):
    """Min-max rescale ``img`` using its global minimum and maximum.

    A small epsilon (1e-8) in the denominator avoids division by zero for
    constant inputs, which then map to all zeros.
    """
    lowest = img.min()
    highest = img.max()
    return (img - lowest) / (highest - lowest + 1e-8)


In [155]:
def compute_ssim_from_dataset(dataset):
    """Compute SSIM and MSE between predictions and targets and plot one example.

    SSIM is computed on percentile-clipped, min-max normalized images with
    ``data_range=1.0``. MSE is computed on the tensors exactly as the dataset
    returns them.

    Args:
        dataset: Indexable of length ``len(dataset)`` whose items are
            ``(pred, gt)`` tensor pairs.

    Returns:
        DataFrame with one row per metric ("SSIM", "MSE") and the columns
        "Metric", "Mean" and "Variance".

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    n_items = len(dataset)
    if n_items == 0:
        raise ValueError("dataset is empty; nothing to evaluate")

    # Sample 4 is the one shown; fall back to the last sample when the
    # dataset has fewer than five items instead of plotting nothing.
    example_idx = min(4, n_items - 1)

    ssim_scores = []
    mse_scores = []

    example_pred = None
    example_gt = None

    for i in range(n_items):
        pred, gt = dataset[i]

        # Convert to numpy and squeeze the channel dimension. The explicit
        # copy matters: Tensor.numpy() shares memory with the tensor, so any
        # in-place clipping downstream would otherwise alter `pred`/`gt`
        # before the MSE below is computed.
        pred_np = pred.squeeze().cpu().numpy().copy()
        gt_np = gt.squeeze().cpu().numpy().copy()

        pred_np = normalize(percnorm(pred_np))
        gt_np = normalize(percnorm(gt_np))

        # Compute SSIM
        ssim_scores.append(ssim_fn(pred_np, gt_np, data_range=1.0))

        # Compute MSE on the unmodified tensors
        mse_scores.append(F.mse_loss(pred, gt).item())

        # Store one example for visualization
        if i == example_idx:
            example_pred = pred_np
            example_gt = gt_np

    ssim_scores = np.array(ssim_scores)
    mse_scores = np.array(mse_scores)

    summary = pd.DataFrame({
        "Metric": ["SSIM", "MSE"],
        "Mean": [ssim_scores.mean(), mse_scores.mean()],
        "Variance": [ssim_scores.var(), mse_scores.var()]
    })

    # Visualize example
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(example_gt, cmap='gray')
    axs[0].set_title("Ground Truth")
    axs[0].axis("off")

    axs[1].imshow(example_pred, cmap='gray')
    axs[1].set_title("Prediction")
    axs[1].axis("off")

    plt.suptitle("Example Comparison")
    plt.tight_layout()
    plt.show()

    return summary

In [16]:
# Pre-processing shared by every split: zero-pad each image (5 px on the
# left/right, 3 px on the top/bottom), then convert it to a tensor in [0, 1].
transform = transforms.Compose([
    transforms.Pad(padding=(5, 3, 5, 3), fill=0),
    transforms.ToTensor(),
])

# One dataset per split; all three use the same root folder and transform.
train_dataset, val_dataset, test_dataset = (
    UnifiedBrainDataset(root_dir=DATAPATH, transform=transform, split=split_name)
    for split_name in ("train", "val", "test")
)

# Loader settings common to the three splits; only training is shuffled.
loader_kwargs = dict(batch_size=6, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# 2-D UNet with one input channel (x) and one output channel
# (the predicted delta_x_t).
model = DiffusionModelUNet(spatial_dims=2, in_channels=1, out_channels=1)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for values that were previously repeated as literals.
PROJECT = 'flowmatching-t1-to-t2'
N_EPOCHS = 150
exp_name = f"unetflow-t1t2-{N_EPOCHS}e"
prediction_dir = f'{OUTPUT_DIR}/{exp_name}'

modelpath = train_flow(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    project=PROJECT,
    exp_name=exp_name,
    # The epoch count is derived from N_EPOCHS so the notes cannot drift from
    # the actual setting (they used to say 50 while 150 epochs were run).
    notes=f"UNet flow model for directional diffusion from T1 to T2. {N_EPOCHS} epochs.",
    n_epochs=N_EPOCHS,
    lr=1e-4,
    generation_steps=300)

# Predictions are generated with the in-memory model as left by training.
generate_and_save_predictions(model, test_loader, device, output_dir=prediction_dir)
out_dataset = PredictionDataset(directory=prediction_dir)
with wandb.init(
    project=PROJECT,
    name=f'evaluation-{exp_name}',
    notes="Evaluation of the flow model on the test set.",
):
    summary = compute_ssim_from_dataset(out_dataset)
    # Look values up by metric name rather than by row position.
    stats = summary.set_index("Metric")
    # One log call keeps all evaluation values on the same wandb step.
    wandb.log({
        "eval/metrics": wandb.Table(dataframe=summary),
        "eval/ssim_mean": stats.loc["SSIM", "Mean"],
        "eval/mse_mean": stats.loc["MSE", "Mean"],
        "eval/ssim_var": stats.loc["SSIM", "Variance"],
        "eval/mse_var": stats.loc["MSE", "Variance"],
    })
summary

In [18]:
# # Load the best checkpoint
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# checkpoint_path = f'{CHECKPOINTS_PATH}/checkpoint_unetflow-t1t2-150e_149.pth'
# model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model_state_dict'])
# model.eval()