In [3]:
from typing import List, Sequence, Tuple
import torch
from torch import nn

class mCSI(nn.Module):
    """Compute the mean critical success index (CSI) over several thresholds.

    For every threshold ``t`` both tensors are binarised with ``x >= t`` and

        CSI = hits / (hits + misses + false_alarms + eps)

    is computed per map by summing over the last two dimensions. The CSI is
    averaged over all remaining (leading) dimensions, and the per-threshold
    values are then averaged over the thresholds.

    Parameters
    ----------
    thresholds : sequence of float
        Event thresholds; a value counts as an event when ``value >= threshold``.
        Must contain at least one entry.
    eps : float
        Added to the denominator to avoid division by zero. A map with no
        observed and no predicted events therefore scores 0.

    Raises
    ------
    ValueError
        If ``thresholds`` is empty.
    """

    def __init__(self, thresholds: Sequence[float] = (16., 74., 133.), eps: float = 1e-4) -> None:
        super().__init__()
        if len(thresholds) == 0:
            raise ValueError("mCSI requires at least one threshold")
        # Copy into a new list so neither the caller's sequence nor a shared
        # default object is aliased by the module.
        self.thresholds = list(thresholds)
        self.eps = eps

    @staticmethod
    def _threshold(
        y_true: torch.Tensor, y_pred: torch.Tensor, threshold: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a threshold to both the target and the prediction tensors.

        Parameters
        ----------
        y_true : Tensor
            The target tensor.
        y_pred : Tensor
            The prediction tensor.
        threshold : float
            The threshold to apply (``x >= threshold`` is an event).

        Returns
        -------
        Tensor
            The thresholded target tensor (float 0/1), a new tensor.
        Tensor
            The thresholded prediction tensor (float 0/1), a new tensor.
        """
        y_true_res = (y_true >= threshold).float()
        y_pred_res = (y_pred >= threshold).float()

        # Where either input is NaN, zero BOTH outputs so the position counts
        # as neither a hit, a miss nor a false alarm.
        is_nan = torch.isnan(y_true) | torch.isnan(y_pred)
        y_true_res[is_nan] = 0
        y_pred_res[is_nan] = 0

        return y_true_res, y_pred_res

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the CSI score averaged over all thresholds.

        Parameters
        ----------
        pred, target : torch.Tensor
            Same shape; the last two dimensions are reduced as one map,
            e.g. (batch_size, seq_len, height, width).

        Returns
        -------
        torch.Tensor
            Scalar tensor: mean CSI over leading dimensions and thresholds.
        """
        total = 0.

        with torch.no_grad():
            for thresh in self.thresholds:
                # Bind to NEW names: overwriting ``target``/``pred`` would feed
                # 0/1 maps into the next (higher) threshold and zero its score.
                target_bin, pred_bin = self._threshold(target, pred, thresh)
                hits = torch.sum(target_bin * pred_bin, dim=(-2, -1))
                misses = torch.sum(target_bin * (1 - pred_bin), dim=(-2, -1))
                fas = torch.sum((1 - target_bin) * pred_bin, dim=(-2, -1))
                csi = hits / (hits + misses + fas + self.eps)
                total += csi.mean()

        return total / len(self.thresholds)