In [None]:
import timm
import torch
import torch.nn as nn

In [None]:

# Token embedding width for each DINOv2 backbone variant, keyed by torch.hub name.
embed_sizes = {
    "dinov2_vits14": 384,
    "dinov2_vitb14": 768,
    "dinov2_vitl14": 1024,
    "dinov2_vitg14": 1536,
}

# Which model to convert: the torch.hub entry point, its timm counterpart,
# and the fine-tuned checkpoint that holds the weights.
MODEL_NAME = "dinov2_vitl14"
TIMM_NAME = "vit_large_patch14_dinov2"
PATH_MODEL = "../models/RedDino_l.pth"

def get_dino_torch(modelpath="/content/dinobloom-s.pth",modelname="dinov2_vits14"):
    """Build a DINOv2 backbone via torch.hub and load fine-tuned teacher weights into it.

    Args:
        modelpath: Path to a checkpoint whose ``'teacher'`` entry maps parameter
            names to tensors.
        modelname: Entry point name in the ``facebookresearch/dinov2`` hub repo.

    Returns:
        The hub model with the checkpoint's backbone weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``'teacher'`` entry.
        RuntimeError: Propagated from ``load_state_dict(strict=True)`` when the
            remaining keys or shapes do not match the model.
    """
    # Load the original DINOv2 model to get the architecture and parameter layout.
    model = torch.hub.load('facebookresearch/dinov2', modelname)

    # NOTE(review): torch.load unpickles the file; only open checkpoints you trust.
    pretrained = torch.load(modelpath, map_location=torch.device('cpu'))
    if 'teacher' not in pretrained:
        raise KeyError(f"checkpoint {modelpath!r} has no 'teacher' entry")

    # Keep only backbone weights: drop the DINO/iBOT heads and strip the leading
    # 'backbone.' prefix. Only a *leading* prefix is removed, so a nested module
    # that happens to be called 'backbone' keeps its name.
    prefix = 'backbone.'
    new_state_dict = {}
    for key, value in pretrained['teacher'].items():
        if 'dino_head' in key or 'ibot_head' in key:
            continue
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[new_key] = value

    # Resize the positional embedding to whatever the checkpoint stores
    # (1 x 257 x dim for 224x224 inputs with 14x14 patches: 16*16 patches + 1),
    # so the strict load below does not fail on a shape mismatch.
    ckpt_pos_embed = new_state_dict.get('pos_embed')
    if ckpt_pos_embed is not None:
        model.pos_embed = nn.Parameter(torch.zeros(ckpt_pos_embed.shape))

    model.load_state_dict(new_state_dict, strict=True)
    return model

# Instantiate the fine-tuned backbone for the checkpoint and variant configured above.
model = get_dino_torch(modelpath=PATH_MODEL, modelname=MODEL_NAME)



In [None]:
# Port the fine-tuned DINOv2 weights into the equivalent timm architecture.

# pretrained=False: the strict load below overwrites every parameter, so
# downloading timm's stock weights first would only waste time and bandwidth.
model_timm = timm.create_model(TIMM_NAME, pretrained=False)
model_timm.head = nn.Identity()
# Give the timm model a positional embedding with the same shape as the
# converted torch-hub model so the strict load accepts it.
pos_embed = nn.Parameter(torch.zeros(model.pos_embed.shape))
model_timm.pos_embed = pos_embed
# `mask_token` is left out of the transfer; the load into the timm model is
# strict, so every remaining key must exist there with a matching shape.
new_state_dict = {
    key: value for key, value in model.state_dict().items() if 'mask_token' not in key
}
model_timm.load_state_dict(new_state_dict, strict=True)

# Check that both models now hold the same weights.
model.eval()
model_timm.eval()
state_dict = model.state_dict()
state_dict_timm = model_timm.state_dict()
new_state_dict = {
    name: param for name, param in state_dict.items() if 'mask_token' not in name
}
new_state_dict_timm = dict(state_dict_timm)

# Compare by parameter name rather than by position: zipping the two dicts
# would silently ignore name mismatches and truncate at the shorter one.
only_in_one = sorted(set(new_state_dict) ^ set(new_state_dict_timm))
differing = [
    name
    for name in new_state_dict
    if name in new_state_dict_timm
    and not torch.equal(new_state_dict[name], new_state_dict_timm[name])
]
if only_in_one or differing:
    print("Not equal")
    if only_in_one:
        print("  parameters present in only one model:", only_in_one)
    if differing:
        print("  parameters with different values:", differing)
else:
    print(f"Equal: all {len(new_state_dict)} parameters match")
   


In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably required so the push to the Hub further down can
# authenticate — confirm against the account/token used.
from huggingface_hub import notebook_login
notebook_login()

In [None]:



#default_config = timm.models.vision_transformer.default_cfgs[TIMM_NAME]
#change the input size to 224x224
#print(default_config)
#default_config.cfgs['lvd142m'].input_size=[3,224,224]
#default_config= default_config.cfgs['lvd142m']
# Configuration uploaded alongside the weights. The architecture name and the
# feature width are taken from the notebook's constants (TIMM_NAME, MODEL_NAME)
# so they cannot drift from the model that was actually converted.
config = {
    "architecture": TIMM_NAME,
    "num_classes": 0,
    "num_features": embed_sizes[MODEL_NAME],
    "global_pool": "token",
    "pretrained_cfg": {
        "tag": "lvd142m",
        "custom_load": False,
        # 224x224 inputs with 14x14 patches: 16*16 patches, matching the
        # 257-token positional embedding used for the converted weights.
        "input_size": [3, 224, 224],
        "fixed_input_size": True,
        "interpolation": "bicubic",
        "crop_pct": 1.0,
        "crop_mode": "center",
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
        "num_classes": 0,
        "pool_size": None,
        "first_conv": "patch_embed.proj",
        "classifier": "head",
        "license": "cc-by-nc-4.0"
    }
}
# Upload the converted timm model together with the config above.
timm.models.push_to_hf_hub(model_timm, 'RedDino-large', model_config=config, commit_message="RedDino_config")