In [1]:
import torch
from collections import OrderedDict
from vit import get_vit

In [2]:
# Training checkpoint to convert; mapped onto the CPU so no GPU is needed here.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
CKPT_PATH = "models/vit_small_dino_2x_mixed.ckpt"
ckpt = torch.load(CKPT_PATH, map_location="cpu")

In [3]:
# Inspect the hyper-parameters stored with the checkpoint.
training_hparams = ckpt["hyper_parameters"]
training_hparams

{'arch': 'small',
 'image_size': 224,
 'patch_size': 16,
 'drop_path_rate': 0.1,
 'pos_emb': 'learned',
 'hidden_dim': 2048,
 'bottleneck_dim': 256,
 'output_dim': 65536,
 'batch_norm': True,
 'freeze_last_layer': 1,
 'norm_last_layer': True,
 'warmup_teacher_temp': 0.04,
 'teacher_temp': 0.06,
 'warmup_teacher_temp_epochs': 30,
 'student_temp': 0.1,
 'center_momentum': 0.9,
 'lr_start': 0.0016875,
 'lr_final': 1e-06,
 'lr_warmup_epochs': 10,
 'wd_start': 0.04,
 'wd_final': 0.4,
 'mm_start': 0.994,
 'mm_final': 1.0}

In [4]:
# get_vit only needs the backbone-shaping entries; the head / loss / schedule
# settings listed in the output above are not passed to it.
VIT_KWARGS = ("arch", "image_size", "patch_size", "drop_path_rate", "pos_emb")
source_hparams = ckpt["hyper_parameters"]
hparams = {name: source_hparams[name] for name in VIT_KWARGS}
model = get_vit(**hparams)

In [5]:
# Keep only the teacher backbone's entries and strip their prefix so the keys
# line up with the parameter names of the standalone ViT built by get_vit.
# Match the full dotted prefix and slice it off instead of using str.split():
# split() fails to unpack for keys that merely start with the text (no dot)
# or that contain the prefix more than once.
TEACHER_PREFIX = "teacher_backbone."
train_state_dict = ckpt["state_dict"]
vit_state_dict = OrderedDict(
    (key[len(TEACHER_PREFIX):], value)
    for key, value in train_state_dict.items()
    if key.startswith(TEACHER_PREFIX)
)
# Fail with a clear message rather than a wall of "missing keys" from strict loading.
assert vit_state_dict, f"no {TEACHER_PREFIX!r} keys found in ckpt['state_dict']"
model.load_state_dict(vit_state_dict)

<All keys matched successfully>

In [6]:
# Smoke test: the loaded backbone should run a forward pass on a random batch
# of one 3-channel image at the training resolution.
model.eval()
image_size = hparams["image_size"]  # reuse the checkpoint's value instead of hard-coding 224
x = torch.randn(1, 3, image_size, image_size)
with torch.no_grad():  # inference only -- no need to build an autograd graph
    y = model(x)
y.shape

torch.Size([1, 384])

In [7]:
# Bundle the get_vit kwargs with the extracted weights so the backbone can be
# rebuilt via get_vit(**hyperparams) without the full training checkpoint.
OUTPUT_PATH = "vit_small_dino_2x_mixed.pt"
new_ckpt = dict(hyperparams=hparams, state_dict=vit_state_dict)
torch.save(new_ckpt, OUTPUT_PATH)