In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn
import time
from scipy.special import comb, perm

In [2]:
import wandb
# Weights & Biases run setup, currently disabled. Uncomment to start a run
# before using any wandb logging in the cells below.
# wandb.init(
#     project = 'VideoReorder',
#     name = 'shot only'
#     )

# Local-time date stamp (YYYY-MM-DD); it is embedded in the checkpoint filename.
timestamp = time.strftime('%Y-%m-%d', time.localtime())

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Sets TOKENIZERS_PARALLELISM for this process (presumably read by the
# HuggingFace tokenizers library -- TODO confirm it is still needed here).
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

In [4]:
# Root folder of the VideoReorder-MovieNet dataset on the local machine.
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'

# Training split (layer='shot'), shuffled every epoch.
# The identity collate_fn returns each batch as the plain list of samples,
# so downstream loops iterate over the samples one at a time.
split = 'train'
train_data = VideoReorderMovieNetDataFolder(layer='shot', split=split, root=data_path)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda x: x,
)

# Validation split: same settings, but kept in a fixed order (no shuffling).
split = 'val'
val_data = VideoReorderMovieNetDataFolder(layer='shot', split=split, root=data_path)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=32,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda x: x,
)

In [5]:
# Pairwise classifier: a three-layer MLP that maps a flattened 2048-value
# input to two logits.
# NOTE(review): 2048 presumably corresponds to two concatenated 1024-d frame
# features -- confirm against the dataset.
_mlp_layers = [
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 2),
]
net = nn.Sequential(*_mlp_layers)

def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` module in place.

    Intended as a callback for ``Module.apply``. Modules whose exact type is
    not ``nn.Linear`` are left untouched, and biases are never modified.

    Args:
        m: the module visited by ``apply``.
    """
    if type(m) != nn.Linear:
        return
    # Small zero-mean Gaussian weights (std = 0.01).
    nn.init.normal_(m.weight, std=0.01)

# Re-initialise every Linear weight, then move the model to the selected
# device. Both calls return the module itself, so the chained expression also
# displays the layer summary as the cell output.
net.apply(init_weights).to(device)
# wandb.watch(net)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=1024, bias=True)
  (4): ReLU()
  (5): Linear(in_features=1024, out_features=2, bias=True)
)

In [7]:
# Hyper-parameters
lr = 1e-4
epoch = 5

loss_func = nn.CrossEntropyLoss()
loss_func.to(device)
# NOTE(review): pred_func is created but never used in this cell; it is kept
# only in case another cell relies on it.
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(net.parameters(), lr=lr)

# Frame-index pairs that are compared inside a three-frame shot.
_FRAME_PAIRS = ((0, 1), (0, 2), (1, 2))


def _shot_loss_and_score(model, features, gt):
    """Compute the loss and ordering score for a single shot.

    Training and validation both go through this helper so they always use
    the same label convention.

    Args:
        model: network mapping a flattened frame pair to two logits.
        features: per-frame features of the shot, indexed by frame on dim 0.
        gt: ground-truth order of the frames; must have length 2 or 3.

    Returns:
        ``(loss, score)``. For a two-frame shot the loss target is ``gt[1]``
        and the score is 1 when the predicted order equals ``gt``, else 0.
        For a three-frame shot every frame pair is classified (target 1 when
        ``gt[first] < gt[second]``, else 0) and both values are averaged over
        the three pairs.

    Raises:
        ValueError: if the shot has neither 2 nor 3 frames.
    """
    if len(gt) == 2:
        output = model(features.reshape(-1).unsqueeze(0).to(device))
        target = torch.tensor(gt[1]).unsqueeze(0).to(device)
        # get_order_index comes from utils; presumably it turns the logits
        # into an order comparable with gt -- TODO confirm.
        pred = get_order_index(output.reshape(-1).cpu())
        return loss_func(output, target), int(pred == gt)

    if len(gt) == 3:
        loss_shot = 0
        score_shot = 0
        for first_frame, second_frame in _FRAME_PAIRS:
            pair = features[[first_frame, second_frame], ...]
            output = model(pair.reshape(-1).unsqueeze(0).to(device))
            in_order = 1 if gt[first_frame] < gt[second_frame] else 0
            loss_shot += loss_func(output, torch.tensor([in_order]).to(device))
            pred = get_order_index(output.reshape(-1).cpu())
            expected = get_order_index([gt[first_frame], gt[second_frame]])
            score_shot += int(pred == expected)
        return loss_shot / len(_FRAME_PAIRS), score_shot / len(_FRAME_PAIRS)

    raise ValueError('shot frame is neither 2 nor 3')


def _run_epoch(model, dataloader, optimizer=None):
    """Run one pass over ``dataloader`` and return ``(mean_loss, mean_score)``.

    When ``optimizer`` is given, the batch-averaged loss is back-propagated
    and one optimizer step is taken per batch. Otherwise the pass is
    evaluation-only. The caller sets train/eval mode and ``torch.no_grad()``.

    Both returned values are means of the per-batch means; the loss is
    returned as a Python float.
    """
    loss_epoch_list = []
    score_epoch_list = []
    for batch_data in tqdm(dataloader):
        if optimizer is not None:
            # Gradients are accumulated once per batch, so clear them once here.
            optimizer.zero_grad()

        loss_batch_list = []
        score_batch_list = []
        for features, gt in batch_data:
            loss_shot, score_shot = _shot_loss_and_score(model, features, gt)
            loss_batch_list.append(loss_shot)
            score_batch_list.append(score_shot)

        # Batch averages
        score_step = sum(score_batch_list) / len(score_batch_list)
        loss_step = sum(loss_batch_list) / len(loss_batch_list)
        if optimizer is not None:
            loss_step.backward()
            optimizer.step()

        score_epoch_list.append(score_step)
        # Detach so the autograd graph of each batch can be freed instead of
        # being kept alive until the end of the epoch.
        loss_epoch_list.append(loss_step.detach())
        # wandb.log({'train loss' / 'val loss': loss_step.item(), ...}) can be added here.

    mean_loss = (sum(loss_epoch_list) / len(loss_epoch_list)).item()
    mean_score = sum(score_epoch_list) / len(score_epoch_list)
    return mean_loss, mean_score


# torch.save does not create missing folders, so make sure the target exists.
checkpoint_dir = Path('./checkpoint')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

best_val_acc = 0.0
for e in range(epoch):
    print(f'epoch {e}:')

    # train
    net.train()
    loss_epoch, score_epoch = _run_epoch(net, train_dataloader, optimizer=optim)
    print('train loss = ', loss_epoch, 'train score = ', score_epoch)

    # val
    net.eval()
    with torch.no_grad():
        loss_epoch, score_epoch = _run_epoch(net, val_dataloader)
    print('val loss = ', loss_epoch, 'val score = ', score_epoch)

    # Keep the weights of the best (or tied-best) validation score so far.
    if score_epoch >= best_val_acc:
        best_val_acc = score_epoch
        torch.save(net.state_dict(), checkpoint_dir / f'frame_to_shot_best_{timestamp}.pth')
        print("save epoch ",e)


epoch 0:


  0%|          | 0/904 [00:00<?, ?it/s]


ValueError: not enough values to unpack (expected 4, got 2)

More tests: evaluate the best checkpoint on the validation, in-domain test, and out-of-domain test splits

In [None]:
# In-domain test split (layer='shot'); evaluation data is never shuffled.
# The identity collate_fn returns each batch as the plain list of samples.
split = 'test_in_domain'
test_in_data = VideoReorderMovieNetDataFolder(layer='shot', split=split, root=data_path)
test_in_dataloader = torch.utils.data.DataLoader(
    dataset=test_in_data,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda x: x,
)

# Out-of-domain test split, loaded with the same settings.
split = 'test_out_domain'
test_out_data = VideoReorderMovieNetDataFolder(layer='shot', split=split, root=data_path)
test_out_dataloader = torch.utils.data.DataLoader(
    dataset=test_out_data,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda x: x,
)

In [None]:
# load and test val
# Rebuild the same architecture as the trained network so the saved
# state_dict keys ('1.weight', '3.weight', ...) line up.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512,2)
)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written from a GPU run still loads on a CPU-only machine.
checkpoint = torch.load(
    Path('./checkpoint', f'frame_to_shot_best_{timestamp}.pth'),
    map_location=device,
)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
# wandb.watch raises unless wandb.init() has been called, so only attach the
# hook when a run is active.
if wandb.run is not None:
    wandb.watch(model)

In [None]:
# test val
# Evaluate the reloaded `model` on the validation split. Loss and score are
# averaged per batch first, then across batches.
with torch.no_grad():
    loss_epoch_list = []
    score_epoch_list = []
    for batch_data in tqdm(val_dataloader):
        loss_batch_list = []
        score_batch_list = []

        for features, gt in batch_data:  # each sample is a (features, gt) pair
            if len(gt) == 2:
                # Two-frame shot: one prediction, target taken from gt[1].
                output = model(features.reshape(-1).unsqueeze(0).to(device))
                loss_shot = loss_func(output, torch.tensor(gt[1]).unsqueeze(0).to(device))
                pred = get_order_index(output.reshape(-1).cpu())
                score_shot = int(pred == gt)
                loss_batch_list.append(loss_shot)
                score_batch_list.append(score_shot)
            elif len(gt) == 3:
                # Three-frame shot: classify each of the three frame pairs
                # (target 1 when the pair is in ground-truth order) and average.
                loss_shot = 0
                score_shot = 0
                for first_frame, second_frame in ((0, 1), (0, 2), (1, 2)):
                    pair = features[[first_frame, second_frame], ...]
                    output = model(pair.reshape(-1).unsqueeze(0).to(device))
                    in_order = 1 if gt[first_frame] < gt[second_frame] else 0
                    loss_shot += loss_func(output, torch.tensor([in_order]).to(device))
                    pred = get_order_index(output.reshape(-1).cpu())
                    expected = get_order_index([gt[first_frame], gt[second_frame]])
                    score_shot += int(pred == expected)
                loss_batch_list.append(loss_shot / 3)
                score_batch_list.append(score_shot / 3)
            else:
                raise ValueError('shot frame is neither 2 nor 3')

        # Batch averages
        score_step = sum(score_batch_list) / len(score_batch_list)
        loss_step = sum(loss_batch_list) / len(loss_batch_list)

        score_epoch_list.append(score_step)
        loss_epoch_list.append(loss_step)
        # wandb.log raises unless wandb.init() has been called, so only log
        # when a run is active.
        if wandb.run is not None:
            wandb.log({'val loss':loss_step.item(), 'val score':score_step})

    score_epoch = sum(score_epoch_list) / len(score_epoch_list)
    loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
    print('val loss = ', loss_epoch.item(), 'val score = ', score_epoch)

In [None]:
# try test_in_domain
# Evaluate the reloaded `model` on the in-domain test split. Loss and score
# are averaged per batch first, then across batches.
with torch.no_grad():
    loss_epoch_list = []
    score_epoch_list = []
    for batch_data in tqdm(test_in_dataloader):
        loss_batch_list = []
        score_batch_list = []

        for features, gt in batch_data:  # each sample is a (features, gt) pair
            if len(gt) == 2:
                # Two-frame shot: one prediction, target taken from gt[1].
                output = model(features.reshape(-1).unsqueeze(0).to(device))
                loss_shot = loss_func(output, torch.tensor(gt[1]).unsqueeze(0).to(device))
                pred = get_order_index(output.reshape(-1).cpu())
                score_shot = int(pred == gt)
                loss_batch_list.append(loss_shot)
                score_batch_list.append(score_shot)
            elif len(gt) == 3:
                # Three-frame shot: classify each of the three frame pairs
                # (target 1 when the pair is in ground-truth order) and average.
                loss_shot = 0
                score_shot = 0
                for first_frame, second_frame in ((0, 1), (0, 2), (1, 2)):
                    pair = features[[first_frame, second_frame], ...]
                    output = model(pair.reshape(-1).unsqueeze(0).to(device))
                    in_order = 1 if gt[first_frame] < gt[second_frame] else 0
                    loss_shot += loss_func(output, torch.tensor([in_order]).to(device))
                    pred = get_order_index(output.reshape(-1).cpu())
                    expected = get_order_index([gt[first_frame], gt[second_frame]])
                    score_shot += int(pred == expected)
                loss_batch_list.append(loss_shot / 3)
                score_batch_list.append(score_shot / 3)
            else:
                raise ValueError('shot frame is neither 2 nor 3')

        # Batch averages
        score_step = sum(score_batch_list) / len(score_batch_list)
        loss_step = sum(loss_batch_list) / len(loss_batch_list)

        score_epoch_list.append(score_step)
        loss_epoch_list.append(loss_step)
        # wandb.log raises unless wandb.init() has been called, so only log
        # when a run is active.
        if wandb.run is not None:
            wandb.log({'test in domain loss':loss_step.item(), 'test_in_domain score':score_step})

    score_epoch = sum(score_epoch_list) / len(score_epoch_list)
    loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
    print('test_in_domain loss = ', loss_epoch.item(), 'test_in_domain score = ', score_epoch)

In [None]:
# test out domain
# Evaluate the reloaded `model` on the out-of-domain test split. Loss and
# score are averaged per batch first, then across batches.
with torch.no_grad():
    loss_epoch_list = []
    score_epoch_list = []
    for batch_data in tqdm(test_out_dataloader):
        loss_batch_list = []
        score_batch_list = []

        for features, gt in batch_data:  # each sample is a (features, gt) pair
            if len(gt) == 2:
                # Two-frame shot: one prediction, target taken from gt[1].
                output = model(features.reshape(-1).unsqueeze(0).to(device))
                loss_shot = loss_func(output, torch.tensor(gt[1]).unsqueeze(0).to(device))
                pred = get_order_index(output.reshape(-1).cpu())
                score_shot = int(pred == gt)
                loss_batch_list.append(loss_shot)
                score_batch_list.append(score_shot)
            elif len(gt) == 3:
                # Three-frame shot: classify each of the three frame pairs
                # (target 1 when the pair is in ground-truth order) and average.
                loss_shot = 0
                score_shot = 0
                for first_frame, second_frame in ((0, 1), (0, 2), (1, 2)):
                    pair = features[[first_frame, second_frame], ...]
                    output = model(pair.reshape(-1).unsqueeze(0).to(device))
                    in_order = 1 if gt[first_frame] < gt[second_frame] else 0
                    loss_shot += loss_func(output, torch.tensor([in_order]).to(device))
                    pred = get_order_index(output.reshape(-1).cpu())
                    expected = get_order_index([gt[first_frame], gt[second_frame]])
                    score_shot += int(pred == expected)
                loss_batch_list.append(loss_shot / 3)
                score_batch_list.append(score_shot / 3)
            else:
                raise ValueError('shot frame is neither 2 nor 3')

        # Batch averages
        score_step = sum(score_batch_list) / len(score_batch_list)
        loss_step = sum(loss_batch_list) / len(loss_batch_list)

        score_epoch_list.append(score_step)
        loss_epoch_list.append(loss_step)
        # wandb.log raises unless wandb.init() has been called, so only log
        # when a run is active.
        if wandb.run is not None:
            wandb.log({'test out domain loss':loss_step.item(), 'test out domain score':score_step})

    score_epoch = sum(score_epoch_list) / len(score_epoch_list)
    loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
    print('test out domain loss = ', loss_epoch.item(), 'test out domain score = ', score_epoch)