In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [2]:
import os
import json
import torch
import random
import pandas as pd
import numpy as np
import seaborn as sns

from data_handlers import YCDataset, SampleBatchIdx
from models import EmbeddingsMapping
from utils import compute_normalization_parameters
from torch.utils.data import DataLoader
from torch import nn, log, exp
from torch.nn import functional as F
from tqdm import tqdm
from pathlib import Path
from matplotlib import pyplot as plt

opj = lambda x, y: os.path.join(x, y)

In [3]:
training_df = pd.read_csv('training_with_labels.csv')
validation_df = pd.read_csv('validation_with_labels.csv')
training_df.shape, validation_df.shape

((1193, 10), (417, 9))

In [6]:
train_dataset = YCDataset(training_df, video_len=775)
batch_sampler = SampleBatchIdx(train_dataset, 8, 24)
train_dl = DataLoader(train_dataset, batch_sampler = batch_sampler)

In [8]:
device = 'cuda:3'
model = EmbeddingsMapping(512, video_layers=2, text_layers=2, drop_layers=2, learnable_drop=True, normalization_dataset=train_dataset)
model = model.to(device)

In [9]:
def mil_nce(features_1, features_2, correspondance_mat, eps=1e-8, gamma=1, hard_ratio=1):
    corresp = correspondance_mat.to(torch.float32)
    prod = (features_1 @ features_2.T) / gamma
    
    prod_exp = exp(prod - prod.max(dim=1, keepdim=True).values)
    nominator = (prod_exp * corresp).sum(dim=1) # this sum needs to maximized?
    denominator = prod_exp.sum(dim=1)
    
    nll = -log(nominator / (denominator + eps)) # minimize this ratio will give spread to the data?
    
    return nll.mean()


def compute_clust_loss(sf, step_len, ff, video_len, dif, frame_gamma=10, l2_normalize=False):
    all_pooled_frames = []
    all_step_features = []
    frame_labels = [0]
    for idx, sample in enumerate(zip(sf, step_len, ff, video_len, dif)):
        st, s_l, fr, v_l, dis = sample
        st, fr = torch.vstack([st[:s_l], dis]), fr[:v_l] # appending distractor to step as a frame can also be correlated to be dropped
        
        if l2_normalize:
            st = F.normalize(st, p=2, dim=1)
            fr = F.normalize(fr, p=2, dim=1)
            frame_gamma = 0.1
        
        sim = (st @ fr.T) # similarity comparison between steps and frames
        # TODO: check to see if some kind of attention can be learned here
        weights = F.softmax(sim / frame_gamma, dim=1) # this gamma allows expanded attention of steps to all frames -> temperature
        attended_st = weights @ fr
        all_pooled_frames.append(attended_st)
        frame_labels.append(s_l + 1)
        all_step_features.append(st)
    all_pooled_frames = torch.cat(all_pooled_frames, dim=0)
    all_step_features = torch.cat(all_step_features, dim=0)
    N_steps = all_pooled_frames.shape[0]
    frame_labels = np.cumsum(frame_labels)
    xz_label_mat = torch.zeros([N_steps, N_steps])
    for i in range(1, len(frame_labels)):
        xz_label_mat[frame_labels[i-1]:frame_labels[i], frame_labels[i-1]:frame_labels[i]] = 1.
    xz_label_mat = xz_label_mat.to(device)
    
    xz_loss = mil_nce(all_pooled_frames, all_step_features, xz_label_mat)
    return xz_loss

In [39]:
def compute_all_costs(sample, l2_normalize, gamma_xz, drop_cost_type, keep_percentile):
    sf, step_len, ff, video_len, dis = sample
    sf, ff = sf[:step_len], ff[:video_len]
    
    if l2_normalize:
        sf = F.normalize(sf, p=2, dim=1)
        ff = F.normalize(ff, p=2, dim=1)
    sim = sf @ ff.T # getting similarity costs
    if drop_cost_type == 'logit':
        k = max([1, int(torch.numel(sim) * keep_percentile)])
        baseline_logit = torch.topk(sim.reshape(-1), k).values[-1].detach()
        baseline_logits = baseline_logit.repeat([1, sim.shape[1]])
        sims_ext = torch.cat([sim, baseline_logits], dim=0)
    else:
        if l2_normalize:
            dis = F.normalize()
        distractor_sim = ff @ dis
        sims_ext = torch.cat([sim, distractor_sim[None, :]], dim=0)

    softmax_sims = F.softmax(sims_ext/gamma_xz, dim=0) 
    matching_probs, drop_probs = softmax_sims[:-1], softmax_sims[-1]
    zx_costs = - log(matching_probs + 1e-5)
    drop_costs = - log(drop_probs + 1e-5)  
    return zx_costs, drop_costs

In [37]:
def compute_alignment_loss(samples, l2_normalize, gamma_xz, drop_cost_type, keep_percentile):
    
    gamma_xz = 0.1 if l2_normalize else gamma_xz
    sf, step_len, ff, video_len, dif = samples

    zx_costs_list = []
    drop_costs_list = []
    print(len(samples))
    
    for idx, sample in enumerate(zip(sf, step_len, ff, video_len, dif)):
        zx_costs, drop_costs = compute_all_costs(sample, l2_normalize=False, gamma_xz=10, drop_cost_type=drop_cost_type, keep_percentile=keep_percentile)
        zx_costs_list.append(zx_costs)
        drop_costs_list.append(drop_costs)
        break
        
    return zx_costs_list, drop_costs_list
    

In [38]:
def run_model(vf, sf, distractor):
    frame_features = model.map_video(vf)
    step_features = model.map_text(sf)
    distractor_features = model.compute_distractors(distractor)

    return frame_features, step_features, distractor_features

frame_gamma = 10
gamma_xz = 10
keep_percentile = 1
drop_cost_type = 
l2_normalize = False
for batch in train_dl:
    step_len, step_features, video_len, video_features = batch['step_len'], batch['step_feature'], batch['video_len'], batch['video_feature']
    distractors = torch.stack([ s[:size].mean(0) for s, size in zip(step_features, step_len)], dim=0) # also taking care of the distractor padding (dont worry about it later)
    ff, sf, dif = run_model(video_features.to(device), step_features.to(device), distractors.to(device))
#     loss = compute_clust_loss(sf, step_len, ff, video_len, dif, frame_gamma, l2_normalize)
    loss = compute_alignment_loss((sf, step_len, ff, video_len, dif), l2_normalize, gamma_xz, keep_percentile) 
    break

5
torch.Size([6, 90])
540 torch.Size([6, 90])
162
tensor(-8.0052, device='cuda:3')
tensor([[-8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052, -8.0052,
         -8.0052, -8.0052]