In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [2]:
import os
import json
import torch
import random
import pandas as pd
import numpy as np
import seaborn as sns

from data_handlers import YCDataset, SampleBatchIdx
from models import EmbeddingsMapping
from losses import compute_clust_loss, compute_alignment_loss
from utils import compute_normalization_parameters
from torch.utils.data import DataLoader
from torch import nn, log, exp
from torch.nn import functional as F
from tqdm import tqdm
from pathlib import Path
from matplotlib import pyplot as plt

opj = lambda x, y: os.path.join(x, y)

In [3]:
training_df = pd.read_csv('training_with_labels.csv')
validation_df = pd.read_csv('validation_with_labels.csv')
training_df.shape, validation_df.shape

((1194, 8), (417, 8))

In [4]:
train_dataset = YCDataset(training_df)
batch_sampler = SampleBatchIdx(train_dataset, 8, 24)
train_dl = DataLoader(train_dataset, batch_sampler = batch_sampler)

In [5]:
device = 'cuda:2'
model = EmbeddingsMapping(512, video_layers=2, text_layers=2, drop_layers=2, learnable_drop=True, normalization_dataset=train_dataset)
model = model.to(device)

In [11]:
def run_model(vf, sf, distractor):
    frame_features = model.map_video(vf)
    step_features = model.map_text(sf)
    distractor_features = model.compute_distractors(distractor)

    return frame_features, step_features, distractor_features

frame_gamma = 10
gamma_xz = 10
keep_percentile = 1
l2_normalize = False
for batch in train_dl:
    step_len, step_features, video_len, video_features = batch['step_len'], batch['step_feature'], batch['video_len'], batch['video_feature']
    distractors = torch.stack([ s[:size].mean(0) for s, size in zip(step_features, step_len)], dim=0) # also taking care of the distractor padding (dont worry about it later)
    ff, sf, dif = run_model(video_features.to(device), step_features.to(device), distractors.to(device))
    clust_loss = compute_clust_loss((sf, step_len, ff, video_len, dif), device=device)
    zx_costs_list, drop_costs_list = compute_alignment_loss((sf, step_len, ff, video_len, dif), l2_normalize, gamma_xz, keep_percentile)
    dtw_loss, D = batch_dropDTW(zx_costs_list, drop_costs_list)
    print(clust_loss, sum([c/len(batch) for c in dtw_loss]))

5
tensor(3.9998, device='cuda:2', grad_fn=<MeanBackward0>) tensor(9.0510, device='cuda:2', grad_fn=<AddBackward0>)
5
tensor(3.9998, device='cuda:2', grad_fn=<MeanBackward0>) tensor(9.0510, device='cuda:2', grad_fn=<AddBackward0>)
5
tensor(3.9998, device='cuda:2', grad_fn=<MeanBackward0>) tensor(9.0510, device='cuda:2', grad_fn=<AddBackward0>)
5
tensor(3.9998, device='cuda:2', grad_fn=<MeanBackward0>) tensor(9.0510, device='cuda:2', grad_fn=<AddBackward0>)
5


KeyboardInterrupt: 

In [6]:
class VarTable():
    def __init__(self, dims, device):
        self.dims = dims
        d1, d2, d_rest = dims[0], dims[1], dims[2:]
        
        self.vars = []
        
        # creating the dtw table
        for i in range(d1):
            self.vars.append([])
            for j in range(d2):
                var = torch.zeros(d_rest).to(torch.float).to(device)
                self.vars[i].append(var)
    
    def __getitem__(self, pos):
        i, j = pos
        return self.vars[i][j]

    def __setitem__(self, pos, new_val):
        i, j = pos
        if self.vars[i][j].sum() != 0:
            assert False, 'already assigned'
        else:
            self.vars[i][j] = self.vars[i][j] + new_val
            
    def show(self):
        pass # TODO: needs to be added for visualization
    
def minProb(inputs, gamma = 1, keepdim = True):
    if inputs[0].shape[0] == 1:
        inputs = torch.cat(inputs)
    else:
        inputs = torch.stack(inputs, dim = 0)
    probs = F.softmax(- inputs / gamma, dim = 0)
    minP = (probs * inputs).sum(dim = 0, keepdim = keepdim)
    return minP

In [7]:
def batch_dropDTW(zx_costs_list, drop_costs_list, exclusive=True, contiguous=True):
    
    inf = 99999999
    min_fn = minProb
    
    # to find max padding need to run drop-dtw in batches 
    B = len(zx_costs_list)
    Ns, Ks = [], []
    
    for i in range(B):
        Ki, Ni = zx_costs_list[i].shape
        Ns.append(Ni)
        Ks.append(Ki)
    
    N, K = max(Ns), max(Ks)
    
    
    padded_cum_drop_costs, padded_drop_costs, padded_zx_costs = [], [], []
    
    for i in range(B):
        zx_costs = zx_costs_list[i]
        drop_costs = drop_costs_list[i]
        cum_drop_costs = torch.cumsum(drop_costs, dim=0)
        
        row_pad = torch.zeros([N - Ns[i]]).to(zx_costs.device)
#         print(row_pad.shape)
#         print(cum_drop_costs.shape)
#         print(torch.cat([cum_drop_costs, row_pad]).shape)
        padded_cum_drop_costs.append(torch.cat([cum_drop_costs, row_pad]))
#         print(len(padded_cum_drop_costs))
        padded_drop_costs.append(torch.cat([drop_costs, row_pad]))
        
        multirow_pad = torch.stack([row_pad + inf] * Ks[i], dim=0) # to add padding to each row
#         print(multirow_pad.shape)
        
#         print('padded_table', zx_costs.shape)

        padded_table = torch.cat([zx_costs, multirow_pad], dim=1)
#         print('padded_table', padded_table.shape)
        
        rest_pad = torch.zeros([K - Ks[i], N]).to(zx_costs.device) + inf
        padded_table = torch.cat([padded_table, rest_pad], dim=0)
        
#         print(padded_table.shape)
        
        padded_zx_costs.append(padded_table)
        
#         print("####")
    
    all_zx_costs = torch.stack(padded_zx_costs, dim=-1)
    all_cum_drop_costs = torch.stack(padded_cum_drop_costs, dim=-1)
    all_drop_costs = torch.stack(padded_drop_costs, dim=-1)
    
    
    D = VarTable((K + 1, N + 1, 3, B), device)
    for zi in range(1, K + 1): 
        D[zi, 0] = torch.zeros_like(D[zi, 0]) + inf # init all rows '0th' row with inf 
    for xi in range(1, N + 1):
        D[0, xi] = torch.zeros_like(D[0, xi]) + all_cum_drop_costs[(xi - 1):xi] # init all columns '0th' col with cumulative drops
        
        
    for zi in range(1, K+1):
        for xi in range(1, N+1):
            z_cost_ind, x_cost_ind = zi-1, xi-1
            
            d_diag, d_left = D[zi-1, xi-1][0:1], D[zi-1, xi][0:1]
            dp_left, dp_up = D[zi, xi-1][2:3], D[zi-1, xi][2:3]
            
            if contiguous:
                pos_neighbours = [d_diag, dp_left]
            else:
                pos_neighbours = [d_diag, d_left]
                
            if not exclusive:
                pos_neighbours.append(dp_up)

            Dp = min_fn(pos_neighbours) + all_zx_costs[z_cost_ind, x_cost_ind]
            
            Dm = d_left + all_drop_costs[x_cost_ind]
            
            D_final = min_fn([Dm, Dp])
            
            D[zi, xi] = torch.cat([D_final, Dm, Dp], dim=0)
    
    min_costs = []
    for i in range(B):
        Ni, Ki = Ns[i], Ks[i]
        min_cost_i = D[Ki, Ni][0, i]
        min_costs.append(min_cost_i / Ni)
        
    return min_costs, D