In [4]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Exported to the process environment; presumably read by the Hugging Face
# tokenizers backend to allow parallel tokenization -- TODO confirm.
os.environ.update(TOKENIZERS_PARALLELISM='true')

In [None]:
# Root directory of the VideoReorder-MovieNet dataset on the local machine.
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'

# Training split. The identity collate function hands each batch over as the
# untouched list of samples instead of stacking them into tensors.
split = 'train'
train_data = VideoReorderMovieNetDataFolder(root=data_path, split=split)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=128,
    shuffle=(split == 'train'),  # True here: training batches are shuffled
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda samples: samples,
)

# Validation split, built with the same loader settings.
split = 'val'
val_data = VideoReorderMovieNetDataFolder(root=data_path, split=split)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=128,
    shuffle=(split == 'train'),  # False here: validation order is kept
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda samples: samples,
)

In [None]:
# Small MLP head: flatten every non-batch dimension, then project
# 1024 -> 1024 -> 512 with a single ReLU in between.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=1024, out_features=1024),
    nn.ReLU(),
    nn.Linear(in_features=1024, out_features=512),
)

def init_weights(m):
    """Re-initialise the weight matrix of a linear layer in place.

    Intended to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module.

    Args:
        m: Any module. Only ``nn.Linear`` instances (including subclasses)
            are modified; every other module is left untouched.

    The weights are drawn from a normal distribution with standard deviation
    0.01. Biases are not modified.
    """
    # isinstance, unlike an exact type comparison, also matches subclasses
    # of nn.Linear, so they are no longer silently skipped.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

# Re-initialise every sub-module via init_weights, then move the model to the
# selected device. Both calls return the module itself, so they can be chained
# and the cell still evaluates to `net`.
net.apply(init_weights).to(device)

In [None]:
# Loss 1
# Hyper-parameters of this training run.
lr = 1e-4
epoch = 40  # total number of passes over the training split

loss_func = ClipPairWiseLoss()
loss_func.to(device)
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(net.parameters(), lr=lr)

for e in range(epoch):
    print(f'epoch {e}:')

    # ---- train ----
    net.train()
    running_loss = 0.0  # sum of the per-batch mean losses over this epoch
    num_batches = 0
    score_list = []
    for batch_data in tqdm(train_dataloader):
        loss_batch_list = []
        score_batch_list = []

        # backward() runs once per batch (below), so the gradients only need
        # to be cleared once per batch, not once per clip.
        optim.zero_grad()

        for clip_data in batch_data:  # one clip per entry
            # read input data
            clip_input, clip_gt_id, clip_shot_id, clip_scene_id = clip_data

            # Tensor.to() is not in-place: the moved copy must be the one fed
            # to the network.
            clip_output = net(clip_input.to(device))

            # record loss
            loss_batch_list.append(loss_func(clip_output, clip_gt_id))
            # record score
            pred_list = pred_func(clip_output)
            score_batch_list.append(TripleLengthMatching(pred_list, clip_gt_id))

        # average the clip losses of the batch and take one optimisation step
        loss = sum(loss_batch_list) / len(loss_batch_list)
        loss.backward()
        optim.step()

        running_loss += loss.item()
        num_batches += 1
        # average score of the batch
        score_list.append(sum(score_batch_list) / len(score_batch_list))

    # Report the epoch mean instead of whatever the last batch happened to be.
    average_loss = running_loss / num_batches
    average_score = sum(score_list) / len(score_list)
    print('train loss = ', average_loss, 'train score = ', average_score)

    # ---- val ----
    net.eval()
    running_loss = 0.0
    num_batches = 0
    score_list = []
    # No gradients are needed for evaluation; disabling autograd avoids keeping
    # a computation graph alive for every clip of every batch.
    with torch.no_grad():
        for batch_data in tqdm(val_dataloader):
            loss_batch_list = []
            score_batch_list = []

            for clip_data in batch_data:
                # read input data
                clip_input, clip_gt_id, clip_shot_id, clip_scene_id = clip_data
                clip_output = net(clip_input.to(device))

                # record loss
                loss_batch_list.append(loss_func(clip_output, clip_gt_id))
                # record score
                pred_list = pred_func(clip_output)
                score_batch_list.append(TripleLengthMatching(pred_list, clip_gt_id))

            # average loss of the batch
            loss = sum(loss_batch_list) / len(loss_batch_list)
            running_loss += loss.item()
            num_batches += 1
            # average score of the batch
            score_list.append(sum(score_batch_list) / len(score_batch_list))

    average_loss = running_loss / num_batches
    average_score = sum(score_list) / len(score_list)
    # The validation loss used to print average_score by mistake.
    print('val loss = ', average_loss, 'val score = ', average_score)


In [2]:
# Toy example: the ground truth is the identity order of four items and the
# prediction swaps the last two.
GT = list(range(4))
Pred = [0, 1, 3, 2]

In [5]:
# Enumerate the length-3 sub-lists of the prediction and of the ground truth
# (prediction first, as before), then show both side by side.
Pred_Sublist, GT_Sublist = (rSublist(order, 3) for order in (Pred, GT))
Pred_Sublist, GT_Sublist

([(0, 1, 3), (0, 1, 2), (0, 3, 2), (1, 3, 2)],
 [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)])

In [6]:
# Numerator: distinct triples that occur in both the prediction and the
# ground truth. Denominator: number of triples taken from the prediction.
score_num = len(set(Pred_Sublist).intersection(GT_Sublist))
score_deno = len(Pred_Sublist)
score_num, score_deno

2