# Partie 2 — Denoising AutoEncoder + ConvLSTM (t2f-only)
Notebook **prêt à exécuter** :
- t2f-only (1 canal) avec **prétraitements** (z-score par volume + padding des dimensions spatiales au multiple de 8 supérieur)
- **AutoEncoder 3D** pour débruitage (Stage A)
- **ConvLSTM 3D** sur latents (TP1→TP2) pour prédire **le masque TP3** (Stage B)
- Sauvegarde des masques prédits dans `OUTPUT_DIR/output_masks/<Patient>/<Timepoint>/pred_mask.nii.gz`, où `<Timepoint>` est le dossier du timepoint TP3 prédit


In [None]:
# Imports
import os, math, random, json
from typing import List, Optional, Tuple

import numpy as np
import nibabel as nib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch.optim as optim
torch.backends.cudnn.benchmark = True


In [None]:
# === Utils NIfTI & padding ===
def pad_to_multiple(volume, multiple=8, mode='constant', value=0):
    """Pad the spatial axes of a volume, at their far end, up to a multiple of ``multiple``.

    A 3D array is padded on all three axes. A 4D array is padded on its last
    three axes only; axis 0 is never padded.

    Args:
        volume: 3D or 4D ``numpy`` array.
        multiple: Each padded axis length becomes divisible by this value.
        mode: Padding mode forwarded to ``np.pad``.
        value: Fill value; only used when ``mode == 'constant'``.

    Returns:
        ``(padded, pad_widths)`` where ``pad_widths`` is the per-axis list of
        ``(before, after)`` tuples that was passed to ``np.pad``.

    Raises:
        ValueError: If ``volume`` is neither 3D nor 4D.
    """
    if volume.ndim not in (3, 4):
        raise ValueError("Volume must be 3D or 4D")

    # Number of leading (non-spatial) axes that must stay untouched: 0 or 1.
    n_leading = volume.ndim - 3
    pad_widths = [(0, 0)] * n_leading
    # -dim % multiple == 0 when dim is already aligned, else multiple - dim % multiple.
    pad_widths += [(0, -dim % multiple) for dim in volume.shape[n_leading:]]

    # np.pad rejects `constant_values` for every mode except 'constant', so it
    # must only be forwarded in that case.
    extra = {'constant_values': value} if mode == 'constant' else {}
    return np.pad(volume, pad_widths, mode=mode, **extra), pad_widths

def load_nifti(path):
    """Load a NIfTI file with nibabel.

    Returns:
        ``(data, affine)``: the array produced by ``get_fdata()`` and the
        image's affine matrix.
    """
    image = nib.load(path)
    return image.get_fdata(), image.affine

# Filename substrings (matched case-insensitively) used to locate the t2f image
# file inside a timepoint folder.
SEQ_KEYWORDS  = ["t2f_processed", "t2f", "flair", "t2_flair", "t2fla"]
# Filename substrings (matched case-insensitively) used to locate the mask file
# inside a timepoint folder. Once lower-cased, "tumorMask" is already covered by "mask".
MASK_KEYWORDS = ["tumorMask", "mask", "seg", "label"]

def is_nifti(fname: str) -> bool:
    """Tell whether ``fname`` ends with ``.nii`` or ``.nii.gz``, ignoring case."""
    return fname.lower().endswith((".nii", ".nii.gz"))

def find_first_matching_file(folder: str, keywords: List[str]):
    """Return the path of the first NIfTI file in ``folder`` whose name contains a keyword.

    Files are scanned in sorted name order and the comparison is
    case-insensitive. The first file containing *any* keyword wins: the order
    of ``keywords`` does not express a priority.

    Returns:
        The joined ``folder``/filename path, or ``None`` when nothing matches.
    """
    needles = tuple(key.lower() for key in keywords)
    hits = (
        name
        for name in sorted(os.listdir(folder))
        if is_nifti(name) and any(needle in name.lower() for needle in needles)
    )
    first = next(hits, None)
    return None if first is None else os.path.join(folder, first)


In [None]:
# === Data helpers ===
def load_volume_and_mask(dataset_dir: str, patient: str, tp: str):
    """Load and preprocess the image and mask stored under ``dataset_dir/patient/tp``.

    The image is cast to float32 and z-scored over the whole volume. The mask
    is binarised (``> 0``). Both get a leading axis of size 1 and are then
    padded with ``pad_to_multiple(..., multiple=8)``.

    Returns:
        ``(vol, mask, affine)``: the two padded float32 arrays and the image
        affine as a float64 array.

    Raises:
        FileNotFoundError: If no image file or no mask file matches the keywords.
    """
    tp_dir = os.path.join(dataset_dir, patient, tp)

    image_file = find_first_matching_file(tp_dir, SEQ_KEYWORDS)
    if image_file is None:
        raise FileNotFoundError(f"[t2f introuvable] {patient}/{tp}")
    image, affine = load_nifti(image_file)

    mask_file = find_first_matching_file(tp_dir, MASK_KEYWORDS)
    if mask_file is None:
        raise FileNotFoundError(f"[mask introuvable] {patient}/{tp}")
    raw_mask, _ = load_nifti(mask_file)

    # Whole-volume z-score; the epsilon guards against a constant image.
    image = image.astype(np.float32, copy=False)
    image = (image - image.mean()) / (image.std() + 1e-8)
    binary_mask = (raw_mask > 0).astype(np.float32)

    # Add the leading axis, then pad the three trailing axes.
    padded_image, _ = pad_to_multiple(np.expand_dims(image, 0), multiple=8)
    padded_mask, _ = pad_to_multiple(np.expand_dims(binary_mask, 0), multiple=8)

    return padded_image, padded_mask, np.asarray(affine, dtype=np.float64)

class AEDataset(Dataset):
    """Denoising-autoencoder dataset yielding ``(noisy, clean)`` float tensor pairs.

    Each item loads the preprocessed volume of one ``(patient, timepoint)``
    pair and adds zero-mean Gaussian noise whose standard deviation is drawn
    uniformly from ``noise_std = (low, high)``.

    The noise comes from the torch RNG rather than ``np.random``: PyTorch
    seeds torch differently in every DataLoader worker, whereas NumPy's global
    RNG state can be duplicated across workers, which would make them emit
    identical noise.
    """

    def __init__(self, dataset_dir: str, pairs_for_ae: List[Tuple[str,str]], noise_std=(0.05, 0.15)):
        self.dataset_dir = dataset_dir
        self.pairs = list(pairs_for_ae)
        self.noise_std = noise_std

    def __len__(self):
        return len(self.pairs)

    def _add_noise(self, clean):
        """Return a new tensor: ``clean`` plus Gaussian noise (``clean`` is not modified)."""
        low, high = float(self.noise_std[0]), float(self.noise_std[1])
        # torch.rand is uniform on [0, 1), so sigma is uniform on [low, high).
        sigma = low + (high - low) * torch.rand(1).item()
        return clean + sigma * torch.randn_like(clean)

    def __getitem__(self, idx):
        patient, tp = self.pairs[idx]
        vol, _, _ = load_volume_and_mask(self.dataset_dir, patient, tp)
        clean = torch.from_numpy(vol)
        return self._add_noise(clean), clean

class SequenceDataset(Dataset):
    """Dataset of timepoint sequences: two input volumes and the mask of a third timepoint.

    Args:
        dataset_dir: Root folder handed to ``load_volume_and_mask``.
        triples: Iterable of ``(patient, tp1, tp2, tp3)`` 4-tuples.

    Raises:
        ValueError: If an entry of ``triples`` does not have exactly 4 elements.
    """

    def __init__(self, dataset_dir: str, triples: List[Tuple[str, str, str, str]]):
        self.dataset_dir = dataset_dir
        self.triples = list(triples)  # (patient, tp1, tp2, tp3)
        # Fail at construction time instead of with an unpacking error later,
        # when an item is first fetched.
        for entry in self.triples:
            if len(entry) != 4:
                raise ValueError(
                    f"Expected (patient, tp1, tp2, tp3), got {len(entry)} element(s): {entry!r}"
                )

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(x1, x2, y3, affine3, (patient, tp3))`` for sequence ``idx``.

        ``x1``/``x2`` are the volumes of tp1/tp2, ``y3`` the mask of tp3 and
        ``affine3`` the affine returned for tp3.
        """
        patient, tp1, tp2, tp3 = self.triples[idx]
        vol1, _, _ = load_volume_and_mask(self.dataset_dir, patient, tp1)
        vol2, _, _ = load_volume_and_mask(self.dataset_dir, patient, tp2)
        _, mask3, affine3 = load_volume_and_mask(self.dataset_dir, patient, tp3)
        x1 = torch.from_numpy(vol1)
        x2 = torch.from_numpy(vol2)
        y3 = torch.from_numpy(mask3)
        return (x1, x2, y3, affine3, (patient, tp3))


In [None]:
# === Models ===
class DoubleConv(nn.Module):
    """Two stacked ``Conv3d(3, padding=1) -> InstanceNorm3d -> ReLU`` stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        channels = in_ch
        for _ in range(2):
            stages.extend([
                nn.Conv3d(channels, out_ch, kernel_size=3, padding=1),
                nn.InstanceNorm3d(out_ch),
                nn.ReLU(inplace=True),
            ])
            channels = out_ch
        # Kept as a flat Sequential named `block` so parameter names stay the same.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)

class Encoder3D(nn.Module):
    """3D encoder: two ``DoubleConv`` + ``MaxPool3d(2)`` stages followed by a bottleneck.

    Channel widths are ``base``, ``2*base`` and ``4*base``; the output width is
    exposed as ``out_channels``. The two poolings divide each spatial size by 4.
    """

    def __init__(self, in_ch=1, base=16):
        super().__init__()
        w1, w2, w3 = base, base * 2, base * 4
        self.enc1 = DoubleConv(in_ch, w1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = DoubleConv(w1, w2)
        self.pool2 = nn.MaxPool3d(2)
        self.bottleneck = DoubleConv(w2, w3)
        self.out_channels = w3

    def forward(self, x):
        half_res = self.pool1(self.enc1(x))
        quarter_res = self.pool2(self.enc2(half_res))
        return self.bottleneck(quarter_res)

class Decoder3D(nn.Module):
    """3D decoder mirroring ``Encoder3D`` (no skip connections).

    Takes a ``4*base``-channel tensor, upsamples twice with stride-2 transposed
    convolutions (each followed by a ``DoubleConv``) and projects to ``out_ch``
    channels with a 1x1x1 convolution.
    """

    def __init__(self, out_ch=1, base=16):
        super().__init__()
        wide, mid, narrow = base * 4, base * 2, base
        self.up2 = nn.ConvTranspose3d(wide, mid, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(mid, mid)
        self.up1 = nn.ConvTranspose3d(mid, narrow, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(narrow, narrow)
        self.out_conv = nn.Conv3d(narrow, out_ch, kernel_size=1)

    def forward(self, z):
        half_res = self.dec2(self.up2(z))
        full_res = self.dec1(self.up1(half_res))
        return self.out_conv(full_res)

class DAE3D(nn.Module):
    """Autoencoder pairing an ``Encoder3D`` with a ``Decoder3D`` of the same width.

    ``forward`` returns ``(reconstruction, latent)``.
    """

    def __init__(self, in_ch=1, base=16):
        super().__init__()
        self.encoder = Encoder3D(in_ch=in_ch, base=base)
        # The decoder outputs as many channels as the encoder receives.
        self.decoder = Decoder3D(out_ch=in_ch, base=base)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent), latent

class ConvLSTMCell3D(nn.Module):
    """Single 3D convolutional LSTM cell.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces the four gate pre-activations.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv3d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size,
            padding=kernel_size // 2,
        )
        self.hidden_dim = hidden_dim

    def forward(self, x, h_prev, c_prev):
        """Run one step and return the new ``(h, c)`` pair."""
        stacked = torch.cat([x, h_prev], dim=1)
        # Four equal channel groups: input, forget, candidate, output.
        in_gate, forget_gate, candidate, out_gate = self.conv(stacked).split(self.hidden_dim, dim=1)
        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        out_gate = torch.sigmoid(out_gate)
        candidate = torch.tanh(candidate)
        c_next = forget_gate * c_prev + in_gate * candidate
        h_next = out_gate * torch.tanh(c_next)
        return h_next, c_next

class TemporalSegNet(nn.Module):
    """Encode two volumes, fuse them with a ConvLSTM cell and decode mask logits.

    Args:
        encoder: Encoder applied to both inputs; must expose ``out_channels``.
        hidden_dim: ConvLSTM state width; defaults to ``encoder.out_channels``.
        base: Width of the ``Decoder3D``, whose input has ``4*base`` channels.

    When ``hidden_dim`` differs from ``4*base``, a 1x1x1 convolution (``proj``)
    adapts the final hidden state to the decoder. In the default configuration
    no such layer is created, so parameter names are unchanged.
    """

    def __init__(self, encoder: Encoder3D, hidden_dim=None, base=16):
        super().__init__()
        self.encoder = encoder
        enc_out = encoder.out_channels
        hidden_dim = hidden_dim or enc_out
        self.hidden_dim = hidden_dim
        self.lstm = ConvLSTMCell3D(input_dim=enc_out, hidden_dim=hidden_dim, kernel_size=3)
        decoder_in = base * 4
        # Only needed when the LSTM state width does not match the decoder input.
        self.proj = None if hidden_dim == decoder_in else nn.Conv3d(hidden_dim, decoder_in, 1)
        self.decoder = Decoder3D(out_ch=1, base=base)

    def forward(self, x1, x2):
        """Return logits computed from the hidden state after seeing ``x1`` then ``x2``."""
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        # States must have `hidden_dim` channels, not the encoder's channel count;
        # zeros_like(z1) was only correct when the two happened to be equal.
        state_shape = (z1.shape[0], self.hidden_dim) + tuple(z1.shape[2:])
        h = torch.zeros(state_shape, dtype=z1.dtype, device=z1.device)
        c = torch.zeros(state_shape, dtype=z1.dtype, device=z1.device)
        h, c = self.lstm(z1, h, c)
        h, c = self.lstm(z2, h, c)
        if self.proj is not None:
            h = self.proj(h)
        logits = self.decoder(h)
        return logits


In [None]:
# === Loss & metric ===
class DiceLoss(nn.Module):
    """Soft Dice loss on logits: ``1 - mean(dice)``.

    Dice is computed from sigmoid probabilities, summed over dims 2, 3 and 4,
    then averaged over the remaining dims. ``eps`` is added to both the
    numerator and the denominator.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        spatial = (2, 3, 4)
        probabilities = torch.sigmoid(logits)
        overlap = (probabilities * targets).sum(dim=spatial)
        total = probabilities.sum(dim=spatial) + targets.sum(dim=spatial)
        score = (2 * overlap + self.eps) / (total + self.eps)
        return 1 - score.mean()

def dice_coefficient_from_logits(logits, targets, thr=0.5, eps=1e-6):
    """Return the mean hard Dice coefficient as a Python float.

    Sigmoid probabilities strictly above ``thr`` count as foreground. Dice is
    computed over dims 2, 3 and 4 and averaged over the remaining dims; ``eps``
    is added to the numerator and the denominator.
    """
    spatial = (2, 3, 4)
    hard = (torch.sigmoid(logits) > thr).float()
    overlap = (hard * targets).sum(dim=spatial)
    total = hard.sum(dim=spatial) + targets.sum(dim=spatial)
    score = (2 * overlap + eps) / (total + eps)
    return score.mean().item()


In [None]:
# === Training loops ===
def train_ae(model, loader, device, optimizer, epochs=5):
    """Train the denoising autoencoder with an L1 reconstruction loss.

    ``loader`` yields ``(noisy, clean)`` batches; ``model(noisy)`` must return
    a tuple whose first element is the reconstruction. The per-epoch average
    is the sample-weighted loss divided by ``len(loader.dataset)``.

    Returns:
        The lowest epoch-average L1 observed (``inf`` when ``epochs`` is 0).
        Model weights are left as they are after the last epoch.
    """
    model.train()
    criterion = nn.L1Loss()
    best_loss = float('inf')

    for ep in range(1, epochs + 1):
        weighted_sum = 0.0
        for noisy, clean in tqdm(loader, desc=f"AE Train {ep}/{epochs}"):
            noisy = noisy.to(device, non_blocking=True)
            clean = clean.to(device, non_blocking=True)

            optimizer.zero_grad()
            reconstruction, _ = model(noisy)
            batch_loss = criterion(reconstruction, clean)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample mean.
            weighted_sum += batch_loss.item() * noisy.size(0)

        epoch_l1 = weighted_sum / max(1, len(loader.dataset))
        print(f"[AE] Epoch {ep}/{epochs} - L1: {epoch_l1:.4f}")
        best_loss = min(best_loss, epoch_l1)

    return best_loss

def train_temporal(model, loader, device, optimizer, epochs=10, lambda_bce=0.5):
    """Train the temporal segmentation model with ``Dice + lambda_bce * BCE``.

    ``loader`` yields 5-item batches of which the first three (``x1``, ``x2``,
    ``y3``) are used. Prints the sample-weighted average loss of every epoch
    and returns nothing.
    """
    model.train()
    dice_term = DiceLoss()
    bce_term = nn.BCEWithLogitsLoss()

    for ep in range(1, epochs + 1):
        weighted_sum = 0.0
        for x1, x2, y3, _, _ in tqdm(loader, desc=f"Seg Train {ep}/{epochs}"):
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)
            y3 = y3.to(device, non_blocking=True)

            optimizer.zero_grad()
            logits = model(x1, x2)
            batch_loss = dice_term(logits, y3) + lambda_bce * bce_term(logits, y3)
            batch_loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample mean.
            weighted_sum += batch_loss.item() * x1.size(0)

        epoch_loss = weighted_sum / max(1, len(loader.dataset))
        print(f"[SEG] Epoch {ep}/{epochs} - Loss: {epoch_loss:.4f}")


In [None]:

# === Inférence robuste & sauvegarde des masques (mise à jour) ===
def inference_and_save_temporal(model, loader, device, out_dir):
    """Run the model over ``loader`` and write one binary mask per sample.

    Each batch is unpacked as ``(x1, x2, y3, affines, metas)``; ``y3`` is not
    used. Sigmoid probabilities of channel 0 are thresholded at 0.5 and saved
    as uint8 NIfTI at ``out_dir/<patient>/<tp3>/pred_mask.nii.gz`` using the
    sample's affine.

    NOTE(review): the saved array has the shape of the model output. If the
    inputs were padded upstream, the mask keeps that padded shape — confirm
    whether it should be cropped back to the original image size.

    Raises:
        ValueError / IndexError: from the nested helpers when ``metas`` or
        ``affines`` cannot be interpreted.
    """
    # Re-imported locally; these names are also imported at file level.
    import os, numpy as np, torch, nibabel as nib
    from tqdm import tqdm

    os.makedirs(out_dir, exist_ok=True)
    model.eval()

    def get_meta_pair(metas, i, B):
        """Extract ``(patient, timepoint)`` for sample ``i`` from several ``metas`` layouts.

        Branch order matters: the first matching layout wins. ``B`` is only
        used in error messages.
        """
        # Layout 1: a pair (patients, tps) of equally long sequences.
        if isinstance(metas, (list, tuple)) and len(metas) == 2 and \
           all(isinstance(x, (list, tuple, np.ndarray)) for x in metas):
            patients, tps = metas
            if len(patients) != len(tps):
                raise ValueError(f"metas incohérent: len(patients)={len(patients)} != len(tps)={len(tps)}")
            if i >= len(patients):
                raise IndexError(f"i={i} hors plage pour metas pair-of-lists (taille {len(patients)}), batch B={B}")
            return patients[i], tps[i]
        # Layout 2: a sequence of (patient, tp) pairs, one per sample.
        if isinstance(metas, (list, tuple)) and len(metas) > 0 and isinstance(metas[0], (list, tuple)) and len(metas[0]) == 2:
            if i >= len(metas):
                raise IndexError(f"i={i} hors plage pour metas list-of-pairs (taille {len(metas)}), batch B={B}")
            return metas[i][0], metas[i][1]
        # Layout 3: a flat (patient, tp) pair of non-sequence values; `i` is ignored.
        if isinstance(metas, (list, tuple)) and len(metas) == 2 and not any(isinstance(x, (list, tuple)) for x in metas):
            return metas[0], metas[1]
        # Layout 4: a sequence of path-like strings; the last two path components
        # are used. An out-of-range `i` falls back to the last entry.
        if isinstance(metas, (list, tuple)) and len(metas) > 0 and isinstance(metas[0], str):
            who = metas[i if i < len(metas) else -1]; s = who.replace('\\', '/').split('/')
            if len(s) >= 2: return s[-2], s[-1]
        # Layout 5: a single path-like string.
        if isinstance(metas, str):
            s = metas.replace('\\', '/').split('/')
            if len(s) >= 2: return s[-2], s[-1]
        raise ValueError(f"Format metas non supporté: {type(metas)} | exemple={repr(metas)[:200]}")

    def get_affine_i(affines, i):
        """Return the affine of sample ``i`` as a float64 ``(4, 4)`` array.

        Accepts a tensor or ndarray (indexed by ``i`` when 3-D, used as-is
        otherwise), a list/tuple (indexed by ``i`` when it has more than one
        entry) or anything ``np.asarray`` can convert.
        """
        if isinstance(affines, torch.Tensor):
            a = affines[i] if affines.ndim == 3 else affines
            a = a.detach().cpu().numpy()
        elif isinstance(affines, np.ndarray):
            a = affines[i] if affines.ndim == 3 else affines
        elif isinstance(affines, (list, tuple)):
            a = affines[i] if len(affines) > 1 else affines[0]; a = np.asarray(a)
        else:
            a = np.asarray(affines)
        a = np.asarray(a, dtype=np.float64)
        if a.shape != (4, 4):
            raise ValueError(f"Affine devrait être (4,4), reçu {a.shape}")
        return a

    with torch.no_grad():
        for batch in tqdm(loader, desc='Inference'):
            x1, x2, y3, affines, metas = batch
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)
            logits = model(x1, x2)
            probs  = torch.sigmoid(logits).cpu().numpy()

            B = probs.shape[0]
            for i in range(B):
                # Channel 0 holds the single predicted class.
                pred = (probs[i, 0] > 0.5).astype(np.uint8)
                patient, tp3 = get_meta_pair(metas, i, B)
                affine = get_affine_i(affines, i)
                out_path = os.path.join(out_dir, patient, tp3, 'pred_mask.nii.gz')
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                nib.save(nib.Nifti1Image(pred.astype(np.uint8), affine), out_path)


In [None]:
# === Paramètres & préparation des DataLoaders ===
# Input root, scanned below as DATASET_DIR/<patient>/<timepoint>/.
DATASET_DIR = r'/home/perfect/Documents/GitHub/projet-AI/data_t2f'
# Output root for checkpoints and predicted masks; created if missing.
OUTPUT_DIR  = r'/home/perfect/Documents/GitHub/projet-AI/partie2_outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Batch sizes of the autoencoder and segmentation training loaders.
BATCH_SIZE_AE  = 1
BATCH_SIZE_SEG = 1
# Number of training epochs for Stage A (autoencoder) and Stage B (segmentation).
EPOCHS_AE = 3
EPOCHS_SEG = 5
# Adam learning rates for the two stages.
LR_AE = 1e-4
LR_SEG = 1e-4
# Channel width passed as `base` to the models.
BASE = 16

def _sorted_subdirs(root):
    """Return the sorted names of the immediate sub-directories of ``root``."""
    return sorted(name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name)))


# patient -> sorted timepoint folders; patients with fewer than 3 timepoints are dropped.
patients_tps = {}
for p in _sorted_subdirs(DATASET_DIR):
    tps = _sorted_subdirs(os.path.join(DATASET_DIR, p))
    if len(tps) >= 3:
        patients_tps[p] = tps

# One sequence per patient, built from its first three timepoints.
triples = [(patient, *timepoints[:3]) for patient, timepoints in patients_tps.items()]

# Patient-level split: the first 80% (at least one patient) go to training.
patients = sorted(patients_tps)
n_train = max(1, int(0.8 * len(patients)))
train_patients = set(patients[:n_train])
test_patients  = set(patients[n_train:])

train_triples = [seq for seq in triples if seq[0] in train_patients]
test_triples  = [seq for seq in triples if seq[0] in test_patients]

# The autoencoder sees the first two timepoints of every training patient.
ae_pairs = [
    (patient, timepoint)
    for patient in train_patients
    for timepoint in patients_tps[patient][:2]
]

# Worker settings shared by the three loaders.
_loader_opts = dict(
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
)

ae_dataset = AEDataset(DATASET_DIR, ae_pairs, noise_std=(0.05, 0.15))
ae_loader = DataLoader(ae_dataset, batch_size=BATCH_SIZE_AE, shuffle=True, **_loader_opts)

train_seq_loader = DataLoader(
    SequenceDataset(DATASET_DIR, train_triples),
    batch_size=BATCH_SIZE_SEG,
    shuffle=True,
    **_loader_opts,
)
# The test loader keeps a fixed order and yields one sequence per batch.
test_seq_loader = DataLoader(
    SequenceDataset(DATASET_DIR, test_triples),
    batch_size=1,
    shuffle=False,
    **_loader_opts,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
print(f"Train triples: {len(train_triples)} | Test triples: {len(test_triples)}")


In [None]:
# === Stage A : entraînement AutoEncoder 3D ===
# Build the denoising autoencoder and train it on the (noisy, clean) pairs.
dae = DAE3D(in_ch=1, base=BASE).to(device)
opt_ae = optim.Adam(dae.parameters(), lr=LR_AE)
best_l1 = train_ae(dae, ae_loader, device, opt_ae, epochs=EPOCHS_AE)
print("[AE] Meilleure L1:", best_l1)
# NOTE: despite the file name, these are the weights after the last epoch;
# train_ae only reports the best loss and does not restore the matching weights.
AE_CKPT = os.path.join(OUTPUT_DIR, "dae_best.pt")
torch.save(dae.state_dict(), AE_CKPT)
print("[AE] Checkpoint sauvegardé:", AE_CKPT)


In [None]:
# === Stage B : ConvLSTM pour la prédiction TP3 ===
# Fresh encoder initialised from the trained autoencoder, then frozen.
encoder = Encoder3D(in_ch=1, base=BASE).to(device)
encoder.load_state_dict(dae.encoder.state_dict())
encoder.requires_grad_(False)

model = TemporalSegNet(encoder=encoder, hidden_dim=encoder.out_channels, base=BASE).to(device)
# Only the parameters that are still trainable are handed to the optimizer.
trainable_params = [weight for weight in model.parameters() if weight.requires_grad]
opt_seg = optim.Adam(trainable_params, lr=LR_SEG)
train_temporal(model, train_seq_loader, device, opt_seg, epochs=EPOCHS_SEG, lambda_bce=0.5)

SEG_CKPT = os.path.join(OUTPUT_DIR, "temporal_seg_best.pt")
torch.save(model.state_dict(), SEG_CKPT)
print("[SEG] Checkpoint sauvegardé:", SEG_CKPT)


In [None]:
# === Évaluation Dice (test) ===
# Per-batch hard Dice on the test sequences, averaged at the end.
model.eval()
dice_list = []
with torch.no_grad():
    for x1, x2, y3, _, _ in tqdm(test_seq_loader, desc="Test Eval"):
        x1 = x1.to(device, non_blocking=True)
        x2 = x2.to(device, non_blocking=True)
        y3 = y3.to(device, non_blocking=True)
        logits = model(x1, x2)
        dice_list.append(dice_coefficient_from_logits(logits, y3, thr=0.5))

# NaN when there is no test sample.
if dice_list:
    mean_dice = float(np.mean(dice_list))
else:
    mean_dice = float('nan')
print(f"[TEST] Mean Dice (TP3): {mean_dice:.4f} over {len(dice_list)} samples")


In [None]:
# === Inférence & sauvegarde dans output_masks/ ===
# Predicted masks are written under OUTPUT_DIR/output_masks/<patient>/<timepoint>/.
OUT_MASKS_DIR = os.path.join(OUTPUT_DIR, "output_masks")
os.makedirs(OUT_MASKS_DIR, exist_ok=True)
# Optional: reload the saved segmentation checkpoint before inference.
# model.load_state_dict(torch.load(SEG_CKPT, map_location=device))  # si besoin
inference_and_save_temporal(model, test_seq_loader, device, OUT_MASKS_DIR)
print("[OK] Masques prédits sauvés dans:", OUT_MASKS_DIR)
