In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn
import time

In [2]:
import wandb

# The notebook name is exported through the environment before wandb.init()
# so the run started below can pick it up.
os.environ.update(WANDB_NOTEBOOK_NAME='VideoReorder')
wandb.init(name='shot only')

[34m[1mwandb[0m: Currently logged in as: [33mjianghui[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Run configuration: compute device, tokenizer threading, and RNG seeds.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Seed every RNG the notebook may touch (Python, NumPy, PyTorch on all devices)
# so weight initialisation and DataLoader shuffle order repeat across runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'


def _shot_split_loader(split_name):
    """Load the shot-layer `split_name` split and wrap it in a DataLoader.

    Returns a (dataset, loader) pair. Only the 'train' split is shuffled. The
    identity collate_fn hands each batch over as a plain list of samples
    (presumably because shots hold different numbers of frames, 2 or 3).
    """
    dataset = VideoReorderMovieNetDataFolder(root=data_path, split=split_name, layer='shot')
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        shuffle=split_name == 'train',
        num_workers=8,
        pin_memory=True,
        collate_fn=lambda x: x,
    )
    return dataset, loader


train_data, train_dataloader = _shot_split_loader('train')
# `split` ends up holding the last split built, exactly as before.
split = 'val'
val_data, val_dataloader = _shot_split_loader(split)

In [5]:
# Pairwise order classifier: the flattened features of a frame pair (2048
# values) pass through two ReLU hidden layers to 2 class logits.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024), nn.ReLU(),
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 2),
)

def init_weights(m: nn.Module, std: float = 0.01) -> None:
    """Re-initialise the weight of every `nn.Linear` layer from N(0, std**2).

    Meant to be passed to `nn.Module.apply`, which calls it once per submodule.
    Non-Linear modules are left untouched, and so are Linear biases.

    Args:
        m: submodule visited by `nn.Module.apply`.
        std: standard deviation of the normal distribution (default 0.01).
    """
    # isinstance (rather than an exact type comparison) also covers Linear subclasses.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=std)

# Apply the custom weight init to every submodule, move the model to the chosen
# device (nn.Module.apply returns the module, so the calls chain), then
# register it with wandb.watch.
net.apply(init_weights).to(device)
wandb.watch(net)

[]

In [6]:
# Training setup: hyperparameters, loss and optimiser.
lr = 1e-4
epoch = 40  # number of training epochs

loss_func = nn.CrossEntropyLoss()
loss_func.to(device)
# NOTE(review): pred_func is not used in this section; kept in case other cells rely on it.
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(net.parameters(), lr=lr)


def shot_loss_and_score(features, gt):
    """Pairwise ordering loss and accuracy for one shot.

    Each frame pair is flattened into one vector and classified by `net`.
    For 3-frame shots every pair (i, j) with i < j is scored, with target 1
    when gt[i] < gt[j] and 0 otherwise; the results are averaged over pairs.

    Args:
        features: per-frame feature tensor indexed by frame along dim 0
            (presumably (n_frames, 1024) so a pair flattens to the net's 2048
            inputs -- TODO confirm).
        gt: ground-truth order values for the shot's frames (2 or 3 entries).

    Returns:
        (loss, score): mean loss tensor (still attached to the autograd graph
        when gradients are enabled) and the fraction of correctly ordered pairs.

    Raises:
        ValueError: if `gt` does not hold 2 or 3 entries.
    """
    if len(gt) == 2:
        # The whole shot is a single pair; gt[1] is used directly as the class label.
        output = net(features.reshape(-1).unsqueeze(0).to(device))
        loss_shot = loss_func(output, torch.tensor(gt[1]).unsqueeze(0).to(device))
        pred = get_order_index(output.reshape(-1).cpu())
        return loss_shot, int(pred == gt)
    if len(gt) == 3:
        pairs = [(i, j) for i in range(3) for j in range(i + 1, 3)]
        loss_shot, score_shot = 0, 0
        for first_frame, second_frame in pairs:
            output = net(features[[first_frame, second_frame], ...].reshape(-1).unsqueeze(0).to(device))
            label = 1 if gt[first_frame] < gt[second_frame] else 0
            loss_shot = loss_shot + loss_func(output, torch.tensor([label]).to(device))
            pred = get_order_index(output.reshape(-1).cpu())
            score_shot += int(pred == get_order_index([gt[first_frame], gt[second_frame]]))
        return loss_shot / len(pairs), score_shot / len(pairs)
    # raise (not assert) so the check survives `python -O`.
    raise ValueError(f'shot frame is neither 2 nor 3 (got {len(gt)})')


def run_epoch(dataloader, train):
    """Run one pass over `dataloader`, training `net` when `train` is True.

    Each batch's loss is the mean of its per-shot losses. When training, one
    optimiser step is taken per batch. Per-batch metrics are logged to wandb as
    '<train|val> loss' / '<train|val> score'.

    Returns:
        (mean batch loss, mean batch score) over the epoch, as Python floats.
    """
    net.train(train)
    prefix = 'train' if train else 'val'
    batch_losses, batch_scores = [], []
    # No autograd graph is built during validation.
    with torch.set_grad_enabled(train):
        for batch_data in tqdm(dataloader):
            if train:
                optim.zero_grad()
            shot_losses, shot_scores = [], []
            for features, gt in batch_data:
                loss_shot, score_shot = shot_loss_and_score(features, gt)
                shot_losses.append(loss_shot)
                shot_scores.append(score_shot)
            loss_step = sum(shot_losses) / len(shot_losses)
            score_step = sum(shot_scores) / len(shot_scores)
            if train:
                loss_step.backward()
                optim.step()
            # Store a detached float: keeping the tensor would hold every batch's
            # autograd history alive for the whole epoch.
            loss_value = loss_step.item()
            batch_losses.append(loss_value)
            batch_scores.append(score_step)
            wandb.log({f'{prefix} loss': loss_value, f'{prefix} score': score_step})
    return sum(batch_losses) / len(batch_losses), sum(batch_scores) / len(batch_scores)


for e in range(epoch):
    print(f'epoch {e}:')
    loss_epoch, score_epoch = run_epoch(train_dataloader, train=True)
    print('train loss = ', loss_epoch, 'train score = ', score_epoch)
    loss_epoch, score_epoch = run_epoch(val_dataloader, train=False)
    print('val loss = ', loss_epoch, 'val score = ', score_epoch)


epoch 0:


100%|██████████| 904/904 [00:30<00:00, 29.90it/s]


train loss =  0.6640595197677612 train score =  0.5821244804237062


100%|██████████| 75/75 [00:01<00:00, 42.72it/s]


val loss =  0.6678372621536255 val score =  0.5786574074074073
epoch 1:


100%|██████████| 904/904 [00:30<00:00, 29.64it/s]


train loss =  0.6498633027076721 train score =  0.6152202333065171


100%|██████████| 75/75 [00:02<00:00, 35.74it/s]


val loss =  0.6653242111206055 val score =  0.5896759259259258
epoch 2:


100%|██████████| 904/904 [00:29<00:00, 30.33it/s]


train loss =  0.6347200274467468 train score =  0.6364820075757577


100%|██████████| 75/75 [00:02<00:00, 33.78it/s]


val loss =  0.6696799397468567 val score =  0.5860648148148148
epoch 3:


100%|██████████| 904/904 [00:29<00:00, 31.00it/s]


train loss =  0.6037259101867676 train score =  0.6707101434700993


100%|██████████| 75/75 [00:01<00:00, 37.81it/s]


val loss =  0.6942498683929443 val score =  0.5700462962962963
epoch 4:


100%|██████████| 904/904 [00:28<00:00, 31.43it/s]


train loss =  0.5295026898384094 train score =  0.7323061226200055


100%|██████████| 75/75 [00:02<00:00, 37.01it/s]


val loss =  0.7622363567352295 val score =  0.558101851851852
epoch 5:


100%|██████████| 904/904 [00:27<00:00, 32.41it/s]


train loss =  0.416706919670105 train score =  0.8078908554572267


100%|██████████| 75/75 [00:01<00:00, 38.93it/s]


val loss =  0.9020805954933167 val score =  0.5681944444444444
epoch 6:


100%|██████████| 904/904 [00:27<00:00, 32.44it/s]


train loss =  0.30437973141670227 train score =  0.8693108490882279


100%|██████████| 75/75 [00:02<00:00, 36.99it/s]


val loss =  1.0730971097946167 val score =  0.5601851851851852
epoch 7:


100%|██████████| 904/904 [00:28<00:00, 32.14it/s]


train loss =  0.20832668244838715 train score =  0.9149309047331718


100%|██████████| 75/75 [00:02<00:00, 35.74it/s]


val loss =  1.265468955039978 val score =  0.55125
epoch 8:


100%|██████████| 904/904 [00:28<00:00, 31.35it/s]


train loss =  0.14660795032978058 train score =  0.942421309332261


100%|██████████| 75/75 [00:02<00:00, 34.20it/s]


val loss =  1.4500564336776733 val score =  0.5508333333333334
epoch 9:


100%|██████████| 904/904 [00:29<00:00, 30.73it/s]


train loss =  0.10729944705963135 train score =  0.9590833668543826


100%|██████████| 75/75 [00:02<00:00, 35.93it/s]


val loss =  1.668874979019165 val score =  0.5504166666666668
epoch 10:


100%|██████████| 904/904 [00:29<00:00, 30.33it/s]


train loss =  0.08331193774938583 train score =  0.9683760307723256


100%|██████████| 75/75 [00:02<00:00, 36.23it/s]


val loss =  1.8411871194839478 val score =  0.5503240740740741
epoch 11:


100%|██████████| 904/904 [00:29<00:00, 30.58it/s]


train loss =  0.0772649422287941 train score =  0.9709917203003499


100%|██████████| 75/75 [00:02<00:00, 36.27it/s]


val loss =  1.9625160694122314 val score =  0.5523148148148148
epoch 12:


100%|██████████| 904/904 [00:30<00:00, 29.55it/s]


train loss =  0.057565171271562576 train score =  0.9797490949316173


100%|██████████| 75/75 [00:02<00:00, 33.55it/s]


val loss =  2.1785905361175537 val score =  0.5646759259259259
epoch 13:


100%|██████████| 904/904 [00:30<00:00, 29.97it/s]


train loss =  0.05875478312373161 train score =  0.9785329092920351


100%|██████████| 75/75 [00:02<00:00, 33.02it/s]


val loss =  2.2507638931274414 val score =  0.5668055555555556
epoch 14:


100%|██████████| 904/904 [00:30<00:00, 29.85it/s]


train loss =  0.0482814684510231 train score =  0.9823071701528574


100%|██████████| 75/75 [00:02<00:00, 33.24it/s]


val loss =  2.269176721572876 val score =  0.5551388888888891
epoch 15:


100%|██████████| 904/904 [00:28<00:00, 31.23it/s]


train loss =  0.05411630868911743 train score =  0.9806489256503113


100%|██████████| 75/75 [00:02<00:00, 36.93it/s]


val loss =  2.4173309803009033 val score =  0.5505555555555557
epoch 16:


100%|██████████| 904/904 [00:27<00:00, 32.74it/s]


train loss =  0.04856187850236893 train score =  0.9834772643470119


100%|██████████| 75/75 [00:02<00:00, 37.01it/s]


val loss =  2.475212574005127 val score =  0.5555092592592591
epoch 17:


100%|██████████| 904/904 [00:28<00:00, 31.86it/s]


train loss =  0.033718179911375046 train score =  0.9887075958702092


100%|██████████| 75/75 [00:02<00:00, 34.85it/s]


val loss =  2.6183581352233887 val score =  0.5514351851851852
epoch 18:


100%|██████████| 904/904 [00:29<00:00, 31.01it/s]


train loss =  0.046350374817848206 train score =  0.9832258564628602


100%|██████████| 75/75 [00:02<00:00, 32.06it/s]


val loss =  2.5759100914001465 val score =  0.5513425925925923
epoch 19:


100%|██████████| 904/904 [00:29<00:00, 30.88it/s]


train loss =  0.0389455109834671 train score =  0.9863569321533946


100%|██████████| 75/75 [00:02<00:00, 32.97it/s]


val loss =  2.7424590587615967 val score =  0.5445833333333332
epoch 20:


100%|██████████| 904/904 [00:29<00:00, 30.39it/s]


train loss =  0.03914128243923187 train score =  0.9870671594261223


100%|██████████| 75/75 [00:02<00:00, 32.76it/s]


val loss =  2.7988979816436768 val score =  0.5519907407407407
epoch 21:


100%|██████████| 904/904 [00:29<00:00, 30.27it/s]


train loss =  0.03305542469024658 train score =  0.9882990580584635


100%|██████████| 75/75 [00:02<00:00, 33.04it/s]


val loss =  2.6030690670013428 val score =  0.5570833333333334
epoch 22:


100%|██████████| 904/904 [00:29<00:00, 30.84it/s]


train loss =  0.030483713373541832 train score =  0.9897750318450024


100%|██████████| 75/75 [00:02<00:00, 34.50it/s]


val loss =  2.7652666568756104 val score =  0.5519444444444442
epoch 23:


100%|██████████| 904/904 [00:30<00:00, 29.74it/s]


train loss =  0.034723252058029175 train score =  0.9883158185840734


100%|██████████| 75/75 [00:02<00:00, 32.72it/s]


val loss =  2.822835683822632 val score =  0.5432870370370368
epoch 24:


100%|██████████| 904/904 [00:30<00:00, 29.71it/s]


train loss =  0.032354604452848434 train score =  0.988776733038351


100%|██████████| 75/75 [00:02<00:00, 32.05it/s]


val loss =  2.5385184288024902 val score =  0.5625925925925925
epoch 25:


100%|█████████▉| 900/904 [00:30<00:00, 29.98it/s]


KeyboardInterrupt: 

More tests: in-domain and out-of-domain test splits

In [None]:
# Test splits are evaluation-only, so neither loader shuffles. Each loader must
# wrap its own test dataset (not train_data / val_data).
split = 'test_in_domain'
test_in_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='shot')
test_in_dataloader = torch.utils.data.DataLoader(test_in_data, batch_size=128, shuffle=False, num_workers=8, pin_memory=True, collate_fn=lambda x: x)

split = 'test_out_domain'
test_out_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='shot')
test_out_dataloader = torch.utils.data.DataLoader(test_out_data, batch_size=128, shuffle=False, num_workers=8, pin_memory=True, collate_fn=lambda x: x)

In [None]:
# Evaluate the trained net on one split without tracking gradients.
# NOTE(review): this section is headed "More test" but evaluates val_dataloader;
# pass test_in_dataloader / test_out_dataloader here to score the test splits.


def evaluate_shots(dataloader, prefix='val'):
    """Score `net` on every shot in `dataloader` without updating it.

    Per-batch metrics are logged to wandb as '<prefix> loss' / '<prefix> score'.
    For 3-frame shots every frame pair (i, j) with i < j is scored (target 1
    when gt[i] < gt[j]) and averaged; for 2-frame shots gt[1] is the label.

    Returns:
        (mean batch loss, mean batch score) as Python floats.

    Raises:
        ValueError: if a shot's `gt` does not hold 2 or 3 entries.
    """
    net.eval()
    batch_losses, batch_scores = [], []
    with torch.no_grad():
        for batch_data in tqdm(dataloader):
            shot_losses, shot_scores = [], []
            for features, gt in batch_data:
                if len(gt) == 2:
                    output = net(features.reshape(-1).unsqueeze(0).to(device))
                    shot_losses.append(loss_func(output, torch.tensor(gt[1]).unsqueeze(0).to(device)))
                    shot_scores.append(int(get_order_index(output.reshape(-1).cpu()) == gt))
                elif len(gt) == 3:
                    pairs = [(i, j) for i in range(3) for j in range(i + 1, 3)]
                    loss_shot, score_shot = 0, 0
                    for first_frame, second_frame in pairs:
                        output = net(features[[first_frame, second_frame], ...].reshape(-1).unsqueeze(0).to(device))
                        label = 1 if gt[first_frame] < gt[second_frame] else 0
                        loss_shot = loss_shot + loss_func(output, torch.tensor([label]).to(device))
                        pred = get_order_index(output.reshape(-1).cpu())
                        score_shot += int(pred == get_order_index([gt[first_frame], gt[second_frame]]))
                    shot_losses.append(loss_shot / len(pairs))
                    shot_scores.append(score_shot / len(pairs))
                else:
                    raise ValueError(f'shot frame is neither 2 nor 3 (got {len(gt)})')
            # Log and store plain floats rather than tensors.
            loss_step = (sum(shot_losses) / len(shot_losses)).item()
            score_step = sum(shot_scores) / len(shot_scores)
            batch_losses.append(loss_step)
            batch_scores.append(score_step)
            wandb.log({f'{prefix} loss': loss_step, f'{prefix} score': score_step})
    return sum(batch_losses) / len(batch_losses), sum(batch_scores) / len(batch_scores)


loss_epoch, score_epoch = evaluate_shots(val_dataloader, 'val')
print('val loss = ', loss_epoch, 'val score = ', score_epoch)