In [2]:
import sys
import os

# Make the project root importable so that the `config`, `data` and `models`
# imports just below resolve to the project's own packages.

# 1. Absolute path of the notebook's current working directory.
current_dir = os.getcwd()

# 2. Its parent directory: the project root expected to contain config/,
#    data/ and models/ (per the printed output, .../Multitask_model).
parent_dir = os.path.dirname(current_dir)

# 3. Prepend rather than append, so the project's modules win over any
#    same-named installed package (e.g. a third-party `config`). Skip the
#    insert when the entry is already present, so re-running this cell does
#    not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

print(f"Dossier ajouté au path : {parent_dir}")
print("Tu peux maintenant importer config, data, et models.")


import config
import data.dataset as dataset
from models.lightning_module  import MultiTaskHemorrhageModule

Dossier ajouté au path : /store/home/tibia/Projet_Hemorragie/Seg_hemorragie/Multitask_model
Tu peux maintenant importer config, data, et models.


In [3]:
import monai.transforms as T
def get_val_transforms(label_key="seg", a_min=-10, a_max=140,
                       pixdim=(1., 1., 1.), spatial_size=(96, 96, 96)):
    """Build the deterministic validation/test preprocessing pipeline.

    The pipeline runs these steps in order:
    - load the image and its label map;
    - move channels first;
    - crop both to the foreground of the image;
    - reorient to RAS;
    - resample to ``pixdim`` (bilinear for the image, nearest for the label);
    - pad to at least ``spatial_size``;
    - window image intensities from [a_min, a_max] to [0, 1] with clipping.

    Args:
        label_key: dict key holding the segmentation. Defaults to "seg",
            which preserves the original behaviour. NOTE(review): the KeyError
            captured at the end of this notebook shows the test split has no
            "seg" key. Pass ``label_key="label"`` for that data.
        a_min, a_max: input intensity window, presumably CT Hounsfield units
            (confirm).
        pixdim: target voxel spacing passed to ``Spacingd``.
        spatial_size: minimum spatial size enforced by ``SpatialPadd``.

    Returns:
        A ``monai.transforms.Compose`` applied to dicts with keys "image" and
        ``label_key``.
    """
    keys = ["image", label_key]
    return T.Compose([
        T.LoadImaged(keys=keys),
        T.EnsureChannelFirstd(keys=keys),
        T.CropForegroundd(keys=keys, source_key="image"),
        T.Orientationd(keys=keys, axcodes="RAS"),
        # Image is interpolated smoothly; the label map must stay integer-valued.
        T.Spacingd(keys=keys, pixdim=pixdim, mode=["bilinear", "nearest"]),
        T.SpatialPadd(keys=keys, spatial_size=spatial_size),
        T.ScaleIntensityRanged(keys=["image"],
                               a_min=a_min, a_max=a_max,
                               b_min=0.0, b_max=1.0,
                               clip=True),
    ])



In [None]:
import os
import torch
import pytorch_lightning as pl
import nibabel as nib
import numpy as np
import pandas as pd 
from pathlib import Path
from monai.data import DataLoader, PersistentDataset
from tqdm import tqdm
import monai.transforms as T
import config 
import data.dataset as dataset
import data.transform as T_seg
from models.lightning_module import MultiTaskHemorrhageModule
from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, DiceHelper


import monai.transforms as T
def get_val_transforms(label_key="label", a_min=-10, a_max=140,
                       pixdim=(1., 1., 1.), spatial_size=(96, 96, 96)):
    """Build the deterministic validation/test preprocessing pipeline.

    NOTE: this redefinition shadows the version from the earlier cell, whose
    label key is "seg". The test data has no "seg" key (see the captured
    KeyError), hence the "label" default here.

    The pipeline runs these steps in order:
    - load the image and its label map;
    - move channels first;
    - crop both to the foreground of the image;
    - reorient to RAS;
    - resample to ``pixdim`` (bilinear for the image, nearest for the label);
    - pad to at least ``spatial_size``;
    - window image intensities from [a_min, a_max] to [0, 1] with clipping.

    Args:
        label_key: dict key holding the segmentation (default "label").
        a_min, a_max: input intensity window, presumably CT Hounsfield units
            (confirm).
        pixdim: target voxel spacing passed to ``Spacingd``.
        spatial_size: minimum spatial size enforced by ``SpatialPadd``.

    Returns:
        A ``monai.transforms.Compose`` applied to dicts with keys "image" and
        ``label_key``.
    """
    keys = ["image", label_key]
    return T.Compose([
        T.LoadImaged(keys=keys),
        T.EnsureChannelFirstd(keys=keys),
        T.CropForegroundd(keys=keys, source_key="image"),
        T.Orientationd(keys=keys, axcodes="RAS"),
        # Image is interpolated smoothly; the label map must stay integer-valued.
        T.Spacingd(keys=keys, pixdim=pixdim, mode=["bilinear", "nearest"]),
        T.SpatialPadd(keys=keys, spatial_size=spatial_size),
        T.ScaleIntensityRanged(keys=["image"],
                               a_min=a_min, a_max=a_max,
                               b_min=0.0, b_max=1.0,
                               clip=True),
    ])



def save_nifti(data, affine, filename, output_dir):
    """Write ``data`` as a NIfTI image at ``output_dir/filename``.

    Args:
        data: voxel array, as a numpy array or torch tensor on any device.
            A 4-D array is assumed channel-first (C, D, H, W) and is moved to
            channel-last (D, H, W, C) before saving.
        affine: voxel-to-world matrix, as a numpy array or torch tensor on
            any device.
        filename: output file name, including extension.
        output_dir: destination directory, created if missing.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # nibabel works on host numpy arrays; CUDA tensors (including MONAI
    # MetaTensors such as a `.meta["affine"]`) cannot be converted implicitly.
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    if isinstance(affine, torch.Tensor):
        affine = affine.detach().cpu().numpy()

    if data.ndim == 4:
        data = np.transpose(data, (1, 2, 3, 0))  # (C, D, H, W) -> (D, H, W, C)

    nifti_img = nib.Nifti1Image(data, affine)
    nib.save(nifti_img, os.path.join(output_dir, filename))
    
def main(
    ckpt_path=Path("/home/tibia/Projet_Hemorragie/MBH_multitask_64x64/best_model.ckpt"),
    output_seg_dir=Path("/home/tibia/Projet_Hemorragie/MBH_multitask_64x64/predictions_nifti"),
    roi_size=(64, 64, 64),
):
    """Run sliding-window segmentation inference on the test split.

    For every test case this function:
    - predicts segmentation logits with a sliding window;
    - computes per-class Dice against the preprocessed label (before inversion);
    - maps the prediction back to the original image space with ``Invertd``;
    - saves the argmax label map as ``PRED_<filename>`` in ``output_seg_dir``.

    Finally it writes one CSV row per case to
    ``<output_seg_dir.parent>/inference_metrics.csv``.

    Args:
        ckpt_path: Lightning checkpoint of ``MultiTaskHemorrhageModule``.
        output_seg_dir: directory receiving the predicted NIfTI files. Its
            parent also receives the on-disk transform cache and the CSV.
        roi_size: sliding-window patch size. 64^3 matches the checkpoint
            folder name; confirm it matches the training patch size.
    """
    # Key of the ground-truth mask. It must match the keys used by
    # get_val_transforms() in this cell ("image", "label"). The KeyError
    # captured below shows the test data has no "seg" key.
    label_key = "label"

    pl.seed_everything(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Configuration
    ckpt_path = Path(ckpt_path)
    output_seg_dir = Path(output_seg_dir)
    output_seg_dir.mkdir(parents=True, exist_ok=True)

    print(f" Démarrage de l'inférence UNIQUE (Images + Métriques)")
    print(f"Chargement : {ckpt_path}")

    # 2. Model. map_location lets a checkpoint saved on GPU load on a
    #    CPU-only host.
    model = MultiTaskHemorrhageModule.load_from_checkpoint(
        ckpt_path, map_location=device, num_steps=1000
    )
    model.eval()
    model.to(device)

    dice_helper = DiceHelper(
        include_background=False,
        softmax=True,
        reduction="none",
    )
    val_transforms = get_val_transforms()

    def seg_predictor(x):
        # The multitask model returns a tuple; sliding-window inference only
        # needs the segmentation logits.
        seg_logits, _ = model(x, task="segmentation")
        return seg_logits

    # 3. Test data (same deterministic preprocessing as validation).
    # NOTE: PersistentDataset caches the transformed items on disk, keyed on
    # the input dict. If get_val_transforms() changes, clear this directory,
    # otherwise stale cached tensors may be reused.
    test_files = dataset.get_segmentation_data(split="test")
    test_dataset = PersistentDataset(
        test_files,
        transform=val_transforms,
        cache_dir=os.path.join(output_seg_dir.parent, "cache_test_inference"),
    )
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    # Quick sanity check: raw shape of one test image before preprocessing.
    if test_files:
        test_img = T.LoadImaged(keys=["image"])(test_files[-1])
        print(f" shape image test : {test_img['image'].shape}")

    # Inverts, on "pred", the transforms recorded on "image" (crop,
    # orientation, spacing, padding) to return to the original image grid.
    inverter = T.Invertd(
        keys=["pred"],
        transform=val_transforms,
        orig_keys=["image"],
        nearest_interp=False,  # interpolate the raw logits (no softmax applied before inversion)
        to_tensor=True,
    )

    results = []  # one dict (filename + per-class Dice) per test case

    # 4. Inference loop
    with torch.no_grad():
        for batch in tqdm(test_loader):
            images = batch["image"].to(device)

            batch["pred"] = sliding_window_inference(
                images,
                roi_size=roi_size,
                sw_batch_size=2,
                predictor=seg_predictor,
                overlap=0.5,
                mode="gaussian",
            )

            # DiceHelper returns (scores, not_nans). With reduction="none" and
            # include_background=False, scores is presumably
            # (batch, foreground classes). Iterating over its actual width
            # handles both MBH (5) and in-house (3) label sets.
            scores, _ = dice_helper(batch["pred"], batch[label_key].to(device))
            dice_scores = {
                f"dice_c{class_idx + 1}": scores[0, class_idx].item()
                for class_idx in range(scores.shape[1])
            }

            case = decollate_batch(batch)[0]
            pred = inverter(case)["pred"]  # Invertd acts on the "pred" key
            argmax_map = torch.argmax(pred, dim=0).cpu().numpy().astype(np.uint8)
            affine = pred.meta["affine"]
            if isinstance(affine, torch.Tensor):
                affine = affine.detach().cpu().numpy()

            filename = Path(case["image"].meta["filename_or_obj"]).name

            results.append({"filename": filename, **dice_scores})
            save_nifti(argmax_map, affine, f"PRED_{filename}", output_seg_dir)

    # 5. Per-patient metrics CSV. The DataFrame is built once, after the loop,
    #    so an empty test split cannot leave it undefined.
    df = pd.DataFrame(results)
    csv_path = os.path.join(output_seg_dir.parent, "inference_metrics.csv")
    df.to_csv(csv_path, index=False)

    print(f"✅ CSV sauvegardé : {csv_path}")

if __name__ == "__main__":
    main()
            


Seed set to 42


 Démarrage de l'inférence UNIQUE (Images + Métriques)
Chargement : /home/tibia/Projet_Hemorragie/MBH_multitask_64x64/best_model.ckpt
BasicUNet features: (32, 32, 64, 128, 256, 32).
Répartition des poids : tensor([1., 1., 1., 1., 1., 1.])




 shape image test : torch.Size([512, 512, 32])


  0%|          | 0/50 [00:00<?, ?it/s]


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/transform.py", line 141, in apply_transform
    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/transform.py", line 98, in _apply_transform
    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
                                                                               ^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/io/dictionary.py", line 162, in __call__
    for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/transform.py", line 475, in key_iterator
    raise KeyError(
KeyError: 'Key `seg` of transform `LoadImaged` was missing in the data and allow_missing_keys==False.'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/data/dataset.py", line 108, in __getitem__
    return self._transform(index)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/data/dataset.py", line 412, in _transform
    pre_random_item = self._cachecheck(self.data[index])
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/data/dataset.py", line 385, in _cachecheck
    _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/data/dataset.py", line 323, in _pre_transform
    item_transformed = self.transform(item_transformed, end=first_random, threading=True)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/compose.py", line 335, in __call__
    result = execute_compose(
             ^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/compose.py", line 111, in execute_compose
    data = apply_transform(
           ^^^^^^^^^^^^^^^^
  File "/home/tibia/Projet_Hemorragie/hemorragie-env/lib/python3.12/site-packages/monai/transforms/transform.py", line 171, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7f9443c2cad0>
