In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from libs.losses import  l2, SignalDiceLoss

In [2]:
# Number of samples per generated signal (make_signals spreads them over
# t in [0, 2*pi], i.e. one period of sin).
sampling_rate = 1000

# Seed torch's global RNG so the random "noisy" and "jittered" signals built by
# make_signals -- and therefore every gradient norm printed below -- are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# SignalDiceLoss variants compared below. The meaning of eps / soft / alpha is
# defined in libs.losses; presumably soft=False selects a hard (non-smoothed)
# variant and alpha sets the smoothing sharpness of the soft one -- TODO confirm.
sdsc = SignalDiceLoss(eps=0, soft=False)
sdsc_1 = SignalDiceLoss(eps=0, alpha=1)
sdsc_10 = SignalDiceLoss(eps=0, alpha=10)
sdsc_100 = SignalDiceLoss(eps=0, alpha=100)

In [3]:
def mse(a, b):
    """Mean squared error between tensors ``a`` and ``b`` (0-d tensor result)."""
    diff = a - b
    return diff.pow(2).mean()

def mae(a, b):
    """Mean absolute error between tensors ``a`` and ``b`` (0-d tensor result)."""
    return (a - b).abs().mean()

def dtw_distance_normalized(s1, s2):
    """Dynamic-time-warping distance between two sequences, divided by the
    number of cells on the optimal warping path.

    Local cost is the absolute difference of elements; the accumulated cost
    uses the standard insertion / deletion / match recurrence. Pure-Python
    double loop: O(len(s1) * len(s2)) time and memory, so slow for long
    signals.

    Args:
        s1, s2: indexable sequences of scalars (lists, numpy arrays, 1-D
            tensors). Each local cost is converted with ``float``, so no
            autograd graph is kept.

    Returns:
        The accumulated DTW cost at (len(s1), len(s2)) divided by the
        warping-path length (a numpy float).

    Raises:
        ValueError: if either sequence is empty (DTW is undefined there).
    """
    n, m = len(s1), len(s2)
    if n == 0 or m == 0:
        raise ValueError("dtw_distance_normalized requires two non-empty sequences")

    # acc[i, j] = minimal accumulated cost aligning s1[:i] with s2[:j]. The inf
    # border forces every path to start by aligning s1[0] with s2[0].
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            local = float(abs(s1[i - 1] - s2[j - 1]))
            acc[i, j] = local + min(
                acc[i - 1, j],      # insertion
                acc[i, j - 1],      # deletion
                acc[i - 1, j - 1],  # match
            )

    # Walk back from (n, m) through the cheapest predecessor, counting path
    # cells. argmin returns the first minimum, so ties prefer the diagonal.
    i, j = n, m
    path_length = 0
    while i > 0 and j > 0:
        path_length += 1
        step = np.argmin([acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1]])
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    # Defensive: with the inf border the walk ends at (0, 0), so this normally
    # adds nothing.
    path_length += i + j

    return acc[n, m] / path_length
    
def make_signals(sampling_rate, noise_level=0.3, shift=1, scale_factor=0.5,
                 jitter_level=0.05, seed=None):
    """Build a one-period sine ground truth and a set of perturbed variants.

    Args:
        sampling_rate: number of samples; ``t`` spans [0, 2*pi] inclusive.
        noise_level: std of the Gaussian noise added for ``"noisy"``.
        shift: constant offset added for ``"shift"``.
        scale_factor: multiplier applied for ``"scaled"``.
        jitter_level: std of the Gaussian noise added for ``"jittered"``.
        seed: if given, random signals are drawn from a dedicated
            ``torch.Generator`` seeded with it (reproducible, and the global
            RNG is left untouched); if ``None`` the global RNG is used.

    Returns:
        dict with keys ``"t"``, ``"gt"``, ``"inverted"``, ``"scaled"``,
        ``"shift"``, ``"noisy"``, ``"jittered"``, each a 1-D tensor of length
        ``sampling_rate``. ``"gt"`` is a leaf tensor with
        ``requires_grad=True``; the variants are computed from it and are
        therefore part of its autograd graph.
    """
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    t = torch.linspace(0, 2 * np.pi, sampling_rate)
    gt = torch.sin(t).clone().detach().requires_grad_(True)

    # Draw the noise before the jitter so the RNG stream is consumed in a
    # fixed order.
    noise = torch.randn(sampling_rate, generator=generator)
    jitter = torch.randn(gt.shape, generator=generator)

    return {
        "t": t,
        "gt": gt,
        "inverted": -gt,
        "scaled": scale_factor * gt,
        "shift": gt + shift,
        "noisy": gt + noise_level * noise,
        "jittered": gt + jitter_level * jitter,
    }


def compute_grad_norm(loss_func, x, y):
    """Return the L2 norm of the gradient of ``loss_func(x, y)`` w.r.t. ``y``.

    ``y`` is copied into a fresh leaf tensor, so the caller's tensor is never
    modified. ``x`` is detached, so no gradient is accumulated into the
    caller's ``x`` even when it requires grad (e.g. ``gt`` from
    make_signals). Previously that happened silently on every call.

    Args:
        loss_func: callable ``(x, y) -> scalar tensor``.
        x: first argument passed to ``loss_func`` (treated as a constant).
        y: second argument; the gradient is taken with respect to it.

    Returns:
        float: ``||d loss / d y||_2``.
    """
    x = x.detach()
    y = y.clone().detach().requires_grad_(True)
    loss = loss_func(x, y)
    # autograd.grad returns the gradient directly instead of accumulating into
    # .grad attributes, and frees the graph (no retain_graph needed).
    (grad_y,) = torch.autograd.grad(loss, y)
    return grad_y.norm().item()

def print_grad_norm(x, y):
    """Print, for each compared loss, the gradient norm w.r.t. ``y``.

    Uses compute_grad_norm on the module-level losses (mse, mae and the four
    SignalDiceLoss variants), one line per loss, in a fixed order.
    """
    labelled_losses = (
        ("MSE : ", mse),
        ("MAE : ", mae),
        ("SDSC LOSS :", sdsc),
        ("SDSC Alpha 1 LOSS :", sdsc_1),
        ("SDSC Alpha 10 LOSS :", sdsc_10),
        ("SDSC Alpha 100 LOSS :", sdsc_100),
    )
    for label, loss_fn in labelled_losses:
        print(label, compute_grad_norm(loss_fn, x, y))

In [4]:
# Ground truth and perturbed variants used by every comparison cell below.
signal = make_signals(sampling_rate=sampling_rate)

# Inverted, scaled, and time-shifted targets

In [5]:
# gt vs. its sign-flipped copy.
print_grad_norm(x=signal["gt"], y=signal["inverted"])

MSE :  0.08939799666404724
MAE :  0.03160696476697922
SDSC LOSS : 0.0
SDSC Alpha 1 LOSS : 0.009130844846367836
SDSC Alpha 10 LOSS : 0.00824382621794939
SDSC Alpha 100 LOSS : 0.004718021955341101


In [6]:
# gt vs. an amplitude-scaled copy (make_signals default scale_factor=0.5).
print_grad_norm(x=signal["gt"], y=signal["scaled"])

MSE :  0.02234949916601181
MAE :  0.03160696476697922
SDSC LOSS : 0.04417587071657181
SDSC Alpha 1 LOSS : 0.028901392593979836
SDSC Alpha 10 LOSS : 0.043713416904211044
SDSC Alpha 100 LOSS : 0.04336824640631676


In [14]:
# gt vs. a copy with doubled amplitude.
print_grad_norm(x=signal["gt"], y=2 * signal["gt"])

MSE :  0.04469899833202362
MAE :  0.03160696476697922
SDSC LOSS : 0.011043967679142952
SDSC Alpha 1 LOSS : 0.006246507633477449
SDSC Alpha 10 LOSS : 0.010244142264127731
SDSC Alpha 100 LOSS : 0.010846496559679508


In [8]:
# gt vs. a copy circularly shifted in time by one sample.
print_grad_norm(x=signal["gt"], y=signal["gt"].roll(shifts=1))

MSE :  0.00028113272855989635
MAE :  0.03162277862429619
SDSC LOSS : 0.024848824366927147
SDSC Alpha 1 LOSS : 0.017074046656489372
SDSC Alpha 10 LOSS : 0.023008694872260094
SDSC Alpha 100 LOSS : 0.02430911548435688


# Constant-offset, noisy, and all-zero targets

In [9]:
# gt vs. gt + shift (make_signals default shift=1).
print_grad_norm(x=signal["gt"], y=signal["shift"])

MSE :  0.06324555724859238
MAE :  0.03162277862429619
SDSC LOSS : 0.007514291908591986
SDSC Alpha 1 LOSS : 0.007416722364723682
SDSC Alpha 10 LOSS : 0.008759652264416218
SDSC Alpha 100 LOSS : 0.007695657666772604


In [10]:
# gt vs. (gt + shift) - 2, i.e. gt - 1 with the default shift=1.
print_grad_norm(x=signal["gt"], y=signal["shift"] - 2)

MSE :  0.06324555724859238
MAE :  0.03162277862429619
SDSC LOSS : 0.007514291908591986
SDSC Alpha 1 LOSS : 0.007416720502078533
SDSC Alpha 10 LOSS : 0.008759652264416218
SDSC Alpha 100 LOSS : 0.007695657666772604


In [11]:
# gt vs. gt plus Gaussian noise (make_signals default noise_level=0.3).
print_grad_norm(x=signal["gt"], y=signal["noisy"])

MSE :  0.019361242651939392
MAE :  0.03162277862429619
SDSC LOSS : 0.023788530379533768
SDSC Alpha 1 LOSS : 0.015233566984534264
SDSC Alpha 10 LOSS : 0.02279873937368393
SDSC Alpha 100 LOSS : 0.023698238655924797


In [13]:
# gt vs. an all-zero signal of the same shape.
print_grad_norm(x=signal["gt"], y=torch.zeros_like(signal["gt"]))

MSE :  0.04469899833202362
MAE :  0.03160696476697922
SDSC LOSS : 0.0
SDSC Alpha 1 LOSS : 0.0
SDSC Alpha 10 LOSS : 0.0
SDSC Alpha 100 LOSS : 0.0


# Jittered

In [12]:
# gt vs. gt plus small Gaussian jitter.
print_grad_norm(x=signal["gt"], y=signal["jittered"])

MSE :  0.0032219982240349054
MAE :  0.03162277862429619
SDSC LOSS : 0.02475747838616371
SDSC Alpha 1 LOSS : 0.016573229804635048
SDSC Alpha 10 LOSS : 0.02287229709327221
SDSC Alpha 100 LOSS : 0.024257011711597443
