# Baseline: Imagen Diffusion (MRI → SPECT)

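Entrenamiento de un modelo de difusión directamente en espacio de imagen (no latente), usando las mismas carpetas de datos (`whole_MRI`, `whole_SPECT`, `data_info`). El SPECT objetivo se reduce por interpolación trilineal a 40×48×40 antes de aplicar la difusión, y la MRI (sin reducir) se usa como condición.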

In [None]:
import os
# Pin this process to a single physical GPU. These variables only take effect
# if they are set before CUDA is initialised in the process. Ordering devices
# by PCI bus id makes index "4" follow the PCI enumeration (as nvidia-smi lists it).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "4",
})

In [None]:
import os
import math
import copy
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import nibabel as nib

from dataset import resolve_nifti, nifti_to_numpy, min_max_norm, z_score_norm, crop
from model import UNet
import config
import torch.nn.functional as F

In [None]:
def seed_all(seed: int = 42):
    """Seed every RNG the notebook touches and put cuDNN in deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA
    devices). It then disables cuDNN autotuning and forces deterministic
    kernels, trading some speed for run-to-run reproducibility.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

# Seed all RNGs from the project config before any data or model is created.
seed_all(config.seed)

# Use the visible GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # notebook cell output: display the selected device

In [None]:
# Experiment settings. Data locations and split files come from the project
# `config` module; the remaining entries are hyper-parameters of this baseline.
cfg = dict(
    train_txt=config.train,
    val_txt=config.validation,
    test_txt=config.test,
    root_mri=config.whole_MRI,
    root_spect=config.whole_Abeta,  # SPECT volumes are read from this config entry
    batch_size=4,
    num_workers=0,
    epochs=200,
    lr=2e-4,
    time_dim=config.time_dim,
    # Original note: later reduced by convolutions in the UNet conditioning
    # branch. Not read by the cells shown here (the UNet is built with image_size=40).
    image_size=160,
    save_dir='result/exp_image_diffusion/',
    patience=20,  # epochs without SSIM improvement before early stopping
)
os.makedirs(cfg['save_dir'], exist_ok=True)

## Dataset (imagen completa)
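- MRI: normalización z-score sobre el volumen completo y después crop a 160×192×160.
- SPECT: normalización min-max sobre el volumen completo y después crop.
- Etiquetas desde `data_info/data_info.csv` (columnas `ID` y `label`); si el ID no aparece, se usa 0.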

In [None]:
class ImagePairDataset(Dataset):
    """Paired (MRI, SPECT) volumes plus label(s), one item per subject ID.

    Each non-blank line of ``ids_txt`` is one subject ID. For every ID the MRI
    is z-score normalised and the SPECT min-max normalised; ``crop`` is applied
    after normalisation. Both volumes get a leading channel axis.

    Args:
        ids_txt: text file with one subject ID per line.
        root_mri: directory passed to ``resolve_nifti`` for the MRI volumes.
        root_spect: directory passed to ``resolve_nifti`` for the SPECT volumes.
        stage: split name ('train', 'val', ...); stored but not used here.
        labels_csv: CSV with ``ID`` and ``label`` columns, read as ISO-8859-1.
            Defaults to the path this notebook has always used.
    """

    def __init__(self, ids_txt, root_mri, root_spect, stage='train',
                 labels_csv='data_info/data_info.csv'):
        # Context manager so the ID file handle is closed deterministically
        # (a bare ``open`` keeps it alive until garbage collection).
        with open(ids_txt) as fh:
            self.ids = [line.strip() for line in fh if line.strip()]
        self.root_mri = root_mri
        self.root_spect = root_spect
        self.stage = stage
        self.labels = pd.read_csv(labels_csv, encoding='ISO-8859-1')
        # Index the label column by string ID once instead of re-casting and
        # scanning the whole table on every __getitem__. All matching rows are
        # kept (in file order), exactly like the boolean-mask lookup.
        ids_as_str = self.labels['ID'].astype(str)
        self._labels_by_id = {
            key: group.values
            for key, group in self.labels['label'].groupby(ids_as_str)
        }

    def __len__(self):
        return len(self.ids)

    def _lookup_label(self, bid):
        """Return the label(s) for ``bid`` as a float32 tensor, or ``[0.]`` if absent."""
        raw = self._labels_by_id.get(bid)
        if raw is None or raw.size == 0:
            raw = np.array([0], dtype=np.float32)
        return torch.tensor(raw.astype(np.float32), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(mri, spect, bid, label)`` for ``idx`` (index wrapped modulo length)."""
        bid = self.ids[idx % len(self.ids)]
        mri_path = resolve_nifti(self.root_mri, bid)
        spect_path = resolve_nifti(self.root_spect, bid)

        # Normalise on the full volume first, then crop.
        mri = crop(z_score_norm(nifti_to_numpy(mri_path)))
        spect = crop(min_max_norm(nifti_to_numpy(spect_path)))

        mri = torch.tensor(mri[None, ...], dtype=torch.float32)
        spect = torch.tensor(spect[None, ...], dtype=torch.float32)

        return mri, spect, bid, self._lookup_label(bid)

## Difusión y modelo

In [None]:
class Diffusion:
    """Linear-beta DDPM schedule with forward noising and ancestral sampling.

    Args:
        noise_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        device: device that holds the schedule tensors. ``None`` (the default)
            selects CUDA when available, otherwise CPU. This is the same choice
            the notebook's global ``device`` makes, without depending on that global.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=2e-2, device=None):
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.noise_steps = noise_steps
        self.beta = torch.linspace(beta_start, beta_end, noise_steps, device=self.device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    @staticmethod
    def _per_sample(values, ndim):
        """Reshape a (batch,) vector to (batch, 1, ..., 1) with ``ndim`` dims in total."""
        return values.view(-1, *([1] * (ndim - 1)))

    def noise_images(self, x, t):
        """Draw x_t ~ q(x_t | x_0) for a batch ``x`` at timesteps ``t``.

        Works for any ``x`` of shape (batch, ...); the schedule coefficients
        are broadcast per sample according to ``x.ndim``.

        Returns:
            ``(x_t, eps)``, where ``eps`` is the Gaussian noise that was added.
        """
        alpha_hat = self._per_sample(self.alpha_hat[t], x.ndim)
        eps = torch.randn_like(x)
        return torch.sqrt(alpha_hat) * x + torch.sqrt(1 - alpha_hat) * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` integer timesteps uniformly from [1, noise_steps - 1]."""
        return torch.randint(1, self.noise_steps, (n,), device=self.device)

    @torch.no_grad()
    def sample(self, model, y, labels=None, shape=(1, 40, 48, 40)):
        """Generate a batch by DDPM ancestral sampling conditioned on ``y``.

        Args:
            model: noise predictor called as ``model(x_t, y, t, labels)``. It
                is switched to eval mode and left there.
            y: conditioning batch; its batch size and device set the output's.
            labels: optional extra conditioning forwarded to ``model``.
            shape: per-sample shape (channels, *spatial) of the generated
                tensor. The default is the 1x40x48x40 grid used in training.

        Returns:
            Tensor of shape ``(y.shape[0], *shape)``.
        """
        model.eval()
        n = y.shape[0]
        x = torch.randn((n, *shape), device=y.device)
        for i in tqdm(reversed(range(1, self.noise_steps)), total=self.noise_steps - 1, leave=False):
            t = torch.full((n,), i, device=y.device, dtype=torch.long)
            pred = model(x, y, t, labels)
            alpha = self._per_sample(self.alpha[t], x.ndim)
            alpha_hat = self._per_sample(self.alpha_hat[t], x.ndim)
            beta = self._per_sample(self.beta[t], x.ndim)
            # No noise is injected on the last step (i == 1), so the final update is deterministic.
            noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
            x = (1 / torch.sqrt(alpha)) * (x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * pred) + torch.sqrt(beta) * noise
        return x


## Entrenamiento

In [None]:
# Build datasets/loaders, the conditional UNet, its EMA copy and the DDPM
# schedule. Then train with per-epoch validation and SSIM-based early stopping.
train_ds = ImagePairDataset(cfg['train_txt'], cfg['root_mri'], cfg['root_spect'], stage='train')
val_ds = ImagePairDataset(cfg['val_txt'], cfg['root_mri'], cfg['root_spect'], stage='val')

train_loader = DataLoader(train_ds, batch_size=cfg['batch_size'], shuffle=True, num_workers=cfg['num_workers'], pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=cfg['num_workers'], pin_memory=True)

# SPECT targets are downsampled to this grid before diffusion. It must match
# the spatial shape that `Diffusion.sample` generates (40x48x40 by default).
SAMPLE_SHAPE = (40, 48, 40)
EMA_DECAY = 0.999

unet = UNet(in_channel=2, out_channel=1, image_size=40).to(device)
optimizer = torch.optim.AdamW(unet.parameters(), lr=cfg['lr'])
diffusion = Diffusion()
ema = copy.deepcopy(unet).eval().requires_grad_(False)

mse = nn.MSELoss()


@torch.no_grad()
def _ema_update(ema_model, model, decay=EMA_DECAY):
    """Blend ``model`` into ``ema_model`` in place: ema = decay * ema + (1 - decay) * model.

    Buffers (e.g. normalisation running statistics, if the UNet has any) are
    copied verbatim. Otherwise the EMA copy would keep its initial buffers
    while its weights drift.
    """
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - decay)
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)


def _volume_metrics(target, recon):
    """Return ``(psnr, ssim)`` between two same-shape numpy volumes.

    ``data_range`` is the target's value range, floored at 1e-8 so a constant
    target still gives a non-zero range. The SSIM window is 7 voxels, or the
    largest odd size that fits when a side is shorter than 7.
    """
    data_range = max(target.max() - target.min(), 1e-8)
    psnr_val = psnr(target, recon, data_range=data_range)
    min_side = min(target.shape)
    if min_side >= 7:
        win = 7
    elif min_side % 2 == 1:
        win = min_side
    else:
        win = max(min_side - 1, 1)
    ssim_val = ssim(target, recon, data_range=data_range, win_size=win)
    return psnr_val, ssim_val


best_ssim = -1e9
patience = 0

for epoch in range(cfg['epochs']):
    unet.train()
    loop = tqdm(train_loader, leave=False)
    epoch_loss = 0.0
    for mri, spect, bid, label in loop:
        mri, spect, label = mri.to(device), spect.to(device), label.to(device)
        # Downsample the SPECT target to the diffusion grid before noising.
        spect_small = F.interpolate(spect, size=SAMPLE_SHAPE, mode='trilinear', align_corners=False)
        t = diffusion.sample_timesteps(spect_small.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(spect_small, t)
        pred_noise = unet(x_t, mri, t, label)
        loss = mse(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _ema_update(ema, unet)

        epoch_loss += loss.item()
        loop.set_description(f"epoch {epoch+1} loss {loss.item():.4f}")

    # Validation: full reverse-diffusion sampling with the EMA weights.
    unet.eval()
    psnr_sum = 0.0
    ssim_sum = 0.0
    with torch.no_grad():
        for mri, spect, bid, label in val_loader:
            mri, spect, label = mri.to(device), spect.to(device), label.to(device)
            target = F.interpolate(spect, size=SAMPLE_SHAPE, mode='trilinear', align_corners=False)
            sampled = diffusion.sample(ema, mri, label)
            recon = sampled.cpu().numpy().squeeze().astype(np.float32)
            target = target.cpu().numpy().squeeze().astype(np.float32)
            psnr_val, ssim_val = _volume_metrics(target, recon)
            psnr_sum += psnr_val
            ssim_sum += ssim_val
    # Guard the averages against an empty split (avoids ZeroDivisionError).
    n_val = max(len(val_loader), 1)
    psnr_avg = psnr_sum / n_val
    ssim_avg = ssim_sum / n_val
    train_loss_avg = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}: loss {train_loss_avg:.4f}, PSNR {psnr_avg:.3f}, SSIM {ssim_avg:.4f}")

    if ssim_avg > best_ssim:
        best_ssim = ssim_avg
        patience = 0
        torch.save({
            'state_dict': ema.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch+1,
            'psnr': psnr_avg,
            'ssim': ssim_avg,
        }, os.path.join(cfg['save_dir'], 'unet_image_best.pth'))
    else:
        patience += 1
        if patience >= cfg['patience']:
            print(f"Early stopping en epoch {epoch+1}")
            break