In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn
import time
from scipy.special import comb, perm

In [2]:
import wandb
# Experiment tracking with Weights & Biases is switched off; uncomment to enable.
# wandb.init(
#     project = 'VideoReorder',
#     name = 'shot only'
#     )
# Local calendar date of this run (YYYY-MM-DD); the training cell embeds it
# in the checkpoint file name.
timestamp = time.strftime('%Y-%m-%d', time.localtime())

In [3]:
# Let the HuggingFace tokenizers library keep its internal parallelism enabled.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
# Root folder of the VideoReorder-MovieNet dataset.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'

# Settings shared by both loaders. The identity collate_fn hands each batch to
# the training loop as a plain Python list of dataset items instead of stacking them.
loader_kwargs = dict(batch_size=32, num_workers=8, pin_memory=True, collate_fn=lambda x: x)

# Training split: reshuffled every epoch.
split = 'train'
train_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='')
train_dataloader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)

# Validation split: fixed order.
split = 'val'
val_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='')
val_dataloader = torch.utils.data.DataLoader(val_data, shuffle=False, **loader_kwargs)

In [5]:
# Pairwise-order classifier: a 2048-d input (the training cell feeds it two
# concatenated frame features) is mapped through two hidden layers to 2 logits.
# Layers are passed positionally so the state_dict keys stay "1.weight", "3.weight", ...
_mlp_layers = [
    nn.Flatten(),
    nn.Linear(2048, 1024), nn.ReLU(),
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 2),
]
net = nn.Sequential(*_mlp_layers)

def init_weights(m):
    """Re-initialise the weight matrix of a linear layer in place.

    Intended to be passed to ``nn.Module.apply``, which calls it once for every
    sub-module. Linear layers get weights drawn from N(0, 0.01^2); their bias
    and every other module type are left untouched.

    Args:
        m: any ``nn.Module``.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # nn.Linear are initialised as well instead of being silently skipped.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

# Re-initialise the Linear weights, then move the model to the target device.
# Both calls return the module itself, so the chained expression is still the
# cell's final value and its repr is displayed.
net.apply(init_weights).to(device)
# wandb.watch(net)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=512, bias=True)
  (4): ReLU()
  (5): Linear(in_features=512, out_features=2, bias=True)
)

In [6]:
lr = 1e-4
epoch = 5

loss_func = nn.CrossEntropyLoss()
loss_func.to(device)
# NOTE(review): pred_func is not used anywhere in this cell; kept in case other cells rely on it.
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(net.parameters(), lr=lr)

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = Path('./checkpoint')
checkpoint_dir.mkdir(parents=True, exist_ok=True)


def _clip_loss_and_score(features, img_id):
    """Pairwise-order loss and accuracy for a single clip.

    For every frame pair (a, b) with a < b the two frame features are
    flattened and concatenated, fed to `net`, and the resulting 2-way logits
    are trained against label 1 when gt[a] < gt[b] and label 0 otherwise,
    where gt = get_order_index(img_id).

    All pairs of the clip go through one batched forward pass. `loss_func`
    uses the default mean reduction, so the returned loss equals the average
    of the per-pair losses.

    Args:
        features: tensor indexed by frame along its first dimension.
        img_id: per-frame identifiers passed to get_order_index.

    Returns:
        (loss, score) where loss is a scalar tensor and score is the fraction
        of pairs for which get_order_list of the logits equals get_order_list
        of the ground-truth pair; or None if the clip has fewer than two
        frames (no pairs to compare).
    """
    gt = get_order_index(img_id)
    N = len(gt)
    if N < 2:
        # No pairs exist; dividing by the pair count would be a division by zero.
        return None

    pairs = [(a, b) for a in range(N) for b in range(a + 1, N)]
    first_idx = torch.tensor([a for a, _ in pairs], device=device)
    second_idx = torch.tensor([b for _, b in pairs], device=device)
    labels = torch.tensor([1 if gt[a] < gt[b] else 0 for a, b in pairs], device=device)

    # Move the clip to the device once, then build one input row per pair:
    # row k is [features[a].flatten(), features[b].flatten()].
    flat = features.to(device).flatten(start_dim=1)
    pair_inputs = torch.cat([flat[first_idx], flat[second_idx]], dim=1)
    outputs = net(pair_inputs)
    loss = loss_func(outputs, labels)

    # Scoring needs no gradients; copy the logits to the CPU in one transfer.
    outputs_cpu = outputs.detach().cpu()
    hits = 0
    for logits, (a, b) in zip(outputs_cpu, pairs):
        PRED = get_order_list(logits)
        GT = get_order_list([gt[a], gt[b]])
        hits += int(PRED == GT)
    return loss, hits / len(pairs)


def _batch_loss_and_score(batch_data):
    """Average `_clip_loss_and_score` over the clips of one batch.

    Returns (loss_step, score_step), or None when no clip in the batch has at
    least two frames.
    """
    results = [_clip_loss_and_score(features, img_id) for features, img_id, _, _ in batch_data]
    results = [r for r in results if r is not None]
    if not results:
        return None
    loss_step = sum(clip_loss for clip_loss, _ in results) / len(results)
    score_step = sum(clip_score for _, clip_score in results) / len(results)
    return loss_step, score_step


best_val_acc = 0.0
for e in range(epoch):
    print(f'epoch {e}:')

    # train
    net.train()
    loss_epoch_list = []
    score_epoch_list = []
    for batch_data in tqdm(train_dataloader):
        # One optimiser step per batch, so gradients are cleared once per batch.
        optim.zero_grad()

        batch_result = _batch_loss_and_score(batch_data)
        if batch_result is None:
            continue
        loss_step, score_step = batch_result

        loss_step.backward()
        optim.step()

        # Store a detached loss so the epoch statistics do not keep the
        # autograd history of every batch alive.
        score_epoch_list.append(score_step)
        loss_epoch_list.append(loss_step.detach())
        # wandb.log({'train loss':loss_step.item(), 'train score':score_step})

    score_epoch = sum(score_epoch_list) / len(score_epoch_list)
    loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
    print('train loss = ', loss_epoch.item(), 'train score = ', score_epoch)

    # val
    net.eval()
    with torch.no_grad():
        loss_epoch_list = []
        score_epoch_list = []
        for batch_data in tqdm(val_dataloader):
            batch_result = _batch_loss_and_score(batch_data)
            if batch_result is None:
                continue
            loss_step, score_step = batch_result

            score_epoch_list.append(score_step)
            loss_epoch_list.append(loss_step)
            # wandb.log({'val loss':loss_step.item(), 'val score':score_step})

        score_epoch = sum(score_epoch_list) / len(score_epoch_list)
        loss_epoch = sum(loss_epoch_list) / len(loss_epoch_list)
        print('val loss = ', loss_epoch.item(), 'val score = ', score_epoch)
        # Keep the weights of the best (or equally good, most recent) validation score.
        if score_epoch >= best_val_acc:
            best_val_acc = score_epoch
            torch.save(net.state_dict(), checkpoint_dir / f'all_best_{timestamp}.pth')
            print("save epoch ",e)


epoch 0:


100%|██████████| 221/221 [04:58<00:00,  1.35s/it]


train loss =  0.6931282877922058 train score =  0.5040786786995005


100%|██████████| 19/19 [00:13<00:00,  1.42it/s]


val loss =  0.6928536295890808 val score =  0.5107599981706323
save epoch  0
epoch 1:


100%|██████████| 221/221 [04:54<00:00,  1.33s/it]


train loss =  0.6926515102386475 train score =  0.5121648323330041


100%|██████████| 19/19 [00:13<00:00,  1.36it/s]


val loss =  0.6924921274185181 val score =  0.5146516291602038
save epoch  1
epoch 2:


100%|██████████| 221/221 [05:09<00:00,  1.40s/it]


train loss =  0.6914034485816956 train score =  0.5234265681512226


100%|██████████| 19/19 [00:13<00:00,  1.39it/s]


val loss =  0.6931868195533752 val score =  0.5079427487786092
epoch 3:


100%|██████████| 221/221 [04:57<00:00,  1.34s/it]


train loss =  0.6859269738197327 train score =  0.5466020443811009


100%|██████████| 19/19 [00:13<00:00,  1.44it/s]


val loss =  0.6984450221061707 val score =  0.5062540407960507
epoch 4:


100%|██████████| 221/221 [04:48<00:00,  1.30s/it]


train loss =  0.6690531969070435 train score =  0.5866740166324739


100%|██████████| 19/19 [00:13<00:00,  1.45it/s]


val loss =  0.714319109916687 val score =  0.4978416102227202
