## Imports

In [16]:
import json
import os

import torch
from torch.utils.data import DataLoader

from pipelines.dataset import SignDataset

In [17]:
# Annotation file for the dummy WLASL subset; handed to SignDataset as `ann_file`.
# NOTE(review): 'amaglgam' looks like a misspelling of 'amalgam' — confirm against
# the file on disk before renaming anything.
ann_file = '../data/dummy_wlasl/wlasl_amaglgam.JSON'

In [18]:
# Build the 'test' split of the dummy WLASL raw-frame data, in test mode.
# clip_len=16 and resolution=224 agree with the (3, 16, 224, 224) sample shape
# printed for the first item. The exact meaning of frame_interval and num_clips
# is defined by SignDataset — presumably frame sampling stride and clips per
# video; TODO confirm.
dataset = SignDataset(
    ann_file=ann_file,
    root_dir='../data/dummy_wlasl/rawframes/',
    split='test',
    clip_len=16,
    frame_interval=4,
    num_clips=1,
    resolution=224,
    test_mode=True,
)

In [19]:
# First item of the test split; element [0] of that item is the video tensor
# whose shape is displayed by this cell.
# "Accident 0" presumably names the sign/gloss of this sample — TODO confirm
# against the annotation file.
accident_0 = dataset[0][0]
accident_0.shape

torch.Size([3, 16, 224, 224])

## Trimming the transformer model

In [24]:
# Load the pretrained weights from disk (VideoMAE V2 ViT-B distilled checkpoint,
# going by its filename).
# map_location='cpu' makes the load independent of the device the checkpoint was
# saved from, so this cell also runs on a machine without a GPU; the tensors are
# copied into the model's own parameters later by load_state_dict.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(
    'model/vit-base-p16_videomaev2-vit-g-dist-k710-pre_16x4x1_kinetics-400_20230510-3e7f93b2.pth',
    map_location='cpu',
)

In [43]:
import torch.nn as nn
from model.cls_head import ClassifierHead
from model.vit_mae import VisionTransformer

In [44]:
# Vision transformer backbone. embed_dims=768 is the feature width that the
# classifier head consumes, and img_size=224 / num_frames=16 match the
# (3, 16, 224, 224) clips produced by the dataset.
# The depth/heads/patch settings are presumably the ViT-Base/16 layout named in
# the checkpoint filename — TODO confirm against the checkpoint config.
backbone = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dims=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    num_frames=16,
    norm_cfg={'type': 'LN', 'eps': 1e-6},
)

In [46]:
# Classification head on top of the backbone features: in_features=768 equals
# the backbone's embed_dims. num_classes=400 presumably corresponds to the
# Kinetics-400 label set named in the checkpoint filename — TODO confirm.
cls_head = ClassifierHead(
    in_features=768,
    num_classes=400,
)

In [47]:
class VideoMAE(nn.Module):
    """Chain a feature backbone with a classification head.

    Both sub-modules are stored as attributes (``backbone`` and ``cls_head``),
    so their parameters are registered under those prefixes in the state dict.

    Args:
        backbone: Module applied first to the input.
        cls_head: Module applied to the output of ``backbone``.
    """

    def __init__(self, backbone, cls_head):
        super().__init__()
        self.backbone = backbone
        self.cls_head = cls_head

    def forward(self, x):
        """Return ``cls_head(backbone(x))``."""
        features = self.backbone(x)
        return self.cls_head(features)

In [48]:
# Assemble the full classifier from the backbone and head built above.
model = VideoMAE(backbone=backbone, cls_head=cls_head)

In [51]:
# Inspect the positional embedding before any weights are loaded.
# The displayed rows match sin/cos of the position index (0.8415 = sin(1),
# 0.5403 = cos(1), 0.9093 = sin(2), ...), so this looks like a fixed sinusoidal
# table built by the constructor rather than a random init — confirm in
# VisionTransformer.
backbone.pos_embed

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  8.2843e-01,  ...,  1.0000e+00,
           1.0243e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.2799e-01,  ...,  1.0000e+00,
           2.0486e-04,  1.0000e+00],
         ...,
         [ 4.6785e-01,  8.8381e-01,  8.8921e-01,  ...,  9.8655e-01,
           1.5961e-01,  9.8718e-01],
         [ 9.9648e-01,  8.3839e-02,  8.7704e-01,  ...,  9.8653e-01,
           1.5971e-01,  9.8716e-01],
         [ 6.0895e-01, -7.9321e-01,  9.3232e-02,  ...,  9.8652e-01,
           1.5982e-01,  9.8715e-01]]])

In [52]:
# Copy the pretrained weights into the model. strict=False is needed because
# the checkpoint has no 'backbone.pos_embed' entry; the returned report is left
# as the cell output so missing/unexpected keys stay visible.
model.load_state_dict(state_dict=checkpoint, strict=False)

_IncompatibleKeys(missing_keys=['backbone.pos_embed'], unexpected_keys=[])

In [54]:
# Run only the backbone (no classifier head) on one clip to inspect the feature
# it produces. unsqueeze(dim=0) adds the leading batch dimension:
# (3, 16, 224, 224) -> (1, 3, 16, 224, 224).
# torch.no_grad() avoids building an autograd graph, which is wasted memory for
# a forward-only check.
# NOTE(review): the model is still in its default training mode here; call
# model.eval() first if the backbone uses dropout / drop-path — confirm.
with torch.no_grad():
    value = model.backbone(accident_0.unsqueeze(dim=0))

In [55]:
# Shape of the backbone output for the single-clip batch: one vector of width
# 768 (the backbone's embed_dims) per clip.
value.shape

torch.Size([1, 768])