In [1]:
import sys
import os

# Absolute path of the directory one level above the current working directory.
# NOTE(review): presumably the project root that holds the packages imported below — confirm.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Only register the path once: re-running this cell must not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import data.dataset as dataset
import data.transform as T_mtsk

# Import depuis models/
from models.architecture import BasicUNetWithClassification
from models.lightning_module import MultiTaskHemorrhageModule, MultiTaskHemorrhageModule_homeo, MultiTaskHemorrhageModule_gradnorm, MultiTaskSoftSharing

import utils

from monai.data import DataLoader, PersistentDataset
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torch
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import config

import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Fixed seed so the shuffled training order is identical after every kernel restart.
SHUFFLE_SEED = 42

train_data = dataset.get_equalized_multitask_dataset('train')
# Shuffle the training entries in place with a dedicated, seeded RNG.
# NOTE(review): this only pins the shuffle order; whether the dataset builder itself
# samples randomly is not visible from this notebook — confirm.
random.Random(SHUFFLE_SEED).shuffle(train_data)
val_data = dataset.get_equalized_multitask_dataset('val')

print(f"Nombre d'images d'entraînement : {len(train_data)}")
print(f"Nombre d'images de validation : {len(val_data)}")

print("Exemple d'entrée de train_data :", train_data[0])

# Show which task each of the first ten (shuffled) entries belongs to.
print("visualisation de quel taches sont présentes dans les données d'entraînement :")
for item in train_data[:10]:
    print(f"  {item['task']}")

# Inspect the first few raw entries; the bound protects against very small datasets.
print("=== Test direct du dataset ===")
for i in range(min(4, len(train_data))):
    item = train_data[i]
    print(f"\nItem {i}:")
    print(f"  Task: {item.get('task')}")
    if 'image' in item:
        if isinstance(item['image'], list):
            # A list-valued image entry: report its length and the first element's shape.
            print(f"  Image: LIST de {len(item['image'])} éléments")
            print(f"    Premier élément shape: {item['image'][0].shape if hasattr(item['image'][0], 'shape') else 'no shape'}")
        else:
            # Objects without a `.shape` attribute (e.g. file paths) are reported as 'no shape'.
            print(f"  Image shape: {item['image'].shape if hasattr(item['image'], 'shape') else 'no shape'}")

TRAIN - Positifs: 371 | Négatifs: 371
Nombre de données de segmentation : 154
Nombre de données de classification équilibrées : 742
VAL - Positifs: 92 | Négatifs: 92
Nombre de données de segmentation : 38
Nombre de données de classification équilibrées : 184
Nombre d'images d'entraînement : 896
Nombre d'images de validation : 222
Exemple d'entrée de train_data : {'image': '/home/tibia/Projet_Hemorragie/MBH_label_case/ID_dfcea57b_ID_331f2c76e6.nii.gz', 'label': array([1., 0., 0., 0., 1., 0.], dtype=float32), 'task': 'classification'}
visualisation de quel taches sont présentes dans les données d'entraînement :
  classification
  segmentation
  classification
  segmentation
  classification
  classification
  classification
  classification
  classification
  classification
=== Test direct du dataset ===

Item 0:
  Task: classification
  Image shape: no shape

Item 1:
  Task: segmentation
  Image shape: no shape

Item 2:
  Task: classification
  Image shape: no shape

Item 3:
  Task: seg

In [3]:
from monai.data import Dataset
import data.transform as T_mtsk
import shutil

# Build the training and validation transform objects (training one first, as before).
train_transforms = T_mtsk.TaskBasedTransform_V2(keys=["image", "label"])
val_transforms = T_mtsk.TaskBasedValTransform_V2(keys=["image", "label"])

# Remove the on-disk cache entirely before rebuilding the dataset.
shutil.rmtree("./cache_dir", ignore_errors=True)

train_dataset = PersistentDataset(
    train_data,
    transform=train_transforms,
    cache_dir="./cache_dir",
)

# Print at most the first 15 raw entries. The bound avoids an IndexError on datasets
# smaller than 15 items, matching the bounded loop used when inspecting train_data earlier.
for i in range(min(15, len(train_data))):
    print(f"Échantillon {i} :")
    for key, value in train_data[i].items():
        print(f"  - {key}: {value}")
print("Nombre total d'éléments dans le dataset:", len(train_data))

>>> TaskBasedTransform initialized
>>> TaskBasedTransform initialized
Échantillon 0 :
  - image: /home/tibia/Projet_Hemorragie/MBH_label_case/ID_dfcea57b_ID_331f2c76e6.nii.gz
  - label: [1. 0. 0. 0. 1. 0.]
  - task: classification
Échantillon 1 :
  - image: /home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/img/ID_a84354a5_ID_6699da85bb.nii.gz
  - label: /home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/seg/ID_a84354a5_ID_6699da85bb.nii.gz
  - task: segmentation
Échantillon 2 :
  - image: /home/tibia/Projet_Hemorragie/MBH_label_case/ID_67b355ba_ID_52ff5ab9b6.nii.gz
  - label: [0. 0. 0. 0. 0. 0.]
  - task: classification
Échantillon 3 :
  - image: /home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/img/ID_96967deb_ID_5d782e9379.nii.gz
  - label: /home/tibia/Projet_Hemorragie/Seg_hemorragie/split_MONAI/train/seg/ID_96967deb_ID_5d782e9379.nii.gz
  - task: segmentation
Échantillon 4 :
  - image: /home/tibia/Projet_Hemorragie/MBH_label_case/ID_3922d1cd_ID



In [7]:
import pprint

loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,  # author's note: "for more clarity"
    num_workers=0,
    collate_fn=utils.multitask_collate_fn,  # project-specific collate function
)

# --- Run the test ---
print(f"Dataset initial (volumes) : {len(train_data)}")
print(f"Batch Size (volumes) : {loader.batch_size}")
print("-" * 50)
print("Vérification de la taille des lots APRÈS l'application des patches (via collate_fn) :")

print (f"Nombre total de lots dans le DataLoader : {len(loader)}")
print("Résultat du premier lot :")
first_batch = next(iter(loader))

# With the initial call

print("=== Infos générales ===")
print(f"type: {type(first_batch)}")
print(f"keys: {list(first_batch.keys())}")

print("=== Infos par clé ===")

# Either task entry may be None when the batch contains no sample of that task.
cls_batch = first_batch['classification']
seg_batch = first_batch['segmentation']

cls_count = cls_batch['image'].shape[0] if cls_batch is not None else 0
seg_count = seg_batch['image'].shape[0] if seg_batch is not None else 0
total_patches = cls_count + seg_count

print(f"Patches de classification: {cls_count}")
print(f"Patches de segmentation: {seg_count}")

# Apply the same None guard as the counts above: a segmentation-only batch used to crash here.
# Single quotes inside the f-string also keep the line valid on Python versions before 3.12.
if cls_batch is not None:
    print(f"  shape : {list(cls_batch['image'].shape)}")
else:
    print("  shape : aucun patch de classification dans ce lot")

#


Dataset initial (volumes) : 896
Batch Size (volumes) : 4
--------------------------------------------------
Vérification de la taille des lots APRÈS l'application des patches (via collate_fn) :
Nombre total de lots dans le DataLoader : 224
Résultat du premier lot :
=== Infos générales ===
type: <class 'dict'>
keys: ['classification', 'segmentation']
=== Infos par clé ===
Patches de classification: 48
Patches de segmentation: 0
  shape : [48, 1, 96, 96, 96]


 Test Dataloader utils 

In [22]:
import torch

# Votre fonction flatten
def flatten(batch):
    """Lazily yield the leaves of ``batch`` in order, descending into nested lists.

    Only ``list`` instances are expanded; every other object (tuples included)
    is yielded unchanged as a single leaf.
    """
    for element in batch:
        if not isinstance(element, list):
            yield element
            continue
        # Nested list: emit its leaves before moving on to the next element.
        for leaf in flatten(element):
            yield leaf

# --- Build an example batch ---
# Plain tensors arranged with several nesting levels, to show that flatten
# reaches every leaf.

# Five random 3x10 tensors, drawn in the order A, B, C, D, E.
tensor_A, tensor_B, tensor_C, tensor_D, tensor_E = (
    torch.randn(3, 10) for _ in range(5)
)

# Nested batch layout:
# [ single item, [ list of items ], single item, [[ doubly nested item ]] ]
example_batch = [tensor_A, [tensor_B, tensor_C], tensor_D, [[tensor_E]]]

print("--- 1. Structure du Batch (Avant flatten) ---")
print(f"Type de l'élément à l'index 1: {type(example_batch[1])}")
print(f"Longueur du batch: {len(example_batch)}")

# --- Exercise flatten ---
# Materialise the generator so the result can be indexed and measured.
flat_batch = [*flatten(example_batch)]

print("\n--- 2. Structure du Batch (Après flatten) ---")
print(f"Type du résultat: {type(flat_batch)}")
print(f"Longueur du batch aplati: {len(flat_batch)}")

print("\n--- 3. Vérification des Tenseurs ---")

# Identity checks: the very same tensor objects must come out, in the original order.
print(f"Le Tenseur A est-il le premier élément? {flat_batch[0] is tensor_A}")
print(f"Le Tenseur E est-il le dernier élément? {flat_batch[-1] is tensor_E}")
print(f"Forme du premier élément aplati: {flat_batch[0].shape}")
print(f"Type du dernier élément aplati: {type(flat_batch[-1])}")
print(f"batch aplati complet: {flat_batch}")

--- 1. Structure du Batch (Avant flatten) ---
Type de l'élément à l'index 1: <class 'list'>
Longueur du batch: 4

--- 2. Structure du Batch (Après flatten) ---
Type du résultat: <class 'list'>
Longueur du batch aplati: 5

--- 3. Vérification des Tenseurs ---
Le Tenseur A est-il le premier élément? True
Le Tenseur E est-il le dernier élément? True
Forme du premier élément aplati: torch.Size([3, 10])
Type du dernier élément aplati: <class 'torch.Tensor'>
batch aplati complet: [tensor([[ 0.8457,  0.4845,  1.5668, -1.8926, -0.6487,  1.3580,  0.9388,  1.2377,
          1.1349, -0.7695],
        [-1.1854,  0.1463,  1.2118, -1.5987, -0.0058, -1.6769,  0.1111,  0.9685,
         -0.9121, -0.6517],
        [-1.3885, -1.3750,  0.7516,  1.0253,  0.2023,  0.6105,  0.7807,  1.9234,
         -0.6046,  1.6436]]), tensor([[-0.7419,  0.5092, -0.0504, -1.5096, -1.2095, -2.1121,  0.1571, -0.8586,
         -0.2099, -0.7358],
        [ 0.5918,  1.1614,  1.2502, -0.4430,  2.1776, -0.8310,  1.3217, -0.2493,
 

# Répartition dataset

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# Location of the case-wise annotation table and the label columns this notebook reads from it.
# NOTE(review): assumed to contain one row per case with 0/1 values per label — confirm.
csv_path = Path("/home/tibia/Projet_Hemorragie/MBH_label_case/case-wise_annotation.csv")
label_cols = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
df = pd.read_csv(csv_path)

# Fail fast with an explicit message if a label column is absent, instead of hitting a
# bare KeyError later in the middle of the counting / plotting code.
missing_label_cols = [col for col in label_cols if col not in df.columns]
if missing_label_cols:
    raise KeyError(f"{csv_path} is missing expected label columns: {missing_label_cols}")

def visualize_label_distribution(label_file_counts, total_files, files_with_lesions=None):
    """Draw a 2x3 dashboard of per-label file counts and print a text summary.

    Panels: bar chart of counts, horizontal bar chart of percentages, pie chart of
    the five hemorrhage labels, stacked bar, donut chart of all six labels, and a
    summary table.

    Args:
        label_file_counts: Object indexable by the integers 0..5. Index 0 is shown
            as 'Background'; indices 1..5 are shown as EDH, IPH, IVH, SAH and SDH.
        total_files: Total number of files; used as the denominator of every
            percentage. Must be strictly positive.
        files_with_lesions: Number of files with at least one hemorrhage, shown in
            the table and in the printed summary. When omitted, it is read from the
            notebook-level ``df`` as ``(df['any'] == 1).sum()`` (the original
            behaviour).

    Returns:
        None. The figure is displayed with ``plt.show()`` and the summary is printed.

    Raises:
        ValueError: If ``total_files`` is not strictly positive.
    """
    # Reject an empty dataset up front rather than failing with ZeroDivisionError
    # after half of the figure has already been built.
    if total_files <= 0:
        raise ValueError(f"total_files must be a positive number, got {total_files!r}")

    if files_with_lesions is None:
        # Backward-compatible fallback on the notebook-level DataFrame.
        files_with_lesions = (df['any'] == 1).sum()

    plt.figure(figsize=(20, 12))

    labels = list(range(6))
    counts = [label_file_counts[i] for i in labels]
    label_names = ['Background', 'EDH', 'IPH', 'IVH', 'SAH', 'SDH']

    colors = ['#2C3E50', '#E74C3C', '#3498DB', '#F39C12', '#9B59B6', '#1ABC9C']

    # Panel 1: number of files per label, with the count written above each bar.
    ax1 = plt.subplot(2, 3, 1)
    bars = ax1.bar(labels, counts, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax1.set_xlabel('Label', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Number of Files', fontsize=12, fontweight='bold')
    ax1.set_title('Files Containing Each Label', fontsize=14, fontweight='bold')
    ax1.set_xticks(labels)
    ax1.grid(axis='y', alpha=0.3)

    for bar, count in zip(bars, counts):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 1,
                f'{count}', ha='center', va='bottom', fontweight='bold')

    # Panel 2: the same counts expressed as a percentage of total_files.
    ax2 = plt.subplot(2, 3, 2)
    percentages = [(count/total_files)*100 for count in counts]
    bars2 = ax2.barh(label_names, percentages, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax2.set_xlabel('Percentage of Files (%)', fontsize=12, fontweight='bold')
    ax2.set_title('Percentage Distribution', fontsize=14, fontweight='bold')
    ax2.grid(axis='x', alpha=0.3)

    for bar, pct in zip(bars2, percentages):
        width = bar.get_width()
        ax2.text(width + 1, bar.get_y() + bar.get_height()/2.,
                f'{pct:.1f}%', ha='left', va='center', fontweight='bold')

    # Panel 3: relative share of the five hemorrhage labels (index 0 excluded).
    ax3 = plt.subplot(2, 3, 3)
    hemorrhage_counts = counts[1:]
    hemorrhage_labels = label_names[1:]
    hemorrhage_colors = colors[1:]

    _, _, autotexts = ax3.pie(hemorrhage_counts, labels=hemorrhage_labels,
                              colors=hemorrhage_colors, autopct='%1.1f%%',
                              startangle=90, explode=(0.05, 0.05, 0.05, 0.05, 0.05))
    ax3.set_title('Hemorrhage Types Distribution\n(Excluding Background)', fontsize=14, fontweight='bold')

    for autotext in autotexts:
        autotext.set_color('white')
        autotext.set_fontweight('bold')

    # Panel 4: a single stacked bar, background first, then each hemorrhage label on top.
    ax4 = plt.subplot(2, 3, 4)
    background_count = counts[0]
    lesion_counts = counts[1:]

    bottom = 0
    bar_width = 0.6

    ax4.bar('Dataset', background_count, bar_width, label='Background',
            color=colors[0], alpha=0.8, edgecolor='black', linewidth=1)
    bottom += background_count

    for count, color, name in zip(lesion_counts, colors[1:], label_names[1:]):
        ax4.bar('Dataset', count, bar_width, bottom=bottom, label=name,
                color=color, alpha=0.8, edgecolor='black', linewidth=1)
        bottom += count

    ax4.set_ylabel('Number of Files', fontsize=12, fontweight='bold')
    ax4.set_title('Stacked Distribution', fontsize=14, fontweight='bold')
    ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Panel 5: donut chart over all six labels, with the file total written in the hole.
    ax5 = plt.subplot(2, 3, 5)
    ax5.pie(counts, labels=labels, colors=colors,
            wedgeprops=dict(width=0.5), startangle=90)

    ax5.text(0, 0, f'Total\n{total_files}\nFiles', ha='center', va='center',
             fontsize=14, fontweight='bold')
    ax5.set_title('Complete Distribution\n(Donut Chart)', fontsize=14, fontweight='bold')

    # Panel 6: summary table (one row per label, then two footer rows).
    ax6 = plt.subplot(2, 3, 6)
    ax6.axis('off')

    summary_data = [
        [f'Label {label}', name, count, f'{(count/total_files)*100:.1f}%']
        for label, count, name in zip(labels, counts, label_names)
    ]
    summary_data.append(['', 'TOTAL FILES', total_files, '100.0%'])
    summary_data.append(['', 'Files w/ Lesions', files_with_lesions, ''])

    table = ax6.table(cellText=summary_data,
                     colLabels=['Label', 'Type', 'Count', 'Percentage'],
                     cellLoc='center',
                     loc='center',
                     colWidths=[0.15, 0.45, 0.2, 0.2])

    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Table row 0 is the header, so data row k of summary_data lives at table row k + 1.
    n_rows = len(summary_data)
    for i in range(n_rows + 1):
        for j in range(4):
            cell = table[(i, j)]
            if i == 0:
                # Header row.
                cell.set_facecolor('#34495E')
                cell.set_text_props(weight='bold', color='white')
            elif i == n_rows - 1 or i == n_rows:
                # The two footer rows (total files, files with lesions).
                cell.set_facecolor('#ECF0F1')
                cell.set_text_props(weight='bold')
            else:
                # Label rows reuse the chart colour of their label.
                cell.set_facecolor(colors[i-1] if i-1 < len(colors) else '#FFFFFF')
                if i-1 == 0:
                    # Background row has a dark fill, so switch to white text.
                    cell.set_text_props(color='white', weight='bold')

    ax6.set_title('Summary Statistics', fontsize=14, fontweight='bold')

    plt.suptitle('Hemorrhage Dataset Annotation Distribution Analysis',
                 fontsize=18, fontweight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    plt.show()

    # Plain-text recap of the same numbers.
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    print(f"Total files: {total_files}")
    print(f"Files with background (Label 0): {counts[0]} ({counts[0]/total_files*100:.1f}%)")
    print("Files with hemorrhage types:")
    for i in range(1, 6):
        pct = (counts[i]/total_files)*100
        print(f"  Label {i}: {counts[i]} files ({pct:.1f}%)")

    print(f"\nFiles with at least one hemorrhage (any=1): {files_with_lesions}")

total_files = df.shape[0]

# Slot 0: files with any == 0 (background, no hemorrhage).
label_file_counts = [(df['any'] == 0).sum()]
# Slots 1-5: files positive for each hemorrhage subtype, in label_cols order
# (epidural, intraparenchymal, intraventricular, subarachnoid, subdural).
label_file_counts.extend((df[column] == 1).sum() for column in label_cols[1:])

visualize_label_distribution(label_file_counts, total_files)

### Débogage avec pprint