## Pr√©paration du dataloader


### Recup dataset , transforms et datatloader 

In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd

SEG_LABEL_COLS = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
SEG_DIR = '/home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI'
CLASSIFICATION_DATA_DIR = '/home/tibia/Projet_Hemorragie/MBH_label_case'
SAVE_DIR = "/home/tibia/Projet_Hemorragie/MBH_multitask_libMTL/saved_models"
os.makedirs(SAVE_DIR, exist_ok=True)
# ======================
# DATA PREPARATION
# ======================
def get_segmentation_data(split="train"):
    img_dir = Path(SEG_DIR) / split / "img"
    seg_dir = Path(SEG_DIR) / split / "seg"
    
    images = sorted(img_dir.glob("*.nii.gz"))
    labels = sorted(seg_dir.glob("*.nii.gz"))
    
    assert len(images) == len(labels), "Mismatch between image and label counts"

    data = []
    for img, lbl in zip(images, labels):
        data.append({
            "image": str(img),
            "label": str(lbl),
        })
        
    return data


def get_classification_data(split="train"):
    csv_path = Path(CLASSIFICATION_DATA_DIR) / "splits" / f"{split}_split.csv"
    df = pd.read_csv(csv_path)
    nii_dir = Path(CLASSIFICATION_DATA_DIR)
    label_cols = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
    
    data = []
    for _, row in df.iterrows():
        image_path = str(nii_dir / f"{row['patientID_studyID']}.nii.gz")
        label = np.array([row[col] for col in label_cols], dtype=np.float32)
        
        data.append({
            "image": image_path,
            "label": label
        })
    return data

In [2]:
#D√©finitions des transformations MONAI
from monai import transforms as T
import torch
 
def get_segmentation_transform(mode='train'):
    # Transforms de base (toujours appliqu√©es)
    base_transforms = [
        T.LoadImaged(keys=["image", "label"], image_only=True ),
        T.EnsureChannelFirstd(keys=["image", "label"]),
        T.CropForegroundd(keys=["image", "label"], source_key='image'),
        T.Orientationd(keys=["image", "label"], axcodes='RAS'),
        T.Spacingd(keys=["image", "label"], pixdim=(1., 1., 1.), mode=["bilinear", "nearest"]),
        T.SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
        T.ScaleIntensityRanged(
            keys=["image"],
            a_min=-10,
            a_max=140,
            b_min=0.0, b_max=1.0, clip=True
        ),
        T.SelectItemsD(keys=["image", "label"])
        
    ]
    augmentation_transforms = []
    if mode == 'train':
        augmentation_transforms = [
            T.RandCropByPosNegLabeld(
                keys=['image', 'label'],
                image_key='image',
                label_key='label',
                pos=5.0,
                neg=1.0,
                spatial_size=(96, 96, 96),
                num_samples=2
            ),
            T.RandFlipd(keys=["image", "label"], spatial_axis=[0, 1], prob=0.5),
            T.RandRotate90d(keys=["image", "label"], spatial_axes=(0, 1), prob=0.5),
            T.RandScaleIntensityd(keys=["image"], factors=0.1, prob=0.5),
            T.RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            
        ]
        
        
        # T.RandGaussianNoised(keys=["image"], prob=0.5, mean=0.0, std=0.1),
        
        # # 3. Optionnel mais top : Flou (simulation de mouvement patient) ou Nettet√©
        # T.RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0), prob=0.1),
     # √† tester
        
    final_transform = [T.EnsureTyped(keys=["image", "label"], track_meta=False)]
    

    all_transforms = base_transforms + augmentation_transforms + final_transform
    
       
    
    return T.Compose(all_transforms)
    
    

    


def get_classification_transform(mode='train'):
    # Transforms de base (toujours appliqu√©es)
    base_transforms = [
            T.LoadImaged(keys=["image"], image_only=True),
            T.EnsureChannelFirstd(keys=["image"]),
            T.Orientationd(keys=["image"], axcodes='RAS'),
            T.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
            T.CropForegroundd(keys=["image"], source_key='image'),
            T.ScaleIntensityRanged(
                keys=["image"],
                a_min=-10,a_max=140, 
                b_min=0.0, b_max=1.0, 
                clip=True) ,
            T.RandSpatialCropd(keys=["image"], roi_size=(96, 96, 96), random_size=False)]
        
    augmentation_transforms = []      
    if mode == 'train':
        augmentation_transforms = [
            T.RandFlipd(keys=["image"], spatial_axis=[0, 1, 2], prob=0.5),
            T.RandRotate90d(keys=["image"], spatial_axes=(0, 1), prob=0.5),
            T.RandScaleIntensityd(keys=["image"], factors=0.1, prob=0.5),
            T.RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            
        ]
        
    final_transform = [T.ToTensord(keys=["image", "label"]),
                       T.SelectItemsD(keys=["image", "label"]),
                       T.EnsureTyped(keys=["image", "label"], track_meta=False)]
        
    all_transforms = base_transforms + augmentation_transforms + final_transform
        
    return T.Compose(all_transforms)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from monai.data import DataLoader, PersistentDataset
seg_train_data=get_segmentation_data("train")
cls_train_data=get_classification_data("train")

seg_train_dataset = PersistentDataset(
        seg_train_data, 
        transform=get_segmentation_transform('train'),
        cache_dir=os.path.join(SAVE_DIR, "cache_train")
    )

cls_train_dataset = PersistentDataset(
        cls_train_data,
        transform=get_classification_transform('train'),
        cache_dir=os.path.join(SAVE_DIR, "cache_train"))
    
#Val dataset
seg_val_data=get_segmentation_data("val")
cls_val_data=get_classification_data("val")
seg_val_dataset = PersistentDataset(
        seg_val_data, 
        transform=get_segmentation_transform('val'),    
        cache_dir=os.path.join(SAVE_DIR, "cache_val")
    )   
cls_val_dataset = PersistentDataset(
        cls_val_data,
        transform=get_classification_transform('val'),
        cache_dir=os.path.join(SAVE_DIR, "cache_val"))      


# DataLoaders
seg_train_loader = DataLoader(
        seg_train_dataset, 
        batch_size=2, 
        shuffle=True, 
        num_workers=8,
        persistent_workers=True,
)

cls_train_loader = DataLoader(
        cls_train_dataset, 
        batch_size=2, 
        shuffle=True, 
        num_workers=8,
        persistent_workers=True,
)
  
train_dataloaders = {'segmentation': seg_train_loader,
                     'classification': cls_train_loader
                     }


seg_val_loader = DataLoader(
        seg_val_dataset, 
        batch_size=1, 
        shuffle=False, 
        num_workers=8,
        persistent_workers=True,
)   

cls_val_loader = DataLoader(  
        cls_val_dataset, 
        batch_size=1, 
        shuffle=False, 
        num_workers=8,
        persistent_workers=True,
)
val_dataloaders = {'segmentation': seg_val_loader,
                   'classification': cls_val_loader
                   }




### Tests de shape

In [12]:
from monai.transforms import LoadImaged, EnsureTyped, Compose
from monai.data import Dataset, DataLoader
from pathlib import Path
import torch

sample_data = get_segmentation_data(split="train")[:1]  # Prendre un seul √©chantillon pour la d√©monstration
# ----------------------------------------------------
# Pipeline simple pour d√©monstration
# ----------------------------------------------------
transforms = Compose([
    LoadImaged(keys=["image", "label"]),   # Charge en numpy
    # üëâ Avant EnsureTyped : ce sont des numpy arrays
   # EnsureTyped(keys=["image", "label"])   # Convertit en MetaTensor
])

dataset = Dataset(data=sample_data, transform=transforms)
loader = DataLoader(dataset, batch_size=1)

# ----------------------------------------------------
# R√©cup√©ration d'un batch
# ----------------------------------------------------
batch = next(iter(loader))

img = batch["image"]
lbl = batch["label"]

print("--------------- AVANT EnsureTyped ---------------")
print("Dans LoadImaged, les donn√©es sont initialement numpy arrays.")

print("--------------- APR√àS EnsureTyped ----------------")
print("Type image :", type(img))
print("Type label :", type(lbl))

print("\nEst-ce que c'est un torch.Tensor ? ", isinstance(img, torch.Tensor))
print("Est-ce un MetaTensor MONAI ?        ", img.__class__.__name__)

# Afficher les m√©tadonn√©es disponibles
print("\n--- Metadata de l'image ---")
print(img.meta)

print("\n--- Metadata du label ---")
print(lbl.meta)

print("\nShape du tenseur :", img.shape)
print("Dtype :", img.dtype)


--------------- AVANT EnsureTyped ---------------
Dans LoadImaged, les donn√©es sont initialement numpy arrays.
--------------- APR√àS EnsureTyped ----------------
Type image : <class 'monai.data.meta_tensor.MetaTensor'>
Type label : <class 'monai.data.meta_tensor.MetaTensor'>

Est-ce que c'est un torch.Tensor ?  True
Est-ce un MetaTensor MONAI ?         MetaTensor

--- Metadata de l'image ---
{'intent_p1': tensor([0.]), 'qoffset_x': tensor([125.]), 'cal_min': tensor([0.]), original_affine: tensor([[[-4.8828e-01,  0.0000e+00,  0.0000e+00,  1.2500e+02],
         [ 0.0000e+00, -4.6168e-01, -1.7216e+00,  1.6214e+02],
         [ 0.0000e+00, -1.5897e-01,  4.9999e+00,  2.5580e+01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]]],
       dtype=torch.float64), 'as_closest_canonical': tensor([False]), 'session_error': tensor([0], dtype=torch.int16), 'datatype': tensor([8], dtype=torch.int16), 'quatern_d': tensor([0.9863]), 'filename_or_obj': ['/home/tibia/Projet_Hemorragie/Seg_h

In [5]:
# shape de ce qui rentre dans le mod√®le
for batch in train_dataloaders['segmentation']:
    print(batch['image'].shape)
    break   

for batch in train_dataloaders['classification']:
    print(batch['image'].shape)
    break

torch.Size([4, 1, 96, 96, 96])
torch.Size([2, 1, 96, 96, 96])




In [13]:
print("=== SEGMENTATION TRAIN BATCH ===")
for batch in seg_train_loader:
    print("Keys:", batch.keys())
    print("Image type:", type(batch["image"]))
    print("Label type:", type(batch["label"]))

    # MONAI MetaTensor metadata
    #print("Image meta:", batch["image"].meta)
    #print("Label meta:", batch["label"].meta)

    # shapes
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)

    break  # Only inspect first batch

print("\n=== CLASSIFICATION TRAIN BATCH ===")
for batch in cls_train_loader:
    print("Keys:", batch.keys())
    print("Image type:", type(batch["image"]))
    print("Label type:", type(batch["label"]))

    # Attention: classification labels are numpy ‚Üí convert?
    print("Label:", batch["label"])
    print("Image meta:", batch["image"].meta)

    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)

    break

=== SEGMENTATION TRAIN BATCH ===
Keys: dict_keys(['image', 'label'])
Image type: <class 'torch.Tensor'>
Label type: <class 'torch.Tensor'>
Image shape: torch.Size([4, 1, 96, 96, 96])
Label shape: torch.Size([4, 1, 96, 96, 96])

=== CLASSIFICATION TRAIN BATCH ===
Keys: dict_keys(['image', 'label'])
Image type: <class 'torch.Tensor'>
Label type: <class 'torch.Tensor'>
Label: tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])


AttributeError: 'Tensor' object has no attribute 'meta'



## Pr√©paration dictionnaire de t√¢che 

### Pr√©paration m√©triques et loss

In [4]:
# Losses
# Ponderer ensuite pa classe avec WeightSampler

from LibMTL.loss import AbsLoss
import torch
from monai.losses import DiceCELoss

class ClassificationLossWrapper(AbsLoss):
    def __init__(self):
        super().__init__()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, gt):
        return self.loss_fn(pred, gt.float())
    
class SegmentationLossWrapper(AbsLoss):
    def __init__(self):
        super().__init__()
        self.loss_fn = DiceCELoss(
            include_background=False,
            to_onehot_y=True,
            softmax=True
        )

    def compute_loss(self, pred, gt):
        return self.loss_fn(pred, gt)



In [5]:
from LibMTL.metrics import AbsMetric
import torch

from torchmetrics.classification import MultilabelAUROC

class MultiLabelAUCMetric(AbsMetric):
    def __init__(self, num_labels=6):
        super().__init__()
        self.metric = MultilabelAUROC(num_labels=num_labels, average=None)   # par classe
        self.metric_mean = MultilabelAUROC(num_labels=num_labels, average="macro")  # moyenne
        self.num_labels = num_labels

    def update_fun(self, pred, gt):
        # pred = logits -> transform needed
        pred = torch.sigmoid(pred)
        gt=gt.detach().cpu().long() #  pour torchmetrics veut des long
        pred=pred.detach().cpu()
        
        self.metric.update(pred,gt)
        self.metric_mean.update(pred,gt)

    def score_fun(self):
        per_class = self.metric.compute().tolist()
        mean_auc = self.metric_mean.compute().item()
        return per_class + [mean_auc]

    def reinit(self):
        super().reinit()
        self.metric.reset()
        self.metric_mean.reset()

In [None]:
# Loss
from LibMTL.metrics import AbsMetric
from monai.metrics import DiceMetric,DiceHelper
from monai.utils import MetricReduction, deprecated_arg
        
class DiceMetricAdapter(AbsMetric):
    """
    Cet adaptateur impl√©mente AbsMetric pour calculer le Dice Score correctement.
    
    - `update_fun` utilise DiceHelper pour obtenir les scores bruts (B, C) 
      et les stocke dans `self.record`.
    - `score_fun` agr√®ge tous les scores de `self.record` et calcule 
      la moyenne finale (le "score des totaux" √©mul√©).
    """
    def __init__(self, num_classes, include_background=False):
        # Initialise self.record et self.bs
        super().__init__()
        
        self.num_classes = num_classes
        self.include_background = include_background
        
        # On utilise DiceHelper comme "calculateur" ponctuel.
        # On lui demande de NE PAS faire de r√©duction (reduction="none")
        # car on veut stocker les scores bruts (Batch, Classes).
        self.dice_helper = DiceHelper(
            include_background=include_background,
            num_classes=num_classes,
            reduction=MetricReduction.NONE,
            ignore_empty=True,  # Important : ignore les cas o√π le GT est vide
            apply_argmax=False  # On le fera nous-m√™mes dans update_fun
        )

    def update_fun(self, pred, gt):
        """
        Appel√© √† chaque batch. Calcule les scores (B, C) et les stocke.
        
        Args:
            pred (torch.Tensor): Pr√©dictions (logits) de forme (B, C, H, W, D)
            gt (torch.Tensor): V√©rit√© terrain (labels) de forme (B, 1, H, W, D)
        """
        # 1. Convertir les logits en labels
        # DiceHelper attend des labels, pas des logits
        pred_labels = torch.argmax(pred, dim=1, keepdim=True)
        
        # 2. Calculer les scores Dice pour ce batch
        # Le r√©sultat est un tenseur de (B, num_classes_calcul√©es)
        # ex: (B, 5) si num_classes=6 et include_background=False
        batch_dice_scores,_ = self.dice_helper(pred_labels, gt)
        
        # 3. Stocker ce tenseur dans notre "record"
        self.record.append(batch_dice_scores)
        
        # 4. Stocker la taille du batch (comme le fait AbsMetric)
        self.bs.append(pred.shape[0])

    def score_fun(self):
        """
        Appel√© √† la fin de l'√©poque. Agr√®ge les scores et calcule la moyenne. Peut etre √† modifier pour le loggage de chaque dice
        """
        if not self.record:
            # Retourne un score pour chaque classe, mis √† 0
            num_expected_classes = self.num_classes - (1 if not self.include_background else 0)
            return torch.zeros(num_expected_classes)
            
        # 1. Rassembler tous les tenseurs de (B, C) en un seul
        # grand tenseur de (Total_B, C)
        all_scores = torch.cat(self.record, dim=0)
        
        # 2. Calculer la moyenne sur la dimension des batches (dim=0)
        # On utilise nanmean pour ignorer les NaN (cas des GT vides)
        # C'est la fa√ßon correcte d'agr√©ger le Dice.
        mean_scores_per_class = torch.nanmean(all_scores, dim=0)
        #mean_gloabal = torch.nanmean(mean_scores_per_class)
        
        # `score_fun` est cens√© retourner une "liste", mais un tenseur
        # est plus utile. On retourne la moyenne par classe.
        return mean_scores_per_class.tolist()
    
    # La m√©thode reinit() est h√©rit√©e de AbsMetric et fonctionne parfaitement
    # car elle vide self.record et self.bs.


In [7]:
## test torch.cat
record = []
dummy_pred = torch.randn(2, 6, 96, 96, 96)
dummy_label =torch.argmax(dummy_pred,dim=1,keepdim=True)
dummy_gt = torch.randint(low=0, high=6, size=(2, 1, 96, 96, 96)) # 6 classes (0, 1, 2, 3, 4, 5)
print("dummy_pred shape :", dummy_pred.shape)
print("dummy_label shape :", dummy_label.shape)
batch_dice_scores, _ = DiceHelper(
    include_background=False,
    num_classes=6,
    reduction=MetricReduction.NONE,
    ignore_empty=True,
    apply_argmax=False
)(dummy_label, dummy_gt)
print("batch_dice_scores shape :", batch_dice_scores.shape)

record.append(batch_dice_scores)
print("record : ", record)


## nouveau batch :
dummy_pred2 = torch.randn(2, 6, 96, 96, 96)
dummy_label2 =torch.argmax(dummy_pred2,dim=1,keepdim=True)
dummy_gt2 = torch.randint(low=0, high=6, size=(2, 1, 96, 96, 96)) # 6 classes (0, 1, 2, 3, 4, 5)
batch_dice_scores2, _ = DiceHelper(
    include_background=False,
    num_classes=6,
    reduction=MetricReduction.NONE,
    ignore_empty=True,
    apply_argmax=False
)(dummy_label2, dummy_gt2)
record.append(batch_dice_scores2)

print("record after 2 batches: ", record)
all_scores = torch.cat(record, dim=0)
print("all_scores shape after cat: ", all_scores.shape)
mean_scores_per_class = torch.nanmean(all_scores, dim=0)
print("mean_scores_per_class: ", mean_scores_per_class)


dummy_pred shape : torch.Size([2, 6, 96, 96, 96])
dummy_label shape : torch.Size([2, 1, 96, 96, 96])
batch_dice_scores shape : torch.Size([2, 5])
record :  [tensor([[0.1662, 0.1655, 0.1659, 0.1675, 0.1670],
        [0.1674, 0.1660, 0.1658, 0.1665, 0.1658]])]
record after 2 batches:  [tensor([[0.1662, 0.1655, 0.1659, 0.1675, 0.1670],
        [0.1674, 0.1660, 0.1658, 0.1665, 0.1658]]), tensor([[0.1662, 0.1659, 0.1672, 0.1662, 0.1659],
        [0.1674, 0.1669, 0.1682, 0.1666, 0.1684]])]
all_scores shape after cat:  torch.Size([4, 5])
mean_scores_per_class:  tensor([0.1668, 0.1661, 0.1668, 0.1667, 0.1668])


### Dictionnaire de t√¢ches

In [9]:
class_names = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
seg_metric_names_list= [ 'EDH', ]
metric_names_list = [f"AUC_{name}" for name in class_names] + ["AUC_Mean"]
seg_metric_names_list = [f"Dice_{name}" for name in class_names]

print("Noms des m√©triques de classification :", metric_names_list)
print("Noms des m√©triques de segmentation   :", seg_metric_names_list)
# dictionnaire de t√¢ches

task_dict = {
    'classification': {
        'loss_fn': ClassificationLossWrapper(),
        'metrics_fn': MultiLabelAUCMetric(num_labels=6),
        'metrics': ['val_auc_class_0', 'val_auc_class_1', 'val_auc_class_2', 
                   'val_auc_class_3', 'val_auc_class_4', 'val_auc_class_5', 
                   'val_auc_mean'],
        'weight': [1.0]
    },
    'segmentation': {
        'loss_fn': SegmentationLossWrapper(),
        'metrics_fn': DiceMetricAdapter(num_classes=6, include_background=False),
        'metrics': ['dice_c1', 'dice_c2', 'dice_c3', 'dice_c4', 'dice_c5'],
        'weight': [1.0] * 5
    }
}
# self.task_num = len(task_dict)
task_num = len(task_dict)
print(f"Nombre de t√¢ches d√©finies : {task_num}")

Noms des m√©triques de classification : ['AUC_any', 'AUC_epidural', 'AUC_intraparenchymal', 'AUC_intraventricular', 'AUC_subarachnoid', 'AUC_subdural', 'AUC_Mean']
Noms des m√©triques de segmentation   : ['Dice_any', 'Dice_epidural', 'Dice_intraparenchymal', 'Dice_intraventricular', 'Dice_subarachnoid', 'Dice_subdural']
Nombre de t√¢ches d√©finies : 2


### Petits tests

In [10]:
# Model
from monai.networks import nets as monai_nets

model = monai_nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=6,
    channels=(32, 64, 128, 256, 320, 320),
    strides=(2, 2, 2, 2, 2),
    num_res_units=2,
)

In [11]:
import torch

B = 2    # Batch Size (nombre d'√©chantillons)
C = 6    # Nombre de Classes (+ background)
H = 32   # Hauteur (Height)
W = 32   # Largeur (Width)
D = 16   # Profondeur (Depth)


pred = torch.rand(B, C, H, W, D) 
pred_labels = torch.argmax(pred, dim=1, keepdim=True)
pred_lables_2= torch.nn.Softmax(dim=1)(pred)

print("Predicted Labels Shape:", pred_labels.shape)  # Devrait afficher (2, 1, 32, 32, 16)
print("Predicted Labels after Softmax Shape:", pred_lables_2.shape)  # Devrait afficher (2, 6, 32, 32, 16)
# --- Tenseur de V√©rit√© Terrain (Labels) ---
# Forme d√©sir√©e : (B, 1, H, W, D) -> (2, 1, 32, 32, 16)
# Utilisation de torch.randint pour simuler des labels (entiers de 0 √† C-1)
# Les labels doivent √™tre des entiers et non des flottants.
gt = torch.randint(low=0, high=C, size=(B, 1, H, W, D))



Predicted Labels Shape: torch.Size([2, 1, 32, 32, 16])
Predicted Labels after Softmax Shape: torch.Size([2, 6, 32, 32, 16])


## Architecture

### Encodeur / D√©codeurs

In [12]:
from typing import Sequence
import torch
import torch.nn as nn
from monai.networks.nets.basic_unet import TwoConv, Down, UpCat


class HemorrhageEncoder(nn.Module):
    """
    Cette classe contient la partie descendante (encodeur) du U-Net.
    Elle est partag√©e par les deux t√¢ches.
    Son forward pass retourne une liste de toutes les feature maps
    n√©cessaires pour les skip connections du d√©codeur de segmentation.
    """
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: str | tuple = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: float | tuple = 0.0,
    ):
        super().__init__()
        
        # Assure que 'features' a la bonne longueur
        self.fea = nn.Parameter(torch.tensor(features), requires_grad=False)
        
        self.conv_0 = TwoConv(spatial_dims, in_channels, self.fea[0], act, norm, bias, dropout)
        self.down_1 = Down(spatial_dims, self.fea[0], self.fea[1], act, norm, bias, dropout)
        self.down_2 = Down(spatial_dims, self.fea[1], self.fea[2], act, norm, bias, dropout)
        self.down_3 = Down(spatial_dims, self.fea[2], self.fea[3], act, norm, bias, dropout)
        self.down_4 = Down(spatial_dims, self.fea[3], self.fea[4], act, norm, bias, dropout)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Le forward pass de l'encodeur.
        Retourne une liste contenant le bottleneck (x4) et toutes les
        sorties interm√©diaires pour les skip connections.
        """
        x0 = self.conv_0(x)
        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)  # C'est le bottleneck (la repr√©sentation partag√©e)
        
        return [x4, x3, x2, x1, x0]


# ========================================================================
# 2. LES D√âCODEURS (Les t√™tes sp√©cifiques √† chaque t√¢che)
# ========================================================================

class SegmentationDecoder(nn.Module):
    """
    Le d√©codeur pour la t√¢che de segmentation.
    Il prend la liste de features de l'encodeur et reconstruit le masque.
    """
    def __init__(
        self,
        spatial_dims: int = 3,
        out_channels: int = 6,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: str | tuple = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
    ):
        super().__init__()
        
        fea = nn.Parameter(torch.tensor(features), requires_grad=False)
        
        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)
        self.final_conv = nn.Conv3d(fea[5], out_channels, kernel_size=1)

    def forward(self, enc_out: list[torch.Tensor]) -> torch.Tensor:
        # On r√©cup√®re les tenseurs de la liste fournie par l'encodeur
        x4, x3, x2, x1, x0 = enc_out
        
        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)
        
        return self.final_conv(u1)

class ClassificationDecoder(nn.Module):
    """
    Le d√©codeur pour la t√¢che de classification.
    Il prend la liste de features de l'encodeur mais n'utilise que le
    bottleneck (x4) pour pr√©dire les classes.
    """
    def __init__(
        self,
        in_features: int,  # Doit correspondre √† features[4] de l'encodeur
        num_cls_classes: int = 6,
    ):
        super().__init__()
        
        # T√™te de classification, exactement comme avant
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool3d((4, 4, 4)),
            nn.Flatten(),
            nn.Linear(in_features * 4 * 4 * 4, 512),
            nn.LayerNorm(512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_cls_classes)
        )

    def forward(self, enc_out: list[torch.Tensor]) -> torch.Tensor:
        # On ne prend que le bottleneck (le premier √©l√©ment de la liste)
        x4 = enc_out[0]
        
        # Toute la logique d'agr√©gation de patches a disparu !
        # On passe directement les features √† la t√™te de classification.
        return self.cls_head(x4)

# ========================================================================
# 3. ASSEMBLAGE FINAL POUR LibMTL
# ========================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device utilis√© : {device}")
# D√©finis tes param√®tres
task_name = ["segmentation", "classification"]
features = (32, 32, 64, 128, 256, 32)
print("Features 4: ", features[4])

# Cr√©e une instance de l'encodeur partag√©
encoder = HemorrhageEncoder(features=features).to(device)

# Cr√©e un dictionnaire de d√©codeurs
decoders = nn.ModuleDict({
    'segmentation': SegmentationDecoder(
        out_channels=6, # 6 classes de segmentation
        features=features
    ).to(device),
    'classification': ClassificationDecoder(
        in_features=features[4], # La taille du bottleneck (256)
        num_cls_classes=6 # 6 classes de classification
    ).to(device)
})



Device utilis√© : cuda
Features 4:  256


### Param√®tres ( optimiseur ,scheduler , kwargs )

In [13]:
# Param√®tres optim & scheduler
optim_param = {
    'optim': 'sgd', 
    'lr': 1e-3, 
    'weight_decay': 3e-5,  # 0.00003 est √©gal √† 3e-5
    'momentum': 0.99, 
    'nesterov': True
}

lengths = [len(loader) for loader in train_dataloaders.values()]

# 2. Trouver le dataloader le plus long (c'est sur lui que LibMTL se cale)
steps_per_epoch = max(lengths)


total_steps = steps_per_epoch * 1000  # 1000 epochs 

scheduler_param = {
    'scheduler': 'linearschedulewithwarmup',  # Correspond √† get_linear_schedule_with_warmup
    'num_warmup_steps': 0, 
    'num_training_steps': total_steps
   }
# --- 2. D√âFINITION MANUELLE DE KWARGS  ---

# Arguments sp√©cifiques √† l'architecture (Exemple pour un U-Net 3D ou une archi complexe)
arch_args = {
    
    # Si vous utilisez CGC, PLE, ou MMoE, vous devez sp√©cifier la taille d'image et le nombre d'experts
    # 'img_size': (96, 96, 96), 
    # 'num_experts': [4, 4, 4], 
    
    # Si votre encodeur ResNet a des arguments sp√©cifiques, mettez-les ici
    # Ex: 'channels': 3 # Si vous devez le passer √† l'initialisation de l'encodeur
}

# Arguments sp√©cifiques √† la m√©thode de pond√©ration (Exemple pour 'EW' qui n'a besoin de rien)
weight_args = {
    # Pour EW (Equal Weighting), c'est souvent vide.
}

# Si vous utilisiez DWA, vous d√©finiriez T :
# weight_args = {'T': 1.0} 

# Si vous utilisiez GradNorm, vous d√©finiriez alpha :
# weight_args = {'alpha': 0.1}

# --- 3. CONSOLIDER KWARGS (Optionnel mais propre) ---

# Cr√©e le dictionnaire kwargs global
kwargs = {
    'arch_args': arch_args,
    'weight_args': weight_args
}

### Petits tests ( shape )

In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device utilis√© : {device}")
tensor_test = torch.rand(2, 1, 96, 96, 96).to(device)  # Batch de 2 √©chantillons
encoder_outputs = encoder(tensor_test)
print (encoder_outputs[0].shape)

Device utilis√© : cuda
torch.Size([2, 256, 6, 6, 6])


In [42]:
encoder_outputs = encoder(tensor_test)
classification_output = decoders['classification'](encoder_outputs)
print (classification_output.shape)

segmentation_output = decoders['segmentation'](encoder_outputs)
print (segmentation_output.shape)


torch.Size([2, 6])
torch.Size([2, 6, 96, 96, 96])


In [21]:

def inspect_dataset_item_robust(dataset, name):
    print(f"\nüîé INSPECTION : {name}")
    try:
        item = dataset[500]
        
        # Cas 1 : C'est une Liste (Segmentation avec plusieurs patchs)
        if isinstance(item, list):
            print(f"   üì¶ C'est une LISTE de {len(item)} patchs.")
            first_patch = item[0]
            print(f"   -> Inspection du premier patch :")
            for key, value in first_patch.items():
                print(f"      - Cl√© '{key}' : {type(value)}")
                if 'numpy' in str(type(value)):
                    print(f" COUPABLE ! (NumPy dans la liste)")
        
        # Cas 2 : C'est un Dictionnaire (Classification)
        elif isinstance(item, dict):
            print(f"   üì¶ C'est un DICTIONNAIRE unique.")
            for key, value in item.items():
                print(f"      - Cl√© '{key}' : {type(value)}")
                if 'numpy' in str(type(value)):
                    print(f"COUPABLE ! (NumPy dans le dict)")
                    
    except Exception as e:
        print(f"Erreur d'inspection : {e}")

# Lance l'inspection
inspect_dataset_item_robust(train_dataloaders['classification'].dataset, "Train Segmentation")

#inspection train_datalaoders
print("\n--- INSPECTION DES DATASETS ---")
print(f" il y a {len(train_dataloaders['segmentation'].dataset)} √©chantillons dans le dataset de segmentation.")
print(f" il y a {len(train_dataloaders['classification'].dataset)} √©chantillons dans le dataset de classification.")


üîé INSPECTION : Train Segmentation


   üì¶ C'est un DICTIONNAIRE unique.
      - Cl√© 'image' : <class 'torch.Tensor'>
      - Cl√© 'label' : <class 'torch.Tensor'>

--- INSPECTION DES DATASETS ---
 il y a 154 √©chantillons dans le dataset de segmentation.
 il y a 1274 √©chantillons dans le dataset de classification.


In [44]:
from LibMTL.trainer import Trainer

test_trainer =Trainer(
    task_dict=task_dict,
    weighting= 'EW',
    architecture='Unet_hemo',
    #save_path=SAVE_DIR, √† ajouter   
    encoder_class=HemorrhageEncoder,
    decoders=decoders,
    rep_grad=False,
    multi_input=True,
    optim_param=optim_param,
    scheduler_param=scheduler_param,
    #device='cuda',
    **kwargs
)

print (f" attributs du trainer : {dir(test_trainer)}")
print (f" devices du trainer : {test_trainer.device}")
print (f" task num : {test_trainer.task_num}")
print (f" task name : {test_trainer.task_name}")
print (f" weighting : {test_trainer.weighting}")
print(f"multi input : {test_trainer.multi_input}")
print(f"kwargs : {test_trainer.kwargs}")

print(f" meter dict : {test_trainer.meter}")

#prepare optimizer
print(f"optimiser : {test_trainer.optimizer}")
print(f"scheduler : {test_trainer.scheduler}")

#prepare model 
print(f"model type : {type(test_trainer.model)}")

print("--- Arbre g√©n√©alogique (MRO) ---")
# MRO = Method Resolution Order
for i, parent in enumerate(test_trainer.model.__class__.mro()):
    print(f"{i}. {parent.__name__}")
    



Total Params: 14273074
Trainable Params: 14273068
Non-trainable Params: 6
LOG FORMAT | classification_LOSS val_auc_class_0 val_auc_class_1 val_auc_class_2 val_auc_class_3 val_auc_class_4 val_auc_class_5 val_auc_mean | segmentation_LOSS dice_c1 dice_c2 dice_c3 dice_c4 dice_c5 | TIME
 attributs du trainer : ['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_compute_loss', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs',

In [45]:
# R√©sum√© du mod√®le avec torchinfo
try:
    from torchinfo import summary
    # On simule une entr√©e (batch_size, channels, H, W, D)

    print(summary(test_trainer.model, input_size=(2, 1, 96, 96, 96), depth=3))
except ImportError:
    print("Installe torchinfo avec 'pip install torchinfo' pour une belle visualisation")
except Exception as e:
    print(f"Impossible de r√©sumer : {e}")

Layer (type:depth-idx)                                       Output Shape              Param #
MTLmodel                                                     [2, 6, 96, 96, 96]        --
‚îú‚îÄHemorrhageEncoder: 1-1                                     [2, 256, 6, 6, 6]         6
‚îÇ    ‚îî‚îÄTwoConv: 2-1                                          [2, 32, 96, 96, 96]       --
‚îÇ    ‚îÇ    ‚îî‚îÄConvolution: 3-1                                 [2, 32, 96, 96, 96]       960
‚îÇ    ‚îÇ    ‚îî‚îÄConvolution: 3-2                                 [2, 32, 96, 96, 96]       27,744
‚îÇ    ‚îî‚îÄDown: 2-2                                             [2, 32, 48, 48, 48]       --
‚îÇ    ‚îÇ    ‚îî‚îÄMaxPool3d: 3-3                                   [2, 32, 48, 48, 48]       --
‚îÇ    ‚îÇ    ‚îî‚îÄTwoConv: 3-4                                     [2, 32, 48, 48, 48]       55,488
‚îÇ    ‚îî‚îÄDown: 2-3                                             [2, 64, 24, 24, 24]       --
‚îÇ    ‚îÇ    ‚îî‚îÄMaxPool3d: 3-

In [46]:
import torch

# 1. Cr√©ation d'un faux batch (1 image, 1 canal, 96x96x96)
dummy_input = torch.randn(2, 1, 96, 96, 96).to(test_trainer.device)

print("--- Test Forward Pass Manuel ---")
try:
    # On passe l'input dans le mod√®le
    outputs = test_trainer.model(dummy_input)
    
    print("Succ√®s ! (Ce serait surprenant)")
    for task, out in outputs.items():
        print(f"Sortie {task}: shape {out.shape}")
        
except Exception as e:
    print("\nBOOM ! Erreur d√©tect√©e :")
    print(e)
    
    # Inspectons la sortie de l'encodeur seul pour confirmer
    print("\n--- Inspection de l'Encodeur seul ---")
    enc_out = test_trainer.model.encoder(dummy_input)
    print(f"Shape sortie encodeur : {enc_out.shape}")
    if len(enc_out.shape) == 2:
        print("-> CONFIRM√â : L'encodeur renvoie un vecteur plat (Batch, Features).")
        print("-> La segmentation a besoin de (Batch, Features, D, H, W).")

--- Test Forward Pass Manuel ---
Succ√®s ! (Ce serait surprenant)
Sortie classification: shape torch.Size([2, 6])
Sortie segmentation: shape torch.Size([2, 6, 96, 96, 96])


In [47]:
print("\n--- Inspection de l'Encodeur seul ---")
enc_out = test_trainer.model.encoder(dummy_input)
print(f"Type sortie encodeur : {type(enc_out)}")
print(f"Nombre d'√©l√©ments : {len(enc_out)}")
for i, tensor in enumerate(enc_out):
    print(f"enc_out[{i}].shape : {tensor.shape}")


--- Inspection de l'Encodeur seul ---
Type sortie encodeur : <class 'list'>
Nombre d'√©l√©ments : 5
enc_out[0].shape : torch.Size([2, 256, 6, 6, 6])
enc_out[1].shape : torch.Size([2, 128, 12, 12, 12])
enc_out[2].shape : torch.Size([2, 64, 24, 24, 24])
enc_out[3].shape : torch.Size([2, 32, 48, 48, 48])
enc_out[4].shape : torch.Size([2, 32, 96, 96, 96])


In [48]:
print("\n--- Test d√©codeur classification seul ---")
enc_out = test_trainer.model.encoder(dummy_input)
print(f"Bottleneck (enc_out[0]) shape: {enc_out[0].shape}")

try:
    cls_output = test_trainer.model.decoders['classification'](enc_out)
    print(f"Classification output shape: {cls_output.shape}")
except Exception as e:
    print(f"Erreur dans le d√©codeur de classification: {e}")
    
    print("\n--- Debug √©tape par √©tape ---")
    x4 = enc_out[0]
    print(f"1. Input x4: {x4.shape}")
    
    pool = nn.AdaptiveAvgPool3d((4, 4, 4))
    after_pool = pool(x4)
    print(f"2. After AdaptiveAvgPool3d: {after_pool.shape}")
    
    after_flatten = torch.flatten(after_pool, start_dim=1)
    print(f"3. After Flatten: {after_flatten.shape}")
    print(f"   Expected: [2, 16384]")


--- Test d√©codeur classification seul ---
Bottleneck (enc_out[0]) shape: torch.Size([2, 256, 6, 6, 6])
Classification output shape: torch.Size([2, 6])


In [49]:
print("\n--- Test du forward complet avec debug ---")

# Patch temporaire pour voir ce qui se passe dans _prepare_rep
original_prepare_rep = test_trainer.model._prepare_rep

def debug_prepare_rep(ss_rep, task, same_rep=None):
    print(f"\nüîç _prepare_rep appel√© pour task: {task}")
    print(f"   Type de ss_rep avant: {type(ss_rep)}")
    if isinstance(ss_rep, list):
        print(f"   Liste de {len(ss_rep)} √©l√©ments")
        for i, x in enumerate(ss_rep):
            print(f"     ss_rep[{i}].shape: {x.shape}")
    else:
        print(f"   ss_rep.shape: {ss_rep.shape}")
    
    result = original_prepare_rep(ss_rep, task, same_rep)
    
    print(f"   Type de ss_rep apr√®s: {type(result)}")
    if isinstance(result, list):
        print(f"   Liste de {len(result)} √©l√©ments")
        for i, x in enumerate(result):
            print(f"     result[{i}].shape: {x.shape}")
    else:
        print(f"   result.shape: {result.shape}")
    
    return result

# Remplacer temporairement
test_trainer.model._prepare_rep = debug_prepare_rep

# Tester le forward
try:
    outputs = test_trainer.model(dummy_input)
    print("\n‚úÖ Succ√®s !")
    for task, out in outputs.items():
        print(f"Sortie {task}: {out.shape}")
except Exception as e:
    print(f"\n‚ùå Erreur: {e}")


--- Test du forward complet avec debug ---

‚úÖ Succ√®s !
Sortie classification: torch.Size([2, 6])
Sortie segmentation: torch.Size([2, 6, 96, 96, 96])


In [50]:
import torch

dummy_input = torch.randn(2, 1, 96, 96, 96).to(test_trainer.device)

print("--- Scanner de l'Encodeur ---")
# On r√©cup√®re la liste
enc_out = test_trainer.model.encoder(dummy_input)

print(f"Type de sortie : {type(enc_out)}")
print(f"Nombre d'√©l√©ments dans la liste : {len(enc_out)}")

print("\n--- D√©tail des Skip Connections ---")
for i, feat in enumerate(enc_out):
    print(f"Feature {i} shape : {feat.shape}")

--- Scanner de l'Encodeur ---
Type de sortie : <class 'list'>
Nombre d'√©l√©ments dans la liste : 5

--- D√©tail des Skip Connections ---
Feature 0 shape : torch.Size([2, 256, 6, 6, 6])
Feature 1 shape : torch.Size([2, 128, 12, 12, 12])
Feature 2 shape : torch.Size([2, 64, 24, 24, 24])
Feature 3 shape : torch.Size([2, 32, 48, 48, 48])
Feature 4 shape : torch.Size([2, 32, 96, 96, 96])


## Entrainement

In [14]:

import wandb

config_l = dict(
    sharing_type="hard",   # "soft" ou "fine_tune"
    model="BasicUNetWithClassification",
    loss_weighting="none",
    dataset_size="balanced",  # "full" ou "balanced" ou "optimized"
    batch_size=2,
    learning_rate=1e-3,
    optimizer="sgd",
    batch_strat√©gie= "loop", 
    seed=42
)
torch.cuda.set_device(0)
# G√©n√©ration automatique de tags √† partir de config
tags = [f"{k}:{v}" for k, v in config_l.items() if k in ["sharing_type", "optimizer", "model", "loss_weighting"]]


# : Initialisation manuelle de wandb
# Au lieu de : wandb_logger = WandbLogger(...)
wandb.init(
    project="hemorrhage_multitask_test",
    group="noponderation",
    tags=tags,
    config=config_l,
    name="multitask_unet3d_libMTL"
)




# 3Ô∏è M√©thodes multit√¢ches
from LibMTL.architecture import HPS
from LibMTL.weighting import GradNorm

# 4Ô∏è Instanciation du Trainer
from LibMTL.trainer import Trainer

hemorrhage_trainer = Trainer(
    task_dict=task_dict,
    weighting= 'EW',
    architecture='Unet_hemo',
    #save_path=SAVE_DIR, √† ajouter 
    encoder_class=HemorrhageEncoder,
    decoders=decoders,
    rep_grad=False,
    multi_input=True,
    optim_param=optim_param,
    scheduler_param=scheduler_param,
    #device='cuda',
    **kwargs
)

#  Entra√Ænement
hemorrhage_trainer.train(train_dataloaders, test_dataloaders = None, epochs=1000 , val_dataloaders=val_dataloaders)


[34m[1mwandb[0m: Currently logged in as: [33mad_92_ywlcod[0m ([33mad_92_ywlcod-polytechnique-montr-al[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Total Params: 14273074
Trainable Params: 14273068
Non-trainable Params: 6
LOG FORMAT | classification_LOSS val_auc_class_0 val_auc_class_1 val_auc_class_2 val_auc_class_3 val_auc_class_4 val_auc_class_5 val_auc_mean | segmentation_LOSS dice_c1 dice_c2 dice_c3 dice_c4 dice_c5 | TIME




Epoch: 0000 | TRAIN: 0.4220 0.4589 0.4322 0.4955 0.4901 0.4910 0.4728 0.4734 | 1.1279 0.0007 0.2995 0.0000 0.0001 0.0004 0.0601 | Time: 184.6781 | 



ValueError: operands could not be broadcast together with shapes (5,) (6,) 