
# DDPM-based Segmentation (2D slice) — Notebook

**But**: segmentation de tumeur avec un modèle de diffusion conditionné sur images multi-modalité (T1c, T2f, T2w).

**Ce notebook fournit** :
- un loader 2D (slices) depuis `RESULTS_DIR` (prétraité)
- un UNet simple conditionné par l'image + bruit de masque
- utilitaires diffusion (q_sample) et une fonction d'échantillonnage simple
- boucle d'entraînement **squelette** (adapter batch_size / epochs)
- fonctions de visualisation (overlay mask)

> **Important** : ce notebook est un point de départ pédagogique. Pour usage réel, adapte la gestion GPU, la formule exacte de reverse DDPM, les normalisations et augmente le dataset.


In [2]:
# Installer dépendances si nécessaire (décommente si besoin)
!pip install nibabel torch torchvision tqdm matplotlib numpy


Collecting torch
  Using cached torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl (888.0 MB)
Collecting torchvision
  Using cached torchvision-0.23.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.6 MB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl (78 kB)
Collecting nvidia-nvjitlink-cu12==12.8.93
  Downloading nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.3/39.3 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting nvidia-nvtx-cu12==12.8.90
  Using cached nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89 kB)
Collecting nvidia-cublas-cu12==12.8.4.1
  Downloading nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl (594.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m594.3/594.3 MB[0m [31m881.3 kB/s[0m eta [36m0:00:00[0m00:01[0m00:05[0m
[?25hCollecting nvidia-cuspars

In [None]:
# Paramètres principaux - MODIFIE LES CHEMINS AVANT D'EXÉCUTER
RESULTS_DIR = "/home/perfect/Documents/GitHub/projet-AI/data_filter"  # dossier contenant PatientID/Timepoint/*.nii.gz
OUTPUT_DIR = "/home/perfect/Documents/GitHub/projet-AI/data_segmentation"
os.makedirs(OUTPUT_DIR, exist_ok=True)

SEQUENCES = ['t1c', 't2f', 't2w']
MASK_NAME = 'tumorMask'

# Entraînement - valeurs par défaut pour test rapide
DEVICE = 'cuda' if __import__('torch').cuda.is_available() else 'cpu'
EPOCHS = 2
BATCH_SIZE = 4
LR = 2e-4
TIMESTEPS = 200  # réduit pour tests


PermissionError: [Errno 13] Permission denied: '/chemin'

In [1]:
import os, math, random
from pathlib import Path
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

print('PyTorch device:', DEVICE)


NameError: name 'DEVICE' is not defined

In [None]:
# --- Diffusion utilities (basic) ---
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2):
    return torch.linspace(beta_start, beta_end, timesteps)

class Diffusion:
    def __init__(self, timesteps=1000, device='cpu'):
        self.timesteps = timesteps
        self.device = device
        betas = linear_beta_schedule(timesteps).to(device)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.tensor([1.], device=device), alphas_cumprod[:-1]])
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev

        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        # x_start: (B,1,H,W)
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_acp = self.sqrt_alphas_cumprod[t].view(-1,1,1,1)
        sqrt_omacp = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
        return sqrt_acp * x_start + sqrt_omacp * noise, noise

# small helper to get scalar tensors
def t_to_device(t, device):
    return t.to(device).long()


In [None]:
class SlicesDataset(Dataset):
    def __init__(self, results_dir, sequences=['t1c','t2f','t2w'], mask_name='tumorMask', transform=None):
        self.results_dir = Path(results_dir)
        self.sequences = sequences
        self.mask_name = mask_name
        self.transform = transform
        self.index = []  # list of tuples (tp_path, z)

        # build index
        for patient in sorted(os.listdir(results_dir)):
            pth = Path(results_dir) / patient
            if not pth.is_dir(): continue
            for tp in sorted(os.listdir(pth)):
                tp_path = pth / tp
                if not tp_path.is_dir(): continue
                # check files exist
                seq_files = {s: list(tp_path.glob(f'*{s}*.nii*')) for s in sequences}
                mask_files = list(tp_path.glob(f'*{mask_name}*.nii*'))
                if any(len(seq_files[s])==0 for s in sequences) or len(mask_files)==0:
                    continue
                # load one modality to know shape
                sample_img = nib.load(str(seq_files[sequences[0]][0])).get_fdata()
                H,W,D = sample_img.shape
                for z in range(D):
                    self.index.append((str(tp_path), int(z)))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        tp_path, z = self.index[idx]
        tp_path = Path(tp_path)
        seq_imgs = {}
        for s in self.sequences:
            f = list(tp_path.glob(f'*{s}*.nii*'))[0]
            arr = nib.load(str(f)).get_fdata().astype(np.float32)
            # per-slice normalization (zero mean, unit std)
            arr = (arr - arr.mean()) / (arr.std() + 1e-8)
            seq_imgs[s] = arr
        mask_f = list(tp_path.glob(f'*{self.mask_name}*.nii*'))[0]
        mask = nib.load(str(mask_f)).get_fdata().astype(np.float32)
        in_slice = np.stack([seq_imgs[s][:,:,z] for s in self.sequences], axis=0)  # (C,H,W)
        mask_slice = (mask[:,:,z] > 0).astype(np.float32)[None,...]  # (1,H,W)
        return torch.from_numpy(in_slice).float(), torch.from_numpy(mask_slice).float()

# quick sanity
# ds = SlicesDataset(RESULTS_DIR)
# print('dataset size', len(ds))


In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU()
        )
    def forward(self,x): return self.net(x)

class SimpleUNetCond(nn.Module):
    def __init__(self, in_ch=4, base=32):
        super().__init__()
        self.inc = DoubleConv(in_ch, base)
        self.down1 = DoubleConv(base, base*2)
        self.down2 = DoubleConv(base*2, base*4)
        self.up1 = DoubleConv(base*4, base*2)
        self.up2 = DoubleConv(base*2, base)
        self.outc = nn.Conv2d(base, 1, 1)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(self.pool(x1))
        x3 = self.down2(self.pool(x2))
        u1 = F.interpolate(x3, scale_factor=2, mode='nearest')
        u1 = self.up1(u1)
        u2 = F.interpolate(u1, scale_factor=2, mode='nearest')
        u2 = self.up2(u2)
        out = self.outc(u2 + x1)
        return out

# model = SimpleUNetCond(in_ch=4).to(DEVICE)
# print(model)


In [None]:
def train(results_dir, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, timesteps=TIMESTEPS, device=DEVICE):
    device = torch.device(device)
    ds = SlicesDataset(results_dir, sequences=SEQUENCES, mask_name=MASK_NAME)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    model = SimpleUNetCond(in_ch=4).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    diff = Diffusion(timesteps=timesteps, device=device)

    for epoch in range(epochs):
        pbar = tqdm(dl, desc=f'Epoch {epoch+1}/{epochs}')
        running_loss = 0.0
        for imgs, masks in pbar:
            imgs = imgs.to(device)       # (B,3,H,W)
            masks = masks.to(device)     # (B,1,H,W)
            b = imgs.shape[0]
            t = torch.randint(0, diff.timesteps, (b,), device=device)
            x_t, noise = diff.q_sample(masks, t)
            # prepare model input = concat(condition images + x_t noisy mask)
            inp = torch.cat([imgs, x_t], dim=1)  # (B,4,H,W)
            pred = model(inp)  # predict noise or x0; here we predict noise
            loss = F.mse_loss(pred, noise)
            optim.zero_grad()
            loss.backward()
            optim.step()
            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss / (pbar.n+1e-8))
        # save checkpoint
        ckpt_path = os.path.join(OUTPUT_DIR, f'model_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), ckpt_path)
        print('Saved', ckpt_path)

    return model, diff

# WARNING: launching training may be long. Uncomment to run:
# model, diff = train(RESULTS_DIR)


In [None]:
@torch.no_grad()
def sample(model, diff, cond_imgs, steps=None, device=DEVICE):
    device = torch.device(device)
    model.eval()
    b, c, h, w = cond_imgs.shape
    steps = steps or diff.timesteps
    x = torch.randn((b,1,h,w), device=device)
    for i in reversed(range(diff.timesteps)):
        t = torch.full((b,), i, device=device, dtype=torch.long)
        inp = torch.cat([cond_imgs, x], dim=1)  # (B,4,H,W)
        pred_noise = model(inp)
        beta = diff.betas[i]
        alpha = diff.alphas[i]
        alpha_cum = diff.alphas_cumprod[i]
        # simplified posterior mean update (not exact ddpm)
        x = (1.0 / torch.sqrt(alpha)) * (x - (beta / torch.sqrt(1 - alpha_cum)) * pred_noise)
        if i > 0:
            x = x + torch.sqrt(beta) * torch.randn_like(x)
    # x is continuous mask; threshold
    return (x.clamp(-1,1) > 0).float()

def show_overlay(cond_img_np, mask_np, title=''):
    # cond_img_np: (3,H,W) numpy, mask_np: (1,H,W)
    t1c = cond_img_np[0]
    mask = mask_np[0]
    mid = mask.shape[1]//2
    plt.figure(figsize=(6,6))
    plt.imshow(t1c, cmap='gray')
    plt.imshow(np.ma.masked_where(mask==0, mask), cmap='autumn', alpha=0.5)
    plt.title(title)
    plt.axis('off')
    plt.show()


In [None]:
# === Exemple d'utilisation ===
# 1) Lancer l'entraînement (décommente si tu veux exécuter)
# model, diff = train(RESULTS_DIR, epochs=2)

# 2) Après entraînement charger un checkpoint et échantillonner sur un batch de validation
# device = torch.device(DEVICE)
# model = SimpleUNetCond(in_ch=4).to(device)
# model.load_state_dict(torch.load('/chemin/vers/output_ddpm/model_epoch_2.pth', map_location=device))
# ds = SlicesDataset(RESULTS_DIR)
# dl = DataLoader(ds, batch_size=2, shuffle=True)
# imgs, masks = next(iter(dl))
# imgs = imgs.to(device)
# pred_mask = sample(model, diff, imgs)
# show_overlay(imgs[0].cpu().numpy(), pred_mask[0].cpu().numpy(), title='Prediction overlay')



---

## Next steps / améliorations recommandées
- Remplacer l'update d'échantillonnage par la formule complète p(x_{t-1}|x_t) (voir DDPM paper).
- Ajouter time embedding injectée dans chaque bloc UNet.
- Utiliser augmentation (torchio / MONAI), early stopping, validation set.
- Pour 3D : entraînement patch-based 3D ou 2.5D avec stacks de slices.
- Post-traitement : seuillage, morpho, largest connected component.

Good luck — adapte les hyperparamètres et bon entraînement !
