In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [2]:
class PositionalEmbedding1D(nn.Module):
    """Adds (optionally learned) positional embeddings to the inputs."""

    def __init__(self, seq_len, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim))
    
    def forward(self, x):
        """Input has shape `(batch_size, seq_len, emb_dim)`"""
        # (1, 3, 14, 14) + (1, 3, 196) ? wtf
        # ohhh flatten before embedding kek
        return x + self.pos_embedding

In [3]:
class MLP(nn.Module):
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)
        
    def forward(self, x):
        out = self.fc2(F.gelu(self.fc1(x)))
        return out

In [4]:
class MHSA(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.project = nn.Linear(dim, dim)
        
    def forward(self, x): # input shape [b, s, d]
        # split into q, k, v : (query, key, values)
        # expand by h (num heads... hence multi attention)
        
        q, k, v = self.project(x), self.project(x), self.project(x)
        # convert q, k, v -> [b, h, s, w] where h = n_heads
        q = q.view(q.shape[0], self.n_heads, q.shape[1], -1)
        k = k.view(k.shape[0], self.n_heads, k.shape[1], -1)
        v = v.view(v.shape[0], self.n_heads, v.shape[1], -1)
        
        # scaled dot product attention on q, k (queries, keys) then matmul with values
        # matmul + scale
        # [b, h, s, w] @ [b, h, w, s] -> [b, h, s, s]
        k = k.transpose(-2, -1) # swap last two dimensions
        p = torch.matmul(q, k)
        p = p / np.sqrt(k.size(-1)) # where s is the dimension of k

        p = F.softmax(p, dim = -1) # softmax across last dimension
        
        out = torch.matmul(p, v) # [b, h, s, s] @ [b, h, s, w] -> [b, h, s, w]
        out = out.view(out.shape[0], out.shape[2], -1) # [b, s, d]
        return out

In [5]:
# dont need linear and dont need dropout i think
# add after if performance is shit

class Block(nn.Module): # inputs are B, S, D
    def __init__(self, dim, n_heads, ff_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.MHSA = MHSA(dim, n_heads)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp_channels = MLP(dim, ff_dim)
        
    def forward(self, x):
        h = self.MHSA(self.norm1(x))
        x = x + h
        h = self.mlp_channels(self.norm2(x))
        x = x + h
        return x

In [6]:
class Transformer(nn.Module):
    def __init__(self, n_layers, dim, n_heads, ff_dim):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim, n_heads, ff_dim) for _ in range(n_layers)])
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

In [7]:
'''
with image 1, 3, 224, 224  and patch encoding of 16x16 we have
((224 - 16) / 16) + 1 = 14
seq len is 14 * 14 bro im trippin bullets
out = (1, 3, 14, 14)
'''

'\nwith image 1, 3, 224, 224  and patch encoding of 16x16 we have\n((224 - 16) / 16) + 1 = 14\nseq len is 14 * 14 bro im trippin bullets\nout = (1, 3, 14, 14)\n'

In [8]:
class ViT(nn.Module):
    def __init__(self, in_channels, dim, fh, fw, n_layers, n_heads, ff_dim, num_classes):
        super().__init__()
        self.fw = fw
        self.fh = fh
        self.patch_encoding = nn.Conv2d(in_channels, dim, kernel_size=(self.fh, self.fw), stride=(self.fh, self.fw))
        # [B, D, FH, FW]
        
        seq_len = 14 * 14
            
        self.positional_embedding = PositionalEmbedding1D(seq_len, dim) # inputs are seq len, dim
        # [B, D, FH, FW]
        
        # flatten into [B, S, D]
        self.Transformer = Transformer(n_layers, dim, n_heads, ff_dim)
        
        # if this doenst work then do it manually
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.mlp_head = nn.Linear(dim, num_classes)
        
    def forward(self, x):
        x = self.patch_encoding(x)
        x = x.view(x.shape[0], -1, x.shape[1]) # b, s, d
        x = self.positional_embedding(x) 
        x = self.Transformer(x)
        x = self.norm(x)
        x = x[:, -1, :] # b, s, d -> b, d
        x = self.mlp_head(x)
        return x

In [9]:
#hyperparameters
fw = fh = 16
dim = 768
ff_dim = 3072
n_heads = 12
n_layers = 12
in_channels = 3
num_classes = 10


In [10]:
model = ViT(in_channels, dim, fh, fw, n_layers, n_heads, ff_dim, num_classes)

In [11]:
# test input of (1, 3, 224, 224) like in research paper
x = torch.rand(1, 3, 224, 224)

In [12]:
out = model(x)

In [13]:
out

tensor([[-0.3238, -0.0299,  0.2320,  0.3181,  0.3008, -0.1919,  0.4774, -0.2592,
          0.3847, -0.4313]], grad_fn=<AddmmBackward0>)

In [14]:
x = torch.load("b16.pth")

### I NEED TO APPEND Transformer.xyz to every block or else it wont load the damn weight fml

## figure out what each key is and rename my model to match... smh
* blocks.n.norm1.bias1 -> transformer blocks (class Block) DO THIS FOR 12 BLOCKS
* blocks.mlp_channels 
    * change mlp to mlp_channels 
        * change l1, l2 to fc1, fc2
    * change self.block to self.blocks
* 

odict_keys(['blocks.0.norm1.bias', 'blocks.0.norm1.weight', 'blocks.0.norm2.bias', 'blocks.0.norm2.weight', 'blocks.0.mlp_channels.fc1.bias', 'blocks.0.mlp_channels.fc1.weight', 'blocks.0.mlp_channels.fc2.bias', 'blocks.0.mlp_channels.fc2.weight', 'blocks.0.mlp_tokens.fc1.bias', 'blocks.0.mlp_tokens.fc1.weight', 'blocks.0.mlp_tokens.fc2.bias', 'blocks.0.mlp_tokens.fc2.weight', 'blocks.1.norm1.bias', 'blocks.1.norm1.weight', 'blocks.1.norm2.bias', 'blocks.1.norm2.weight', 'blocks.1.mlp_channels.fc1.bias', 'blocks.1.mlp_channels.fc1.weight', 'blocks.1.mlp_channels.fc2.bias', 'blocks.1.mlp_channels.fc2.weight', 'blocks.1.mlp_tokens.fc1.bias', 'blocks.1.mlp_tokens.fc1.weight', 'blocks.1.mlp_tokens.fc2.bias', 'blocks.1.mlp_tokens.fc2.weight', 'blocks.10.norm1.bias', 'blocks.10.norm1.weight', 'blocks.10.norm2.bias', 'blocks.10.norm2.weight', 'blocks.10.mlp_channels.fc1.bias', 'blocks.10.mlp_channels.fc1.weight', 'blocks.10.mlp_channels.fc2.bias', 'blocks.10.mlp_channels.fc2.weight', 'blocks.10.mlp_tokens.fc1.bias', 'blocks.10.mlp_tokens.fc1.weight', 'blocks.10.mlp_tokens.fc2.bias', 'blocks.10.mlp_tokens.fc2.weight', 'blocks.11.norm1.bias', 'blocks.11.norm1.weight', 'blocks.11.norm2.bias', 'blocks.11.norm2.weight', 'blocks.11.mlp_channels.fc1.bias', 'blocks.11.mlp_channels.fc1.weight', 'blocks.11.mlp_channels.fc2.bias', 'blocks.11.mlp_channels.fc2.weight', 'blocks.11.mlp_tokens.fc1.bias', 'blocks.11.mlp_tokens.fc1.weight', 'blocks.11.mlp_tokens.fc2.bias', 'blocks.11.mlp_tokens.fc2.weight', 'blocks.2.norm1.bias', 'blocks.2.norm1.weight', 'blocks.2.norm2.bias', 'blocks.2.norm2.weight', 'blocks.2.mlp_channels.fc1.bias', 'blocks.2.mlp_channels.fc1.weight', 'blocks.2.mlp_channels.fc2.bias', 'blocks.2.mlp_channels.fc2.weight', 'blocks.2.mlp_tokens.fc1.bias', 'blocks.2.mlp_tokens.fc1.weight', 'blocks.2.mlp_tokens.fc2.bias', 'blocks.2.mlp_tokens.fc2.weight', 'blocks.3.norm1.bias', 'blocks.3.norm1.weight', 'blocks.3.norm2.bias', 'blocks.3.norm2.weight', 'blocks.3.mlp_channels.fc1.bias', 'blocks.3.mlp_channels.fc1.weight', 'blocks.3.mlp_channels.fc2.bias', 'blocks.3.mlp_channels.fc2.weight', 'blocks.3.mlp_tokens.fc1.bias', 'blocks.3.mlp_tokens.fc1.weight', 'blocks.3.mlp_tokens.fc2.bias', 'blocks.3.mlp_tokens.fc2.weight', 'blocks.4.norm1.bias', 'blocks.4.norm1.weight', 'blocks.4.norm2.bias', 'blocks.4.norm2.weight', 'blocks.4.mlp_channels.fc1.bias', 'blocks.4.mlp_channels.fc1.weight', 'blocks.4.mlp_channels.fc2.bias', 'blocks.4.mlp_channels.fc2.weight', 'blocks.4.mlp_tokens.fc1.bias', 'blocks.4.mlp_tokens.fc1.weight', 'blocks.4.mlp_tokens.fc2.bias', 'blocks.4.mlp_tokens.fc2.weight', 'blocks.5.norm1.bias', 'blocks.5.norm1.weight', 'blocks.5.norm2.bias', 'blocks.5.norm2.weight', 'blocks.5.mlp_channels.fc1.bias', 'blocks.5.mlp_channels.fc1.weight', 'blocks.5.mlp_channels.fc2.bias', 'blocks.5.mlp_channels.fc2.weight', 'blocks.5.mlp_tokens.fc1.bias', 'blocks.5.mlp_tokens.fc1.weight', 'blocks.5.mlp_tokens.fc2.bias', 'blocks.5.mlp_tokens.fc2.weight', 'blocks.6.norm1.bias', 'blocks.6.norm1.weight', 'blocks.6.norm2.bias', 'blocks.6.norm2.weight', 'blocks.6.mlp_channels.fc1.bias', 'blocks.6.mlp_channels.fc1.weight', 'blocks.6.mlp_channels.fc2.bias', 'blocks.6.mlp_channels.fc2.weight', 'blocks.6.mlp_tokens.fc1.bias', 'blocks.6.mlp_tokens.fc1.weight', 'blocks.6.mlp_tokens.fc2.bias', 'blocks.6.mlp_tokens.fc2.weight', 'blocks.7.norm1.bias', 'blocks.7.norm1.weight', 'blocks.7.norm2.bias', 'blocks.7.norm2.weight', 'blocks.7.mlp_channels.fc1.bias', 'blocks.7.mlp_channels.fc1.weight', 'blocks.7.mlp_channels.fc2.bias', 'blocks.7.mlp_channels.fc2.weight', 'blocks.7.mlp_tokens.fc1.bias', 'blocks.7.mlp_tokens.fc1.weight', 'blocks.7.mlp_tokens.fc2.bias', 'blocks.7.mlp_tokens.fc2.weight', 'blocks.8.norm1.bias', 'blocks.8.norm1.weight', 'blocks.8.norm2.bias', 'blocks.8.norm2.weight', 'blocks.8.mlp_channels.fc1.bias', 'blocks.8.mlp_channels.fc1.weight', 'blocks.8.mlp_channels.fc2.bias', 'blocks.8.mlp_channels.fc2.weight', 'blocks.8.mlp_tokens.fc1.bias', 'blocks.8.mlp_tokens.fc1.weight', 'blocks.8.mlp_tokens.fc2.bias', 'blocks.8.mlp_tokens.fc2.weight', 'blocks.9.norm1.bias', 'blocks.9.norm1.weight', 'blocks.9.norm2.bias', 'blocks.9.norm2.weight', 'blocks.9.mlp_channels.fc1.bias', 'blocks.9.mlp_channels.fc1.weight', 'blocks.9.mlp_channels.fc2.bias', 'blocks.9.mlp_channels.fc2.weight', 'blocks.9.mlp_tokens.fc1.bias', 'blocks.9.mlp_tokens.fc1.weight', 'blocks.9.mlp_tokens.fc2.bias', 'blocks.9.mlp_tokens.fc2.weight', 'head.bias', 'head.weight', 'norm.bias', 'norm.weight', 'stem.proj.bias', 'stem.proj.weight'])


In [37]:
from collections import OrderedDict

In [38]:
d = OrderedDict()

In [39]:
for i, j in x.items():
    if i[:7] == "blocks.":
        newkey = "Transformer." + i
        d[newkey] = j
    else:
        d[i] = j

In [40]:
d.keys()

odict_keys(['Transformer.blocks.0.norm1.bias', 'Transformer.blocks.0.norm1.weight', 'Transformer.blocks.0.norm2.bias', 'Transformer.blocks.0.norm2.weight', 'Transformer.blocks.0.mlp_channels.fc1.bias', 'Transformer.blocks.0.mlp_channels.fc1.weight', 'Transformer.blocks.0.mlp_channels.fc2.bias', 'Transformer.blocks.0.mlp_channels.fc2.weight', 'Transformer.blocks.0.mlp_tokens.fc1.bias', 'Transformer.blocks.0.mlp_tokens.fc1.weight', 'Transformer.blocks.0.mlp_tokens.fc2.bias', 'Transformer.blocks.0.mlp_tokens.fc2.weight', 'Transformer.blocks.1.norm1.bias', 'Transformer.blocks.1.norm1.weight', 'Transformer.blocks.1.norm2.bias', 'Transformer.blocks.1.norm2.weight', 'Transformer.blocks.1.mlp_channels.fc1.bias', 'Transformer.blocks.1.mlp_channels.fc1.weight', 'Transformer.blocks.1.mlp_channels.fc2.bias', 'Transformer.blocks.1.mlp_channels.fc2.weight', 'Transformer.blocks.1.mlp_tokens.fc1.bias', 'Transformer.blocks.1.mlp_tokens.fc1.weight', 'Transformer.blocks.1.mlp_tokens.fc2.bias', 'Transform

# basically just mlp_tokens at this point
## everything new i add though imma have to train myself cuz no weights :/

In [20]:
model.load_state_dict(d)

RuntimeError: Error(s) in loading state_dict for ViT:
	Missing key(s) in state_dict: "patch_encoding.weight", "patch_encoding.bias", "positional_embedding.pos_embedding", "Transformer.blocks.0.MHSA.project.weight", "Transformer.blocks.0.MHSA.project.bias", "Transformer.blocks.1.MHSA.project.weight", "Transformer.blocks.1.MHSA.project.bias", "Transformer.blocks.2.MHSA.project.weight", "Transformer.blocks.2.MHSA.project.bias", "Transformer.blocks.3.MHSA.project.weight", "Transformer.blocks.3.MHSA.project.bias", "Transformer.blocks.4.MHSA.project.weight", "Transformer.blocks.4.MHSA.project.bias", "Transformer.blocks.5.MHSA.project.weight", "Transformer.blocks.5.MHSA.project.bias", "Transformer.blocks.6.MHSA.project.weight", "Transformer.blocks.6.MHSA.project.bias", "Transformer.blocks.7.MHSA.project.weight", "Transformer.blocks.7.MHSA.project.bias", "Transformer.blocks.8.MHSA.project.weight", "Transformer.blocks.8.MHSA.project.bias", "Transformer.blocks.9.MHSA.project.weight", "Transformer.blocks.9.MHSA.project.bias", "Transformer.blocks.10.MHSA.project.weight", "Transformer.blocks.10.MHSA.project.bias", "Transformer.blocks.11.MHSA.project.weight", "Transformer.blocks.11.MHSA.project.bias", "mlp_head.weight", "mlp_head.bias". 
	Unexpected key(s) in state_dict: "head.bias", "head.weight", "stem.proj.bias", "stem.proj.weight", "Transformer.blocks.0.mlp_tokens.fc1.bias", "Transformer.blocks.0.mlp_tokens.fc1.weight", "Transformer.blocks.0.mlp_tokens.fc2.bias", "Transformer.blocks.0.mlp_tokens.fc2.weight", "Transformer.blocks.1.mlp_tokens.fc1.bias", "Transformer.blocks.1.mlp_tokens.fc1.weight", "Transformer.blocks.1.mlp_tokens.fc2.bias", "Transformer.blocks.1.mlp_tokens.fc2.weight", "Transformer.blocks.2.mlp_tokens.fc1.bias", "Transformer.blocks.2.mlp_tokens.fc1.weight", "Transformer.blocks.2.mlp_tokens.fc2.bias", "Transformer.blocks.2.mlp_tokens.fc2.weight", "Transformer.blocks.3.mlp_tokens.fc1.bias", "Transformer.blocks.3.mlp_tokens.fc1.weight", "Transformer.blocks.3.mlp_tokens.fc2.bias", "Transformer.blocks.3.mlp_tokens.fc2.weight", "Transformer.blocks.4.mlp_tokens.fc1.bias", "Transformer.blocks.4.mlp_tokens.fc1.weight", "Transformer.blocks.4.mlp_tokens.fc2.bias", "Transformer.blocks.4.mlp_tokens.fc2.weight", "Transformer.blocks.5.mlp_tokens.fc1.bias", "Transformer.blocks.5.mlp_tokens.fc1.weight", "Transformer.blocks.5.mlp_tokens.fc2.bias", "Transformer.blocks.5.mlp_tokens.fc2.weight", "Transformer.blocks.6.mlp_tokens.fc1.bias", "Transformer.blocks.6.mlp_tokens.fc1.weight", "Transformer.blocks.6.mlp_tokens.fc2.bias", "Transformer.blocks.6.mlp_tokens.fc2.weight", "Transformer.blocks.7.mlp_tokens.fc1.bias", "Transformer.blocks.7.mlp_tokens.fc1.weight", "Transformer.blocks.7.mlp_tokens.fc2.bias", "Transformer.blocks.7.mlp_tokens.fc2.weight", "Transformer.blocks.8.mlp_tokens.fc1.bias", "Transformer.blocks.8.mlp_tokens.fc1.weight", "Transformer.blocks.8.mlp_tokens.fc2.bias", "Transformer.blocks.8.mlp_tokens.fc2.weight", "Transformer.blocks.9.mlp_tokens.fc1.bias", "Transformer.blocks.9.mlp_tokens.fc1.weight", "Transformer.blocks.9.mlp_tokens.fc2.bias", "Transformer.blocks.9.mlp_tokens.fc2.weight", "Transformer.blocks.10.mlp_tokens.fc1.bias", "Transformer.blocks.10.mlp_tokens.fc1.weight", "Transformer.blocks.10.mlp_tokens.fc2.bias", "Transformer.blocks.10.mlp_tokens.fc2.weight", "Transformer.blocks.11.mlp_tokens.fc1.bias", "Transformer.blocks.11.mlp_tokens.fc1.weight", "Transformer.blocks.11.mlp_tokens.fc2.bias", "Transformer.blocks.11.mlp_tokens.fc2.weight". 

# im removing mlp tokens and seeing how this performs

In [24]:
len('blocks.11.mlp_tokens.fc1.bias')

29

In [27]:
'blocks.11.mlp_tokens.fc1.bias'[10:20]

'mlp_tokens'

In [23]:
'mlp_tokens' in d.keys()

False

In [60]:
dd = OrderedDict()

In [61]:
for i, j in d.items():
    if i[21:31] != 'mlp_tokens':
        dd[i] = j

In [62]:
dd.keys()

odict_keys(['Transformer.blocks.0.norm1.bias', 'Transformer.blocks.0.norm1.weight', 'Transformer.blocks.0.norm2.bias', 'Transformer.blocks.0.norm2.weight', 'Transformer.blocks.0.mlp_channels.fc1.bias', 'Transformer.blocks.0.mlp_channels.fc1.weight', 'Transformer.blocks.0.mlp_channels.fc2.bias', 'Transformer.blocks.0.mlp_channels.fc2.weight', 'Transformer.blocks.1.norm1.bias', 'Transformer.blocks.1.norm1.weight', 'Transformer.blocks.1.norm2.bias', 'Transformer.blocks.1.norm2.weight', 'Transformer.blocks.1.mlp_channels.fc1.bias', 'Transformer.blocks.1.mlp_channels.fc1.weight', 'Transformer.blocks.1.mlp_channels.fc2.bias', 'Transformer.blocks.1.mlp_channels.fc2.weight', 'Transformer.blocks.10.norm1.bias', 'Transformer.blocks.10.norm1.weight', 'Transformer.blocks.10.norm2.bias', 'Transformer.blocks.10.norm2.weight', 'Transformer.blocks.10.mlp_channels.fc1.bias', 'Transformer.blocks.10.mlp_channels.fc1.weight', 'Transformer.blocks.10.mlp_channels.fc2.bias', 'Transformer.blocks.10.mlp_channe

In [63]:
ddd = OrderedDict()

In [64]:
for i, j in dd.items():
    if i[22:32] != 'mlp_tokens':
        ddd[i] = j

In [65]:
ddd.keys()

odict_keys(['Transformer.blocks.0.norm1.bias', 'Transformer.blocks.0.norm1.weight', 'Transformer.blocks.0.norm2.bias', 'Transformer.blocks.0.norm2.weight', 'Transformer.blocks.0.mlp_channels.fc1.bias', 'Transformer.blocks.0.mlp_channels.fc1.weight', 'Transformer.blocks.0.mlp_channels.fc2.bias', 'Transformer.blocks.0.mlp_channels.fc2.weight', 'Transformer.blocks.1.norm1.bias', 'Transformer.blocks.1.norm1.weight', 'Transformer.blocks.1.norm2.bias', 'Transformer.blocks.1.norm2.weight', 'Transformer.blocks.1.mlp_channels.fc1.bias', 'Transformer.blocks.1.mlp_channels.fc1.weight', 'Transformer.blocks.1.mlp_channels.fc2.bias', 'Transformer.blocks.1.mlp_channels.fc2.weight', 'Transformer.blocks.10.norm1.bias', 'Transformer.blocks.10.norm1.weight', 'Transformer.blocks.10.norm2.bias', 'Transformer.blocks.10.norm2.weight', 'Transformer.blocks.10.mlp_channels.fc1.bias', 'Transformer.blocks.10.mlp_channels.fc1.weight', 'Transformer.blocks.10.mlp_channels.fc2.bias', 'Transformer.blocks.10.mlp_channe

In [157]:
x.keys()

odict_keys(['blocks.0.norm1.bias', 'blocks.0.norm1.weight', 'blocks.0.norm2.bias', 'blocks.0.norm2.weight', 'blocks.0.mlp_channels.fc1.bias', 'blocks.0.mlp_channels.fc1.weight', 'blocks.0.mlp_channels.fc2.bias', 'blocks.0.mlp_channels.fc2.weight', 'blocks.0.mlp_tokens.fc1.bias', 'blocks.0.mlp_tokens.fc1.weight', 'blocks.0.mlp_tokens.fc2.bias', 'blocks.0.mlp_tokens.fc2.weight', 'blocks.1.norm1.bias', 'blocks.1.norm1.weight', 'blocks.1.norm2.bias', 'blocks.1.norm2.weight', 'blocks.1.mlp_channels.fc1.bias', 'blocks.1.mlp_channels.fc1.weight', 'blocks.1.mlp_channels.fc2.bias', 'blocks.1.mlp_channels.fc2.weight', 'blocks.1.mlp_tokens.fc1.bias', 'blocks.1.mlp_tokens.fc1.weight', 'blocks.1.mlp_tokens.fc2.bias', 'blocks.1.mlp_tokens.fc2.weight', 'blocks.10.norm1.bias', 'blocks.10.norm1.weight', 'blocks.10.norm2.bias', 'blocks.10.norm2.weight', 'blocks.10.mlp_channels.fc1.bias', 'blocks.10.mlp_channels.fc1.weight', 'blocks.10.mlp_channels.fc2.bias', 'blocks.10.mlp_channels.fc2.weight', 'blocks.

# BRO I CANT EVEN FIND MLP_TOKENS ON THIS DUDES GITHUB WTF IS AN MLP TOKEN
## ... >:(

In [67]:
model.load_state_dict(ddd)

RuntimeError: Error(s) in loading state_dict for ViT:
	Missing key(s) in state_dict: "patch_encoding.weight", "patch_encoding.bias", "positional_embedding.pos_embedding", "Transformer.blocks.0.MHSA.project.weight", "Transformer.blocks.0.MHSA.project.bias", "Transformer.blocks.1.MHSA.project.weight", "Transformer.blocks.1.MHSA.project.bias", "Transformer.blocks.2.MHSA.project.weight", "Transformer.blocks.2.MHSA.project.bias", "Transformer.blocks.3.MHSA.project.weight", "Transformer.blocks.3.MHSA.project.bias", "Transformer.blocks.4.MHSA.project.weight", "Transformer.blocks.4.MHSA.project.bias", "Transformer.blocks.5.MHSA.project.weight", "Transformer.blocks.5.MHSA.project.bias", "Transformer.blocks.6.MHSA.project.weight", "Transformer.blocks.6.MHSA.project.bias", "Transformer.blocks.7.MHSA.project.weight", "Transformer.blocks.7.MHSA.project.bias", "Transformer.blocks.8.MHSA.project.weight", "Transformer.blocks.8.MHSA.project.bias", "Transformer.blocks.9.MHSA.project.weight", "Transformer.blocks.9.MHSA.project.bias", "Transformer.blocks.10.MHSA.project.weight", "Transformer.blocks.10.MHSA.project.bias", "Transformer.blocks.11.MHSA.project.weight", "Transformer.blocks.11.MHSA.project.bias", "mlp_head.weight", "mlp_head.bias". 
	Unexpected key(s) in state_dict: "head.bias", "head.weight", "stem.proj.bias", "stem.proj.weight". 