In [12]:
import warnings
from typing import Dict, Any, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.data import DataLoader, PersistentDataset
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceHelper
from monai.networks.nets import UNet
import monai.transforms as T
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelRecall, MultilabelAUROC, MultilabelPrecision
from transformers import get_linear_schedule_with_warmup
import pandas as pd
import numpy as np
from pathlib import Path

In [10]:
import os

SEG_DIR = '/home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI'
CLASSIFICATION_DATA_DIR = '/home/tibia/Projet_Hemorragie/MBH_label_case'
SAVE_DIR = "/home/tibia/Projet_Hemorragie/MBH_multitask_log"
os.makedirs(SAVE_DIR, exist_ok=True)

NUM_CLASSES = 6
CLASS_NAMES = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']

In [4]:
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
DEVICE


device(type='cuda', index=1)

## Vérfier que les CT Scans sont différents

In [None]:
import os

SEG_DIR = '/home/tibia/Projet_Hemorragie/MBH_SEG_2025_LLG_2025_06_12/imagesTr'
CLASSIFICATION_DATA_DIR = '/home/tibia/Projet_Hemorragie/MBH_label_case'


seg_files = set(os.listdir(SEG_DIR))
classif_files = set(os.listdir(CLASSIFICATION_DATA_DIR))

# Fichiers communs
common_files = seg_files & classif_files

# Fichiers spécifiques à chaque dossier
only_in_seg = seg_files - classif_files
only_in_classif = classif_files - seg_files

print(f"Nombre de fichiers dans SEG_DIR : {len(seg_files)}")
print(f"Nombre de fichiers dans CLASSIFICATION_DATA_DIR : {len(classif_files)}")
print(f"Nombre de fichiers communs : {len(common_files)}")
print(f"\nExemples communs : {list(common_files)[:5]}")

print(f"\nFichiers uniquement dans SEG_DIR : {len(only_in_seg)}")
print(f"Exemples : {list(only_in_seg)[:5]}")

print(f"\nFichiers uniquement dans CLASSIFICATION_DATA_DIR : {len(only_in_classif)}")
print(f"Exemples : {list(only_in_classif)[:5]}")

Nombre de fichiers dans SEG_DIR : 192
Nombre de fichiers dans CLASSIFICATION_DATA_DIR : 1982
Nombre de fichiers communs : 1

Exemples communs : ['ID_3be7cc35_ID_cf2643c5a8.nii.gz']

Fichiers uniquement dans SEG_DIR : 191
Exemples : ['ID_3e481fc5_ID_750530c215.nii.gz', 'ID_cb4c887c_ID_ec2a4643f3.nii.gz', 'ID_a429fd9a_ID_2eabc43f3d.nii.gz', 'ID_a1703be5_ID_d0e02f0443.nii.gz', 'ID_5f53ab26_ID_f1c7d54a53.nii.gz']

Fichiers uniquement dans CLASSIFICATION_DATA_DIR : 1981
Exemples : ['ID_b5a546db_ID_a008d70806.nii.gz', 'ID_0685c3cb_ID_a44a99c41b.nii.gz', 'ID_abdba136_ID_05f1970bac.nii.gz', 'ID_9c138c03_ID_d640907eaa.nii.gz', 'ID_1a7ca051_ID_6cc5549a4b.nii.gz']


## Test sur le modèle multitask

In [None]:
class MultiTaskHemorrhageNet(nn.Module):
    def __init__(self, num_seg_classes=6, num_cls_classes=6):
        super().__init__()
        
        # Encodeur partagé basé sur UNet
        self.shared_encoder = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=32,  # Features intermédiaires
            channels=(32, 64, 128, 256, 320), 
            strides=(2, 2, 2, 2),
            num_res_units=2,
            up_kernel_size=3,
            act=('LeakyReLU', {'inplace': True}),
        )
       
        
        # Tête de segmentation
        self.seg_head = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(64, num_seg_classes, kernel_size=1)
        )
        
        # Tête de classification
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool3d((4, 4, 4)),  # Global pooling adaptatif
            nn.Flatten(),
            nn.Linear(32 * 4 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_cls_classes)
        )
        
    def forward(self, x):
        # Encodage partagé
        shared_features = self.shared_encoder(x)
        
        # Segmentation
        seg_logits = self.seg_head(shared_features)
        
        # Classification
        cls_logits = self.cls_head(shared_features)
        
        return seg_logits, cls_logits
    
model= MultiTaskHemorrhageNet(num_seg_classes=NUM_CLASSES, num_cls_classes=NUM_CLASSES)
model = model.to(DEVICE)
# Vérification de la structure du modèle
# print(model)
# model.state_dict()

In [None]:
def forward(self, x: torch.Tensor):
        """
        Args:
            x: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.
                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
                even edge lengths.

        Returns:
            A torch Tensor of "raw" predictions in shape
            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
        """
        x0 = self.conv_0(x)

        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)

        cls_logits = self.cls_head(x4)


        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)

        logits = self.final_conv(u1)
        return logits

In [19]:
model =  UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=32,  # Features intermédiaires
            channels=(32, 64, 128, 256, 320),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            up_kernel_size=3,
            act=('LeakyReLU', {'inplace': True}),
        )

model_2 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(64, 6, kernel_size=1)
        )
        
        # Tête de classification
model_3=   nn.Sequential(
            nn.AdaptiveAvgPool3d((4, 4, 4)),  # Global pooling adaptatif
            nn.Flatten(),
            nn.Linear(32 * 4 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 6)
        )
        

In [14]:
try : 
    ! pip install torchinfo

except:
    print("torchinfo is already installed or installation failed.")
    
import torchinfo
from torchinfo import summary
# # Display model summary
model = model.to(DEVICE)
summary(model, input_size=(8, 1, 96, 96,96))

# model_2 = model_2.to(DEVICE)
# summary(model_2, input_size=(1, 32, 96, 96, 96))
    
# model_3 = model_3.to(DEVICE)
# summary(model_3, input_size=(32, 32, 96, 96,96))




Layer (type:depth-idx)                                                                          Output Shape              Param #
MultiTaskHemorrhageNet                                                                          [8, 6, 96, 96, 96]        --
├─UNet: 1-1                                                                                     [8, 32, 96, 96, 96]       --
│    └─Sequential: 2-1                                                                          [8, 32, 96, 96, 96]       --
│    │    └─ResidualUnit: 3-1                                                                   [8, 32, 48, 48, 48]       29,472
│    │    └─SkipConnection: 3-2                                                                 [8, 64, 48, 48, 48]       12,830,400
│    │    └─Sequential: 3-3                                                                     [8, 32, 96, 96, 96]       83,008
├─Sequential: 1-2                                                                               [8, 6, 9

## Création du dataset

In [28]:


def get_segmentation_data(split="train"):
    img_dir = Path(SEG_DIR) / split / "img"
    seg_dir = Path(SEG_DIR) / split / "seg"
    
    images = sorted(img_dir.glob("*.nii.gz"))
    labels = sorted(seg_dir.glob("*.nii.gz"))
    
    assert len(images) == len(labels), "Mismatch between image and label counts"

    data = []
    for img, lbl in zip(images, labels):
        data.append({
            "image": str(img),
            "label": str(lbl),
            "task": "segmentation"
        })
        
    return data


def get_classification_data(split="train"):
    csv_path = Path(CLASSIFICATION_DATA_DIR) / "splits" / f"{split}_split.csv"
    df = pd.read_csv(csv_path)
    nii_dir = Path(CLASSIFICATION_DATA_DIR)
    label_cols = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
    
    data = []
    for _, row in df.iterrows():
        image_path = str(nii_dir / f"{row['patientID_studyID']}.nii.gz")
        label = np.array([row[col] for col in label_cols], dtype=np.float32)
        
        data.append({
            "image": image_path,
            "label": label,
            "task": "classification"
        })
    return data


def get_multitask_dataset(split="train"):
    seg_data = get_segmentation_data(split)
    cls_data = get_classification_data(split)
    return seg_data + cls_data


new_segmentation_data = get_segmentation_data("train")
print(f"Nombre de données de segmentation : {len(new_segmentation_data)}")
print(f"Exemples de données de segmentation : {new_segmentation_data[:2]}")




new_classification_data = get_classification_data("train")
print(f"Nombre de données de classification : {len(new_classification_data)}")
print(f"Exemples de données de classification : {new_classification_data[:2]}")

new_multitask_data = get_multitask_dataset("train")
print(f"Nombre total de données multitâches : {len(new_multitask_data)}")
print(f"Exemples de données multitâches : {new_multitask_data[:2]}")


from collections import Counter
counter = Counter([sample["task"] for sample in new_multitask_data])
print(f"Répartition des tâches : {dict(counter)}")

Nombre de données de segmentation : 154
Exemples de données de segmentation : [{'image': '/home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/img/ID_0237f3c9_ID_40015688b9.nii.gz', 'label': '/home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/seg/ID_0237f3c9_ID_40015688b9.nii.gz', 'task': 'segmentation'}, {'image': '/home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/img/ID_02b882cc_ID_a4892e60ae.nii.gz', 'label': '/home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/seg/ID_02b882cc_ID_a4892e60ae.nii.gz', 'task': 'segmentation'}]
Nombre de données de classification : 1274
Exemples de données de classification : [{'image': '/home/tibia/Projet_Hemorragie/MBH_label_case/ID_89c8dc6b_ID_690d17c09c.nii.gz', 'label': array([1., 0., 1., 1., 1., 0.], dtype=float32), 'task': 'classification'}, {'image': '/home/tibia/Projet_Hemorragie/MBH_label_case/ID_6d4908ca_ID_bdaf1e8b99.nii.gz', 'label': array([1., 0., 0., 0., 0., 1.], dtype=float32), 'task': 'class

## Creation des transformées

In [None]:
# Idée est de faire deux pipelines de transformation, une pour la segmentation et une pour la classification



class TaskBasedTransform(T.MapTransform):
    """
    Applique un pipeline différent selon la tâche : "segmentation" ou "classification".
    """
    def __init__(self, keys):
        super().__init__(keys)
        self.window_preset = {"window_center": 40, "window_width": 80}

        self.seg_pipeline = T.Compose([
            T.EnsureChannelFirstd(keys=["image", "label"]),
            T.CropForegroundd(keys=["image", "label"], source_key='image'),
            T.Orientationd(keys=["image", "label"], axcodes='RAS'),
            T.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=["bilinear", "nearest"]),
            T.SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
            T.ScaleIntensityRanged(
                keys=["image"],
                a_min=self.window_preset["window_center"] - self.window_preset["window_width"] // 2,
                a_max=self.window_preset["window_center"] + self.window_preset["window_width"] // 2,
                b_min=0.0, b_max=1.0, clip=True
            ),
            T.RandCropByPosNegLabeld(
                keys=['image', 'label'],
                image_key='image',
                label_key='label',
                pos=5.0,
                neg=1.0,
                spatial_size=(96, 96, 64),
                num_samples=2
            ),
            T.RandFlipd(keys=["image", "label"], spatial_axis=[0, 1], prob=0.5),
            T.RandRotate90d(keys=["image", "label"], spatial_axes=(0, 1), prob=0.5),
            T.RandScaleIntensityd(keys=["image"], factors=0.02, prob=0.5),
            T.RandShiftIntensityd(keys=["image"], offsets=0.05, prob=0.5)
        ])

        self.cls_pipeline = T.Compose([
            T.EnsureChannelFirstd(keys=["image"]),
            T.Orientationd(keys=["image"], axcodes='RAS'),
            T.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
            T.ResizeWithPadOrCropd(keys=["image"], spatial_size=(96,96,96)), #bof ( trop petit) pour classification mais on test
            T.ScaleIntensityRanged(
                keys=["image"],
                a_min=self.window_preset["window_center"] - self.window_preset["window_width"] // 2,
                a_max=self.window_preset["window_center"] + self.window_preset["window_width"] // 2,
                b_min=0.0, b_max=1.0, clip=True
            ),
            T.RandFlipd(keys=["image"], spatial_axis=[0, 1, 2], prob=0.5),
            T.RandRotate90d(keys=["image"], spatial_axes=(0, 1), prob=0.5),
            T.RandScaleIntensityd(keys=["image"], factors=0.1, prob=0.5),
            T.RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5)
        ])
        
    def __call__(self, data):
        task = data["task"]
        if task == "segmentation":
            return self.seg_pipeline(data)
        elif task == "classification":
            return self.cls_pipeline(data)
        else:
            raise ValueError(f"Tâche inconnue : {task}")
        
def get_multitask_transforms():
    return T.Compose([
        T.LoadImaged(keys=["image", "label"], image_only=True),
        TaskBasedTransform(keys=["image", "label"]),
        T.ToTensord(keys=["image", "label"]),
    ])
        
        
from monai.data import Dataset
transforms = get_multitask_transforms()
dataset = Dataset(data=new_multitask_data, transform=transforms)
for i in range(3):
    sample = dataset[i]
    sample = dataset[i]
    print(type(sample))  # debug
    print(sample)
    





<class 'list'>
[{'image': metatensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.5109, 0.4996, 0.4883],
          [0.0000, 0.0000, 0.0000,  ..., 0.4588, 0.4720, 0.4853],
          [0.0000, 0.0000, 0.0000,  ..., 0.4566, 0.4629, 0.4691],
          ...,
          [0.5307, 0.5325, 0.5365,  ..., 0.4462, 0.4519, 0.4576],
          [0.5764, 0.5798, 0.5846,  ..., 0.4625, 0.4626, 0.4627],
          [0.5829, 0.5869, 0.5933,  ..., 0.4615, 0.4697, 0.4778]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.4817, 0.4787, 0.4757],
          [0.0000, 0.0000, 0.0000,  ..., 0.5201, 0.5139, 0.5076],
          [0.0000, 0.0000, 0.0000,  ..., 0.4765, 0.4789, 0.4813],
          ...,
          [0.6058, 0.6257, 0.6405,  ..., 0.4233, 0.4339, 0.4445],
          [0.6124, 0.6147, 0.6199,  ..., 0.4432, 0.4455, 0.4478],
          [0.5991, 0.5903, 0.5891,  ..., 0.4002, 0.4102, 0.4202]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.4960, 0.4767, 0.4574],
          [0.0000, 0.0000, 0.0000,  ..., 0.5408, 0.5202, 0.4996],
          [0

## Création du modèle 

In [5]:
# But est de partager l'encodeur entre les deux tâches, donc modifié le Unet de base de MONAI



from collections.abc import Sequence
from typing import Optional

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution, UpSample
from monai.networks.layers.factories import Conv, Pool
from monai.networks.nets.basic_unet import TwoConv, Down, UpCat
from monai.utils import ensure_tuple_rep


class BasicUNetWithClassification(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 6,  # pour segmentation
        num_cls_classes: int = 6,  # pour classification
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: str | tuple = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
    ):
        super().__init__()
        fea = ensure_tuple_rep(features, 6)
        print(f"BasicUNet features: {fea}.")

        # Encoder
        self.conv_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout)
        self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)
        self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)
        self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)
        self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)

        # Decoder
        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)

        self.final_conv = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1)

        # Classification head → à partir du bottleneck `x4`
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool3d((4, 4, 4)),
            nn.Flatten(),
            nn.Linear(fea[4] * 4 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_cls_classes)
        )

    def forward(self, x: torch.Tensor):
        # Encoder
        x0 = self.conv_0(x)
        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)

        # Decoder (segmentation)
        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)
        seg_logits = self.final_conv(u1)

        # Classification
        cls_logits = self.cls_head(x4)  # x4 est le bottleneck

        return seg_logits, cls_logits
    
model_4 = BasicUNetWithClassification(
    spatial_dims=3,
    in_channels=1,
    out_channels=6,  # pour segmentation
    num_cls_classes=6,  # pour classification       
)
model_4 = model_4.to(DEVICE)


BasicUNet features: (32, 32, 64, 128, 256, 32).


In [6]:
try : 
    ! pip install torchinfo

except:
    print("torchinfo is already installed or installation failed.")
    
import torchinfo
from torchinfo import summary
# # Display model summary
model_4 = model_4.to(DEVICE)
summary(model_4, input_size=(8, 1, 96, 96,96))



Layer (type:depth-idx)                             Output Shape              Param #
BasicUNetWithClassification                        [8, 6, 96, 96, 96]        --
├─TwoConv: 1-1                                     [8, 32, 96, 96, 96]       --
│    └─Convolution: 2-1                            [8, 32, 96, 96, 96]       --
│    │    └─Conv3d: 3-1                            [8, 32, 96, 96, 96]       896
│    │    └─ADN: 3-2                               [8, 32, 96, 96, 96]       64
│    └─Convolution: 2-2                            [8, 32, 96, 96, 96]       --
│    │    └─Conv3d: 3-3                            [8, 32, 96, 96, 96]       27,680
│    │    └─ADN: 3-4                               [8, 32, 96, 96, 96]       64
├─Down: 1-2                                        [8, 32, 48, 48, 48]       --
│    └─MaxPool3d: 2-3                              [8, 32, 48, 48, 48]       --
│    └─TwoConv: 2-4                                [8, 32, 48, 48, 48]       --
│    │    └─Convolution: 3-5  

In [16]:
x = torch.randn(2, 1, 96, 96, 96).to(DEVICE) 
seg, cls = model(x)
print("segmentation:", seg.shape)
print("classification:", cls.shape)

segmentation: torch.Size([2, 6, 96, 96, 96])
classification: torch.Size([2, 6])
