In [10]:
#|default_exp losses

#|export
import torch
import numpy as np
import pandas as pd
from tsai.basics import *


# Custom losses

In [11]:
#|export
class Loss(nn.Module):
    """Base class for level-weighted regression losses.

    The elementwise error returned by `loss_measure` (implemented by subclasses)
    is multiplied by a weight that depends on which level (value range) each
    target element falls into, and the product is averaged.

    Args:
        ranges: array-like of shape (variables, levels, 2) with the inclusive
            ``[low, high]`` bounds of every level for every variable.
        weights: array-like of shape (levels,) with one weight per level, shared
            by all variables.
        solact_levels: level names, one per entry of `weights`; used as the
            names of the per-level metrics returned by `metrics`. If None, the
            names in `default_level_names` are used.
    """

    # Metric names used when no `solact_levels` are supplied.
    default_level_names = ('low', 'moderate', 'elevated', 'high')

    def __init__(self, ranges, weights, solact_levels):
        super().__init__()
        # Buffers move with `.to(device)` / dtype casts but are not trainable parameters.
        self.register_buffer('ranges', torch.Tensor(ranges))
        self.register_buffer('weights', torch.Tensor(weights))
        self.solact_levels = solact_levels

    def weighted_loss_tensor(self, target):
        """Return the loss weight of every element of `target`.

        Args:
            target: tensor of shape (batch, variables, horizon), e.g. (32, 4, 6).

        Returns:
            Tensor of shape (batch, variables, horizon) holding, for each element,
            the sum of the weights of every level whose range contains it
            (0 when it lies outside all ranges).

        Raises:
            ValueError: if `target` is not 3-dimensional.

        Note:
            Both bounds are inclusive, so a value lying exactly on a boundary
            shared by two levels (e.g. 1.0 for [0, 1] and [1, 2]) receives the
            weights of both levels.
        """
        if target.dim() != 3:
            raise ValueError(
                f"expected target of shape (batch, variables, horizon), got {tuple(target.shape)}"
            )
        # (batch, variables, 1, horizon): add a level axis so it broadcasts against the ranges.
        target_b = target.unsqueeze(2)
        # (variables, levels, 1) lower/upper bounds, broadcast over batch and horizon.
        lower = self.ranges[..., 0].unsqueeze(-1)
        upper = self.ranges[..., 1].unsqueeze(-1)
        # (batch, variables, levels, horizon) membership mask; cast to the weights' dtype
        # so einsum still works after the module is moved to e.g. half precision.
        in_level = ((lower <= target_b) & (target_b <= upper)).to(self.weights.dtype)
        return torch.einsum('r,bvrh->bvh', self.weights, in_level)

    def loss_measure(self, y_pred, y_true):
        """Elementwise error between `y_pred` and `y_true`; subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement loss_measure")

    def forward(self, y_pred, y_true):
        """Mean of the elementwise error scaled by the level weights of `y_true`."""
        error = self.loss_measure(y_pred, y_true)
        weights = self.weighted_loss_tensor(y_true)
        return (error * weights).mean()

    # Metrics
    def _loss_call(self, input, target, weight_idx):
        """Evaluate the loss with every level weight except `weight_idx` set to 0.

        Raises:
            IndexError: if `weight_idx` is not a valid index into `weights`.
        """
        # Operate on a copy so the live loss's weights are never modified.
        loss_copy = deepcopy(self)
        keep = torch.zeros_like(loss_copy.weights, dtype=torch.bool)
        keep[weight_idx] = True
        loss_copy.weights.masked_fill_(~keep, 0)
        return loss_copy(input, target)

    def metrics(self):
        """Return one metric callable per level, named after `solact_levels`.

        Each metric `m(input, target)` evaluates the loss using only that level's
        weight. Because the loss is linear in the weights, the metrics sum to the
        full loss.
        """
        level_names = self.default_level_names if self.solact_levels is None else self.solact_levels
        metrics = []
        for i, name in enumerate(level_names):
            # Bind `i` as a default so each closure keeps its own level index.
            def level_metric(self, input, target, i=i):
                return self._loss_call(input, target, i)

            # Reporting frameworks typically label a metric by its `__name__`.
            level_metric.__name__ = str(name)
            metrics.append(types.MethodType(level_metric, self))

        return metrics

In [12]:
# Test
device = 'cpu'

# Four variables that share the same four unit-width levels: [0, 1], [1, 2], [2, 3], [3, 4].
ranges = np.array([[[lo, lo + 1] for lo in range(4)] for _ in range(4)])

# One weight per level, applied to every variable.
weights = np.array([1, 2, 3, 4])

# One sample, four identical variables, horizon of six; the last two steps lie outside every level.
target = torch.tensor(
    [[[0.5, 1.5, 2.5, 3.5, 4.5, 5.5] for _ in range(4)]],
    device=device,
    dtype=torch.float32,
)

# Predictions off by exactly one everywhere, so each elementwise error is 1.
input = target + 1

# Weight of each target element: the weight of its level, or 0 when outside all levels.
expected_weights = torch.tensor(
    [[[1, 2, 3, 4, 0, 0] for _ in range(4)]],
    device=device,
    dtype=torch.float32,
)

solact_levels = ['low', 'moderate', 'elevated', 'high']




def test_LossWeightsTensor():
    """Check `Loss.weighted_loss_tensor` on the shared fixtures against `expected_weights`."""
    weighted = Loss(ranges, weights, solact_levels).to(device).weighted_loss_tensor(target)
    matches = torch.equal(weighted, expected_weights)
    assert matches, f"Expected {expected_weights}, but got {weighted}"
    print("Loss Tensor test passed!")

In [13]:
#|export
class wMSELoss(Loss):
    """Level-weighted MSE: squared error scaled by the weight of each target's level."""

    def __init__(self, ranges, weights, solact_levels):
        super().__init__(ranges, weights, solact_levels)

    def loss_measure(self, y_pred, y_true):
        """Elementwise squared difference between target and prediction."""
        residual = y_true - y_pred
        return residual ** 2

In [14]:
#|export
class wMAELoss(Loss):
    """Level-weighted MAE: absolute error scaled by the weight of each target's level."""

    def __init__(self, ranges, weights, solact_levels):
        super().__init__(ranges, weights, solact_levels)

    def loss_measure(self, y_pred, y_true):
        """Elementwise absolute difference between target and prediction."""
        residual = y_true - y_pred
        return residual.abs()

In [15]:
# Test

def check_loss_function(loss_class, expected_value):
    """Build `loss_class` from the shared fixtures and compare its value on (`input`, `target`) to `expected_value`."""
    criterion = loss_class(ranges, weights, solact_levels)
    criterion = criterion.to(device)
    value = criterion(input, target)

    assert torch.isclose(value, expected_value), f"Expected {expected_value}, but got {value}"
    print(f"{type(criterion).__name__} test passed!")

def test_wMSELoss():
    """wMSELoss should equal the mean of `expected_weights` times the squared error."""
    squared_error = (input - target) ** 2
    check_loss_function(wMSELoss, (expected_weights * squared_error).mean())

def test_wMAELoss():
    """wMAELoss should equal the mean of `expected_weights` times the absolute error."""
    absolute_error = (input - target).abs()
    check_loss_function(wMAELoss, (expected_weights * absolute_error).mean())

In [16]:
# Test
def test_LossMetrics():
    """The per-level metrics of a wMAELoss should add up to the full loss value."""
    criterion = wMAELoss(ranges, weights, solact_levels).to(device)
    level_metrics = criterion.metrics()

    full_value = criterion(input, target)
    per_level = [fn(input, target) for fn in level_metrics]
    total = sum(per_level)

    assert torch.isclose(full_value, total), f"Expected {full_value}, but got {total} ({per_level})"
    print("LossMetrics test passed!")

In [17]:
#| test
# Run every check in definition order; each prints a confirmation on success.
for run_test in (test_LossWeightsTensor, test_wMSELoss, test_wMAELoss, test_LossMetrics):
    run_test()


Loss Tensor test passed!
wMSELoss test passed!
wMAELoss test passed!
LossMetrics test passed!
