# 1) Imports

#### PyTorch

In [4]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, OneCycleLR

In [5]:
# Report the installed PyTorch build, then pick the GPU when CUDA is usable
# and fall back to the CPU otherwise.
print('PyTorch version', torch.__version__)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

PyTorch version 2.9.1+cu128
Device: cuda


#### General

In [3]:
import os
import json
import math
import time
import random
from pathlib import Path
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from tqdm import tqdm

from sklearn.metrics import roc_auc_score

#### Monai

In [1]:
import monai
print("MONAI version:", monai.__version__)

MONAI version: 1.5.1


In [2]:
from monai.networks.nets import UNet as MonaiUNet
from monai.losses import DiceCELoss
from monai.utils import set_determinism
from monai.transforms import (
    Compose,
    RandFlipd,
    RandRotate90d,
    RandZoomd,
    RandAffined,
    RandGaussianNoised,
    RandShiftIntensityd,
    RandScaleIntensityd,
    EnsureTyped,
)

# 2) Parameters

#### 2.1) Directories

In [6]:
# Root of the DRIVE dataset. Defaults to the original location but can be
# redirected with the DRIVE_ROOT environment variable, so the notebook is not
# tied to one user's home directory.
drive_root = Path(os.environ.get("DRIVE_ROOT", "/home/usrs/hnoel/DRIVE"))

# TRAIN = 20 images (21–40)
train_split_dir  = drive_root / "training"
train_images_dir = train_split_dir / "images"
train_manual_dir = train_split_dir / "1st_manual"   # 21_manual1.gif, ...
train_fov_dir    = train_split_dir / "masks"        # 21_training_mask.tif, ...

# TEST = 20 images (01–20)
test_split_dir   = drive_root / "test"
test_images_dir  = test_split_dir / "images"
test_manual_dir  = test_split_dir / "1st_manual"    # 01_manual1.gif, ...
test_fov_dir     = test_split_dir / "masks"

#### 2.2) Hyperparameters

In [29]:
# Image sizes
# Every image is resized to (IMG_HEIGHT, IMG_WIDTH) by the dataset class.
# The network below halves the resolution four times (strides=(2, 2, 2, 2)),
# so both sizes must be multiples of 2**4 = 16; otherwise the decoder
# feature maps no longer match the skip connections. The previous setting,
# 584 x 565, is not a multiple of 16, so the nearest larger multiples are used.
IMG_HEIGHT = 592
IMG_WIDTH  = 576

# Batch size
BATCH_SIZE = 6

# Number of epochs
EPOCHS = 200

# Learning rate
LR = 1e-3

# Classes
NUM_CLASSES = 2
BACKGROUND_IDX = 0
VESSEL_CLASS = 1

#### 2.3) Seed

In [8]:
# One seed shared by every random number generator seeded in this cell.
SEED = 42

# Python's built-in `random` module
random.seed(SEED)

# NumPy's global RNG
np.random.seed(SEED)

# PyTorch: default generator, then the CUDA generators
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# cuDNN flags: request deterministic behaviour and turn benchmarking off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# MONAI's determinism helper, called with the same seed
# NOTE(review): this presumably also seeds the random MONAI transforms — confirm
set_determinism(seed=SEED)

print(f"Seed fixed to {SEED}")

Seed fixed to 42


#### 2.4) Save directories

In [30]:
# Name of this run; used as the last path component of both output folders.
model_name = "UNet_DRIVE_v1"

# Curves, JSON files, figures, etc.
# Defaults to the original location; override with UNET_DATA_ROOT.
data_dir = Path(os.environ.get("UNET_DATA_ROOT", "/home/usrs/hnoel/Tohoku/Monai/UNet")) / model_name
data_dir.mkdir(parents=True, exist_ok=True)

# Model weights
# Defaults to the original location; override with UNET_MODELS_ROOT.
models_dir = Path(os.environ.get("UNET_MODELS_ROOT", "/home/usrs/hnoel/MODELS/UNet")) / model_name
models_dir.mkdir(parents=True, exist_ok=True)

# 3) DataLoaders

#### 3.1) Data augmentation

In [11]:
# Entries that must receive identical spatial transforms so that the image,
# the vessel label and the FOV mask stay pixel-aligned.
_SPATIAL_KEYS = ["image", "label", "fov"]

# Interpolation per key, in the same order as _SPATIAL_KEYS:
# bilinear for the image, nearest for the two masks.
_SPATIAL_MODES = ("bilinear", "nearest", "nearest")

# --- Geometric transforms (image + label + fov) ---
_geometric_transforms = [
    RandFlipd(keys=_SPATIAL_KEYS, prob=0.6, spatial_axis=0),
    RandFlipd(keys=_SPATIAL_KEYS, prob=0.5, spatial_axis=1),
    RandRotate90d(keys=_SPATIAL_KEYS, prob=0.4, max_k=3),
    RandZoomd(
        keys=_SPATIAL_KEYS,
        prob=0.3,
        min_zoom=0.9,
        max_zoom=1.1,
        mode=_SPATIAL_MODES,
        keep_size=True,
    ),
    RandAffined(
        keys=_SPATIAL_KEYS,
        prob=0.2,
        rotate_range=(0.15, 0.15),      # ~8.5°
        shear_range=(0.05, 0.05),
        translate_range=(8, 8),
        scale_range=(0.05, 0.05),
        mode=_SPATIAL_MODES,
        padding_mode="zeros",
    ),
]

# --- Intensity transforms (image only) ---
_intensity_transforms = [
    RandGaussianNoised(keys=["image"], prob=0.1, std=0.01),
    RandShiftIntensityd(keys=["image"], prob=0.1, offsets=0.1),
    RandScaleIntensityd(keys=["image"], prob=0.1, factors=0.1),
]

# The final step makes sure every entry is a PyTorch tensor.
train_transforms = Compose(
    _geometric_transforms
    + _intensity_transforms
    + [EnsureTyped(keys=_SPATIAL_KEYS)]
)

# No augmentation for validation
val_transforms = None


#### 3.2) Create the class for the DRIVE dataset images

In [12]:
class DriveDataset(Dataset):
    """
    Dataset for DRIVE retinal images.

    Each item is a dict with three tensors:
      - "image": float32, shape (3, H, W), values scaled to [0, 1]
      - "label": int64,   shape (1, H, W), 1 = vessel, 0 = background
      - "fov"  : int64,   shape (1, H, W), 1 inside the field of view, 0 outside

    Args:
      img_dir    : .../training/images or .../test/images (the *.tif files)
      manual_dir : .../training/1st_manual or .../test/1st_manual (vessel GT)
      fov_dir    : .../training/masks or .../test/masks (FOV disc)
      img_size   : (H, W) every image and mask is resized to
      transforms : optional callable applied to the sample dict
      indices    : optional subset of positions in the sorted image list
                   (used for the train/val split)
    """
    def __init__(self, img_dir, manual_dir, fov_dir,
                 img_size=(IMG_HEIGHT, IMG_WIDTH),
                 transforms=None,
                 indices=None):
        self.img_paths = sorted(Path(img_dir).glob("*.tif"))
        self.manual_dir = Path(manual_dir)
        self.fov_dir    = Path(fov_dir)
        self.img_size   = img_size  # (H, W)
        self.transforms = transforms
        self.indices    = indices

    def __len__(self):
        if self.indices is None:
            return len(self.img_paths)
        return len(self.indices)

    @staticmethod
    def _load_resized(path, mode, size_wh, resample):
        """Open `path`, convert it to `mode`, resize it to (W, H) and close the file."""
        # The context manager releases the file handle as soon as the pixels
        # have been copied by convert(); the returned image is independent.
        with Image.open(path) as src:
            return src.convert(mode).resize(size_wh, resample)

    def _get_paths(self, idx):
        """Return (image path, manual annotation path, FOV mask path) for item `idx`."""
        if self.indices is not None:
            idx = self.indices[idx]

        img_path = self.img_paths[idx]
        base = img_path.stem              # e.g. '21_training' or '01_test'
        num  = base.split('_')[0]         # '21' or '01'

        manual_path = self.manual_dir / f"{num}_manual1.gif"

        # FOV mask: look for "<base>_mask.*" whatever the extension.
        # Sorted so the choice is reproducible if several files match.
        candidates = sorted(self.fov_dir.glob(f"{base}_mask.*"))
        if len(candidates) == 0:
            raise FileNotFoundError(f"No FOV mask found for {base} in {self.fov_dir}")
        fov_path = candidates[0]

        return img_path, manual_path, fov_path

    def __getitem__(self, idx):
        img_path, manual_path, fov_path = self._get_paths(idx)

        H, W = self.img_size
        size_wh = (W, H)  # PIL expects (width, height)

        # --- RGB image ---
        img = self._load_resized(img_path, "RGB", size_wh, Image.BILINEAR)
        img_np = np.array(img).astype(np.float32) / 255.0
        img_t = torch.from_numpy(img_np).permute(2, 0, 1)   # (3,H,W)

        # --- Vessel ground truth (manual1, binarised) ---
        gt_img = self._load_resized(manual_path, "L", size_wh, Image.NEAREST)
        gt_np = np.array(gt_img)

        label_np = np.zeros_like(gt_np, dtype=np.int64)
        label_np[gt_np > 0] = 1      # 1 = vessel

        label_t = torch.from_numpy(label_np).long().unsqueeze(0)  # (1,H,W)

        # --- FOV mask (0 / 1) ---
        fov_img = self._load_resized(fov_path, "L", size_wh, Image.NEAREST)
        fov_np = (np.array(fov_img) > 0).astype(np.uint8)
        fov_t = torch.from_numpy(fov_np).long().unsqueeze(0)      # (1,H,W)

        sample = {"image": img_t, "label": label_t, "fov": fov_t}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample


In [13]:
# ========= Train / val index split on the TRAINING folder (images 21–40) =========
train_img_paths_all = sorted(Path(train_images_dir).glob("*.tif"))
train_nums = [int(p.stem.split("_")[0]) for p in train_img_paths_all]

# Validation on images 22, 29, 31, 35.
# (Mnemonic kept from the original notebook, French département numbers:
#  22 = Côtes d'Armor, 29 = Finistère, 31 = Haute-Garonne, 35 = Ille-et-Vilaine.)
val_nums = {22, 29, 31, 35}

# Fail early if a requested validation image is absent, rather than
# silently training with a smaller (or empty) validation set.
missing_val_nums = sorted(val_nums - set(train_nums))
if missing_val_nums:
    raise FileNotFoundError(
        f"Validation images {missing_val_nums} not found in {train_images_dir}"
    )

val_indices = [i for i, n in enumerate(train_nums) if n in val_nums]
train_indices = [i for i, n in enumerate(train_nums) if n not in val_nums]

print("Train indices (image numbers):", [train_nums[i] for i in train_indices])
print("Val indices (image numbers):  ", [train_nums[i] for i in val_indices])

# ========= Datasets =========
# Train and val both read the TRAINING folder; only `indices` and the
# transforms differ. The TEST folder is used in full, without augmentation.
train_ds = DriveDataset(
    train_images_dir,
    train_manual_dir,
    train_fov_dir,
    img_size=(IMG_HEIGHT, IMG_WIDTH),
    transforms=train_transforms,
    indices=train_indices,
)

val_ds = DriveDataset(
    train_images_dir,
    train_manual_dir,
    train_fov_dir,
    img_size=(IMG_HEIGHT, IMG_WIDTH),
    transforms=val_transforms,
    indices=val_indices,
)

test_ds = DriveDataset(
    test_images_dir,
    test_manual_dir,
    test_fov_dir,
    img_size=(IMG_HEIGHT, IMG_WIDTH),
    transforms=val_transforms,
    indices=None,
)

# ========= DataLoaders =========
# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("Train / Val / Test sizes:", len(train_ds), len(val_ds), len(test_ds))


Train indices (image numbers): [21, 23, 24, 25, 26, 27, 28, 30, 32, 33, 34, 36, 37, 38, 39, 40]
Val indices (image numbers):   [22, 29, 31, 35]
Train / Val / Test sizes: 16 4 20


# 4) UNet Model

In [14]:
# 2D U-Net from MONAI, moved to the selected device.
# NOTE(review): with four stride-2 levels the input height and width
# presumably need to be multiples of 16 — confirm against the MONAI UNet docs.
model = MonaiUNet(
    spatial_dims=2,                             # UNet in 2D
    in_channels=3,                              # RGB images => 3 channels
    out_channels=NUM_CLASSES,                   # 2 classes: background + vessels
    channels=(64, 128, 256, 512, 1024),         # number of feature channels at each level
    strides=(2, 2, 2, 2),                       # 4 downsampling steps of factor 2
    kernel_size=3,                              # conv 3x3
    up_kernel_size=3,                           # up-conv 3x3
    num_res_units=0,                            # no residual units
    act="RELU",                                 # use ReLU (original comment: the default is PReLU)
    norm="instance",                                  # instance normalization is enabled here (the 2015 article used no normalization)
    dropout=0.0,                                # no dropout here
).to(device)

In [15]:
# Confirm which network class was built and where it lives.
print(type(model).__name__, "initialized on", device)

UNet initialized on cuda


# 5) Optimizer

In [17]:
# AdamW over all model parameters, starting from the global learning rate LR.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LR,                # initial learning rate (changed later by the scheduler)
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-2     # weight-decay coefficient
    )

# 6) Scheduler

In [18]:
# Step decay: the learning rate is multiplied by `gamma` every `step_size`
# calls to scheduler.step().
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)

# 7) Loss function

In [19]:
# For DRIVE (binary), start without explicit class weighting.
# If needed, weights can be computed later from the pixel frequencies.
weights = None

In [20]:
# Combined Dice + cross-entropy loss from MONAI; both terms have weight 1.0.
loss_function = DiceCELoss(
    include_background=False,   # the background class is not counted in the Dice term
    to_onehot_y=True,           # NOTE(review): presumably one-hot encodes the integer labels inside the loss — confirm
    softmax=True,               # NOTE(review): presumably applies softmax to the network output inside the loss — confirm
    weight=weights,             # class weights taken from the `weights` variable
    lambda_dice=1.0,
    lambda_ce=1.0,
)

# 8) History and metrics

In [22]:
# ===========================================
#   HISTORY DICTIONARIES (ONE MEAN PER EPOCH)
# ===========================================

# Metrics recorded for both the training and the validation phase.
_SHARED_HISTORY_KEYS = (
    "loss",
    "accuracy",
    "dice",
    "iou",
    "precision",
    "recall",
    "sensitivity",
    "specificity",
    "auc",
)

# Each key gets its own empty list; training has two extra entries.
train_history = {key: [] for key in _SHARED_HISTORY_KEYS + ("lr", "alpha")}
val_history = {key: [] for key in _SHARED_HISTORY_KEYS}

In [23]:
def init_confmat_sums(num_classes=NUM_CLASSES):
    """
    Build the empty accumulator for dataset-wide TP/FP/FN/TN counts.

    Returns a dict with one int64 array of length `num_classes` per counter
    ("TP", "FP", "FN", "TN") plus "N", the number of pixels taken into
    account so far (starts at 0).
    """
    counters = {
        name: np.zeros(num_classes, dtype=np.int64)
        for name in ("TP", "FP", "FN", "TN")
    }
    counters["N"] = 0
    return counters

In [24]:
def init_auc_buffers(classes):
    """
    Create one empty {"y_true": [], "y_score": []} buffer per class index.

    `classes` is an iterable of class indices, e.g. [VESSEL_CLASS].
    """
    return {class_idx: {"y_true": [], "y_score": []} for class_idx in classes}

In [25]:
def update_confmat_sums(sums, preds, targets, num_classes=NUM_CLASSES, fov=None):
    """
    Add one batch to the per-class TP/FP/FN/TN counters, in place.

    Args:
      sums        : accumulator dict with "TP", "FP", "FN", "TN" arrays and "N".
      preds       : predicted class indices, shape (N,H,W).
      targets     : ground-truth class indices, shape (N,H,W); a singleton
                    channel axis, (N,1,H,W), is accepted and removed.
      num_classes : number of classes to count.
      fov         : optional mask, same shape rules as `targets`. When given,
                    only pixels with a non-zero value contribute to ANY of
                    the four counters and to "N".

    Raises:
      ValueError: if the shapes still differ after removing the channel axis.
    """
    with torch.no_grad():
        # Drop a singleton channel axis so that (N,1,H,W) labels cannot be
        # silently broadcast against (N,H,W) predictions.
        if targets.dim() == preds.dim() + 1 and targets.size(1) == 1:
            targets = targets.squeeze(1)
        if targets.shape != preds.shape:
            raise ValueError(
                f"preds {tuple(preds.shape)} and targets {tuple(targets.shape)} must have the same shape"
            )

        valid = None
        if fov is not None:
            valid = fov.bool()
            if valid.dim() == preds.dim() + 1 and valid.size(1) == 1:
                valid = valid.squeeze(1)
            if valid.shape != preds.shape:
                raise ValueError(
                    f"preds {tuple(preds.shape)} and fov {tuple(valid.shape)} must have the same shape"
                )

        for c in range(num_classes):
            p = (preds == c)
            t = (targets == c)

            tp_map = p & t
            fp_map = p & ~t
            fn_map = ~p & t
            tn_map = ~p & ~t

            if valid is not None:
                # Mask each outcome map. Masking p and t beforehand would
                # turn every pixel outside the FOV into a true negative.
                tp_map = tp_map & valid
                fp_map = fp_map & valid
                fn_map = fn_map & valid
                tn_map = tn_map & valid

            sums["TP"][c] += tp_map.sum().item()
            sums["FP"][c] += fp_map.sum().item()
            sums["FN"][c] += fn_map.sum().item()
            sums["TN"][c] += tn_map.sum().item()

        if valid is not None:
            sums["N"] += valid.sum().item()
        else:
            sums["N"] += targets.numel()


In [26]:
def update_auc_buffers(buffers, logits, targets, fov=None):
    """
    Append this batch's per-pixel labels and scores to the AUC buffers.

    - buffers: dict {class index: {"y_true": list, "y_score": list}}
    - logits:  (N,C,H,W) raw network outputs; softmax is applied over C
    - targets: (N,H,W) class indices
    - fov:     (N,H,W) mask or None; when given, only pixels whose mask value
               is non-zero are stored

    For each class key, "y_true" receives a uint8 array (1 where the target
    equals that class) and "y_score" the softmax probability of that class.
    """
    with torch.no_grad():
        n_classes = logits.shape[1]
        class_probs = F.softmax(logits, dim=1)  # (N,C,H,W)

        # One row per pixel, one column per class.
        pixel_probs = class_probs.permute(0, 2, 3, 1).reshape(-1, n_classes)
        pixel_labels = targets.reshape(-1)

        if fov is not None:
            keep = fov.reshape(-1).bool()
            pixel_probs, pixel_labels = pixel_probs[keep], pixel_labels[keep]

        for class_idx, store in buffers.items():
            is_class = (pixel_labels == class_idx).cpu().numpy().astype(np.uint8)
            store["y_true"].append(is_class)
            store["y_score"].append(pixel_probs[:, class_idx].cpu().numpy())

In [None]:
def compute_scalar_metrics_from_confmat(sums, include_classes):
    """
    Turn accumulated TP/FP/FN/TN counts into scalar metrics.

    Accuracy, precision, recall, specificity, Dice and IoU are computed per
    class, then averaged over the classes in `include_classes`
    (e.g. [VESSEL_CLASS]). Denominators are clamped from below to 1e-7, so a
    metric whose denominator is zero comes out as 0.0.
    """
    eps = 1e-7
    tp, fp, fn, tn = (sums[name] for name in ("TP", "FP", "FN", "TN"))

    def _ratio(numerator, denominator):
        # Clamp the denominator to avoid a division by zero.
        return numerator / np.maximum(denominator, eps)

    per_class = {
        "accuracy":    _ratio(tp + tn, tp + tn + fp + fn),
        "precision":   _ratio(tp, tp + fp),
        "recall":      _ratio(tp, tp + fn),
        "specificity": _ratio(tn, tn + fp),
        "dice":        _ratio(2 * tp, 2 * tp + fp + fn),
        "iou":         _ratio(tp, tp + fp + fn),
    }

    selected = np.array(include_classes, dtype=int)
    return {
        name: float(np.mean(values[selected]))
        for name, values in per_class.items()
    }

In [28]:
def compute_mean_auc(buffers):
    """
    Compute the ROC AUC averaged over the classes stored in `buffers`.

    Args:
      buffers: dict {class index: {"y_true": [arrays], "y_score": [arrays]}}.

    Returns:
      The mean AUC as a float, or NaN when no class yields a value.

    A class is skipped when its buffer is empty or when its labels contain
    only positives or only negatives. If the AUC computation itself fails,
    the class is also skipped, but a RuntimeWarning is emitted so the failure
    is visible rather than silently ignored.
    """
    import warnings

    aucs = []
    for c, pack in buffers.items():
        if len(pack["y_true"]) == 0:
            continue
        y_true  = np.concatenate(pack["y_true"])
        y_score = np.concatenate(pack["y_score"])

        # ROC AUC needs at least one positive and one negative sample.
        n_pos = int((y_true == 1).sum())
        n_neg = int((y_true == 0).sum())
        if n_pos == 0 or n_neg == 0:
            continue

        try:
            aucs.append(roc_auc_score(y_true, y_score))
        except Exception as exc:
            # Best effort: keep going, but report why this class was dropped.
            warnings.warn(
                f"AUC computation failed for class {c}: {exc!r}",
                RuntimeWarning,
            )

    if len(aucs) == 0:
        return float("nan")
    return float(np.mean(aucs))

# 9) Training

In [None]:
# Best validation scores seen so far. Initialised here so the cell runs on a
# fresh kernel and the first epoch always produces a checkpoint.
best_val_loss = float("inf")
best_val_dice = float("-inf")

# Metric names copied as-is from the metric dict into the history dict.
_HISTORY_METRIC_KEYS = ("accuracy", "dice", "iou", "precision", "recall", "specificity")

for epoch in range(EPOCHS):

    # =========================
    #        TRAIN
    # =========================
    model.train()

    running_loss = 0.0
    conf_sums = init_confmat_sums()
    auc_buf   = init_auc_buffers([VESSEL_CLASS])

    # The dataset yields dicts, so each batch is a dict of stacked tensors.
    for batch in train_loader:
        imgs  = batch["image"].to(device, non_blocking=True)          # (N,3,H,W)
        masks = batch["label"].to(device, non_blocking=True).long()   # (N,1,H,W)
        fovs  = batch["fov"].to(device, non_blocking=True)            # (N,1,H,W)

        optimizer.zero_grad()

        outputs = model(imgs)  # (N,C,H,W)
        loss    = loss_function(outputs, masks)  # Dice + CE (MONAI)

        loss.backward()
        optimizer.step()

        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size

        # The metric helpers expect (N,H,W) tensors: drop the channel axis.
        preds    = outputs.argmax(dim=1)
        targets  = masks.squeeze(1)
        fov_mask = fovs.squeeze(1) > 0
        update_confmat_sums(conf_sums, preds, targets, NUM_CLASSES, fov=fov_mask)
        update_auc_buffers(auc_buf, outputs.detach(), targets, fov=fov_mask)

    train_loss    = running_loss / len(train_ds)
    train_metrics = compute_scalar_metrics_from_confmat(conf_sums, [VESSEL_CLASS])
    train_auc     = compute_mean_auc(auc_buf)

    current_lr = optimizer.param_groups[0]["lr"]

    train_history["loss"].append(train_loss)
    for key in _HISTORY_METRIC_KEYS:
        train_history[key].append(train_metrics[key])
    # Sensitivity is the same quantity as recall: TP / (TP + FN).
    train_history["sensitivity"].append(train_metrics["recall"])
    train_history["auc"].append(train_auc)
    train_history["lr"].append(current_lr)

    # =========================
    #       VALIDATION
    # =========================
    model.eval()

    val_running_loss = 0.0
    conf_sums_val = init_confmat_sums()
    auc_buf_val   = init_auc_buffers(classes=[VESSEL_CLASS])

    with torch.no_grad():
        for batch in val_loader:
            imgs  = batch["image"].to(device, non_blocking=True)
            masks = batch["label"].to(device, non_blocking=True).long()
            fovs  = batch["fov"].to(device, non_blocking=True)

            outputs = model(imgs)
            loss    = loss_function(outputs, masks)

            batch_size = imgs.size(0)
            val_running_loss += loss.item() * batch_size

            preds    = outputs.argmax(dim=1)
            targets  = masks.squeeze(1)
            fov_mask = fovs.squeeze(1) > 0
            update_confmat_sums(conf_sums_val, preds, targets, NUM_CLASSES, fov=fov_mask)
            update_auc_buffers(auc_buf_val, outputs, targets, fov=fov_mask)

    val_loss    = val_running_loss / len(val_ds)
    val_metrics = compute_scalar_metrics_from_confmat(conf_sums_val, [VESSEL_CLASS])
    val_auc     = compute_mean_auc(auc_buf_val)

    val_history["loss"].append(val_loss)
    for key in _HISTORY_METRIC_KEYS:
        val_history[key].append(val_metrics[key])
    val_history["sensitivity"].append(val_metrics["recall"])
    val_history["auc"].append(val_auc)

    # =========================
    #        LOGS
    # =========================
    print("-" * 40)
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val   Loss: {val_loss:.4f}")
    print(f"  Train Dice: {train_metrics['dice']:.4f} | Val Dice: {val_metrics['dice']:.4f}")
    print(f"  Current LR: {current_lr:.6f}")

    # =========================
    #   SAVE BEST MODELS
    # =========================
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), str(models_dir / "loss_best_model_state_dict.pth"))
        torch.save(model,              str(models_dir / "loss_best_model_full.pth"))
        print("✅ New best model (loss) saved")

    if val_metrics["dice"] > best_val_dice:
        best_val_dice = val_metrics["dice"]
        torch.save(model.state_dict(), str(models_dir / "dice_best_model_state_dict.pth"))
        torch.save(model,              str(models_dir / "dice_best_model_full.pth"))
        print("✅ New best model (dice) saved")

    # =========================
    #   SCHEDULER STEP
    # =========================
    # Stepped once per epoch, after the optimizer updates of that epoch.
    scheduler.step()

# =========================
#     SAVE LAST MODEL
# =========================
torch.save(model.state_dict(), str(models_dir / f"epoch{EPOCHS}_model_state_dict.pth"))
torch.save(model,              str(models_dir / f"epoch{EPOCHS}_model_full.pth"))
print("✅ Last model saved")


In [None]:
# Write both history dictionaries as indented JSON files in the run's data folder.
for history_file, history in (
    ("UNet_train_history.json", train_history),
    ("UNet_val_history.json", val_history),
):
    with open(data_dir / history_file, "w") as f:
        json.dump(history, f, indent=4)

print("Historiques sauvegardés en JSON")

# 10) Validation tests

#### 10.1) Load the model

In [None]:
# Restore the weights of the checkpoint with the best validation Dice,
# then switch the network to evaluation mode.
best_dice_ckpt = models_dir / "dice_best_model_state_dict.pth"
model.load_state_dict(torch.load(best_dice_ckpt, map_location=device, weights_only=True))
model.eval()

#### 10.2) Images

#### 10.3) Curves

#### 10.4) Table of metrics

_TODO: add the table of metrics._