In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

In [3]:
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'
split = 'train'
train_data = VideoReorderMovieNetDataFolder(root=data_path, split=split)
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=(split == 'train'), num_workers=8, pin_memory=True, collate_fn=lambda x: x)

split = 'val'
val_data = VideoReorderMovieNetDataFolder(root=data_path, split=split)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=(split == 'train'), num_workers=8, pin_memory=True, collate_fn=lambda x: x)

In [4]:
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512)
)

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)
net.to(device)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=1024, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=512, bias=True)
)

In [5]:
# Loss 1
lr = 1e-4
epoch = 40

loss_func = ClipPairWiseLoss()
loss_func.to(device)
pred_func = ClipPairWisePred()
pred_func.to(device)
optim = torch.optim.AdamW(net.parameters(), lr=lr)

for e in range(epoch):
    print(f'epoch {e}:')
    running_loss = 0.0

    # train
    score_list = []
    
    for batch_data in tqdm(train_dataloader):
        loss_batch_list = []
        score_batch_list = []

        for clip_data in batch_data: #clip data

            # read input data
            clip_input, clip_gt_id, clip_shot_id, clip_scene_id = clip_data
            clip_input.to(device)

            optim.zero_grad()

            clip_output = net(clip_input.to(device))
            
            # record loss
            loss_clip = loss_func(clip_output, clip_gt_id)
            loss_batch_list.append(loss_clip)
            # record score
            pred_list = pred_func(clip_output)
            # score = TripleLengthMatching(pred_list, clip_gt_id)
            score = DoubleLengthMatching(pred_list, clip_gt_id)
            score_batch_list.append(score)
            print(pred_list, clip_gt_id, score, loss_clip.item())
            # calcuclate avearge loss
            loss_clip.backward()
            optim.step()
        # caculate avearge score
        average_batch_score = sum(score_batch_list)/len(score_batch_list)
        score_list.append(average_batch_score)

    loss = sum(loss_batch_list) / len(loss_batch_list)
    average_score = sum(score_list) / len(score_list)
    print('train loss = ', loss.item(), 'train score = ', average_score)  

    # val
    score_list = []
    for batch_data in tqdm(val_dataloader):
        loss_batch_list = []
        score_batch_list = []

        for clip_data in batch_data:

            # read input data
            clip_input, clip_gt_id, clip_shot_id, clip_scene_id = clip_data
            clip_output = net(clip_input.to(device))

            # record loss
            loss_clip = loss_func(clip_output, clip_gt_id)
            loss_batch_list.append(loss_clip)
            #record score
            pred_list = pred_func(clip_output)
            # score = TripleLengthMatching(pred_list, clip_gt_id)
            score = DoubleLengthMatching(pred_list, clip_gt_id)
            score_batch_list.append(score)

        # calcuclate avearge loss
        loss = sum(loss_batch_list) / len(loss_batch_list)
        # cal avearage score
        average_batch_score = sum(score_batch_list)/len(score_batch_list)
        score_list.append(average_batch_score)
    average_score = sum(score_list) / len(score_list)
    print('val loss = ', loss.item(), 'val score = ', average_score) 


epoch 0:


100%|██████████| 56/56 [02:25<00:00,  2.61s/it]


train loss =  0.001221206272020936 train score =  0.5010716027460849


100%|██████████| 5/5 [00:07<00:00,  1.52s/it]


val loss =  0.0020686599891632795 val score =  0.5056603271104285
epoch 1:


100%|██████████| 56/56 [02:31<00:00,  2.70s/it]


train loss =  9.351068729301915e-05 train score =  0.5006146899230474


100%|██████████| 5/5 [00:07<00:00,  1.55s/it]


val loss =  0.0007488965638913214 val score =  0.502397060811776
epoch 2:


100%|██████████| 56/56 [02:59<00:00,  3.20s/it]


train loss =  0.0 train score =  0.500019162340784


100%|██████████| 5/5 [00:10<00:00,  2.02s/it]


val loss =  0.0006653349264524877 val score =  0.5070711575260924
epoch 3:


100%|██████████| 56/56 [02:59<00:00,  3.21s/it]


train loss =  0.0 train score =  0.5000512651720045


100%|██████████| 5/5 [00:10<00:00,  2.00s/it]


val loss =  0.0006559159373864532 val score =  0.5070711575260924
epoch 4:


100%|██████████| 56/56 [03:00<00:00,  3.22s/it]


train loss =  0.0 train score =  0.5002365779551153


100%|██████████| 5/5 [00:10<00:00,  2.02s/it]


val loss =  0.0006464524776674807 val score =  0.5070711575260924
epoch 5:


100%|██████████| 56/56 [02:58<00:00,  3.19s/it]


train loss =  0.0 train score =  0.5003982373332599


100%|██████████| 5/5 [00:10<00:00,  2.03s/it]


val loss =  0.0006373508949764073 val score =  0.5070711575260924
epoch 6:


100%|██████████| 56/56 [03:04<00:00,  3.29s/it]


train loss =  0.0 train score =  0.5001207926222572


100%|██████████| 5/5 [00:10<00:00,  2.02s/it]


val loss =  0.0006282954127527773 val score =  0.5070711575260924
epoch 7:


  2%|▏         | 1/56 [00:04<04:20,  4.73s/it]