In [1]:
import timm
import torch
from timm.models.vision_transformer import VisionTransformer, _cfg


# Build a randomly initialised ViT-B/16 to use as the architectural reference.
reference_model = timm.create_model('vit_base_patch16_224', pretrained=False)


# Replace the classification head with a pass-through so the forward pass
# returns the feature vector instead of class logits.
reference_model.head = torch.nn.Identity()

# Dummy batch holding one 3x224x224 image.
# torch.Tensor(1, 3, 224, 224) only allocates storage without initialising it,
# so the input contained arbitrary values (possibly NaN/inf) that changed from
# run to run. Sample well-defined values from a standard normal instead.
im = torch.randn(1, 3, 224, 224)

reference_model(im).shape

torch.Size([1, 768])

In [2]:
# Read the architecture hyper-parameters off a stock ViT-B/16 instead of
# hard-coding them. This rebinds reference_model to a freshly created model.
reference_model = timm.create_model('vit_base_patch16_224', pretrained=False)
first_block = reference_model.blocks[0]
embed_dim = first_block.attn.qkv.in_features      # input width of the qkv projection
num_heads = first_block.attn.num_heads            # attention heads per block
out_features = first_block.mlp.fc1.out_features   # hidden width of the block MLP
mlp_ratio = out_features // embed_dim             # reference MLP expansion ratio (integer division, no float round-trip)
num_layers = len(reference_model.blocks)          # number of transformer blocks


# The custom model reuses the reference width, depth and head count, but its
# MLP ratio is 4x the reference ratio, so every block's fc1 is 4x wider.
# NOTE(review): the previous comments spoke of "doubling" while the code
# multiplies by 4 — confirm that the 4x widening is the intended design.
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=embed_dim,
    depth=num_layers,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio * 4,
    num_classes=1000
)
pretrained_model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Copy the pretrained positional-embedding values into the model's own
# parameter. Assigning the Parameter object itself would make both models
# share a single tensor, so updating one would silently change the other.
with torch.no_grad():
    model.pos_embed.copy_(pretrained_model.pos_embed)

# Remove the classification head so the model outputs features.
model.head = torch.nn.Identity()

In [4]:
# Push the dummy image batch through the custom model and display the shape
# of what comes out.
features = model(im)
features.shape

torch.Size([1, 768])

In [8]:
# Print the module tree of reference_model (created via
# timm.create_model('vit_base_patch16_224', pretrained=False)).
print(reference_model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [6]:
# Attach a new 10000-way linear classification head. The input width is read
# from the model rather than repeating the literal 768, so the head stays
# consistent with the backbone if embed_dim is ever changed.
# NOTE(review): the model was constructed with num_classes=1000 — confirm that
# 10000 outputs is intended here and not a typo.
model.head = torch.nn.Linear(model.embed_dim, 10000)

In [7]:
# Print the module tree of the custom VisionTransformer held in `model`.
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=12288, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity