In [None]:
#| default_exp vision_metrics

In [None]:
#| export
import torch
import numpy as np
from monai.metrics import compute_hausdorff_distance, compute_dice
from fastMONAI.vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask

# Vision metrics
> Dice score and Hausdorff distance metrics for binary and multi-class semantic segmentation.

In [None]:
#| export
def calculate_dsc(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Per-sample Dice score computed with MONAI `compute_dice`.

    Args:
        pred: Batch of predicted masks; iterated along the first dimension.
        targ: Batch of target masks; iterated along the first dimension.

    Returns:
        Tensor holding one Dice score per (pred, targ) pair.
    """
    # Each sample is scored on its own: `[None]` re-adds the leading batch
    # dimension that iterating over the batch strips off.
    # NOTE(review): wrapping the results in `torch.Tensor([...])` presumably
    # requires each per-sample result to be a single-element tensor
    # (i.e. one channel) — confirm against callers.
    return torch.Tensor([compute_dice(p[None], t[None]) for p, t in zip(pred, targ)])

In [None]:
#| export
def calculate_haus(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Per-sample Hausdorff distance computed with MONAI `compute_hausdorff_distance`.

    Args:
        pred: Batch of predicted masks; iterated along the first dimension.
        targ: Batch of target masks; iterated along the first dimension.

    Returns:
        Tensor holding one Hausdorff distance per (pred, targ) pair.
    """
    distances = []
    for sample_pred, sample_targ in zip(pred, targ):
        # `[None]` re-adds the leading batch dimension removed by iteration.
        distances.append(compute_hausdorff_distance(sample_pred[None], sample_targ[None]))

    return torch.Tensor(distances)

In [None]:
#| export
def binary_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculates the mean Dice score for binary semantic segmentation tasks.

    Args:
        act: Activation tensor with dimensions [B, C, W, H, D].
        targ: Target masks with dimensions [B, C, W, H, D].

    Returns:
        Mean Dice score.
    """
    # Turn the raw activations into a binary mask before scoring.
    pred = pred_to_binary_mask(act)

    # Both tensors are moved to the CPU before the per-sample Dice computation.
    dsc = calculate_dsc(pred.cpu(), targ.cpu())

    return torch.mean(dsc)

In [None]:
#| export
def multi_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculate the mean Dice score for each class in multi-class semantic
    segmentation tasks.

    Each class label is scored as its own binary problem (label vs. rest) and
    the per-sample scores are averaged while ignoring NaN entries.

    Args:
        act: Activation tensor with dimensions [B, C, W, H, D].
        targ: Target masks with dimensions [B, C, W, H, D].

    Returns:
        Mean Dice score for each class, starting at label 1.
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)

    # Label 0 is skipped (presumably the background class — confirm).
    # TODO update torch to get torch.nanmean() to work
    per_class_means = [
        np.nanmean(
            calculate_dsc(
                torch.where(pred == label, 1, 0),
                torch.where(targ == label, 1, 0),
            )
        )
        for label in range(1, n_classes)
    ]

    return torch.Tensor(per_class_means)

In [None]:
#| export
def binary_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculate the mean Hausdorff distance for binary semantic segmentation tasks.

    Args:
        act: Activation tensor with dimensions [B, C, W, H, D].
        targ: Target masks with dimensions [B, C, W, H, D].

    Returns:
        Mean Hausdorff distance.
    """
    # Convert activations to a binary mask, then score on the CPU.
    binary_mask = pred_to_binary_mask(act)
    distances = calculate_haus(binary_mask.cpu(), targ.cpu())

    return distances.mean()

In [None]:
#| export
def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculate the mean Hausdorff distance for each class in multi-class semantic segmentation tasks.

    Each class label is scored as its own binary problem (label vs. rest) and
    the per-sample distances are averaged while ignoring NaN entries.

    Args:
        act: Activation tensor with dimensions [B, C, W, H, D].
        targ: Target masks with dimensions [B, C, W, H, D].

    Returns:
        Mean Hausdorff distance for each class, starting at label 1.
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)
    binary_haus = []

    # Label 0 is skipped (presumably the background class — confirm).
    for c in range(1, n_classes):
        c_pred, c_targ = torch.where(pred == c, 1, 0), torch.where(targ == c, 1, 0)
        # Score the per-class binary masks, not the full multi-class label
        # maps; otherwise every class would report the same distance.
        haus = calculate_haus(c_pred, c_targ)
        binary_haus.append(np.nanmean(haus))

    return torch.Tensor(binary_haus)

In [None]:
#| hide 

# Sanity check: identical masks must give a perfect Dice score (1.0)
# and a Hausdorff distance of zero.
pred = torch.zeros((1, 1, 10, 10, 10))
pred[:, :, :5, :5, :5] = 1

# The target is an exact copy of the prediction.
targ = pred.clone()

dsc = float(calculate_dsc(pred, targ))
haus = float(calculate_haus(pred, targ))

assert dsc == 1.0
assert haus == 0.0