In [None]:
import torch

In [None]:
from typing import Optional

import torch.nn as nn

from focoos.nn.backbone.vit import VisionTransformer

# Reference: pretrained-config entries (hf_hub_id 'timm/') and the matching model
# kwargs for the three DINOv2 register-token ViT variants (small / base / large).
#
# 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
#     url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
#     hf_hub_id='timm/',
#     license='apache-2.0',
#     mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
#     input_size=(3, 518, 518), crop_pct=1.0),
# model kwargs (small): patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, reg_tokens=4, no_embed_class=True,
#
# 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
#     url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
#     hf_hub_id='timm/',
#     license='apache-2.0',
#     mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
#     input_size=(3, 518, 518), crop_pct=1.0),
# model kwargs (base): patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, reg_tokens=4, no_embed_class=True,
#
# 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
#     url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
#     hf_hub_id='timm/',
#     license='apache-2.0',
#     mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
#     input_size=(3, 518, 518), crop_pct=1.0),
# model kwargs (large): patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, reg_tokens=4, no_embed_class=True,


class ViT(nn.Module):
    """Vision Transformer encoder that normalizes its input before the backbone.

    Builds a ``VisionTransformer`` backbone (``init_values=1e-5``,
    ``no_embed_class=True``), optionally loads a checkpoint into it, and
    registers the ImageNet default mean/std as buffers so that ``forward``
    can normalize the input itself.

    Args:
        img_size: Input resolution forwarded to the backbone.
        patch_size: Patch size forwarded to the backbone.
        ckpt_path: Path of a state-dict checkpoint to load into the backbone,
            or ``None`` to keep the backbone's initial weights.
        embed_dim: Embedding dimension forwarded to the backbone.
        depth: Number of transformer blocks forwarded to the backbone.
        num_heads: Number of attention heads forwarded to the backbone.
        reg_tokens: Number of register tokens forwarded to the backbone.
    """

    def __init__(
        self,
        img_size: tuple[int, int],
        patch_size=14,
        ckpt_path: Optional[str] = "dinov2_vits14_reg4_pretrain_timm.pth",
        embed_dim=384,
        depth=12,
        num_heads=6,
        reg_tokens=4,
    ):
        import warnings

        super().__init__()

        self.backbone = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            init_values=1e-5,
            reg_tokens=reg_tokens,
            no_embed_class=True,
        )

        if ckpt_path is not None:
            # map_location="cpu": a checkpoint saved from a GPU process must still
            # load on a CPU-only machine; load_state_dict copies the values into
            # the backbone's own parameters afterwards.
            # weights_only=True: the file is expected to be a plain state dict, so
            # there is no reason to let the unpickler build arbitrary objects.
            checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            load_result = self.backbone.load_state_dict(checkpoint, strict=False)
            # strict=False tolerates key mismatches; surface them so that a wrong
            # or unconverted checkpoint does not silently leave random weights.
            if load_result.missing_keys or load_result.unexpected_keys:
                warnings.warn(
                    f"Checkpoint {ckpt_path!r} only partially matches the backbone: "
                    f"missing keys={list(load_result.missing_keys)}, "
                    f"unexpected keys={list(load_result.unexpected_keys)}",
                    stacklevel=2,
                )

        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        # Shape (1, 3, 1, 1) so the statistics broadcast over (N, 3, H, W) inputs.
        pixel_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).reshape(1, -1, 1, 1)
        pixel_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1, -1, 1, 1)

        # Buffers (not parameters): they follow .to(device) but are never trained.
        self.register_buffer("pixel_mean", pixel_mean)
        self.register_buffer("pixel_std", pixel_std)

    def forward(self, x):
        """Normalize ``x`` per channel and return the backbone's features.

        ``x`` is presumably an un-normalized RGB batch of shape (N, 3, H, W)
        with values in [0, 1] -- TODO confirm against callers.
        """
        x = (x - self.pixel_mean) / self.pixel_std
        x = self.backbone.forward_features(x)
        return x

In [None]:
# NOTE: the checkpoint loaded here is the file written by the conversion cell further
# down (torch.save(out_dict, "dinov2_vits14_reg4_pretrain_timm.pth")); on a fresh
# kernel that cell has to run first, otherwise this one fails with a missing file.
encoder = ViT(img_size=(518, 518), patch_size=14, ckpt_path="dinov2_vits14_reg4_pretrain_timm.pth")

encoder.eval()

# Smoke test on a random batch. This is inference only, so disable autograd to
# avoid building (and holding in memory) a graph that is never used.
with torch.no_grad():
    encoder_features = encoder(torch.randn(1, 3, 518, 518))

encoder_features

In [None]:
def convert_dinov2(state_dict):
    """Return a copy of a DINOv2 state dict with keys renamed to the timm layout.

    The input mapping is left untouched; tensors are shared with the result,
    not cloned.

    Conversion steps:
      * ``mask_token`` is dropped.
      * If ``register_tokens`` is present, it is renamed to ``reg_token``, the
        first position embedding is folded into ``cls_token``, and ``pos_embed``
        keeps only the remaining positions.
      * ``blocks.<i>.mlp.w12.{weight,bias}`` becomes ``...mlp.fc1...`` and
        ``blocks.<i>.mlp.w3.{weight,bias}`` becomes ``...mlp.fc2...``.
      * Every other entry is copied through unchanged.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        A new ``dict`` holding the converted entries.

    Raises:
        KeyError: If ``register_tokens`` is present but ``cls_token`` or
            ``pos_embed`` is missing.
    """
    import re

    # Work on a shallow copy so the caller's dict keeps every one of its keys.
    remaining = dict(state_dict)
    remaining.pop("mask_token", None)

    out_dict = {}
    if "register_tokens" in remaining:
        # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed)
        pos_embed = remaining.pop("pos_embed")
        out_dict["reg_token"] = remaining.pop("register_tokens")
        out_dict["cls_token"] = remaining.pop("cls_token") + pos_embed[:, 0]
        out_dict["pos_embed"] = pos_embed[:, 1:]

    # (pattern, substring to replace, replacement); the first matching rule wins.
    rename_rules = (
        (re.compile(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)"), "w12", "fc1"),
        (re.compile(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)"), "w3", "fc2"),
    )
    for key, value in remaining.items():
        for pattern, old, new in rename_rules:
            if pattern.match(key):
                key = key.replace(old, new)
                break
        out_dict[key] = value
    return out_dict


# Load the raw checkpoint onto the CPU so the conversion also works on a machine
# without a GPU and the re-saved tensors are not tied to a particular device.
# weights_only=True: the file is expected to be a plain state dict of tensors, so
# the unpickler does not need to be able to build arbitrary objects.
state_dict = torch.load("dinov2_vits14_reg4_pretrain.pth", map_location="cpu", weights_only=True)

# Rename the keys and write the converted checkpoint next to the original.
out_dict = convert_dinov2(state_dict)
torch.save(out_dict, "dinov2_vits14_reg4_pretrain_timm.pth")

In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from timm.models.vision_transformer import VisionTransformer as TimmVisionTransformer

# Benchmark the forward pass of a randomly initialised timm ViT at each resolution.
# Resolutions are meant to be multiples of the patch size (14).
resolutions = [518]
times = []

for res in resolutions:
    # A fresh model is built for every resolution, so it is already configured
    # for (res, res) and nothing has to be patched afterwards.
    vit = TimmVisionTransformer(
        img_size=(res, res),
        patch_size=14,
        embed_dim=768,
        depth=12,
        num_heads=12,
        init_values=1e-5,
        reg_tokens=4,
    )

    vit.eval()
    # Create square image of size res x res
    image = torch.randn(1, 3, res, res)

    with torch.no_grad():
        # Untimed warm-up pass: keeps one-off start-up costs of the first call
        # out of the measurement.
        vit.forward_features(image)

        # perf_counter is the clock intended for measuring short intervals;
        # time.time() follows the wall clock and can jump.
        start = time.perf_counter()
        features = vit.forward_features(image)
        times.append(time.perf_counter() - start)

# Plot results
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(resolutions, times, "-o")
ax.set_xlabel("Image Resolution")
ax.set_ylabel("Time (seconds)")
ax.set_title("ViT Forward Pass Time vs Image Resolution")
ax.grid(True)
plt.show()

# Print some key points
print(f"Min time: {min(times):.4f}s at resolution {resolutions[np.argmin(times)]}")
print(f"Max time: {max(times):.4f}s at resolution {resolutions[np.argmax(times)]}")