In [2]:
from IPython.display import display, HTML
# Widen the notebook's main container to 95% of the browser width (cosmetic only).
display(HTML("<style>.container { width:95% !important; }</style>"))

In [3]:
import os
import json
import torch
import random
import pandas as pd
import numpy as np
import seaborn as sns
import wandb


from dtw import drop_dtw
from data_handlers import YCDataset, SampleBatchIdx
from models import EmbeddingsMapping
from losses import compute_all_costs, compute_clust_loss, compute_alignment_loss
from utils import compute_normalization_parameters
from torch.utils.data import DataLoader
from torch import nn, log, exp
from torch.nn import functional as F
from tqdm import tqdm
from pathlib import Path
from matplotlib import pyplot as plt

def opj(x, y, *more):
    """Join path components; short alias for ``os.path.join``.

    Args:
        x: First path component.
        y: Second path component.
        *more: Optional further components, appended in order.

    Returns:
        The joined path string.
    """
    return os.path.join(x, y, *more)

In [4]:
# Per-video metadata tables for the training and validation splits.
training_df = pd.read_csv('training_with_labels_s3dg.csv')
validation_df = pd.read_csv('validation_with_labels_s3dg.csv')

# Ground-truth labels; `gt_validation` is indexed by video id in `evaluate`.
# NOTE(review): torch.load unpickles these files, so only load trusted ones.
gt_training = torch.load('s3d_labelled_video_train.pkl')
gt_validation = torch.load('s3d_labelled_video_val.pkl')
# Show (rows, columns) of both tables as the cell output.
training_df.shape, validation_df.shape

((1237, 9), (436, 9))

In [5]:
# Training split; batches are drawn through a custom batch sampler.
# NOTE(review): the meaning of the positional arguments 8 and 24 is defined in
# data_handlers.SampleBatchIdx and is not visible here - confirm before tuning.
train_dataset = YCDataset(training_df, video_len=775)
batch_sampler = SampleBatchIdx(train_dataset, 8, 24)
train_dl = DataLoader(train_dataset, batch_sampler = batch_sampler)

# Validation split is served as one batch holding the whole dataset, so
# `next(iter(valid_dl))` yields every validation sample at once.
valid_dataset = YCDataset(validation_df, video_len=775)
valid_dl = DataLoader(valid_dataset, batch_size=len(valid_dataset))

In [6]:
# wandb.init(project="dropdtw", entity="dhruvmetha")

In [7]:
# Device that hosts the model and every tensor moved with `.to(device)`.
device = 'cuda:3'
# Index of the first epoch to train; advanced below when a checkpoint is restored.
epoch_curr = 0
# NOTE: the training cell assigns `epochs` again before looping.
epochs = 10
model = EmbeddingsMapping(512, video_layers=3, text_layers=3, drop_layers=2, learnable_drop=True, normalization_dataset=train_dataset, batch_norm=True)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-5, steps_per_epoch=49, epochs=epochs)
folder_name = 's3d_our_ckpts_notext_s3dg_cl4_dtw2.5_learn_adamw_lr_1cr_1e-5'
# Checkpoint to resume from, if it exists inside `folder_name`.
model_name = 'best_model_state_25.pth'
# Create the checkpoint folder if needed (no error when it already exists).
os.makedirs(folder_name, exist_ok=True)
if os.path.exists(opj(folder_name, model_name)):
    # Load exactly the file whose existence was checked. The original checked
    # `model_name` but loaded a hard-coded 'best_model_state_5.pth'.
    model_state = torch.load(opj(folder_name, model_name), map_location=device)
    model.load_state_dict(model_state['model'])
    optimizer.load_state_dict(model_state['optimizer'])
    # The checkpoint stores the 0-based index of the epoch that just finished,
    # so training resumes at the following epoch instead of repeating it.
    epoch_curr = model_state['epoch'] + 1
model = model.to(device)

In [8]:



# Module-level settings read by `evaluate` (and re-assigned by the training cell).
l2_normalize = False  # forwarded as `l2_normalize=` to the cost/loss helpers
drop_cost_type = 'learn'  # 'learn' makes the model compute distractor features
keep_percentile = 0.3  # forwarded as `keep_percentile=` to the cost/loss helpers

# Per-batch loss histories, appended to inside the training loop.
clust_losses = []
dtw_losses = []

def run_model(vf, sf, distractor, drop_cost_type):
    """Map raw video/step features through the global ``model`` (gradients kept).

    Args:
        vf: Video features, passed to ``model.map_video``.
        sf: Step (text) features, passed to ``model.map_text``.
        distractor: Input for ``model.compute_distractors``; only used when
            ``drop_cost_type == 'learn'``.
        drop_cost_type: Drop-cost mode; ``'learn'`` enables the distractor head.

    Returns:
        Tuple ``(step_features, frame_features, distractor_features)``. For any
        mode other than ``'learn'`` the third item is a list of ``None`` with
        one entry per element along ``frame_features``' first dimension.
    """
    frame_features = model.map_video(vf)
    step_features = model.map_text(sf)
    distractor_features = (
        model.compute_distractors(distractor)
        if drop_cost_type == 'learn'
        # Flagged by the original author as possibly buggy - please verify.
        else [None] * frame_features.shape[0]
    )
    return step_features, frame_features, distractor_features

def run_model_eval(vf, sf, distractor, drop_cost_type):
    """Inference-time twin of ``run_model``: same mapping, autograd disabled.

    Args:
        vf: Video features, passed to ``model.map_video``.
        sf: Step (text) features, passed to ``model.map_text``.
        distractor: Input for ``model.compute_distractors``; only used when
            ``drop_cost_type == 'learn'``.
        drop_cost_type: Drop-cost mode; ``'learn'`` enables the distractor head.

    Returns:
        Tuple ``(step_features, frame_features, distractor_features)``; the
        last item is a list of ``None`` (one per element along the first
        dimension of ``frame_features``) unless the mode is ``'learn'``.
    """
    with torch.no_grad():
        frame_features = model.map_video(vf)
        step_features = model.map_text(sf)
        distractor_features = (
            model.compute_distractors(distractor)
            if drop_cost_type == 'learn'
            else [None] * frame_features.shape[0]
        )
    return step_features, frame_features, distractor_features
    
def framewise_accuracy(frame_assignment, gt_assignment, num_frames, use_unlabeled=False):
    """Fraction of frames whose predicted label equals the ground-truth label.

    Args:
        frame_assignment: Predicted label per frame (array-like supporting
            element-wise ``==``).
        gt_assignment: Ground-truth label per frame; ``-1`` marks an unlabeled
            frame.
        num_frames: Number of frames used as the denominator.
        use_unlabeled: When False (default), frames with ground truth ``-1``
            are removed from both the numerator and the denominator.

    Returns:
        The accuracy as a scalar. When the denominator is zero the result is
        ``np.float64(0.0)`` rather than a plain ``int``, so callers that invoke
        ``.item()`` on the result (as ``evaluate`` does) keep working.
    """
    if use_unlabeled:
        correct = np.count_nonzero(frame_assignment == gt_assignment)
    else:
        # Discount unlabeled ground-truth frames from the denominator ...
        unlabeled = np.count_nonzero(gt_assignment == -1)
        num_frames = num_frames - unlabeled
        # ... and never count a match on an unlabeled frame.
        correct = np.logical_and(frame_assignment == gt_assignment, gt_assignment != -1).sum()

    if num_frames == 0:
        # Nothing to score: return a numpy scalar (not int 0) for a uniform return type.
        return np.float64(0.0)
    return correct / num_frames

def IoU(frame_assignment, gt_assignment, num_steps):
    """Intersection-over-union of predicted vs. ground-truth step segments.

    Intersections and unions are summed over step ids ``0 .. num_steps - 1``
    before dividing, so the result is pooled over steps rather than averaged
    per step.

    Args:
        frame_assignment: Predicted label per frame (numpy array).
        gt_assignment: Ground-truth label per frame (numpy array).
        num_steps: Number of step ids to score.

    Returns:
        Pooled IoU as a scalar; ``np.float64(0.0)`` when no frame carries any
        of the scored step ids in either array (empty union), instead of
        ``nan`` or a ``ZeroDivisionError``.
    """
    intersection, union = 0, 0
    for s in range(num_steps):
        in_gt = gt_assignment == s
        in_pred = frame_assignment == s
        intersection += np.logical_and(in_gt, in_pred).sum()
        union += np.logical_or(in_gt, in_pred).sum()
    if union == 0:
        # Neither side uses any scored step id: define the score as 0.
        return np.float64(0.0)
    return intersection / union


def evaluate(batch, drop_cost_type):
    """Align every video in ``batch`` with Drop-DTW and sum the framewise accuracy.

    Relies on the notebook globals ``model``, ``device``, ``l2_normalize``,
    ``keep_percentile`` and ``gt_validation``.

    Args:
        batch: Mapping with the keys ``'id'``, ``'step_len'``,
            ``'step_feature'``, ``'video_len'`` and ``'video_feature'``, each
            holding one entry per video.
        drop_cost_type: Drop-cost mode; with ``'learn'`` the distractor of a
            video is the mean of its first ``step_len`` step features.

    Returns:
        Tuple ``(framewise_acc, iou)``. ``framewise_acc`` is the SUM of the
        per-video accuracies (the caller divides by the dataset size);
        ``iou`` is always ``0.`` because the IoU computation is disabled.
    """
    ids = batch['id']
    step_lens = batch['step_len']
    all_step_feats = batch['step_feature']
    video_lens = batch['video_len']
    all_video_feats = batch['video_feature']

    if drop_cost_type != 'learn':
        distractors = [None] * len(ids)
    else:
        # One distractor per video: mean over its valid (unpadded) step features.
        step_means = [feats[:n].mean(0) for feats, n in zip(all_step_feats, step_lens)]
        distractors = torch.stack(step_means, dim=0).to(device)

    framewise_acc = 0.
    iou = 0.
    per_video = zip(ids, step_lens, all_step_feats, video_lens, all_video_feats, distractors)
    for _id, s_l, sf, v_l, vf, dis in per_video:
        with torch.no_grad():
            if model is None:
                # No model available: align the raw features directly.
                m_sf, m_vf, m_dis = sf, vf, dis
            else:
                sf = sf.to(device)
                vf = vf.to(device)
                if dis is not None:
                    dis = dis.to(device)
                m_sf, m_vf, m_dis = run_model_eval(vf, sf, dis, drop_cost_type)
                m_sf = m_sf.detach().cpu().numpy()
                m_vf = m_vf.detach().cpu().numpy()
                if dis is not None:
                    m_dis = m_dis.detach().cpu().numpy()

            zx_costs, drop_costs = compute_all_costs(
                (m_sf, s_l, m_vf, v_l, m_dis),
                l2_normalize=l2_normalize,
                gamma_xz=10,
                drop_cost_type=drop_cost_type,
                keep_percentile=keep_percentile,
            )
            zx_costs = zx_costs.detach().cpu().numpy()
            drop_costs = drop_costs.detach().cpu().numpy()
            # drop_dtw labels are shifted down by one before scoring.
            optimal_assignment = drop_dtw(zx_costs, drop_costs, return_labels=True) - 1

            framewise_acc += framewise_accuracy(optimal_assignment, gt_validation[_id], v_l).item()
            # Disabled by the author:
            # iou += IoU(optimal_assignment, gt_validation[_id], s_l).item()
    return framewise_acc, iou

In [None]:
# Total number of epochs to train (overrides the value set in the setup cell).
epochs = 30

# Settings shared with `evaluate` through module-level globals.
l2_normalize = False
drop_cost_type = 'learn'
keep_percentile = 0.3

# Per-batch loss histories (reset every time this cell is run).
clust_losses = []
dtw_losses = []

# NOTE(review): `prev_loss` is never read below, and checkpoints are written
# every 5 epochs regardless of loss despite the "best_model_state" file name.
prev_loss = 100000.0
for epoch in range(epoch_curr, epochs):
    loss = 0
    model.train()
    for batch in tqdm(train_dl):

        optimizer.zero_grad()
        step_len, step_features, video_len, video_features = batch['step_len'], batch['step_feature'], batch['video_len'], batch['video_feature']

        if drop_cost_type == 'learn':
            # Distractor per video = mean of its valid (unpadded) step features.
            distractors = torch.stack([s[:size].mean(0) for s, size in zip(step_features, step_len)], dim=0).to(device)
        else:
            distractors = None

        # Move the whole batch of features to the GPU and map them through the model.
        sf, ff, dif = run_model(video_features.to(device), step_features.to(device), distractors, drop_cost_type)

        sample = (sf, step_len, ff, video_len, dif)
        clust_loss = compute_clust_loss(sample, xz_gamma=30, frame_gamma=10, l2_normalize=l2_normalize, device=device)

        dtw_loss = compute_alignment_loss(sample, drop_cost_type, gamma_xz=10, gamma_min=1, keep_percentile=keep_percentile, l2_normalize=l2_normalize, device=device)

        clust_losses.append(clust_loss.item())
        dtw_losses.append(dtw_loss.item())

        # NOTE(review): only the alignment loss is optimized; `clust_loss` is
        # logged but does not contribute to the gradient.
        total_loss = dtw_loss
#         print((4 * clust_loss), (2.5 * dtw_loss))
        loss += total_loss.item()
        total_loss.backward()
        optimizer.step()
#         scheduler.step()

    # Mean loss per batch. Uses the loader's real length instead of the
    # previously hard-coded 51.0, so it stays correct if the sampler changes.
    loss /= max(len(train_dl), 1)
#     print(clust_losses)
    model.eval()
    # `valid_dl` yields the entire validation set as a single batch.
    frame_acc, iou_acc = evaluate(next(iter(valid_dl)), drop_cost_type)

    if (epoch + 1) % 5 == 0:

        # Stores the 0-based index of the epoch that just finished.
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, f'{folder_name}/best_model_state_{epoch+1}.pth')

    print(f'Epoch: {epoch} \t  Loss: {loss} Frame Acc: {frame_acc/len(valid_dataset)} IoU: {iou_acc/len(valid_dataset)}')

100%|██████████| 51/51 [02:02<00:00,  2.40s/it]


Epoch: 0 	  Loss: 5.793319141163545 Frame Acc: 0.5899965973878536 IoU: 0.0


100%|██████████| 51/51 [01:51<00:00,  2.18s/it]


Epoch: 1 	  Loss: 4.728760766048057 Frame Acc: 0.5898080412725654 IoU: 0.0


100%|██████████| 51/51 [01:55<00:00,  2.26s/it]


Epoch: 2 	  Loss: 4.004155074848848 Frame Acc: 0.5872949571733218 IoU: 0.0


100%|██████████| 51/51 [01:53<00:00,  2.22s/it]


Epoch: 3 	  Loss: 3.355844965168074 Frame Acc: 0.5850236794885693 IoU: 0.0


100%|██████████| 51/51 [01:56<00:00,  2.28s/it]


Epoch: 4 	  Loss: 2.844718914405972 Frame Acc: 0.5836069938649825 IoU: 0.0


100%|██████████| 51/51 [01:53<00:00,  2.22s/it]


Epoch: 5 	  Loss: 2.460940496594298 Frame Acc: 0.5840526704736259 IoU: 0.0


100%|██████████| 51/51 [01:55<00:00,  2.26s/it]


Epoch: 6 	  Loss: 2.179210106531779 Frame Acc: 0.5854672565372712 IoU: 0.0


100%|██████████| 51/51 [01:55<00:00,  2.27s/it]


Epoch: 7 	  Loss: 1.8782471371631997 Frame Acc: 0.5848227989851335 IoU: 0.0


100%|██████████| 51/51 [01:55<00:00,  2.27s/it]


Epoch: 8 	  Loss: 1.64111889109892 Frame Acc: 0.584742695588721 IoU: 0.0


100%|██████████| 51/51 [01:54<00:00,  2.24s/it]


Epoch: 9 	  Loss: 1.4525970173817055 Frame Acc: 0.5849216485549824 IoU: 0.0


100%|██████████| 51/51 [01:53<00:00,  2.22s/it]


Epoch: 10 	  Loss: 1.253883869040246 Frame Acc: 0.5838135846745257 IoU: 0.0


100%|██████████| 51/51 [01:56<00:00,  2.28s/it]


In [None]:
# Display the current value of `loss` left behind by the training cell.
loss

In [None]:
# import numpy as np

In [None]:
# Scratch check: 5x5 matrix with 10 on the main diagonal and 0 elsewhere.
np.diag([10] * 5)