In [9]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.utils.loss import HybridDiceCLDiceLoss

import numpy as np
import nibabel as nib

import torch

In [None]:
# loss init
loss_fn = HybridDiceCLDiceLoss()
loss_fn_with_weights = HybridDiceCLDiceLoss(class_weights=[2, 2])

# input dimensions
B, C, H, W, D = 1, 3, 16, 16, 16

# perfect case : pred = target
target = torch.randint(low=0, high=C, size=(B, 1, H, W, D))
pred = torch.nn.functional.one_hot(target.squeeze(1), num_classes=C).permute(0, 4, 1, 2, 3).float()

# logits -> add int to simulate true model output
pred_logits = pred + 10

loss = loss_fn(pred_logits, target)
loss_with_weights = loss_fn_with_weights(pred_logits, target)
print(f"Loss perfect case (no weights), expected : ~ 0.0000 : {loss.item():.4f}")

print()
# random case
pred_logits_rand = torch.randn((B, C, H, W, D))
loss_rand = loss_fn(pred_logits_rand, target)
loss_rand_with_weights = loss_fn_with_weights(pred_logits_rand, target)
print(f"Loss random case (no weights): {loss_rand.item():.4f}")
print(f"Loss random case (with weights, expected : x2): {loss_rand_with_weights.item():.4f}")

# check backward
pred_logits_rand.requires_grad_()
loss = loss_fn(pred_logits_rand, target)
loss.backward()
print("Gradients calculés ?", pred_logits_rand.grad.abs().sum().item() > 0)

Loss perfect case (no weights), expected : ~ 0.0000 : 0.4249

Loss random case (no weights): 0.6663
Loss random case (with weights, expected : x2): 1.3326
