In [2]:
import os

# Root data folder; UnifiedBrainDataset expects it to contain `t1/` and `t2/` sub-folders.
DATAPATH = '/home/andrea_moschetto/flow_matching_t1t2/data'
print(DATAPATH)

In [3]:
import os
import random
from torch.utils.data import Dataset
from PIL import Image

class UnifiedBrainDataset(Dataset):
    """Paired T1/T2 image dataset with a deterministic train/validation split.

    ``root_dir`` must contain two sub-directories, ``t1`` and ``t2``. Only file
    names present in *both* directories are used; each sample is the pair of
    images that share that file name.

    Args:
        root_dir: Directory containing the ``t1`` and ``t2`` sub-directories.
        transform: Optional callable applied independently to the T1 and the T2
            image (each already converted to single-channel ``"L"`` mode).
        is_train: If True keep the first ``split_ratio`` fraction of the
            shuffled pairs, otherwise keep the remaining pairs.
        split_ratio: Fraction of pairs assigned to the training split, in [0, 1].
        seed: Seed of the shuffle. Two instances built with the same
            ``root_dir``/``split_ratio``/``seed`` and opposite ``is_train`` get
            disjoint splits that together cover every pair.

    Raises:
        ValueError: If ``split_ratio`` is outside [0, 1].
    """

    def __init__(self, root_dir, transform=None, is_train=True, split_ratio=0.8, seed=42):
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train
        self.split_ratio = split_ratio
        self.seed = seed
        self.samples = self._create_file_pairs()
        self._split_dataset()

    def _create_file_pairs(self):
        """Return ``(t1_path, t2_path)`` pairs, sorted by file name, for names present in both folders."""
        t1_dir = os.path.join(self.root_dir, "t1")
        t2_dir = os.path.join(self.root_dir, "t2")

        # Sorting makes the pair order (and therefore the seeded split)
        # independent of the filesystem's listing order.
        common_files = sorted(set(os.listdir(t1_dir)) & set(os.listdir(t2_dir)))

        return [(os.path.join(t1_dir, fname), os.path.join(t2_dir, fname)) for fname in common_files]

    def _split_dataset(self):
        """Shuffle the pairs with a seeded RNG and keep only this instance's split."""
        # A private Random instance yields the same permutation as
        # `random.seed(seed); random.shuffle(...)` without resetting the global
        # `random` state that other code in the process relies on.
        rng = random.Random(self.seed)
        rng.shuffle(self.samples)

        split_idx = int(len(self.samples) * self.split_ratio)
        if self.is_train:
            self.samples = self.samples[:split_idx]
        else:
            self.samples = self.samples[split_idx:]

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_grayscale(path):
        """Open an image, convert it to 8-bit grayscale (``"L"``) and close the file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays usable
            # after the underlying file handle is closed.
            return img.convert("L")

    def __getitem__(self, idx):
        """Return ``{"t1": ..., "t2": ..., "filename": ...}`` for sample ``idx``."""
        t1_path, t2_path = self.samples[idx]
        t1_image = self._load_grayscale(t1_path)
        t2_image = self._load_grayscale(t2_path)

        if self.transform is not None:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)

        return {
            "t1": t1_image,
            "t2": t2_image,
            "filename": os.path.basename(t1_path)
        }


In [34]:
import torch
from torch import nn
from torch import Tensor
from generative.networks.nets import DiffusionModelUNet


def euler_step(model: nn.Module, x_t: Tensor, t_start: Tensor, t_end: Tensor):
    """Advance a sample by one explicit Euler step of the learned velocity field.

    Computes ``x_t + (t_end - t_start) * model(x_t, t_start)``.

    Args:
        model: Velocity network, called as ``model(x_t, t_start)`` with one
            time value per batch element.
        x_t: Current state, shape (B, ...). Any number of non-batch dimensions
            is supported, e.g. (B, C, H, W) or (B, C, D, H, W).
        t_start: Per-sample start times, shape (B,).
        t_end: Per-sample end times, shape (B,).

    Returns:
        Tensor with the same shape as ``x_t``.
    """
    # Reshape the per-sample step to (B, 1, ..., 1) so it broadcasts over every
    # non-batch dimension of x_t, whatever its rank.
    delta_t = (t_end - t_start).view(-1, *([1] * (x_t.dim() - 1)))

    # NOTE(review): t in [0, 1] is passed to the network unscaled (training does
    # the same). If the network's timestep embedding was designed for integer
    # diffusion steps, consider scaling t consistently in both places — verify.
    v_hat = model(x_t, t_start)

    return x_t + delta_t * v_hat

@torch.no_grad()
def generate(model: nn.Module, x_T: Tensor, n_steps: int = 20):
    """Integrate the learned velocity field from t=0 to t=1 with ``n_steps`` Euler steps.

    Args:
        model: Velocity network, called as ``model(x, t)`` with ``t`` of shape (B,).
        x_T: Starting sample at t=0, e.g. Gaussian noise, shape (B, ...).
        n_steps: Number of equally spaced Euler steps on [0, 1].

    Returns:
        The integrated sample at t=1, same shape as ``x_T``.

    The model is put in eval mode for sampling. Its previous train/eval mode is
    restored afterwards, so calling this inside a training loop does not leave
    the model in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        batch_size = x_T.shape[0]
        # n_steps + 1 grid points from 0 to 1 inclusive -> n_steps intervals.
        time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=x_T.device, dtype=torch.float32)

        x = x_T
        for t_start, t_end in zip(time_steps[:-1], time_steps[1:]):
            x = euler_step(
                model,
                x_t=x,
                t_start=t_start.expand(batch_size),
                t_end=t_end.expand(batch_size),
            )
        return x
    finally:
        model.train(was_training)


In [5]:
import torch

# Report GPU availability. get_device_name raises when no CUDA device exists,
# so only query it when CUDA is available.
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.get_device_name(0) if cuda_available else "CPU")

In [None]:

import torchvision
import wandb
from matplotlib import pyplot as plt


def log_generation(model, epoch, device, reference_image=None, use_wandb=True,
                   image_shape=(1, 224, 192), n_steps=100):
    """Sample one image from ``model`` and optionally log it to Weights & Biases.

    Args:
        model: Velocity network passed to ``generate``.
        epoch: Epoch number used in the W&B caption.
        device: Device on which the starting noise is drawn.
        reference_image: Optional real image of shape (1, C, H, W) placed next
            to the sample. When given, the noise is drawn with the reference's
            (C, H, W) so the two tensors can always be concatenated.
        use_wandb: If True, log the grid under the ``"generation"`` key.
        image_shape: (C, H, W) of the noise when no reference is given. The
            default matches the size hard-coded here previously.
        n_steps: Number of Euler steps used by ``generate``.

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid`` (min-max
        normalized because ``normalize=True``).
    """
    with torch.no_grad():
        if reference_image is not None:
            image_shape = tuple(reference_image.shape[1:])
        noise = torch.randn(1, *image_shape, device=device)
        x_gen = generate(model=model, x_T=noise, n_steps=n_steps)

        # Layout: [real, generated] when a reference is available.
        images = x_gen if reference_image is None else torch.cat([reference_image, x_gen], dim=0)

        grid = torchvision.utils.make_grid(images, nrow=2, normalize=True)
        if use_wandb:
            wandb.log(
                {"generation": [wandb.Image(grid, caption=f"Epoch {epoch}")]})
        return grid


def show_grid(grid):
    """Display an image grid tensor of shape [C, H, W] with matplotlib.

    Works for tensors on any device and for tensors that track gradients.
    Single-channel grids are shown as 2-D grayscale images, because ``imshow``
    does not accept (H, W, 1) arrays.
    """
    np_grid = grid.detach().cpu().permute(1, 2, 0).numpy()  # [C, H, W] -> [H, W, C]
    if np_grid.shape[-1] == 1:
        np_grid = np_grid[..., 0]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(np_grid, cmap="gray" if np_grid.ndim == 2 else None)
    ax.axis('off')
    plt.show()

In [None]:
import time
from torch.utils.data import DataLoader
from tqdm import trange, tqdm
import wandb
import os

CHECKPOINTS_PATH = '/home/andrea_moschetto/flow_matching_t1t2/checkpoints'
# exist_ok avoids the check-then-create race and keeps the cell idempotent on re-run.
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)


def _flow_matching_loss(model: nn.Module, x_1: Tensor, criterion: nn.Module) -> Tensor:
    """Return the straight-line flow-matching loss for one batch of targets ``x_1``.

    Draws noise ``x_0 ~ N(0, I)`` and ``t ~ U[0, 1)`` per sample, forms
    ``x_t = (1 - t) * x_0 + t * x_1``, and compares ``model(x_t, t)`` with the
    constant path velocity ``x_1 - x_0`` using ``criterion``.
    """
    batch_size = x_1.shape[0]
    x_0 = torch.randn_like(x_1)  # same shape, device and dtype as x_1
    t = torch.rand(batch_size, device=x_1.device)
    # (B, 1, ..., 1) so t broadcasts over every non-batch dimension of x_1.
    t_img = t.view(batch_size, *([1] * (x_1.dim() - 1)))

    x_t = (1 - t_img) * x_0 + t_img * x_1
    target_velocity = x_1 - x_0
    return criterion(model(x_t, t), target_velocity)


def train_flow(model: DiffusionModelUNet, train_loader: DataLoader, val_loader: DataLoader, project: str, exp_name: str, notes: str, n_epochs: int = 10, lr: float = 1e-3):
    """Train ``model`` on T2 images with a flow-matching objective, logging to W&B.

    Each batch's ``"t2"`` tensor is the target ``x_1``. The loss is computed by
    ``_flow_matching_loss`` with an MSE criterion and optimized with Adam.

    Per epoch, the mean train loss, mean validation loss and epoch duration in
    minutes are logged in one ``wandb.log`` call. The validation mean is NaN if
    ``val_loader`` is empty. Every 5 epochs and on the last epoch:

    - a sample is generated and logged via ``log_generation``;
    - a checkpoint (epoch, model and optimizer state) is saved under
      ``CHECKPOINTS_PATH``. Only the most recent checkpoint of this run is kept.

    Args:
        model: Velocity network, called as ``model(x_t, t)``.
        train_loader: Yields dicts with a ``"t2"`` tensor batch.
        val_loader: Same format as ``train_loader``; used without gradients.
        project: W&B project name.
        exp_name: W&B run name, also used in checkpoint file names.
        notes: W&B run notes.
        n_epochs: Number of training epochs.
        lr: Adam learning rate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"

    with wandb.init(
        project=project,
        name=exp_name,
        notes=notes,
        tags=["flow", "brain", "diffusion"],
        config={
            'model': model.__class__.__name__,
            'epochs': n_epochs,
            'batch_size': train_loader.batch_size,
            'num_workers': train_loader.num_workers,
            'optimizer': 'Adam',
            'learning_rate': lr,
            'loss_function': 'MSELoss',
            'device': str(device_name),
        }
    ):
        print("Using", device_name)

        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()
        last_ckpt_path = None

        start_time = time.time()
        for e in trange(n_epochs, desc="Epochs"):
            start_e_time = time.time()

            # Training
            model.train()
            train_losses = []
            for batch in tqdm(train_loader, desc=f"Training epoch {e}"):
                x_1 = batch["t2"].to(device)  # [B, 1, H, W]
                optimizer.zero_grad()
                loss = _flow_matching_loss(model, x_1, criterion)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())

            # Validation
            model.eval()
            val_losses = []
            with torch.no_grad():
                for batch in val_loader:
                    x_1 = batch["t2"].to(device)
                    val_losses.append(_flow_matching_loss(model, x_1, criterion).item())

            # True division: floor division truncated short epochs to 0 minutes.
            epoch_minutes = (time.time() - start_e_time) / 60
            wandb.log({
                "epoch": e,
                "train_loss": sum(train_losses) / len(train_losses) if train_losses else float("nan"),
                "val_loss": sum(val_losses) / len(val_losses) if val_losses else float("nan"),
                "epoch_time_minutes": epoch_minutes,
            })

            # Checkpoint
            if e % 5 == 0 or e == n_epochs - 1:
                sample_batch = next(iter(train_loader))  # just one batch
                reference = sample_batch["t2"][0].unsqueeze(0).to(device)  # [1, 1, H, W]
                log_generation(model, epoch=e, device=device, reference_image=reference)

                ckpt_path = os.path.join(CHECKPOINTS_PATH, f"checkpoint_{exp_name}_{e}.pth")
                torch.save({
                    'epoch': e,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, ckpt_path)
                # Keep only the latest checkpoint. Compare against the previously
                # saved path (not an epoch counter) so the file just written is
                # never removed.
                if last_ckpt_path is not None and last_ckpt_path != ckpt_path and os.path.exists(last_ckpt_path):
                    os.remove(last_ckpt_path)
                last_ckpt_path = ckpt_path

        elapsed_time = time.time() - start_time
        wandb.log({"total_running_hours": elapsed_time / 3600})
        minutes, seconds = divmod(int(round(elapsed_time)), 60)
        print(f"Training completed in {minutes}m {seconds}s")
    print("Training complete.")

In [8]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Pad 5 px on the left/right and 3 px on the top/bottom (torchvision's 4-tuple
# order is left, top, right, bottom).
# NOTE(review): presumably this brings the slices to 224x192 (H x W), matching
# the noise shape used in log_generation — verify against the data.
IMAGE_PADDING = (5, 3, 5, 3)
BATCH_SIZE = 4
NUM_WORKERS = 2

transform = transforms.Compose([
    transforms.Pad(padding=IMAGE_PADDING, fill=0),
    transforms.ToTensor(),  # PIL "L" image -> float tensor [1, H, W] scaled to [0, 1]
])

# Both splits use the same root and the default seed/split_ratio, so they are
# complementary and disjoint.
train_dataset = UnifiedBrainDataset(root_dir=DATAPATH, transform=transform, is_train=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_dataset = UnifiedBrainDataset(root_dir=DATAPATH, transform=transform, is_train=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

model = DiffusionModelUNet(
    spatial_dims=2,  # 2-D slices
    in_channels=1,   # the interpolated state x_t
    out_channels=1   # predicted velocity (trained against x_1 - x_0)
)

In [None]:
# Derive the run name from the epoch count so the two cannot drift apart.
N_EPOCHS = 50

train_flow(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    project='flowmatching-t2gen',
    exp_name=f"unetflow-t2gen-{N_EPOCHS}e",
    notes="UNet flow model for T2 generation",
    n_epochs=N_EPOCHS,
    lr=1e-4,
)