In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.utils.labels_mapping import RelabelByModality
from src.utils.loss import HybridDiceCLDiceLoss

from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    RandCropByPosNegLabeld
)
from monai.data import DataLoader, Dataset

import numpy as np
import pandas as pd
import nibabel as nib

import torch

In [2]:
train_data = pd.read_csv('../data/processed/train_split.csv')
train_data = train_data.drop(columns=['file_name']).rename(columns={'image_path': 'image', 'label_path': 'label'})
train_data['image'] = train_data['image'].apply(lambda x: os.path.join('..', x))
train_data['label'] = train_data['label'].apply(lambda x: os.path.join('..', x))
train_dict = train_data.to_dict('records')

transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    RelabelByModality(keys=['label']),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        pos=1,
        neg=1,
        num_samples=4,
        image_key="image",
        image_threshold=0,
    )])

dataset = Dataset(data=train_dict, transform=transforms)
dataloader = DataLoader(dataset, batch_size=1)

In [3]:
# loss init
loss_fn = HybridDiceCLDiceLoss(iter_=7)

# perfect case : pred = target
print("Perfect case (pred = target):")
for batch_data in dataloader:
    inputs, targets = batch_data["image"], batch_data["label"]
    # If targets is (B, 1, H, W, D), squeeze to (B, H, W, D)
    preds = torch.nn.functional.one_hot(targets.squeeze(1).long(), num_classes=49).permute(0, 4, 1, 2, 3).float()
    preds_logits = preds * 20 - (1 - preds) * 20  # make logits very large for perfect prediction
    loss = loss_fn(preds_logits, targets)
    print(f"Loss (perfect case): {loss.item():.4f}")
    break

print('-' * 30)
# Bad case : True foreground, but class completely wrong
print("Bad case (Foreground (v > 0) in the good place, but class completely wrong):")
for batch_data in dataloader:
    inputs, targets = batch_data["image"], batch_data["label"]
    # Create completely wrong predictions by shifting the one-hot encoding
    targets_shifted = torch.where(targets == 0, 0, (targets + 1) % 49)  # Shift classes by 1, keep background as 0
    preds = torch.nn.functional.one_hot(targets_shifted.squeeze(1).long(), num_classes=49).permute(0, 4, 1, 2, 3).float()
    preds_logits = preds * 20 - (1 - preds) * 20  # make logits very large for wrong prediction
    loss = loss_fn(preds_logits, targets)
    print(f"Loss (bad case): {loss.item():.4f}")
    break

print('-' * 30)
print('Worst case (pred = background):')
for batch_data in dataloader:
    inputs, targets = batch_data["image"], batch_data["label"]
    # Create completely wrong predictions by setting all to background
    preds = torch.nn.functional.one_hot(torch.zeros_like(targets).squeeze(1).long(), num_classes=49).permute(0, 4, 1, 2, 3).float()
    preds_logits = preds * 20 - (1 - preds) * 20  # make logits very large for wrong prediction
    loss = loss_fn(preds_logits, targets)
    print(f"Loss (worst case): {loss.item():.4f}")
    break

Perfect case (pred = target):
Loss (perfect case): 0.0000
------------------------------
Bad case (Foreground (v > 0) in the good place, but class completely wrong):
Loss (bad case): 0.6266
------------------------------
Worst case (pred = background):
Loss (worst case): 0.7713
