### Base settings

In [1]:
%load_ext autoreload
%autoreload 2

# %env CUDA_VISIBLE_DEVICES=1

import os
import time
from pprint import pprint
from pathlib import Path
from collections import OrderedDict
from tqdm import tqdm
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

import torch
import torch.nn as nn
import torchvision.transforms as transforms

from einops import rearrange, reduce, repeat
from IPython.display import display

import video_transformer.data_transform as T
from video_transformer.dataset import DecordInit, load_annotation_data
from video_transformer.transformer import PatchEmbed, TransformerContainer, ClassificationHead

# Default compute device for the notebook: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Prepare the Kinetics-400 validation dataset

In [2]:
import pickle

# Unpickle the Kinetics-400 validation metadata stored under ./data.
# NOTE: pickle can execute arbitrary code while loading; only open files you produced yourself.
with open('./data/kinetics400_val_metadata.pkl', mode='rb') as metadata_file:
    kinetics400_val_metadata = pickle.load(metadata_file)

In [14]:
# from torchvision.datasets import Kinetics
from ood_with_vit.datasets.kinetics import MyKinetics
from torch.utils.data import DataLoader, Dataset

# Kinetics-400 root directory, with "~" expanded to the current user's home.
dataset_root = os.path.expanduser('~/workspace/dataset/kinetics/k400')
print(dataset_root)

# Normalisation statistics; the same value is used for each of the three channels.
dataset_mean = (0.45,) * 3
dataset_std = (0.225,) * 3

# Evaluation-time transform (is_training=False) producing 224-pixel inputs.
val_transform = T.create_video_transform(
    mean=dataset_mean,
    std=dataset_std,
    input_size=224,
    interpolation='bicubic',
    is_training=False,
)
# Alternative three-crop pipeline, kept for reference but disabled:
# data_transform = T.Compose([
#     T.Resize(scale_range=(-1, 256)),
#     T.ThreeCrop(size=224),
#     T.ToTensor(),
#     T.Normalize(dataset_mean, dataset_std)
# ])
# data_transform.randomize_parameters()

# Validation split, cut into 16-frame clips. The metadata unpickled earlier in
# the notebook is handed over through `_precomputed_metadata`.
kinetics400_val_ds = MyKinetics(
    root=dataset_root,
    split='val',
    frames_per_clip=16,
    step_between_clips=1,
    frame_rate=2,
    num_workers=24,
    transform=val_transform,
    _precomputed_metadata=kinetics400_val_metadata,
)

# Fixed-order loader (shuffle=False) used for evaluation.
kinetics400_val_dl = DataLoader(
    kinetics400_val_ds,
    batch_size=32,
    shuffle=False,
    num_workers=32,
)

/home/simc/workspace/dataset/kinetics/k400


In [4]:
# Sanity check: fetch the first clip and show its tensor shape next to its class name.
video, label = kinetics400_val_ds[0]
first_class_name = kinetics400_val_ds.classes[label]
print(video.shape, first_class_name)

torch.Size([16, 3, 224, 224]) abseiling


In [6]:
# Inspect the per-video metadata and the clip index table held by the dataset.
val_metadata = kinetics400_val_ds.metadata
val_video_clips = kinetics400_val_ds.video_clips

video_paths = val_metadata['video_paths']
video_pts = val_metadata['video_pts']
video_fps = val_metadata['video_fps']
clips = val_video_clips.clips

# For each collection: container type, number of videos, and a peek at the first entry.
print('video paths:', type(video_paths), len(video_paths), video_paths[0])
print('video pts:', type(video_pts), len(video_pts), video_pts[0].shape)
print('video fps:', type(video_fps), len(video_fps), video_fps[0])
print('clips', type(clips), len(clips), clips[0].shape)

# The last cumulative size is the value printed as the overall count.
print('cumulative sizes:', val_video_clips.cumulative_sizes[-1])

video paths: <class 'list'> 19881 /home/simc/workspace/dataset/kinetics/k400/val/abseiling/0wR5jVB-WPk_000417_000427.mp4
video pts: <class 'list'> 19881 torch.Size([300])
video fps: <class 'list'> 19881 29.97002997002997
clips <class 'list'> 19881 torch.Size([5, 16])
cumulative sizes: 88540


### Prepare the ViViT model

In [5]:
from video_transformer.video_transformer import ViViT

In [6]:
def replace_state_dict(state_dict):
    """Strip the leading ``model.`` / ``cls_head.`` prefix from checkpoint keys, in place.

    Keys that carry neither prefix are left unchanged. The previous
    implementation removed the first nine characters of every key that did not
    start with ``model``, which corrupted unrelated keys (``'epoch'`` became
    ``''``), and treated any key merely beginning with the letters ``model``
    as prefixed.

    Args:
        state_dict: Mutable mapping of parameter names to values. It is
            modified in place.

    Returns:
        None.
    """
    known_prefixes = ('model.', 'cls_head.')
    # Iterate over a snapshot of the keys because the mapping is mutated in the loop.
    for old_key in list(state_dict.keys()):
        new_key = old_key
        for prefix in known_prefixes:
            if old_key.startswith(prefix):
                new_key = old_key[len(prefix):]
                break
        # Pop and re-insert even unchanged keys so the original key order is kept.
        state_dict[new_key] = state_dict.pop(old_key)

In [7]:
def init_from_pretrain_(module, pretrained, init_module):
    """Load checkpoint weights into ``module`` with non-strict key matching.

    The checkpoint keys are renamed by ``replace_state_dict`` before loading.
    Because ``strict=False`` is used, keys that do not belong to ``module`` are
    reported in the return value instead of raising.

    Args:
        module: Object exposing ``load_state_dict`` (e.g. an ``nn.Module``).
        pretrained: Path of the checkpoint file passed to ``torch.load``.
        init_module: Which part is being initialised; must be ``'transformer'``
            or ``'cls_head'``.

    Returns:
        The value returned by ``module.load_state_dict(..., strict=False)``.

    Raises:
        TypeError: If ``init_module`` is not one of the accepted names. This is
            checked before the checkpoint is read, so a bad argument fails
            fast without any file access.
    """
    if init_module not in ('transformer', 'cls_head'):
        raise TypeError(f'pretrained weights do not include the {init_module} module')

    # Without CUDA, tensors saved from a GPU must be remapped to the CPU.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(pretrained, map_location=map_location)

    replace_state_dict(state_dict)
    return module.load_state_dict(state_dict, strict=False)

In [8]:
# Pick the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

num_class = 400
arch = 'vivit'  # turn to vivit for initializing vivit model
pretrain_pth = './logs/vivit/vivit_model.pth'

# Base sampling values are 8 frames at an interval of 32; this setup doubles
# the frame count and halves the interval.
num_frames = 8 * 2
frame_interval = 32 // 2

# Backbone; the checkpoint path is also passed to the constructor.
vivit_kwargs = dict(
    num_frames=num_frames,
    img_size=224,
    patch_size=16,
    embed_dims=768,
    in_channels=3,
    attention_type='fact_encoder',
    return_cls_token=True,
    pretrain_pth=pretrain_pth,
    weights_from='kinetics',
)
model = ViViT(**vivit_kwargs)

# Classification head, initialised from the same checkpoint file.
cls_head = ClassificationHead(num_classes=num_class, in_channels=768)
# msg_trans = init_from_pretrain_(model, pretrain_pth, init_module='transformer')
msg_cls = init_from_pretrain_(cls_head, pretrain_pth, init_module='cls_head')

# Switch both parts to evaluation mode and move them to the selected device.
model.eval()
model = model.to(device)
cls_head.eval()
cls_head = cls_head.to(device)

print(f'load model finished, the missing key of cls is:{msg_cls[0]}')

_IncompatibleKeys(missing_keys=[], unexpected_keys=['cls_head.weight', 'cls_head.bias'])
load model finished, the missing key of cls is:[]


In [15]:
# Evaluate top-1 accuracy and mean cross-entropy loss over the validation loader.
model.eval()
cls_head.eval()  # the head takes part in the forward pass too; keep it in eval mode
criterion = nn.CrossEntropyLoss()
total_test_loss, n_correct, n_total = 0.0, 0, 0
with torch.no_grad():
    for x, y in tqdm(kinetics400_val_dl):
        x, y = x.to(device), y.to(device)
        outputs = cls_head(model(x))
        loss = criterion(outputs, y)

        batch_size = y.size(0)
        # The criterion returns the mean over the batch. Weight it by the batch
        # size so a smaller final batch is not over-represented in the average.
        total_test_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        n_total += batch_size
        n_correct += predicted.eq(y).sum().item()

# An empty loader would otherwise surface as an opaque NameError/ZeroDivisionError.
if n_total == 0:
    raise RuntimeError('validation dataloader yielded no samples')

avg_test_loss = total_test_loss / n_total
test_accuracy = 100. * n_correct / n_total
print(f'Test Loss: {avg_test_loss:.3f} | Test Acc: {test_accuracy:.3f}% ({n_correct}/{n_total})')

# return total_test_loss, test_accuracy

100%|██████████| 2767/2767 [57:24<00:00,  1.24s/it]  

Test Loss: 1.105 | Test Acc: 74.190% (65688/88540)





In [6]:
from IPython.display import display, HTML

# Local clip shown here; `video_path` is also read by the frame-sampling cell.
video_path = './YABnJL_bDzw.mp4'

# Inline HTML5 player (480x480) so the clip can be watched inside the notebook.
html_str = f'\n<video controls width="480" height="480" src="{video_path}">animation</video>\n'
display(HTML(html_str))

In [16]:
# Prepare data preprocess: resize, three 224-pixel crops, tensor conversion, normalisation
mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)
data_transform = T.Compose([
        T.Resize(scale_range=(-1, 256)),
        T.ThreeCrop(size=224),
        T.ToTensor(),
        T.Normalize(mean, std)
        ])
# Random temporal window of num_frames * frame_interval frames
temporal_sample = T.TemporalRandomCrop(num_frames*frame_interval)

# Sampling video frames
video_decoder = DecordInit()
v_reader = video_decoder(video_path)
total_frames = len(v_reader)
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
print(total_frames, start_frame_ind, end_frame_ind)
if end_frame_ind-start_frame_ind < num_frames:
    raise ValueError(f'the total frames of the video {video_path} is less than {num_frames}')
# Sample evenly spaced indices INSIDE the cropped window [start, end).
# The previous code started at 0, which ignored start_frame_ind and always
# read frames from the beginning of the video.
frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num_frames, dtype=int)
video = v_reader.get_batch(frame_indice).asnumpy()
del v_reader  # drop the reader once the frames have been copied out

print('original:', video.shape)
video = torch.from_numpy(video).permute(0,3,1,2) # Video transform: T C H W
data_transform.randomize_parameters()
video = data_transform(video)
video = video.to(device)
print('transformed:', video.shape)

302 22 278
original: (16, 256, 454, 3)
transformed: torch.Size([3, 16, 3, 224, 224])


In [15]:
# Predict class label
with torch.no_grad():
    logits = model(video)
    output = cls_head(logits)
    print(output.shape)
    # Average the class scores over the leading (crop) dimension. The sizes are
    # taken from the tensor itself rather than hard-coding 3 crops x 400 classes.
    output = output.view(-1, output.size(-1)).mean(0)
    cls_pred = output.argmax().item()

# Adjacent f-strings replace the backslash continuation, which leaked the
# source indentation into the printed message.
print(f'the shape of output: {output.shape}, '
      f'and the prediction is: {kinetics400_val_ds.classes[cls_pred]}')

torch.Size([3, 400])
the shape of ouptut: torch.Size([400]),     and the prediction is: laughing
