In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# TOKENIZERS_PARALLELISM is the switch read by the HuggingFace `tokenizers`
# library; it is set explicitly here before any DataLoader workers are started.
os.environ.update(TOKENIZERS_PARALLELISM='true')

In [3]:
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'


def _build_split(split_name):
    """Create the shot-level dataset and its DataLoader for one split.

    Only the 'train' split is shuffled. The identity ``collate_fn`` hands each
    batch to the caller as the plain list of samples produced by the dataset,
    without stacking them into tensors.

    Returns:
        (dataset, dataloader) for ``split_name``.
    """
    folder = VideoReorderMovieNetDataFolder(root=data_path, split=split_name, layer='shot')
    loader = DataLoader(
        folder,
        batch_size=128,
        shuffle=(split_name == 'train'),
        num_workers=8,
        pin_memory=True,
        collate_fn=lambda x: x,
    )
    return folder, loader


split = 'train'
train_data, train_dataloader = _build_split(split)

# `split` is left as 'val' after this cell, exactly as before.
split = 'val'
val_data, val_dataloader = _build_split(split)

In [4]:
# Pairwise-order classifier: an MLP that flattens its input to 2048 values
# and emits 2 logits. Layer order (and therefore the Sequential indices 0..5)
# is Flatten, Linear, ReLU, Linear, ReLU, Linear.
_mlp_widths = [2048, 1024, 512]
_mlp_layers = [nn.Flatten()]
for _fan_in, _fan_out in zip(_mlp_widths[:-1], _mlp_widths[1:]):
    _mlp_layers += [nn.Linear(_fan_in, _fan_out), nn.ReLU()]
# Final projection to the two output logits (no activation afterwards).
_mlp_layers.append(nn.Linear(_mlp_widths[-1], 2))
net = nn.Sequential(*_mlp_layers)

def init_weights(m):
    """Re-initialise the weight matrix of a linear layer in place.

    Intended to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module. Linear layers (including subclasses of ``nn.Linear``)
    get their ``weight`` drawn from a normal distribution with mean 0 and
    standard deviation 0.01; the bias and every other module type are left
    untouched.

    Args:
        m: the module currently being visited.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # nn.Linear are initialised too instead of being silently skipped.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

# Re-initialise every sub-module via init_weights, then move the model to the
# selected device. Both calls return the module itself, so this stays the
# cell's final expression and the notebook still displays the model summary.
net.apply(init_weights).to(device)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=512, bias=True)
  (4): ReLU()
  (5): Linear(in_features=512, out_features=2, bias=True)
)

In [5]:
# Optimisation hyper-parameters: learning rate and number of epochs.
lr, epoch = 1e-4, 40

# Two-class cross-entropy over the network's 2 output logits.
loss_func = nn.CrossEntropyLoss().to(device)
# ClipPairWisePred comes from `utils`; its `.to()` return value is not visible
# here, so construction and device placement stay as separate statements.
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(params=net.parameters(), lr=lr)

def _shot_loss_and_score(features, gt):
    """Compute the pairwise-order loss and accuracy for a single shot sample.

    A sample holds either 2 or 3 frames. Every ordered pair of frames
    (first index < second index) is flattened into one input row for `net`.
    For 3-frame samples the loss and score are averaged over the 3 pairs and
    the target class is 1 when ``gt[first] < gt[second]``, otherwise 0.
    For 2-frame samples the target class is ``gt[1]``.

    Args:
        features: per-frame features of the sample, indexable along dim 0.
        gt: ground-truth order of the frames (length 2 or 3).

    Returns:
        (loss, score): loss is a tensor, score a number in [0, 1].

    Raises:
        ValueError: if the sample has neither 2 nor 3 frames.
    """
    if len(gt) == 2:
        output = net(features.reshape(-1).unsqueeze(0).to(device))
        loss_shot = loss_func(output, torch.tensor(gt[1]).unsqueeze(0).to(device))
        PRED = get_order_index(output.reshape(-1).cpu())
        return loss_shot, int(PRED == gt)

    if len(gt) == 3:
        loss_shot = 0
        score_shot = 0
        for first_frame in range(3):
            for second_frame in range(first_frame + 1, 3):
                pair = features[[first_frame, second_frame], ...]
                output = net(pair.reshape(-1).unsqueeze(0).to(device))
                label = 1 if gt[first_frame] < gt[second_frame] else 0
                loss_shot += loss_func(output, torch.tensor([label]).to(device))
                PRED = get_order_index(output.reshape(-1).cpu())
                GT = get_order_index([gt[first_frame], gt[second_frame]])
                score_shot += int(PRED == GT)
        # Three frames give exactly three pairs.
        return loss_shot / 3, score_shot / 3

    # Explicit raise instead of `assert False`, which is stripped under `python -O`.
    raise ValueError('shot frame is neither 2 nor 3')


def _run_epoch(dataloader, train):
    """Run one full pass over `dataloader` and return (mean loss, mean score).

    When `train` is True, one optimizer step is taken per batch on the mean
    loss of its samples. When False, the model is put in eval mode and
    autograd is disabled, so no graph is built or kept during validation.

    Args:
        dataloader: iterable yielding batches, each a list of (features, gt).
        train: whether to update `net` with `optim`.

    Returns:
        (loss, score): batch-averaged loss as a float and batch-averaged score.
    """
    net.train(train)
    loss_epoch_list = []
    score_epoch_list = []

    with torch.set_grad_enabled(train):
        for batch_data in tqdm(dataloader):
            if train:
                # Gradients are only produced by the single backward() below,
                # so clearing them once per batch is sufficient.
                optim.zero_grad()

            loss_batch_list = []
            score_batch_list = []
            for features, gt in batch_data:
                loss_shot, score_shot = _shot_loss_and_score(features, gt)
                loss_batch_list.append(loss_shot)
                score_batch_list.append(score_shot)

            # Average over the samples in this batch.
            score_step = sum(score_batch_list) / len(score_batch_list)
            loss_step = sum(loss_batch_list) / len(loss_batch_list)

            if train:
                loss_step.backward()
                optim.step()

            score_epoch_list.append(score_step)
            # Detach so the per-batch history never keeps an autograd graph alive.
            loss_epoch_list.append(loss_step.detach())

    score_epoch = sum(score_epoch_list) / len(score_epoch_list)
    loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
    return loss_epoch.item(), score_epoch


for e in range(epoch):
    print(f'epoch {e}:')

    # train
    train_loss, train_score = _run_epoch(train_dataloader, train=True)
    print('train loss = ', train_loss, 'train score = ', train_score)

    # val
    val_loss, val_score = _run_epoch(val_dataloader, train=False)
    print('val loss = ', val_loss, 'val score = ', val_score)


epoch 0:


100%|██████████| 226/226 [00:24<00:00,  9.19it/s]


train loss =  0.6659634113311768 train score =  0.5804912879356031
epoch 1:


100%|██████████| 226/226 [00:24<00:00,  9.19it/s]


train loss =  0.6510956287384033 train score =  0.6137570074621268
epoch 2:


100%|██████████| 226/226 [00:25<00:00,  8.75it/s]


train loss =  0.6396486163139343 train score =  0.6361566687290631
epoch 3:


100%|██████████| 226/226 [00:25<00:00,  8.70it/s]


train loss =  0.6170020699501038 train score =  0.6626561250062498
epoch 4:


100%|██████████| 226/226 [00:26<00:00,  8.65it/s]


train loss =  0.5691925883293152 train score =  0.7099629393530325
epoch 5:


100%|██████████| 226/226 [00:25<00:00,  8.85it/s]


train loss =  0.47833314538002014 train score =  0.7805906579671017
epoch 6:


100%|██████████| 226/226 [00:26<00:00,  8.57it/s]


train loss =  0.36778098344802856 train score =  0.8464348266961658
epoch 7:


100%|██████████| 226/226 [00:26<00:00,  8.65it/s]


train loss =  0.26813822984695435 train score =  0.8966180597220141
epoch 8:


100%|██████████| 226/226 [00:25<00:00,  8.78it/s]


train loss =  0.18456830084323883 train score =  0.9336720663966804
epoch 9:


100%|██████████| 226/226 [00:26<00:00,  8.59it/s]


train loss =  0.12439285218715668 train score =  0.9592135627593623
epoch 10:


100%|██████████| 226/226 [00:26<00:00,  8.58it/s]


train loss =  0.08120259642601013 train score =  0.9760697511999398
epoch 11:


 24%|██▍       | 54/226 [00:07<00:22,  7.57it/s]