In [None]:
import torch
import torch.nn as nn
import random
import os
import numpy as np
#set the random seed
# random.seed(0)
# torch.manual_seed(0)
# np.random.seed(0)   



dir_path = '../../../DataSet/IAM-Online/Resized_Dataset/Train/Images/'
num_files = len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))])
print(num_files)

def loss_fn(input_seq, target_seq):
    '''Loss Function for DTW Loss'''
    ''' Input sequence and target sequence are of the form n x 5, where n is the number of points in the sequence and 5 is the dimension of each point representing x, y, time, start_of_stroke (binary), end_of_stroke (binary)'''

    n = input_seq.shape[0]
    m = target_seq.shape[0]

    # Create a matrix to store the accumulated distances
    dtw_matrix = torch.zeros((n + 1, m + 1))

    # Initialize the first row and column of the matrix
    dtw_matrix[0, 1:] = float('inf')
    dtw_matrix[1:, 0] = float('inf')

    # Create a matrix to store the optimal warping path
    path_matrix = torch.zeros((n + 1, m + 1), dtype=torch.int)

    # Calculate the accumulated distances and optimal warping path
    cost = torch.cdist(input_seq[:, :2], target_seq[:, :2])  # Pairwise Euclidean distances

    min_cost = torch.minimum(dtw_matrix[:-1, 1:], dtw_matrix[1:, :-1])
    dtw_matrix[1:, 1:] = cost + torch.minimum(min_cost, dtw_matrix[:-1, :-1])

    for i in range(n):
        for j in range(m):
            min_cost = torch.min(torch.stack([dtw_matrix[i, j+1], dtw_matrix[i+1, j], dtw_matrix[i, j]]))
            dtw_matrix[i+1, j+1] = cost[i, j] + min_cost

    # for i in range(n):
    #     for j in range(m):
    #         min_cost = min(dtw_matrix[i, j+1], dtw_matrix[i+1, j], dtw_matrix[i, j])
    #         dtw_matrix[i+1, j+1] = cost[i, j] + min_cost

     # Update the path matrix based on the minimum cost
    path_matrix[1:, 1:][dtw_matrix[1:, :-1] < dtw_matrix[:-1, :-1]] = 2  # Horizontal movement
    path_matrix[1:, 1:][dtw_matrix[:-1, 1:] < dtw_matrix[:-1, :-1]] = 1  # Vertical movement
    path_matrix[1:, 1:][dtw_matrix[:-1, :-1] <= torch.min(dtw_matrix[1:, :-1], dtw_matrix[:-1, 1:])] = 3  # Diagonal movement

    # Calculate the DTW loss as the last element in the matrix
    dtw_loss = dtw_matrix[-1, -1]

    # Compute the optimal warping path
    i, j = n, m
    warping_path = [(i, j)]
    while i > 1 or j > 1:
        if path_matrix[i, j] == 1:
            i -= 1  # Vertical movement
        elif path_matrix[i, j] == 2:
            j -= 1  # Horizontal movement
        else:
            i -= 1  # Diagonal movement
            j -= 1
        warping_path.append((i, j))

    warping_path.reverse()

    # Perform backward propagation to compute gradients
    # dtw_loss.backward()

    # Retrieve the gradients
    # gradients = input_seq.grad

    return dtw_loss, warping_path


In [None]:
def loss_fn1(input_seq, target_seq):
    '''Loss Function for DTW Loss'''
    ''' Input sequence and target sequence are of the form n x 5, where n is the number of points in the sequence and 5 is the dimension of each point representing x, y, time, start_of_stroke (binary), end_of_stroke (binary)'''

    n = input_seq.shape[0]
    m = target_seq.shape[0]

    # Create a matrix to store the accumulated distances
    dtw_matrix = torch.zeros((n + 1, m + 1))

    # Initialize the first row and column of the matrix
    dtw_matrix[0, 1:] = float('inf')
    dtw_matrix[1:, 0] = float('inf')

    # Calculate the accumulated distances
    cost_matrix = torch.cdist(input_seq[:, :2], target_seq[:, :2])  # Pairwise Euclidean distances
    dtw_matrix[1:, 1:] = cost_matrix + torch.min(torch.min(dtw_matrix[:-1, 1:], dtw_matrix[1:, :-1]), dtw_matrix[:-1, :-1])

    # Create a matrix to store the optimal warping path
    path_matrix = torch.zeros((n + 1, m + 1), dtype=torch.int)

    # Update the path matrix based on the minimum cost
    path_matrix[1:, 1:][dtw_matrix[1:, :-1] < dtw_matrix[:-1, :-1]] = 2  # Horizontal movement
    path_matrix[1:, 1:][dtw_matrix[:-1, 1:] < dtw_matrix[:-1, :-1]] = 1  # Vertical movement
    path_matrix[1:, 1:][dtw_matrix[:-1, :-1] <= torch.min(dtw_matrix[1:, :-1], dtw_matrix[:-1, 1:])] = 3  # Diagonal movement

    # Calculate the DTW loss as the last element in the matrix
    dtw_loss = dtw_matrix[-1, -1]

    # Compute the optimal warping path
    i, j = n, m
    warping_path = [(i, j)]
    while i > 1 or j > 1:
        if path_matrix[i, j] == 1:
            i -= 1  # Vertical movement
        elif path_matrix[i, j] == 2:
            j -= 1  # Horizontal movement
        else:
            i -= 1  # Diagonal movement
            j -= 1
        warping_path.append((i, j))

    warping_path.reverse()

    # Perform backward propagation to compute gradients
    # dtw_loss.backward()

    # Retrieve the gradients
    # gradients = input_seq.grad

    return dtw_loss, warping_path

In [70]:
# Test the loss function
# function to get random images from the dataset
img_num1 = random.randint(1, num_files + 1)
stroke_path = '../../../DataSet/IAM-Online/Resized_Dataset/Train/Strokes/' + f'stroke_{img_num1}.npy'
stroke = np.load(stroke_path)
input_seq = torch.from_numpy(stroke).float()
inp_seq = input_seq.clone() 
#multiply an offset of constant value to x and y in input_seq
offset = 5
inp_seq[:, 0] += offset
inp_seq[:, 1] += offset

img_num2 = random.randint(1, num_files + 1)
stroke_path = '../../../DataSet/IAM-Online/Resized_Dataset/Train/Strokes/' + f'stroke_{img_num2}.npy'
stroke = np.load(stroke_path)
target_seq = torch.from_numpy(stroke).float()
print(input_seq.shape)
print(inp_seq.shape)

torch.Size([168, 5])
torch.Size([168, 5])


In [None]:
loss, path = loss_fn(input_seq, target_seq)
#plot input_seq(only x and y coordinates)
import matplotlib.pyplot as plt
plt.plot(input_seq[:, 0], input_seq[:, 1])
plt.show()

#plot inp_seq(only x and y coordinates)
plt.plot(target_seq[:, 0], target_seq[:, 1])
plt.show()
print('DTW loss:', loss)
print('Optimal Warping Path:', path)
#print input_seq size
print(f'Input Sequence size:- {input_seq.shape}')
print(f'No. of mappings:- {len(path)}')
# print('Gradients:', gradients)

In [None]:
#Method1
loss, path = loss_fn1(input_seq, target_seq)
#plot input_seq(only x and y coordinates)
import matplotlib.pyplot as plt
plt.plot(input_seq[:, 0], input_seq[:, 1])
plt.show()

#plot inp_seq(only x and y coordinates)
plt.plot(target_seq[:, 0], target_seq[:, 1])
plt.show()
print('DTW loss:', loss)
print('Optimal Warping Path:', path)
#print input_seq size
print(f'Input Sequence size:- {input_seq.shape}')
print(f'No. of mappings:- {len(path)}')
# print('Gradients:', gradients)


In [None]:
from soft_dtw import SoftDTW
criterion = SoftDTW(gamma=1.0, normalize=True)
loss = criterion(input_seq, target_seq)
print('DTW loss:', loss)

In [None]:
import numpy as np
from scipy.spatial.distance import euclidean

from fastdtw import fastdtw

x = input_seq[:, :2].numpy()
y = target_seq[:, :2].numpy()
distance, path = fastdtw(x, y, dist=euclidean)
print(distance)
#print shape of path
print(f'Number of mappings:- {np.array(path)}')

## Class for DTW-Loss

Comments:- Does not work at the moment, need to fix tensor to np array and dimension issues

In [71]:
class DTWLoss(nn.Module):
    def __init__(self, gamma=1.0, normalize=True):
        super(DTWLoss, self).__init__()
        self.gamma = gamma
        self.normalize = normalize

    def forward(self, input, target):
        # Compute the loss
        distance, path = fastdtw(input[:, :2].numpy(), target[:, :2].numpy(), dist=euclidean)
        loss = distance
        loss = torch.tensor(loss)

        # Return the loss
        return loss

In [72]:
import torch
from torch.autograd import Function
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw

class DTWLoss(Function):
    @staticmethod
    def forward(ctx, input, target):
        # Compute the loss
        distance, path = fastdtw(input[:, :2].detach().numpy(), target[:, :2].detach().numpy(), dist=euclidean)
        loss = torch.tensor(distance, requires_grad=True)

        # Save the path for use in backward pass
        ctx.save_for_backward(input, target, torch.tensor(path))

        # Return the loss
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors from forward pass
        input, target, path = ctx.saved_tensors

        # Initialize gradients for input and target
        grad_input = torch.zeros_like(input)
        grad_target = torch.zeros_like(target)

        # Compute gradients for input and target based on the path
        for i, j in path:
            grad_input[i, :2] += grad_output * (input[i, :2] - target[j, :2])
            grad_target[j, :2] += grad_output * (target[j, :2] - input[i, :2])

        # Return the gradients
        return grad_input, grad_target

# Example usage
# input = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
# target = torch.tensor([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]], requires_grad=True)

# Create input and target tensors using input_sec and target_seq for calculating DTW loss and gradients
input = inp_seq.clone()
target = target_seq.clone()

#activate gradient
input.requires_grad = True
target.requires_grad = True


# Create an instance of DTWLoss
dtw_loss = DTWLoss.apply

# Compute DTW loss
loss = dtw_loss(input, target)

# Perform backpropagation
loss.backward()

# Retrieve gradients
gradients_input = input.grad
gradients_target = target.grad

print("Gradients for input:", gradients_input)
print("Gradients for target:", gradients_target)

Gradients for input: tensor([[ 1.0412e+01,  4.8414e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0509e+01,  4.2664e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0412e+01,  3.4996e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0509e+01,  2.6370e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0412e+01,  1.8702e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0316e+01,  7.2004e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 9.7387e+00, -3.3428e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 8.5839e+00, -1.5803e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.4292e+00, -2.3471e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 5.6008e+00, -2.7304e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 4.1574e+00, -2.0595e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 3.3446e+00, -1.8476e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 3.3444e+00, -2.4309e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00],
    