In [1]:
import numpy as np
import torch
import cv2
import json
from pathlib import Path
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import os
import random
from utils import *
import requests
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch import nn
import time
import wandb
from scipy.special import comb, perm
import copy
import itertools
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# read inference data
data_path = '/home/jianghui/dataset/VideoReorder-MovieNet'
split = 'val'
val_data = VideoReorderMovieNetDataFolder(root=data_path, split=split, layer='')
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=(split == 'train'), num_workers=8, pin_memory=True, collate_fn=lambda x: x)

In [3]:
loss_func = nn.CrossEntropyLoss()
loss_func.to(device)

CrossEntropyLoss()

In [4]:
# scene order on clip model
scene_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512,2)
)
checkpoint = torch.load(Path('./checkpoint', f'scene_to_clip_best_2023-02-23_01.pth'))
scene_model.load_state_dict(checkpoint)
scene_model.to(device)
scene_model.eval()

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=512, bias=True)
  (4): ReLU()
  (5): Linear(in_features=512, out_features=2, bias=True)
)

In [5]:
# shot order on scene model
shot_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512,2)
)
checkpoint = torch.load(Path('./checkpoint', f'frame_to_scene_best_2023-02-22_01.pth'))
shot_model.load_state_dict(checkpoint)
shot_model.to(device)
shot_model.eval()

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=512, bias=True)
  (4): ReLU()
  (5): Linear(in_features=512, out_features=2, bias=True)
)

In [6]:
# frame order on shot model
frame_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512,2)
)
checkpoint = torch.load(Path('./checkpoint', f'frame_to_shot_best_2023-02-22_01.pth'))
frame_model.load_state_dict(checkpoint)
frame_model.to(device)
frame_model.eval()

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2048, out_features=1024, bias=True)
  (2): ReLU()
  (3): Linear(in_features=1024, out_features=512, bias=True)
  (4): ReLU()
  (5): Linear(in_features=512, out_features=2, bias=True)
)

In [7]:
score_list = []
for data in tqdm(val_dataloader):
    # load data
    features, img_id, shot_id, scene_id = data[0]
    input_id = [i for i in range(len(img_id))]
    gt_id = get_order_index(img_id)

    # scene cluster
    features.squeeze_()
    gt_scene_clustered, _ = clip_to_clip(img_id, shot_id, scene_id)
    features, input_id = KMeanCLustering(features=features, input_id=input_id, gt_clusters=gt_scene_clustered,layer='scene')

    # scene reorder
    N_scene = len(gt_scene_clustered)
    features_scene = [torch.mean(torch.stack(i, dim=0), dim=0) for i in features]
    score_square = [[float('-inf') for i in range(N_scene)]for i in range(N_scene)]

    for I in range(N_scene):
        for J in range(N_scene):
            if I == J: continue
            output = scene_model(torch.concat((features_scene[I], features_scene[J])).unsqueeze(0).to(device))
            score_square[I][J] = torch_to_list(output[0][1]-output[0][0])
    
    scene_order = beam_search_all(score_square)['path']

    # print(scene_order)

    input_id = same_shuffle(input_id, scene_order)
    features = same_shuffle(features, scene_order)

    # shot cluster
    gt_shot_clustered, _ = clip_to_scene(img_id, shot_id, scene_id)
    features, input_id = KMeanCLustering(features=features, input_id=input_id, gt_clusters=gt_shot_clustered,layer='shot')

    # shot order 
    N_scene = len(features)
    for idx in range(N_scene):
        N_shot = len(input_id[idx])
        features_shot = [torch.mean(torch.stack(i, dim=0), dim=0) for i in features[idx]]
        score_square = [[float('-inf') for i in range(N_shot)]for i in range(N_shot)]

        for I in range(N_shot):
            for J in range(N_shot):
                if I == J: continue
                output = shot_model(torch.concat((features_shot[I], features_shot[J])).unsqueeze(0).to(device))
                score_square[I][J] = torch_to_list(output[0][1]-output[0][0])   

        shot_order = beam_search_all(score_square)['path']

        features[idx] = same_shuffle(features[idx], shot_order)
        input_id[idx] = same_shuffle(input_id[idx], shot_order)

    # frame reorder
    N_scene = len(input_id)
    for idx in range(N_scene):
        N_shot = len(input_id[idx])
        for jdx in range(N_shot):
            N_frame = len(input_id[idx])
            features_frame = [torch.mean(torch.stack(i, dim=0), dim=0) for i in features[idx]]
            score_square = [[float('-inf') for i in range(N_frame)]for i in range(N_frame)]

            for I in range(N_frame):
                for J in range(N_frame):
                    if I == J: continue
                    output = frame_model(torch.concat((features_frame[I], features_frame[J])).unsqueeze(0).to(device))
                    score_square[I][J] = torch_to_list(output[0][1]-output[0][0])

            frame_order = beam_search_all(score_square)['path']

            features[idx][jdx] = same_shuffle(features[idx][jdx], frame_order)
            input_id[idx][jdx] = same_shuffle(input_id[idx][jdx], frame_order)

    # print(input_id)
    pred = list_to_one_dim(input_id)
    # print(pred)
    score = DoubleLengthMatching(pred, gt_id)
    # print(score)

    score_list.append(score)

print(sum(score_list) / len(score_list))

100%|██████████| 589/589 [01:20<00:00,  7.30it/s]

0.49057759298117937



