In [244]:
#|default_exp metrics

#|export
import sys
sys.path.append('..')
from abc import ABC, abstractmethod
import torch
import numpy as np
import pandas as pd
from tsai.basics import *
from swdf.losses import wMAELoss, MSELoss, WeightedLoss, ClassificationLoss
from sklearn.metrics import precision_recall_curve, auc



# Metrics
---
## Index

---

[Intro]

In [245]:
#|export

class Metrics(ABC):
    """Abstract base class for metric providers.

    Subclasses implement `get_metrics`, which returns a list of callables,
    each computing one metric.
    """
    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_metrics(self) -> list:
        """Return the list of metric callables provided by this object."""
        # Raise (not return) so that a subclass delegating here fails loudly.
        raise NotImplementedError

## Loss Metrics

We have implemented a class that generates relevant metrics to better evaluate the performance of the weighted loss function. This class calculates how much of the loss is associated with each activity level, providing **deeper insights into the model's behavior**. The classes contain many methods because each condition requires its own explicitly coded function; generating these functions dynamically can lead to errors.

### Regression Metrics
[Text]

In [246]:
#|export

class RegressiveMetrics(Metrics):
    """Base class for per-activity-level metrics of a weighted regression loss.

    The metric for one level is obtained by evaluating a copy of the loss in
    which every weight except the selected one has been set to 0.
    """
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func

    def _apply_weighted_loss_by_level(self, input, target, weight_idx):
        """Evaluate the loss keeping only the weight selected by `weight_idx`.

        Parameters:
        input, target: forwarded unchanged to the copied loss function.
        weight_idx (sequence of two ints): `[row, col]` of the weight to keep
            when `loss_func.weights` is nested; when it is flat, only
            `weight_idx[1]` is used.

        Returns:
        Whatever the copied loss function returns for `(input, target)`.
        """
        # Work on a deep copy so the original loss keeps all of its weights.
        loss_copy = deepcopy(self.loss_func)
        weights = loss_copy.weights
        # Decided once from the first entry, as before: nested (rows of levels) vs flat.
        nested = is_iter(weights[0])

        for row in range(len(weights)):
            if nested:
                for col in range(len(weights[row])):
                    if row != weight_idx[0] or col != weight_idx[1]:
                        weights[row][col] = 0
            elif row != weight_idx[1]:
                weights[row] = 0

        return loss_copy(input, target)

    @abstractmethod
    def get_metrics(self) -> list:
        """Return the list of per-level metric callables."""
        raise NotImplementedError

#### Solar Indices FSMY 10.7 Metrics
[Text]

In [247]:
#|export

class SOLFMYMetrics(RegressiveMetrics):
    """Per-level loss metrics for the solar indices (FSMY 10.7).

    Every metric evaluates the weighted loss keeping a single weight of the
    first weight row, so each value shows how much of the loss one activity
    level contributes.
    """
    def __init__(self, loss_func):
        # The parent constructor already stores `loss_func`.
        super().__init__(loss_func)

    def _level_loss(self, input, target, level):
        # Solar levels live in weight row 0; `level` picks the column.
        return self._apply_weighted_loss_by_level(input, target, [0, level])

    # Metrics (one named method per level so each gets its own metric name)
    def Loss_Low(self, input, target):
        return self._level_loss(input, target, 0)

    def Loss_Moderate(self, input, target):
        return self._level_loss(input, target, 1)

    def Loss_Elevated(self, input, target):
        return self._level_loss(input, target, 2)

    def Loss_High(self, input, target):
        return self._level_loss(input, target, 3)

    # Metrics retrieval function
    def get_metrics(self) -> list:
        """Return the four solar per-level loss metrics, lowest level first."""
        return [self.Loss_Low, self.Loss_Moderate, self.Loss_Elevated, self.Loss_High]

#### Geomagnetic Indices DST and AP Metrics
[Text]

In [248]:
#|export

class GEODSTAPMetrics(RegressiveMetrics):
    """Per-level loss metrics for the geomagnetic indices (AP and DST).

    The AP metrics (Low, Medium, Active) keep a weight of row 0; the DST
    metrics (G0-G5) keep a weight of row 1.

    Parameters:
    loss_func: weighted loss whose weights are isolated one level at a time.
    indices (str): 'geoap' for the AP metrics only, 'geodst' for the DST
        metrics only; any other value (default 'geodstap') returns both.
        Compared case-insensitively.
    """
    def __init__(self, loss_func, indices:str='geodstap'):
        super().__init__(loss_func)
        self.indices = indices

    # Metrics
    def Loss_Low(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [0,0])

    def Loss_Medium(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [0,1])

    def Loss_Active(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [0,2])

    def Loss_G0(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [1,0])

    def Loss_G1(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [1,1])

    def Loss_G2(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [1,2])

    def Loss_G3(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [1,3])

    def Loss_G4(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [1,4])

    def Loss_G5(self, input, target):
        return self._apply_weighted_loss_by_level(input, target, [1,5])

    # Metrics retrieval function
    def get_metrics(self) -> list:
        """Return the AP and/or DST per-level metrics selected by `indices`."""
        ap_metrics = [self.Loss_Low, self.Loss_Medium, self.Loss_Active]
        dst_metrics = [
            self.Loss_G0,
            self.Loss_G1,
            self.Loss_G2,
            self.Loss_G3,
            self.Loss_G4,
            self.Loss_G5,
        ]

        # Case-insensitive, matching how LossMetrics dispatches on the same name.
        indices = str(self.indices).lower()
        if indices == 'geodst':
            return dst_metrics
        if indices == 'geoap':
            return ap_metrics
        return ap_metrics + dst_metrics

### Classification Metrics
[Text]

In [249]:
#|export

class ClassificationMetrics(Metrics):
    """Base class for misclassification metrics of a classification loss.

    Predictions and targets are mapped to labels with the loss function's
    `weighted_loss_tensor`; disagreeing labels are then summed at a chosen
    `(row, col)` position.
    """
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func

    def _compute_misclassifications(self, predictions, targets):
        """Return the predicted labels where they differ from the true labels, 0 elsewhere.

        NOTE(review): assumes `loss_func.weighted_loss_tensor` maps a tensor to
        a label tensor of the same shape - TODO confirm in swdf.losses.
        """
        classifier = self.loss_func.weighted_loss_tensor

        true_labels = classifier(targets)
        predicted_labels = classifier(predictions)

        # Mismatch mask (0/1) times the predicted label: keeps the wrong label's value.
        return (true_labels != predicted_labels).int() * predicted_labels

    def _count_misclassifications_by_position(self, predictions, targets, row, col):
        """Sum the misclassified labels at index `[:, row, col]`.

        Returns 0 when `(row, col)` lies outside dims 1 and 2 of the label tensor.

        NOTE(review): this sums the predicted label values at misclassified
        positions, which equals a count only if those labels are 1 - confirm
        this is intended.
        """
        misclassified_labels = self._compute_misclassifications(predictions, targets)
        n_rows, n_cols = misclassified_labels.shape[1], misclassified_labels.shape[2]

        # Negative indices would silently wrap around, so treat them as out of range too.
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            return 0

        # Sum over the first dimension (presumably the batch) at the selected position.
        return misclassified_labels[:, row, col].sum().item()

    @abstractmethod
    def get_metrics(self) -> list:
        """Return the list of misclassification metric callables."""
        raise NotImplementedError

#### Solar Indices FSMY 10.7 Metrics
[Text]

In [250]:
#|export

class SOLFMYClassificationMetrics(ClassificationMetrics):
    """Misclassification metrics for the solar indices (FSMY 10.7).

    Each metric reads row 0 of the misclassification tensor at the column of
    its activity level (Low=1, Moderate=2, Elevated=3, High=4).
    """
    def __init__(self, loss_func):
        super().__init__(loss_func)

    def _solar_misclassifications(self, predictions, targets, col):
        # Solar levels share row 0; only the column changes.
        return self._count_misclassifications_by_position(predictions, targets, 0, col)

    # Metrics
    def Missclassifications_Low(self, predictions, targets):
        return self._solar_misclassifications(predictions, targets, 1)

    def Missclassifications_Moderate(self, predictions, targets):
        return self._solar_misclassifications(predictions, targets, 2)

    def Missclassifications_Elevated(self, predictions, targets):
        return self._solar_misclassifications(predictions, targets, 3)

    def Missclassifications_High(self, predictions, targets):
        return self._solar_misclassifications(predictions, targets, 4)

    # Metrics retrieval function
    def get_metrics(self) -> list:
        """Return the four solar misclassification metrics, lowest level first."""
        metrics = [self.Missclassifications_Low, self.Missclassifications_Moderate]
        metrics += [self.Missclassifications_Elevated, self.Missclassifications_High]
        return metrics

#### Geomagnetic Indices DST and AP Metrics
[Text]

In [251]:
#|export

class GEODSTAPClassificationMetrics(ClassificationMetrics):
    """Misclassification metrics for the geomagnetic indices (AP and DST).

    AP levels (Low, Medium, Active) read row 0 at columns 1-3; DST levels
    (G0-G5) read row 1 at columns 1-6.

    Parameters:
    loss_func: classification loss providing `weighted_loss_tensor`.
    indices (str): 'geoap' for the AP metrics only, 'geodst' for the DST
        metrics only; any other value (default 'geodstap') returns both.
        Compared case-insensitively.
    """
    def __init__(self, loss_func, indices:str='geodstap'):
        super().__init__(loss_func)
        self.indices = indices

    # Metrics
    def Missclassifications_Low(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 0, 1)

    def Missclassifications_Medium(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 0, 2)

    def Missclassifications_Active(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 0, 3)

    def Missclassifications_G0(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 1, 1)

    def Missclassifications_G1(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 1, 2)

    def Missclassifications_G2(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 1, 3)

    def Missclassifications_G3(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 1, 4)

    def Missclassifications_G4(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 1, 5)

    def Missclassifications_G5(self, predictions, targets):
        return self._count_misclassifications_by_position(predictions, targets, 1, 6)

    # Metrics retrieval function
    def get_metrics(self) -> list:
        """Return the AP and/or DST misclassification metrics selected by `indices`."""
        ap_metrics = [
            self.Missclassifications_Low,
            self.Missclassifications_Medium,
            self.Missclassifications_Active,
        ]
        dst_metrics = [
            self.Missclassifications_G0,
            self.Missclassifications_G1,
            self.Missclassifications_G2,
            self.Missclassifications_G3,
            self.Missclassifications_G4,
            self.Missclassifications_G5,
        ]

        # Case-insensitive, matching how LossMetrics dispatches on the same name.
        indices = str(self.indices).lower()
        if indices == 'geodst':
            return dst_metrics
        if indices == 'geoap':
            return ap_metrics
        return ap_metrics + dst_metrics

### Loss Metrics Retrieval
[Text]

In [252]:
#|export

class LossMetrics(Metrics):
    """Pick the per-level metrics that match a loss function and an index family.

    Parameters:
    loss_func: the training loss. A `ClassificationLoss` gets misclassification
        metrics; a `WeightedLoss` gets per-level loss metrics.
    indices (str): index family, matched case-insensitively: 'solfsmy',
        'geodstap', 'geoap' or 'geodst'.

    Any other combination yields a single `Metrics_Not_Available` placeholder.
    """
    def __init__(self, loss_func, indices:str = ''):
        super().__init__()
        self.loss_func = loss_func
        self.indices = indices

    ## Metrics Not Available
    def Metrics_Not_Available(self, input, target):
        """Placeholder metric used when no specific metrics apply; always NaN."""
        return np.nan

    # Metrics retrieval
    def get_metrics(self):
        """Return the metric callables for `loss_func` and `indices`."""
        indices = (self.indices or '').lower()
        geo_indices = ('geodstap', 'geoap', 'geodst')

        if isinstance(self.loss_func, ClassificationLoss):
            if indices == 'solfsmy':
                return SOLFMYClassificationMetrics(self.loss_func).get_metrics()
            if indices in geo_indices:
                # Pass the lower-cased name so AP/DST selection ignores the caller's capitalisation.
                return GEODSTAPClassificationMetrics(self.loss_func, indices).get_metrics()

        if isinstance(self.loss_func, WeightedLoss):
            if indices == 'solfsmy':
                return SOLFMYMetrics(self.loss_func).get_metrics()
            if indices in geo_indices:
                return GEODSTAPMetrics(self.loss_func, indices).get_metrics()

        return [self.Metrics_Not_Available]

## Validation Metrics
[Text]

### Outliers Metrics
[Text]

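> The threshold of 3.5 follows the recommendation of Iglewicz and Hoaglin for the modified Z-score. For comparison, in normally distributed data approximately 99.7% of the values lie within 3 standard deviations of the mean (the 68–95–99.7 rule). A point whose modified Z-score exceeds 3.5 in absolute value is therefore considered significantly deviant.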

In [253]:
#|export

class OutlierDetectionMetrics(Metrics):
    """Base class for metrics that compare outliers in targets and predictions.

    A value is an outlier when the absolute value of its modified Z-score
    exceeds `threshold`.

    Parameters:
    threshold (float): modified Z-score cut-off (3.5 by default).
    """
    def __init__(self, threshold=3.5):
        super().__init__()
        self.threshold = threshold

    @staticmethod
    def _modified_z_score(x):
        """
        Calculate the Modified Z-Score of every value, per series along dim 2.

        The median and the median absolute deviation (MAD) are taken over
        dim 2; a MAD of 0 is replaced by 1 to avoid dividing by zero. For an
        even-length series `torch.median` returns the lower of the two middle
        values.

        Parameters:
        x (torch.Tensor): Input tensor of shape (batch_size, variables, horizon)

        Returns:
        torch.Tensor: Modified Z-Score tensor of the same shape as input
        """
        median = torch.median(x, dim=2, keepdim=True).values

        mad = torch.median(torch.abs(x - median), dim=2, keepdim=True).values
        mad = torch.where(mad == 0, torch.tensor(1.0, device=x.device), mad)

        # 0.6745 ~ 0.75 quantile of the standard normal: rescales the MAD so the
        # score is comparable to an ordinary Z-score for normal data.
        modified_z_scores = 0.6745 * (x - median) / mad

        return modified_z_scores

    def _detect_outliers(self, values):
        """
        Detect outliers based on Modified Z-Scores.

        Parameters:
        values (torch.Tensor): Tensor of shape (batch_size, variables, horizon)

        Returns:
        torch.Tensor: Boolean tensor, True where |modified Z-score| > threshold
        """
        z_scores = self._modified_z_score(values)
        return torch.abs(z_scores) > self.threshold

    @abstractmethod
    def get_metrics(self) -> list:
        """Return the list of outlier metric callables."""
        raise NotImplementedError


    

#### F1 Score Metric

In [254]:
#|export
class F1ScoreMetrics(OutlierDetectionMetrics):
    """Precision, recall and F1 score of the outliers found in predictions.

    True outliers are detected in `y_true` and predicted outliers in `y_pred`
    with the same modified Z-score threshold; the metrics compare the two
    boolean masks element-wise.
    """
    def __init__(self, threshold=3.5):
        super().__init__(threshold)

    def _evaluate_outlier_predicted(self, y_true, y_pred):
        """
        Evaluate the performance of outlier detection.

        Parameters:
        y_true (torch.Tensor): Actual values tensor of shape (batch_size, variables, horizon)
        y_pred (torch.Tensor): Predicted values tensor of the same shape as y_true

        Returns:
        AttrDict: Dictionary with true/false positives, false negatives, and the
        boolean masks of true/predicted outliers
        """
        # Detect outliers based on the threshold
        true_outliers = self._detect_outliers(y_true)
        pred_outliers = self._detect_outliers(y_pred)

        # Evaluate the detection by comparing true outliers and predicted outliers
        tp = torch.sum((pred_outliers & true_outliers).float())  # True Positives
        fp = torch.sum((pred_outliers & ~true_outliers).float()) # False Positives
        fn = torch.sum((~pred_outliers & true_outliers).float()) # False Negatives

        return AttrDict({
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "true_outliers": true_outliers,
            "predicted_outliers": pred_outliers
        })

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator, or a 0.0 tensor on the numerator's device if denominator <= 0."""
        if denominator > 0:
            return numerator / denominator
        return torch.zeros((), device=numerator.device)

    # Metrics
    def Precision(self, y_true, y_pred):
        """Fraction of predicted outliers that are true outliers (0 when none are predicted)."""
        stats = self._evaluate_outlier_predicted(y_true, y_pred)
        return self._safe_ratio(stats.tp, stats.tp + stats.fp)

    def Recall(self, y_true, y_pred):
        """Fraction of true outliers that are also predicted (0 when there are none)."""
        stats = self._evaluate_outlier_predicted(y_true, y_pred)
        return self._safe_ratio(stats.tp, stats.tp + stats.fn)

    def F1_Score(self, y_true, y_pred):
        """Harmonic mean of precision and recall (0 when both are 0)."""
        # Detect outliers once and derive both precision and recall from it.
        stats = self._evaluate_outlier_predicted(y_true, y_pred)
        precision = self._safe_ratio(stats.tp, stats.tp + stats.fp)
        recall = self._safe_ratio(stats.tp, stats.tp + stats.fn)
        return self._safe_ratio(2 * (precision * recall), precision + recall)

    def Outliers_Difference(self, y_true, y_pred):
        """Number of true outliers not flagged in the prediction (integer tensor)."""
        stats = self._evaluate_outlier_predicted(y_true, y_pred)

        return torch.sum(stats.true_outliers & ~stats.predicted_outliers)

    # Metrics retrieval function
    def get_metrics(self) -> list:
        return [self.Precision, self.Recall, self.F1_Score, self.Outliers_Difference]

### Area Under the Precision-Recall Curve (AUPRC)
[Text]

In [255]:
#|export

class AUPRCMetric(OutlierDetectionMetrics):
    """Area under the precision-recall curve of outlier detection.

    The positive class is the outliers of `y_true`. The score of each point is
    the absolute modified Z-score of `y_pred`.
    """
    def __init__(self, threshold=3.5):
        super().__init__(threshold=threshold)

    def AURPC(self, y_true, y_pred):
        """
        Calculate the Area Under the Precision-Recall Curve (AUPRC).

        (The method name keeps its original spelling for compatibility.)

        Parameters:
        y_true (torch.Tensor): Actual values tensor of shape (batch_size, variables, horizon)
        y_pred (torch.Tensor): Predicted values tensor of the same shape as y_true

        Returns:
        torch.Tensor: AUPRC score on the device of y_true
        """
        # Outliers are two-sided (|z| > threshold), so rank by |z|: a strongly
        # negative score is as much an outlier as a strongly positive one.
        pred_scores = self._modified_z_score(y_pred).abs()

        # reshape (not view) accepts non-contiguous input; detach allows inputs requiring grad.
        pred_scores_flat = pred_scores.detach().reshape(-1).cpu().numpy()
        true_outliers_flat = self._detect_outliers(y_true).detach().reshape(-1).cpu().numpy()

        # Use precision_recall_curve to get precision and recall for different thresholds
        precision, recall, _ = precision_recall_curve(true_outliers_flat, pred_scores_flat)
        auprc_value = auc(recall, precision)

        return torch.tensor(auprc_value, device=y_true.device)

    def get_metrics(self) -> list:
        return [self.AURPC]

### R2 Metric
[Text]

In [256]:
#|export

# TODO: review
class R2Score(Metrics):
    """Coefficient of determination (R^2) between predictions and targets."""
    def __init__(self):
        super().__init__()

    def R2(self, y_true, y_pred):
        """
        Calculate the R^2 score.

        The total sum of squares is measured around the mean of each series
        (taken over dim 2), not around a single global mean. When it is 0
        (every series is constant), the score is 1.0 for a perfect prediction
        and 0.0 otherwise, instead of inf/nan.

        Parameters:
        y_true (torch.Tensor): Actual values tensor of shape (batch_size, variables, horizon)
        y_pred (torch.Tensor): Predicted values tensor of the same shape as y_true

        Returns:
        torch.Tensor: R^2 score
        """
        # Calculate the mean of the actual values (per series)
        y_true_mean = torch.mean(y_true, dim=2, keepdim=True)

        # Calculate the total sum of squares
        ss_tot = torch.sum((y_true - y_true_mean) ** 2)

        # Calculate the residual sum of squares
        ss_res = torch.sum((y_true - y_pred) ** 2)

        # Constant targets: avoid 0/0 or x/0 and return a finite score.
        if ss_tot == 0:
            return torch.ones_like(ss_tot) if ss_res == 0 else torch.zeros_like(ss_tot)

        return 1 - ss_res / ss_tot

    def get_metrics(self) -> list:
        return [self.R2]
    
    

## Tests

In [257]:
device = 'cpu'

# The same four [lower, upper] pairs for every key 'A'-'D' (a fresh array per key).
ranges = {key: np.array([[0, 1], [1, 2], [2, 3], [3, 4]]) for key in 'ABCD'}

weights = {'A': np.array([1, 2, 3, 4])}

# Shape (1, 4, 6): four identical rows 0.5, 1.5, ..., 5.5.
target = torch.arange(0.5, 6.0, 1.0, dtype=torch.float32, device=device).repeat(1, 4, 1)

# Every prediction is off by exactly +1.
input = target + 1

In [258]:
# Test

def test_LossMetrics():
    """The SolFSMY per-level loss metrics must add up to the full weighted loss."""
    loss = wMAELoss(ranges, weights).to(device)
    metric_fns = LossMetrics(loss, 'SolFSMY').get_metrics()

    total = loss(input, target)
    per_level = [fn(input, target) for fn in metric_fns]
    combined = sum(per_level)

    assert torch.isclose(total, combined), f"Expected {total}, but got {combined} ({per_level})"
    print("LossMetrics test passed!")

def test_LossMetrics_for_classification():
    """LossMetrics must hand a ClassificationLoss with 'SolFSMY' to the solar misclassification metrics."""
    loss = ClassificationLoss(ranges, MSELoss()).to(device)
    metrics = SOLFMYClassificationMetrics(loss)

    # Compute the total misclassifications manually for all specific positions
    total_counts = 0
    total_counts += metrics.Missclassifications_Low(input, target)
    total_counts += metrics.Missclassifications_Moderate(input, target)
    total_counts += metrics.Missclassifications_Elevated(input, target)
    total_counts += metrics.Missclassifications_High(input, target)

    # Use the get_metrics method to retrieve all defined metrics
    metrics_functions = LossMetrics(loss, 'SolFSMY').get_metrics()

    # The sum comparison below uses the same methods on both sides, so it cannot
    # detect a wrong dispatch; check which metrics were returned explicitly.
    expected_names = [
        'Missclassifications_Low',
        'Missclassifications_Moderate',
        'Missclassifications_Elevated',
        'Missclassifications_High',
    ]
    dispatched_names = [metric.__name__ for metric in metrics_functions]
    assert dispatched_names == expected_names, f"Expected {expected_names}, but got {dispatched_names}"

    metrics_values = [metric(input, target) for metric in metrics_functions]

    # Assert that the total manually calculated matches the sum of individual metrics
    assert np.isclose(total_counts, sum(metrics_values)), f"Expected {total_counts}, but got {sum(metrics_values)} ({metrics_values})"
    print("LossMetrics for classification loss test passed!")

In [259]:
# Two samples x two series x ten steps. Every y_true series ends far from its
# other values (100, 200, -105, -210); y_pred follows closely, except that its
# first series of the second sample ends at -10 instead of near -105.
y_true = torch.as_tensor(
    [
        [
            [10, 12, 12, 13, 12, 12, 12, 14, 12, 100],
            [20, 22, 23, 20, 22, 20, 21, 22, 23, 200],
        ],
        [
            [-11, -12, -13, -13, -14, -14, -12, -13, -14, -105],
            [-22, -23, -25, -23, -22, -24, -23, -22, -25, -210],
        ],
    ],
    dtype=torch.float32,
)

y_pred = torch.as_tensor(
    [
        [
            [11, 12, 12, 13, 12, 12, 13, 14, 13, 90],
            [21, 22, 23, 21, 22, 21, 22, 22, 23, 195],
        ],
        [
            [-12, -13, -14, -14, -15, -14, -13, -14, -14, -10],
            [-23, -24, -26, -24, -23, -25, -24, -23, -26, -205],
        ],
    ],
    dtype=torch.float32,
)

In [260]:
def test_OutlierDetectionMetrics():
    """F1ScoreMetrics on the y_true / y_pred fixtures: 3 of 4 true outliers found, no false alarms."""
    precision_fn, recall_fn, f1_fn, missed_fn = F1ScoreMetrics().get_metrics()
    metrics_precision = precision_fn(y_true, y_pred)
    metrics_recall = recall_fn(y_true, y_pred)
    metrics_f1_score = f1_fn(y_true, y_pred)
    metrics_outliers_difference = missed_fn(y_true, y_pred)

    f1 = 2 * (metrics_precision * metrics_recall) / (metrics_precision + metrics_recall)

    assert metrics_precision == 1.0, f"Expected 1.0, but got {metrics_precision}"
    assert metrics_recall == 0.75, f"Expected 0.75, but got {metrics_recall}"
    # Floating-point result: compare with a tolerance rather than exact equality.
    assert torch.isclose(metrics_f1_score, f1), f"Expected {f1}, but got {metrics_f1_score}"
    assert metrics_outliers_difference == 1, f"Expected 1, but got {metrics_outliers_difference}"

    print("OutlierDetectionMetrics test passed!")

In [261]:
# Run every test in order; the first failing assertion stops the cell.
for run_test in (test_LossMetrics, test_LossMetrics_for_classification, test_OutlierDetectionMetrics):
    run_test()






LossMetrics test passed!
















LossMetrics for classification loss test passed!
OutlierDetectionMetrics test passed!
