In [2]:
import os
os.environ['HF_HOME'] = "/Users/artemmerinov/data/backbones/huggingface" # before 

import torch.nn as nn
import timm
from timm.models import vision_transformer

from timesformer.models.vit import TimeSformer

  from .autonotebook import tqdm as notebook_tqdm


# Insert weigths from VIT to TimeSformer

In [3]:
vit_model = timm.create_model(
    model_name='timm/vit_base_patch16_224.augreg_in21k_ft_in1k',
    pretrained=True,
    num_classes=2000,
)
vit_state = vit_model.state_dict()

for key in list(vit_state.keys()):
    vit_state[f"model.{key}"] = vit_state.pop(key)

In [5]:
timesformer_model_init = TimeSformer(
    img_size=224,
    num_classes=2000, 
    num_frames=8, 
    attention_type='divided_space_time',
    # pretrained_model="/Users/artemmerinov/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth"
)
timesformer_model_init_state = timesformer_model_init.state_dict()

timesformer_model = TimeSformer(
    img_size=224, 
    num_classes=2000, 
    num_frames=8, 
    attention_type='divided_space_time',
    # pretrained_model="/Users/artemmerinov/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth"
)
timesformer_model.load_state_dict(vit_state, strict=False)
timesformer_model_state = timesformer_model.state_dict()

In [6]:
def show_layers_loaded(model_init, model):
    updated_layers = 0
    for current_params, loaded_params in zip(model_init.parameters(), model.parameters()):
        old_weight, new_weight = current_params.data, loaded_params.data
        if (old_weight - new_weight).sum() < 1e-6:
            updated_layers = updated_layers + 1

    print(f"Layers that have been loaded: {updated_layers}")
    
show_layers_loaded(timesformer_model_init, timesformer_model)

Layers that have been loaded: 176


In [7]:
vit_state["model.blocks.0.attn.proj.bias"]

tensor([-6.4844e-01, -2.2557e-02,  5.1206e-02, -4.6747e-02,  4.5531e-01,
        -5.9921e-01,  1.1520e-01,  4.2701e-02, -1.0368e-01,  2.6839e-03,
         1.1767e-05,  1.6197e-02, -4.0263e-02, -3.7523e-02, -9.1415e-02,
         1.3250e-01, -1.4634e+00,  1.9360e-01, -1.9466e-02,  7.3603e-02,
        -5.5523e-02,  5.9937e-03, -4.3688e-02, -8.9596e-02,  9.9444e-02,
         4.1670e-01,  8.1113e-03, -1.7477e-02,  5.9074e-02,  6.5741e-02,
         1.1873e-01, -7.1284e-02, -1.5079e-01, -5.2319e-02,  2.4504e-02,
         1.5026e-01,  1.7120e-01,  4.7337e-02, -1.3337e-02,  3.0768e-01,
        -1.3243e-01, -4.3470e-02,  2.9323e-02,  5.7413e-03,  2.2525e-01,
        -9.4953e-02,  1.3537e-01,  5.2324e-02,  1.8358e-02,  1.0647e-01,
         7.7834e-01,  2.7970e-03,  1.3089e-01,  1.5985e-01, -7.1230e-02,
         1.0196e-01, -4.4460e-03,  1.9149e-01, -6.0605e-02, -1.2802e-02,
        -1.7848e-01,  7.2975e-02, -6.0625e-02,  2.8588e-01,  2.2940e-02,
        -7.7762e-02,  4.1791e-02, -2.2625e-01,  2.8

In [8]:
timesformer_model_state["model.blocks.0.attn.proj.bias"]

tensor([-6.4844e-01, -2.2557e-02,  5.1206e-02, -4.6747e-02,  4.5531e-01,
        -5.9921e-01,  1.1520e-01,  4.2701e-02, -1.0368e-01,  2.6839e-03,
         1.1767e-05,  1.6197e-02, -4.0263e-02, -3.7523e-02, -9.1415e-02,
         1.3250e-01, -1.4634e+00,  1.9360e-01, -1.9466e-02,  7.3603e-02,
        -5.5523e-02,  5.9937e-03, -4.3688e-02, -8.9596e-02,  9.9444e-02,
         4.1670e-01,  8.1113e-03, -1.7477e-02,  5.9074e-02,  6.5741e-02,
         1.1873e-01, -7.1284e-02, -1.5079e-01, -5.2319e-02,  2.4504e-02,
         1.5026e-01,  1.7120e-01,  4.7337e-02, -1.3337e-02,  3.0768e-01,
        -1.3243e-01, -4.3470e-02,  2.9323e-02,  5.7413e-03,  2.2525e-01,
        -9.4953e-02,  1.3537e-01,  5.2324e-02,  1.8358e-02,  1.0647e-01,
         7.7834e-01,  2.7970e-03,  1.3089e-01,  1.5985e-01, -7.1230e-02,
         1.0196e-01, -4.4460e-03,  1.9149e-01, -6.0605e-02, -1.2802e-02,
        -1.7848e-01,  7.2975e-02, -6.0625e-02,  2.8588e-01,  2.2940e-02,
        -7.7762e-02,  4.1791e-02, -2.2625e-01,  2.8

# Load w/o weigths

In [9]:
model = timm.create_model(
    model_name='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k', 
    pretrained=False # random init
)

In [10]:
model.state_dict()["blocks.0.attn.proj.bias"]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

# Load 384

In [11]:
vit_model = timm.create_model(
    model_name='timm/vit_base_patch16_384.augreg_in21k_ft_in1k',
    pretrained=True,
    num_classes=2000,
)
vit_state = vit_model.state_dict()

for key in list(vit_state.keys()):
    vit_state[f"model.{key}"] = vit_state.pop(key)

In [12]:
timesformer_model = TimeSformer(
    img_size=384, 
    num_classes=2000, 
    num_frames=8, 
    attention_type='divided_space_time', 
    # pretrained_model="/Users/artemmerinov/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth"
)
timesformer_model.load_state_dict(vit_state, strict=False)
timesformer_model_state = timesformer_model.state_dict()