# Segmentation 3D — Notebook

Ce notebook contient un pipeline complet et commenté pour segmenter des tumeurs 3D à partir de volumes `.nii.gz`.

**Organisation**:
- Entraînement: Timepoint_1 + Timepoint_2
- Test: Timepoint_3

> Adapté pour usage avec MedSegDiff-V2 (si disponible) ou fallback UNet3D.

In [1]:

# %%
# Setup - imports et paramètres
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Séquences et nom de masque
SEQUENCES = ["t1c", "t2f", "t2w"]
MASK_NAME = "tumorMask"

print('PyTorch version:', torch.__version__)


PyTorch version: 2.8.0+cu128


## Fonctions utilitaires

Chargement/sauvegarde NIfTI et fonctions de visualisation.

In [2]:

# %%
def load_nifti(path):
    """Charge un fichier NIfTI et retourne (data_float32, affine)."""
    img = nib.load(path)
    return img.get_fdata(dtype=np.float32), img.affine

def save_nifti(data, affine, out_path):
    """Sauvegarde data (numpy) au format NIfTI en gardant affine."""
    img = nib.Nifti1Image(data.astype(np.float32), affine)
    nib.save(img, out_path)

def find_sequence_file(tp_dir, seq_key):
    """Retourne le chemin d'un fichier dans tp_dir contenant seq_key dans son nom (ou None)."""
    for f in os.listdir(tp_dir):
        if seq_key in f and (f.endswith('.nii') or f.endswith('.nii.gz')):
            return os.path.join(tp_dir, f)
    return None

def show_slices(imgs, masks=None, slice_idx=None, figsize=(12,6)):
    """Affiche des slices (axial) côte-à-côte pour chaque canal et le masque si fourni."""
    # imgs : numpy array (C, D, H, W) or (D,H,W)
    if imgs.ndim == 4:
        C, D, H, W = imgs.shape
    elif imgs.ndim == 3:
        C = 1
        D, H, W = imgs.shape
        imgs = imgs[None, ...]
    if slice_idx is None:
        slice_idx = D // 2
    ncols = C + (1 if masks is not None else 0)
    fig, axes = plt.subplots(1, ncols, figsize=figsize)
    for c in range(C):
        ax = axes[c] if ncols>1 else axes
        im = imgs[c, slice_idx, :, :]
        ax.imshow(im.T, cmap='gray', origin='lower')
        ax.set_title(f'Canal {c} - slice {slice_idx}')
        ax.axis('off')
    if masks is not None:
        m = masks if masks.ndim==3 else masks[0]
        ax = axes[-1]
        ax.imshow(m[slice_idx].T, cmap='gray', origin='lower')
        ax.set_title('Masque')
        ax.axis('off')
    plt.show()


## Dataset PyTorch

Classe Dataset qui charge les échantillons en respectant la règle : utiliser Timepoint_1 et Timepoint_2 pour l'entraînement et Timepoint_3 pour le test.

In [3]:

# %%
class MRIDataset3D(Dataset):
    """Dataset PyTorch pour volumes multi-séquences.
    Retourne: (image_tensor, mask_tensor, affine, (patient, timepoint))
    """
    def __init__(self, dataset_dir, patients, timepoints, sequences=SEQUENCES):
        self.dataset_dir = dataset_dir
        self.patients = patients
        self.timepoints = timepoints
        self.sequences = sequences
        self.samples = []
        for patient in patients:
            for tp in timepoints:
                tp_path = os.path.join(dataset_dir, patient, tp)
                if os.path.isdir(tp_path):
                    # Vérifier que toutes les séquences et le masque existent
                    ok = True
                    for s in sequences + [MASK_NAME]:
                        if find_sequence_file(tp_path, s) is None:
                            ok = False
                            break
                    if ok:
                        self.samples.append((patient, tp))

    def __len__(self):
        return len(self.samples)

    def pad_to_multiple(vol, multiple=8):
        """Pad numpy array 3D or 4D (C,D,H,W) so that D,H,W are multiples of 'multiple'."""
        if vol.ndim == 4:
            C, D, H, W = vol.shape
            pad_D = (multiple - D % multiple) if D % multiple != 0 else 0
            pad_H = (multiple - H % multiple) if H % multiple != 0 else 0
            pad_W = (multiple - W % multiple) if W % multiple != 0 else 0
            return np.pad(vol, ((0,0),(0,pad_D),(0,pad_H),(0,pad_W)), mode='constant'), (pad_D,pad_H,pad_W)
        elif vol.ndim == 3:
            D, H, W = vol.shape
            pad_D = (multiple - D % multiple) if D % multiple != 0 else 0
            pad_H = (multiple - H % multiple) if H % multiple != 0 else 0
            pad_W = (multiple - W % multiple) if W % multiple != 0 else 0
            return np.pad(vol, ((0,pad_D),(0,pad_H),(0,pad_W)), mode='constant'), (pad_D,pad_H,pad_W)
        else:
            return vol, (0,0,0)

    def __getitem__(self, idx):
        patient, tp = self.samples[idx]
        tp_path = os.path.join(self.dataset_dir, patient, tp)
        channels = []
        affine = None
        for seq in self.sequences:
            p = find_sequence_file(tp_path, seq)
            data, a = load_nifti(p)
            if affine is None:
                affine = a
            # z-score normalisation
            data = (data - np.mean(data)) / (np.std(data) + 1e-8)
            channels.append(data)
        mask_p = find_sequence_file(tp_path, MASK_NAME)
        mask, _ = load_nifti(mask_p)
        mask = (mask > 0).astype(np.float32)
        vol = np.stack(channels, axis=0)  # (C, D, H, W)
        mask = mask[None, ...]            # (1, D, H, W)
        vol, _ = pad_to_multiple(vol, multiple=8)
        mask, _ = pad_to_multiple(mask, multiple=8)
        return torch.from_numpy(vol.astype(np.float32)), torch.from_numpy(mask.astype(np.float32)), affine, (patient, tp)


## Modèle UNet3D (fallback)

Architecture UNet3D simple commentée.

In [4]:

# %%
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1), nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1), nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class UNet3D(nn.Module):
    def __init__(self, in_ch=3, out_ch=1, base_filters=16):
        super().__init__()
        f = base_filters
        self.enc1 = DoubleConv(in_ch, f)
        self.pool = nn.MaxPool3d(2)
        self.enc2 = DoubleConv(f, f*2)
        self.enc3 = DoubleConv(f*2, f*4)
        self.bottleneck = DoubleConv(f*4, f*8)
        self.up3 = nn.ConvTranspose3d(f*8, f*4, 2, stride=2)
        self.dec3 = DoubleConv(f*8, f*4)
        self.up2 = nn.ConvTranspose3d(f*4, f*2, 2, stride=2)
        self.dec2 = DoubleConv(f*4, f*2)
        self.up1 = nn.ConvTranspose3d(f*2, f, 2, stride=2)
        self.dec1 = DoubleConv(f*2, f)
        self.outc = nn.Conv3d(f, out_ch, 1)
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = self.up3(b)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        return self.outc(d1)

# Dice loss
class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
    def forward(self, logits, target):
        probs = torch.sigmoid(logits)
        inter = (probs * target).sum(dim=[1,2,3,4])
        denom = probs.sum(dim=[1,2,3,4]) + target.sum(dim=[1,2,3,4])
        return 1 - ((2*inter + self.eps) / (denom + self.eps)).mean()

def dice_coef(logits, target, eps=1e-6):
    pred = (torch.sigmoid(logits) > 0.5).float()
    inter = (pred * target).sum(dim=[1,2,3,4])
    denom = pred.sum(dim=[1,2,3,4]) + target.sum(dim=[1,2,3,4])
    return ((2*inter + eps) / (denom + eps)).mean().item()


## Entraînement

Boucles d'entraînement et d'évaluation. Entraîne sur Timepoint_1+Timepoint_2 et évalue sur Timepoint_3.

In [5]:

# %%
def train_epoch(model, loader, device, optimizer, criterion):
    model.train()
    total_loss = 0.0
    for imgs, masks, _, _ in tqdm(loader, desc='Train'):
        imgs = imgs.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def eval_epoch(model, loader, device, criterion):
    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    with torch.no_grad():
        for imgs, masks, _, _ in tqdm(loader, desc='Eval'):
            imgs = imgs.to(device)
            masks = masks.to(device)
            logits = model(imgs)
            loss = criterion(logits, masks)
            total_loss += loss.item()
            total_dice += dice_coef(logits, masks)
    return total_loss / len(loader), total_dice / len(loader)

def inference_and_save(model, loader, device, out_dir):
    model.eval()
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        for imgs, masks, affines, metas in tqdm(loader, desc='Infer'):
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.sigmoid(logits).cpu().numpy()
            for i in range(probs.shape[0]):
                pred = (probs[i,0] > 0.5).astype(np.uint8)
                patient, tp = metas[i]
                out_path = os.path.join(out_dir, patient, tp, 'pred_mask.nii.gz')
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                save_nifti(pred, affines[i], out_path)


## Exemple d'exécution

Modifie `DATASET_DIR` et exécute cette cellule pour lancer l'entraînement (GPU recommandé).

In [6]:

# %%
# Paramètres à éditer par l'utilisateur
DATASET_DIR = '/home/perfect/Documents/GitHub/projet-AI/data_filter'   # <-- change
OUTPUT_DIR = '/home/perfect/Documents/GitHub/projet-AI/data_segmentation'      # <-- change
BATCH_SIZE = 1
EPOCHS = 20
LR = 1e-4

# Lister patients ayant >=3 timepoints
patients = []
for p in sorted(os.listdir(DATASET_DIR)):
    ppath = os.path.join(DATASET_DIR, p)
    if not os.path.isdir(ppath): continue
    tps = sorted([d for d in os.listdir(ppath) if os.path.isdir(os.path.join(ppath, d))])
    if len(tps) >= 3:
        patients.append(p)
print(f'Patients retenus: {len(patients)}')

# Datasets: train = TP1+TP2, test = TP3
train_ds = MRIDataset3D(DATASET_DIR, patients, ['Timepoint_1', 'Timepoint_2'])
test_ds  = MRIDataset3D(DATASET_DIR, patients, ['Timepoint_3'])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader  = DataLoader(test_ds,  batch_size=1, shuffle=False, num_workers=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3D(in_ch=len(SEQUENCES), out_ch=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = DiceLoss()

best_dice = 0.0
for epoch in range(1, EPOCHS+1):
    train_loss = train_epoch(model, train_loader, device, optimizer, criterion)
    val_loss, val_dice = eval_epoch(model, test_loader, device, criterion)
    print(f'Epoch {epoch}/{EPOCHS} - Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}')
    # Save best model
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'best_model.pt'))
        print('Best model saved.')

# Inference and save
inference_and_save(model, test_loader, device, os.path.join(OUTPUT_DIR, 'predictions'))
print('Done. Predictions saved to', os.path.join(OUTPUT_DIR, 'predictions'))


Patients retenus: 110


Train:   0%|          | 0/195 [00:04<?, ?it/s]


NameError: Caught NameError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/perfect/Documents/GitHub/projet-AI/env/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/perfect/Documents/GitHub/projet-AI/env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/perfect/Documents/GitHub/projet-AI/env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_3686504/4224057586.py", line 63, in __getitem__
    vol, _ = pad_to_multiple(vol, multiple=8)
NameError: name 'pad_to_multiple' is not defined


## Visualisation des résultats

Affiche des slices (axial) des volumes et des masques (vérité terrain et prédiction).

In [None]:

# %%
# Exemple pour visualiser un patient/timepoint précis (après inférence)
def visualize_patient_timepoint(data_root, patient, timepoint, pred_root=None, slice_idx=None):
    tp_path = os.path.join(data_root, patient, timepoint)
    channels = []
    for seq in SEQUENCES:
        f = find_sequence_file(tp_path, seq)
        d, _ = load_nifti(f)
        channels.append(d)
    vol = np.stack(channels, axis=0)
    mask_f = find_sequence_file(tp_path, MASK_NAME)
    mask, _ = load_nifti(mask_f)
    pred_mask = None
    if pred_root:
        pred_p = os.path.join(pred_root, patient, timepoint, 'pred_mask.nii.gz')
        if os.path.exists(pred_p):
            pred_mask, _ = load_nifti(pred_p)
    show_slices(vol, masks=pred_mask if pred_mask is not None else mask, slice_idx=slice_idx)

# Usage example (modifier les chemins et ids)
# visualize_patient_timepoint(DATASET_DIR, 'PatientID_0162', 'Timepoint_3', pred_root=os.path.join(OUTPUT_DIR,'predictions'))
