In [None]:
import timm
import torch
import torch.nn as nn

In [None]:

# Embedding width used for each DINOv2 backbone, keyed by its torch.hub model name.
embed_sizes = dict(
    dinov2_vits14=384,
    dinov2_vitb14=768,
    dinov2_vitl14=1024,
    dinov2_vitg14=1536,
)

# Local checkpoint file, the hub architecture it is loaded into,
# and the Hugging Face repo id of the timm model it is compared against.
PATH_MODEL = "../models/RedDino_l.pth"
MODEL_NAME = "dinov2_vitl14"
TIMM_NAME = "Snarcy/RedDino-large"

def get_dino_torch(modelpath="/content/dinobloom-s.pth",modelname="dinov2_vits14"):
    """Build a DINOv2 backbone from torch.hub and load fine-tuned teacher weights into it.

    Args:
        modelpath: Path to a checkpoint readable by ``torch.load`` that holds a
            ``'teacher'`` state dict.
        modelname: Entry-point name passed to
            ``torch.hub.load('facebookresearch/dinov2', ...)``.

    Returns:
        The hub model with the checkpoint's backbone weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``'teacher'`` entry.
        RuntimeError: If the filtered weights do not match the model's
            parameters exactly (``load_state_dict`` is called with ``strict=True``).
    """
    # Load the original DINOv2 model to get the correct architecture and parameters.
    model = torch.hub.load('facebookresearch/dinov2', modelname)

    # Load the fine-tuned weights on CPU.
    # NOTE(review): torch.load unpickles the file -- only open checkpoints from a
    # trusted source (or pass weights_only=True where the torch version supports it).
    pretrained = torch.load(modelpath, map_location=torch.device('cpu'))

    # Keep only backbone weights: skip the DINO / iBOT heads and strip the
    # 'backbone.' marker so the keys line up with the hub model's parameter names.
    new_state_dict = {
        key.replace('backbone.', ''): value
        for key, value in pretrained['teacher'].items()
        if 'dino_head' not in key and 'ibot_head' not in key
    }

    # Resize the model's positional embedding to whatever the checkpoint stores
    # (e.g. 1 + 16*16 = 257 tokens for 224x224 images with 14x14 patches) instead
    # of assuming a fixed token count; the values are overwritten by the load below.
    saved_pos_embed = new_state_dict.get('pos_embed')
    if saved_pos_embed is not None:
        model.pos_embed = nn.Parameter(torch.zeros(saved_pos_embed.shape))

    model.load_state_dict(new_state_dict, strict=True)
    return model

# Instantiate the torch.hub backbone with the local fine-tuned checkpoint.
model = get_dino_torch(modelpath=PATH_MODEL, modelname=MODEL_NAME)



In [None]:
# Same weights, but fetched through timm from the Hugging Face hub.
model_timm = timm.create_model('hf_hub:'+TIMM_NAME, pretrained=True)

# Switch both networks to eval mode so train-only layers cannot skew the comparison.
model.eval()
model_timm.eval()

# Fixed seed: the random probe image (and therefore the printed values) repeat on every run.
torch.manual_seed(0)
input_img = torch.randn(1, 3, 224, 224)

# Forward-only comparison: no autograd graph is needed.
with torch.no_grad():
    output = model(input_img)
    output_timm = model_timm(input_img)

print(output.shape)
print(output_timm.shape)



In [None]:
# Do the torch.hub model and the timm model produce the same features
# (element-wise, within an absolute tolerance of 1e-4)?
outputs_match = torch.allclose(output, output_timm, atol=1e-4)
print(outputs_match)

# First ten feature values of the first sample, for a visual side-by-side check.
first_values = output[0, :10]
first_values_timm = output_timm[0, :10]
print(first_values)
print(first_values_timm)