In [33]:
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import ViTForImageClassification,ViTConfig
from dataclasses import dataclass
import math

In [125]:
model= ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") #finetuned on imagenet

In [126]:
model_hf_1=model.state_dict()

In [69]:
conf=ViTConfig()
print(conf)

ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.41.0"
}



In [70]:
@dataclass
class ViT_Config():
    image_size = 224
    patch_size = 16
    embed_dim=768
    ff_dim=768*4 
    num_heads=12
    layers=12



In [120]:
x=0
for k,v in model_hf.items():
    x=x+1
print(x)

200


In [140]:
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, max_length, embed_dim, ff_dim, num_heads, dropout=0.1):

        super(TransformerBlock, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisble by num_heads"

        self.max_length = max_length
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.dp = dropout

        #derv:
        self.head_size = self.embed_dim // self.num_heads

        #attention blocks
        self.attention=nn.ModuleDict(dict(
            attention=nn.ModuleDict(dict(
                                    query = nn.Linear(self.embed_dim, self.embed_dim),
                                    key = nn.Linear(self.embed_dim, self.embed_dim),
                                    value = nn.Linear(self.embed_dim, self.embed_dim),
                                    )),
            
            output=nn.ModuleDict(dict(
                                dense=nn.Linear(self.embed_dim, self.ff_dim)
                                    ))
        ))

        self.intermediate=nn.ModuleDict(dict(
                        dense=nn.Linear(self.embed_dim, self.ff_dim)
        ))
        self.output=nn.ModuleDict(dict(
                        dense=nn.Linear(self.embed_dim, self.ff_dim),
        ))
        
        
        #after attn and ff blocks
        # self.dropout = nn.Dropout(self.dp, inplace=True)

        #depends
        self.layernorm_before = nn.LayerNorm(self.embed_dim)
        self.layernorm_after = nn.LayerNorm(self.embed_dim)


class ViT_Config:
    def __init__(self):
        self.image_size = 224
        self.patch_size = 16
        self.embed_dim = 768
        self.ff_dim = 3072
        self.num_heads = 12
        self.layers = 12

class ViT(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim, ff_dim, num_heads, layers):
        super(ViT, self).__init__()
        assert image_size % patch_size == 0

        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.layers = layers
    
        self.num_channels = 3
        self.num_patches = (image_size // patch_size) ** 2
        
        self.register_buffer("cls_token", torch.ones(1, 1, self.embed_dim))

        self.position_embeddings = nn.Parameter(torch.randn(1, self.num_patches+1, self.embed_dim))

        #For patch embedding fn:
        self.projection = nn.Conv2d(
            self.num_channels, 
            self.embed_dim, 
            kernel_size=self.patch_size, 
            stride=self.patch_size
        )


        

        self.dropout = nn.Dropout(0.1, inplace=True)
        self.vit=nn.ModuleDict(dict(
            encoder = nn.ModuleDict(layer=nn.ModuleList([
                                        TransformerBlock(
                                            self.num_patches, 
                                            self.embed_dim, 
                                            self.ff_dim, 
                                            self.num_heads
                                        ) for _ in range(self.layers)
                                            ])
        )))
        

        self.ln_f = nn.LayerNorm(self.embed_dim)

        self.head = nn.Linear(self.embed_dim, 1000)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.projection(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.position_embeddings
        x = self.dropout(x)

        for layer in self.encoder_layers:
            x = layer(x)

        x = self.ln_f(x[:, 0])
        x = self.head(x)
        return x
    @classmethod
    def from_pretrained(cls, model_type):
        assert model_type in ('google/vit-base-patch16-224')

        model_hf = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") #finetuned on imagenet
        sd_hf = model_hf.state_dict()
        image_size = 224
        patch_size = 16
        config = {
            'google/vit-base-patch16-224' : dict(embed_dim=768,  ff_dim=768*4,  num_heads=12, layers=12),
        }[model_type]


        model = cls(image_size, patch_size, **config)
        sd = model.state_dict()

        assert len(sd) == len(sd_hf), "mismatch state dict, maybe you forgot to consider something"

        mapping = {
            'vit.embeddings.cls_token': 'cls_token',
            'vit.embeddings.position_embeddings': 'position_embeddings',
            'vit.embeddings.patch_embeddings.projection.weight': 'projection.weight',
            'vit.embeddings.patch_embeddings.projection.bias': 'projection.bias',
            'vit.layernorm.weight': 'ln_f.weight',
            'vit.layernorm.bias': 'ln_f.bias',
            'classifier.weight': 'head.weight',
            'classifier.bias': 'head.bias'
        }
        # assert len(mapping.keys()) == len(sd_hf.keys()), "mismatch mapping between the models"

        from tqdm import tqdm
        print("Importing ViT")
        for k in tqdm(sd_hf):
            kn = mapping[k]
            assert sd_hf[k].shape == sd[kn].shape;
            with torch.no_grad():
                sd[kn].copy_(sd_hf[k])

        return model
conf = ViT_Config()
vit = ViT.from_pretrained("google/vit-base-patch16-224")
# model = ViT(conf)


TypeError: ModuleDict.__init__() got an unexpected keyword argument 'layer'

In [None]:
i=0
for k in model.state_dict().keys():
    print(k)
# print(i)

position_embeddings
cls_token
projection.weight
projection.bias
vit.encoder_layers.0.attention.attention.query.weight
vit.encoder_layers.0.attention.attention.query.bias
vit.encoder_layers.0.attention.attention.key.weight
vit.encoder_layers.0.attention.attention.key.bias
vit.encoder_layers.0.attention.attention.value.weight
vit.encoder_layers.0.attention.attention.value.bias
vit.encoder_layers.0.attention.output.dense.weight
vit.encoder_layers.0.attention.output.dense.bias
vit.encoder_layers.0.intermediate.dense.weight
vit.encoder_layers.0.intermediate.dense.bias
vit.encoder_layers.0.output.dense.weight
vit.encoder_layers.0.output.dense.bias
vit.encoder_layers.0.layernorm_before.weight
vit.encoder_layers.0.layernorm_before.bias
vit.encoder_layers.0.layernorm_after.weight
vit.encoder_layers.0.layernorm_after.bias
vit.encoder_layers.1.attention.attention.query.weight
vit.encoder_layers.1.attention.attention.query.bias
vit.encoder_layers.1.attention.attention.key.weight
vit.encoder_layers

In [129]:
for k in model_hf_1.keys():
    print(k)

vit.embeddings.cls_token
vit.embeddings.position_embeddings
vit.embeddings.patch_embeddings.projection.weight
vit.embeddings.patch_embeddings.projection.bias
vit.encoder.layer.0.attention.attention.query.weight
vit.encoder.layer.0.attention.attention.query.bias
vit.encoder.layer.0.attention.attention.key.weight
vit.encoder.layer.0.attention.attention.key.bias
vit.encoder.layer.0.attention.attention.value.weight
vit.encoder.layer.0.attention.attention.value.bias
vit.encoder.layer.0.attention.output.dense.weight
vit.encoder.layer.0.attention.output.dense.bias
vit.encoder.layer.0.intermediate.dense.weight
vit.encoder.layer.0.intermediate.dense.bias
vit.encoder.layer.0.output.dense.weight
vit.encoder.layer.0.output.dense.bias
vit.encoder.layer.0.layernorm_before.weight
vit.encoder.layer.0.layernorm_before.bias
vit.encoder.layer.0.layernorm_after.weight
vit.encoder.layer.0.layernorm_after.bias
vit.encoder.layer.1.attention.attention.query.weight
vit.encoder.layer.1.attention.attention.query