### Import Block

In [55]:
import os
import glob
import numpy as np
import nibabel as nib
from pathlib import Path
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, Model

In [56]:
import tensorflow as tf

### Preprocessing Block

##### Path and constants

In [57]:
PREPROC_ROOT = Path("../../../../raw_data/segmentation/brats2023_preprocessed")

PATCH_SIZE = (80, 80, 80)
BATCH_SIZE = 1
AUTOTUNE = tf.data.AUTOTUNE

RAW_ROOT = Path("../../../../raw_data/segmentation/brats2023_raw")

# Taille cible 
TARGET_SHAPE = (160, 192, 160)  # (H, W, D)

# On choisit un ordre de canaux cohérent (à garder ensuite dans le Model)
# [T1 = t1 native, T1c, T2w, T2 FLAIR]
MODALITIES = [
    ("t1n", "*-t1n.nii.gz"),
    ("t1c", "*-t1c.nii.gz"),
    ("t2w", "*-t2w.nii.gz"),
    ("t2f", "*-t2f.nii.gz"),
]

##### Loading function for .npz files

In [58]:
def load_npz(path):
    path = path.decode("utf-8")  # tf.string -> str
    data = np.load(path)
    img = data["image"]  # (H,W,D,4)
    seg = data["label"]  # (H,W,D)
    # On ajoute éventuellement un canal au label pour être (H,W,D,1)
    seg = np.expand_dims(seg, axis=-1)  # (H,W,D,1)
    return img.astype(np.float32), seg.astype(np.uint8)

def tf_load_npz(path):
    img, seg = tf.numpy_function(load_npz, [path], [tf.float32, tf.uint8])
    # Fixer les shapes statiques
    img.set_shape((*TARGET_SHAPE, len(MODALITIES)))  # (H,W,D,4)
    seg.set_shape((*TARGET_SHAPE, 1))                # (H,W,D,1)
    return img, seg


##### Extraction de patchs 3D aléatoires

In [59]:
PATCH_H, PATCH_W, PATCH_D = PATCH_SIZE

def random_patch_3d(img, seg):
    """img: (H,W,D,4), seg: (H,W,D,1)"""
    shape = tf.shape(img)
    H, W, D = shape[0], shape[1], shape[2]

    max_h = H - PATCH_H
    max_w = W - PATCH_W
    max_d = D - PATCH_D

    # Sécurité si patch > volume
    max_h = tf.maximum(max_h, 0)
    max_w = tf.maximum(max_w, 0)
    max_d = tf.maximum(max_d, 0)

    h = tf.random.uniform((), 0, max_h + 1, dtype=tf.int32)
    w = tf.random.uniform((), 0, max_w + 1, dtype=tf.int32)
    d = tf.random.uniform((), 0, max_d + 1, dtype=tf.int32)

    img_patch = img[h:h+PATCH_H, w:w+PATCH_W, d:d+PATCH_D, :]
    seg_patch = seg[h:h+PATCH_H, w:w+PATCH_W, d:d+PATCH_D, :]

    img_patch.set_shape((PATCH_H, PATCH_W, PATCH_D, len(MODALITIES)))
    seg_patch.set_shape((PATCH_H, PATCH_W, PATCH_D, 1))

    return img_patch, seg_patch


##### Data augmentation simple

In [60]:
def augment(img, seg):
    # Flip aléatoire gauche-droite (axe W)
    if tf.random.uniform(()) > 0.5:
        img = tf.reverse(img, axis=[1])
        seg = tf.reverse(seg, axis=[1])

    # Flip aléatoire avant-arrière (axe H)
    if tf.random.uniform(()) > 0.5:
        img = tf.reverse(img, axis=[0])
        seg = tf.reverse(seg, axis=[0])

    # Flip aléatoire sur l'axe de profondeur (D)
    if tf.random.uniform(()) > 0.5:
        img = tf.reverse(img, axis=[2])
        seg = tf.reverse(seg, axis=[2])

    # léger jitter d'intensité
    if tf.random.uniform(()) > 0.5:
        factor = tf.random.uniform((), 0.9, 1.1)
        img = img * factor

    return img, seg


##### Construction des datasets train / val

In [61]:
# Liste de tous les fichiers .npz
all_npz = sorted(str(p) for p in PREPROC_ROOT.glob("*.npz"))

train_paths, val_paths = train_test_split(all_npz, test_size=0.2, random_state=42)

def make_dataset(paths, augment_data=True):
    ds = tf.data.Dataset.from_tensor_slices(paths)
    ds = ds.map(tf_load_npz, num_parallel_calls=AUTOTUNE)
    ds = ds.map(random_patch_3d, num_parallel_calls=AUTOTUNE)
    if augment_data:
        ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(AUTOTUNE)
    return ds

train_ds = make_dataset(train_paths, augment_data=True).repeat()
val_ds   = make_dataset(val_paths, augment_data=False)

steps_per_epoch = 200
validation_steps = 50

callbacks = [
    tf.keras.callbacks.ModelCheckpoint("best_unet3d.h5", save_best_only=True, monitor="val_loss"),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, verbose=1),
    tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
]




##### Better loss function : Dice + CE

In [62]:
import tensorflow as tf

def sparse_dice_loss(y_true, y_pred, smooth=1e-5):
    # y_true: (B,H,W,D,1) entiers [0..C-1]
    # y_pred: (B,H,W,D,C) probas softmax
    y_true = tf.squeeze(y_true, axis=-1)              # (B,H,W,D)
    y_true_one_hot = tf.one_hot(y_true, depth=tf.shape(y_pred)[-1])  # (B,H,W,D,C)

    # Dice par classe
    axes = (1,2,3)  # on somme sur H,W,D
    intersection = tf.reduce_sum(y_true_one_hot * y_pred, axis=axes)
    denom = tf.reduce_sum(y_true_one_hot + y_pred, axis=axes)

    dice = (2. * intersection + smooth) / (denom + smooth)
    dice_loss = 1 - dice  # shape (B,C)

    # moyenne des classes
    return tf.reduce_mean(dice_loss)

def combined_loss(y_true, y_pred):
    ce = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    ce = tf.reduce_mean(ce)
    dsc = sparse_dice_loss(y_true, y_pred)
    return ce + dsc


In [None]:

def conv_block(x, filters):
    x = layers.Conv3D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv3D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x


def unet_3d(input_shape=(PATCH_H, PATCH_W, PATCH_D, 4), n_classes=4):
    
    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = conv_block(inputs, 32)
    p1 = layers.MaxPooling3D(2)(c1)

    c2 = conv_block(p1, 64)
    p2 = layers.MaxPooling3D(2)(c2)

    c3 = conv_block(p2, 128)
    p3 = layers.MaxPooling3D(2)(c3)

    # Bottleneck
    bn = conv_block(p3, 256)

    # Decoder
    u3 = layers.Conv3DTranspose(128, 2, strides=2, padding='same')(bn)
    u3 = layers.Concatenate()([u3, c3])
    c4 = conv_block(u3, 128)

    u2 = layers.Conv3DTranspose(64, 2, strides=2, padding='same')(c4)
    u2 = layers.Concatenate()([u2, c2])
    c5 = conv_block(u2, 64)

    u1 = layers.Conv3DTranspose(32, 2, strides=2, padding='same')(c5)
    u1 = layers.Concatenate()([u1, c1])
    c6 = conv_block(u1, 32)

    outputs = layers.Conv3D(n_classes, 1, activation="softmax", dtype='float32')(c6)
    return Model(inputs, outputs)


model_unet = unet_3d()
model_unet.summary()

model_unet.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=combined_loss,
    metrics=["accuracy"])

model_unet.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",  # ou Dice+CE custom
    metrics=["accuracy"]
)

steps_per_epoch = 200   # par exemple 200 patchs par epoch


In [None]:
# model_unet.fit(
#     train_ds,
#     epochs=20,
#     steps_per_epoch=steps_per_epoch,
#     validation_data=val_ds,
#     validation_steps=validation_steps,
#     callbacks=callbacks,
# )

Epoch 1/20
[1m  5/200[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m11:47[0m 4s/step - accuracy: 0.1821 - loss: 1.6543

KeyboardInterrupt: 