In [1]:
import os

# Must run before timm / huggingface_hub are imported (next cell): the Hub library
# reads HF_HOME when it is first imported. setdefault lets a value already exported
# in the shell win over this per-user default.
os.environ.setdefault("HF_HOME", os.path.expanduser("~/data/backbones/huggingface"))


def ensure_repo_root(marker: str = os.path.join("src", "TimeSformer")) -> str:
    """Step up one directory unless the working directory already contains `marker`.

    Later cells import `src.TimeSformer...` relative to the current working
    directory, which presumably is the repo root one level above this notebook.
    Guarding on `marker` makes the cell idempotent: re-running it no longer walks
    one more directory up each time.

    Returns the resulting working directory.
    """
    if not os.path.isdir(marker):
        os.chdir("..")
    return os.getcwd()


ensure_repo_root()

In [2]:
import torch
import torch.nn as nn
import timm
from timm.models import vision_transformer

from src.TimeSformer.timesformer.models.vit import TimeSformer

# ViT-B/16 backbones at 224 px and 384 px

In [3]:
# ViT-B/16 at 224 px. num_classes=2000 differs from the 1000-class ImageNet-1k
# checkpoint, so timm loads the pretrained backbone but leaves the classification
# head freshly initialised (2000 outputs).
vit224 = timm.create_model(
    model_name='timm/vit_base_patch16_224.augreg_in21k_ft_in1k',
    pretrained=True,
    num_classes=2000,
)
vit224_params = sum(p.numel() for p in vit224.parameters())
print(f"Total number of parameters in the ViT model: {vit224_params}")

# TimeSformer keeps its ViT backbone under the `model.` prefix (see the
# "model.blocks.0..." keys compared below), so re-key the state dict to match.
# Build a new mapping rather than popping/re-inserting into the live state_dict.
vit224_state = {f"model.{key}": tensor for key, tensor in vit224.state_dict().items()}

Total number of parameters in the ViT model: 87336656


In [5]:
# ViT-B/16 at 384 px, configured like the 224-px model above. The only size
# difference is pos_embed: (384/16)^2 + 1 = 577 tokens vs 197, i.e.
# 380 * 768 = 291,840 extra parameters. load_state_dict raises on shape
# mismatches even with strict=False, so this state cannot go into a 224-px
# TimeSformer without interpolating pos_embed first.
vit384 = timm.create_model(
    model_name='timm/vit_base_patch16_384.augreg_in21k_ft_in1k',
    pretrained=True,
    num_classes=2000,
)
vit384_params = sum(p.numel() for p in vit384.parameters())
print(f"Total number of parameters in the ViT model: {vit384_params}")

# Re-key under TimeSformer's `model.` prefix, building a new mapping instead of
# mutating the live state_dict.
vit384_state = {f"model.{key}": tensor for key, tensor in vit384.state_dict().items()}

Total number of parameters in the ViT model: 87628496


# TimeSformer (randomly initialized, no weights loaded)

In [17]:
# Baseline: a TimeSformer built with no checkpoint, sized to match the 2000-class
# ViT heads above. (The constructor appears to accept a `pretrained_model` path;
# it is deliberately not used here.)
timesformer = TimeSformer(
    img_size=224, num_classes=2000, num_frames=16, attention_type='divided_space_time'
)
timesformer_params = sum(param.numel() for param in timesformer.parameters())
print(f"Total number of parameters in the TimeSformer model: {timesformer_params}")

timesformer_state = timesformer.state_dict()

# Spot-check one tensor present in both layouts: without any loading it is not
# expected to equal the pretrained ViT value.
probe_key = "model.blocks.0.attn.proj.bias"
torch.equal(vit224_state[probe_key], timesformer_state[probe_key])

Total number of parameters in the TimeSformer model: 122802896


False

# Load weights from ViT into TimeSformer

In [22]:
# Fresh TimeSformer with the same configuration as the baseline, so the load below
# does not depend on whatever state an earlier cell left behind.
timesformer = TimeSformer(
    img_size=224,
    num_classes=2000,
    num_frames=16,
    attention_type='divided_space_time',
)
timesformer_params = sum(p.numel() for p in timesformer.parameters())
print(f"Total number of parameters in the TimeSformer model: {timesformer_params}")

# strict=False tolerates keys present on only one side (presumably the extra
# layers that divided_space_time adds, given the ~35M extra parameters), but it
# does so silently. Report them so a key-naming mismatch cannot go unnoticed.
# Shape mismatches still raise.
load_result = timesformer.load_state_dict(vit224_state, strict=False)
print(f"Missing keys (left at init): {len(load_result.missing_keys)}, e.g. {load_result.missing_keys[:3]}")
print(f"Unexpected keys (ignored): {len(load_result.unexpected_keys)}, e.g. {load_result.unexpected_keys[:3]}")
timesformer_state = timesformer.state_dict()

# Verify every shared key was copied, not just the single tensor probed below.
shared_keys = vit224_state.keys() & timesformer_state.keys()
assert shared_keys, "no ViT key matches the TimeSformer state_dict layout"
not_copied = sorted(k for k in shared_keys if not torch.equal(vit224_state[k], timesformer_state[k]))
assert not not_copied, f"tensors not copied: {not_copied[:5]}"

torch.equal(
    vit224_state["model.blocks.0.attn.proj.bias"],
    timesformer_state["model.blocks.0.attn.proj.bias"]
)

Total number of parameters in the TimeSformer model: 122802896


True