In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn
import time
import wandb
from scipy.special import comb, perm

In [2]:
timestamp = time.strftime('%Y-%m-%d', time.localtime(time.time()))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# wandb.init(
#     project = 'VideoReorder',
#     name = 'shot to scene' + timestamp
#     )

In [3]:
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'
split = 'train'
train_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='scene')
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=(split == 'train'), num_workers=8, pin_memory=True, collate_fn=lambda x: x)
split = 'val'
val_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='shot')
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=32, shuffle=(split == 'train'), num_workers=8, pin_memory=True, collate_fn=lambda x: x)

In [5]:
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024,2)
)

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
   nn.Linear(2048, 1024),
net.apply(init_weights)
net.to(device)
# wandb.watch(net)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=1024, bias=True)
  (4): ReLU()
  (5): Linear(in_features=1024, out_features=2, bias=True)
)

In [8]:
# load and test val
frame_to_shot_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024,2)
)
checkpoint = torch.load(Path('./checkpoint', f'frame_to_shot_best_2023-02-23_01.pth'))
frame_to_shot_model.load_state_dict(checkpoint)
frame_to_shot_model.to(device)
frame_to_shot_model.eval()
# wandb.watch(frame_to_shot_model)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=1024, bias=True)
  (4): ReLU()
  (5): Linear(in_features=1024, out_features=2, bias=True)
)

In [9]:
def get_shot_feature(input):
    output = input
    for i in range(len(frame_to_shot_model)-1):
        output = frame_to_shot_model[i](output)
    return output

In [10]:
a = torch.rand(4, 2048).to(device)
get_shot_feature(a)

tensor([[0.0000, 0.2000, 0.0000,  ..., 0.1127, 0.0000, 0.0455],
        [0.0000, 0.1710, 0.0000,  ..., 0.0847, 0.0265, 0.0667],
        [0.0000, 0.2582, 0.0000,  ..., 0.1491, 0.0892, 0.0857],
        [0.0000, 0.1627, 0.0000,  ..., 0.0605, 0.0215, 0.0000]],
       device='cuda:0', grad_fn=<ReluBackward0>)

In [None]:
lr = 1e-4
epoch = 5

loss_func = nn.CrossEntropyLoss()
loss_func.to(device)
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(net.parameters(), lr=lr)

best_val_acc = 0.0
for e in range(epoch):
    print(f'epoch {e}:')

    # train
    net.train()
    loss_epoch_list = []
    score_epoch_list = []
    for batch_data in tqdm(train_dataloader):
        loss_batch_list = []
        score_batch_list = []

        for scene_data in batch_data: #clip data
            # read input data
            features, gt = scene_data

            # get average feature
            for i in range(len(features)):
                features[i] = torch.mean(features[i], dim = 0)
            features = torch.stack(features, dim=0)

            # process model
            optim.zero_grad()
            
            N = len(gt)
            loss_scene = 0
            score_scene = 0
            for first_frame in range(N):
                for second_frame in range(first_frame + 1, N):
                    output = net(features[[first_frame, second_frame],...].reshape(-1).unsqueeze(0).to(device))
                    loss_scene += loss_func(output, torch.tensor([1]).to(device) if gt[first_frame] < gt[second_frame] else torch.tensor([0]).to(device))
                    PRED = get_order_index(output.reshape(-1).cpu())
                    GT = get_order_index([gt[first_frame], gt[second_frame]])
                    # print(PRED, GT)
                    score_scene += int(PRED == GT)
                    # print(score_scene)
            loss_batch_list.append(loss_scene / comb(N, 2))
            score_batch_list.append(score_scene / comb(N, 2))
                
            
        # calcuclate avearge batch
        score_step = sum(score_batch_list) / len(score_batch_list)
        loss_step = sum(loss_batch_list) / len(loss_batch_list)
        loss_step.backward()
        optim.step()
        # caculate avearge score
        score_epoch_list.append(score_step)
        loss_epoch_list.append(loss_step)
        wandb.log({'train loss':loss_step.item(), 'train score':score_step})

    score_epoch = sum(score_epoch_list) / len(score_epoch_list)
    loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
    print('train loss = ', loss_epoch.item(), 'train score = ', score_epoch)  

    # val
    net.eval()
    with torch.no_grad():
        loss_epoch_list = []
        score_epoch_list = []
        for batch_data in tqdm(val_dataloader):
            loss_batch_list = []
            score_batch_list = []

            for scene_data in batch_data: #clip data
                # read input data
                features, gt = scene_data
                
                # get average feature
                for i in range(len(features)):
                    features[i] = torch.mean(features[i], dim = 0)
                features = torch.stack(features, dim=0)

                N = len(gt)
                loss_scene = 0
                score_scene = 0
                for first_frame in range(N):
                    for second_frame in range(first_frame + 1, N):
                        output = net(features[[first_frame, second_frame],...].reshape(-1).unsqueeze(0).to(device))
                        loss_scene += loss_func(output, torch.tensor([1]).to(device) if gt[first_frame] < gt[second_frame] else torch.tensor([0]).to(device))
                        PRED = get_order_index(output.reshape(-1).cpu())
                        GT = get_order_index([gt[first_frame], gt[second_frame]])
                        # print(PRED, GT)
                        score_scene += int(PRED == GT)
                        # print(score_scene)
                loss_batch_list.append(loss_scene / comb(N, 2))
                score_batch_list.append(score_scene / comb(N, 2))
                
            # calcuclate avearge batch
            score_step = sum(score_batch_list) / len(score_batch_list)
            loss_step = sum(loss_batch_list) / len(loss_batch_list)

            # caculate avearge score
            score_epoch_list.append(score_step)
            loss_epoch_list.append(loss_step)
            wandb.log({'val loss':loss_step.item(), 'val score':score_step})

        score_epoch = sum(score_epoch_list) / len(score_epoch_list)
        loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
        print('val loss = ', loss_epoch.item(), 'val score = ', score_epoch)
        if score_epoch >= best_val_acc: 
            best_val_acc = score_epoch
            torch.save(net.state_dict(), Path('./checkpoint', f'shot_to_scene_best_{timestamp}.pth'))
            print("save epoch ",e)
