In [1]:
import torch
import torch.nn as nn

In [2]:
class PatchEmbed(nn.Module):
    """
    Cut a square image into non-overlapping square patches and embed each one

    Parameters
    ----------

    img_size: int
        Side length of the (square) input image

    patch_size: int
        Side length of each (square) patch

    in_chans: int
        Number of channels of the input image

    embed_dim: int
        Size of the vector every patch is mapped to

    Attributes
    ----------

    img_size: int
        The `img_size` passed to the constructor

    patch_size: int
        The `patch_size` passed to the constructor

    n_patches: int
        How many patches one image yields, `(img_size // patch_size) ** 2`

    proj: nn.Conv2d
        Convolution whose kernel size and stride both equal `patch_size`, so a single
        call both slices the image into patches and projects every patch to `embed_dim`

    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side ** 2
        # kernel == stride, so each output position sees exactly one patch
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Embed a batch of images

        Parameters
        ----------
        x: torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`

        """
        # conv:      (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        # flatten:   (n_samples, embed_dim, n_patches)
        # transpose: (n_samples, n_patches, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
    
class Attention(nn.Module):
    """
    Multi-head self-attention mechanism

    Parameters
    ----------
    dim: int
        The input and output dimension of per token features.
        Must be divisible by `n_heads`

    n_heads: int
        Number of attention heads

    qkv_bias: bool
        If True then we include bias to the query, key and value projections

    attn_p: float
        Dropout probability applied to the attention weights (after the softmax)

    proj_p: float
        Dropout probability applied to the output tensor

    Attributes
    ----------
    dim: int
        Per token feature dimension expected by `forward`

    n_heads: int
        Number of attention heads

    head_dim: int
        Feature dimension handled by a single head, `dim // n_heads`

    scale: float
        Normalizing constant for the dot product, `head_dim ** -0.5`

    qkv: nn.Linear
        Linear mapping that produces the queries, keys and values of all heads
        in one tensor of size `3 * dim`

    proj: nn.Linear
        Linear mapping that takes in the concatenated output of all attention heads and maps it into a new space

    attn_drop, proj_drop: nn.Dropout
        Dropout layers

    Raises
    ------
    ValueError
        If `dim` is not divisible by `n_heads`

    """

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """
        Run forward pass.

        Parameters
        ----------

        x: torch.Tensor
            Shape `(n_samples, n_tokens, dim)`

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_tokens, dim)`

        Raises
        ------
        ValueError
            If the last dimension of `x` differs from the `dim` given at construction

        """
        n_samples, n_tokens, dim = x.shape

        # Compare against the configured feature size (not the number of heads)
        if dim != self.dim:
            raise ValueError(f"Expected last dimension of size {self.dim}, got {dim}")

        qkv = self.qkv(x)  # (n_samples, n_tokens, 3 * dim)
        # Split the 3 * dim features into (q/k/v, head, per-head feature)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, n_samples, n_heads, n_tokens, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]  # each (n_samples, n_heads, n_tokens, head_dim)
        k_t = k.transpose(-2, -1)  # (n_samples, n_heads, head_dim, n_tokens)
        dp = (q @ k_t) * self.scale  # (n_samples, n_heads, n_tokens, n_tokens)
        attn = dp.softmax(dim=-1)  # every row sums to 1 over the attended tokens
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v  # (n_samples, n_heads, n_tokens, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2)  # (n_samples, n_tokens, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)  # (n_samples, n_tokens, dim)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)
        return x
    
class MLP(nn.Module):
    """
    Two-layer feed-forward network with a GELU in between

    Parameters
    ----------
    in_features: int
        Size of the last dimension of the input

    hidden_features: int
        Width of the hidden layer. Falls back to `in_features` when not given

    out_features: int
        Size of the last dimension of the output. Falls back to `in_features` when not given

    p: float
        Dropout probability

    Attributes
    ----------

    fc: nn.Linear
        First linear layer, `in_features -> hidden_features`

    act: nn.GELU
        GELU activation function

    fc2: nn.Linear
        Second linear layer, `hidden_features -> out_features`

    drop: nn.Dropout
        Dropout layer, used after the activation and again after `fc2`

    """

    def __init__(self, in_features, hidden_features=None, out_features=None, p=0.):
        super().__init__()
        # Any falsy size (None or 0) is replaced by the input width
        n_hidden = hidden_features if hidden_features else in_features
        n_out = out_features if out_features else in_features

        self.fc = nn.Linear(in_features, n_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(n_hidden, n_out)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """
        Run forward pass.

        Parameters
        ----------

        x: torch.Tensor
            Tensor whose last dimension has size `in_features`

        Returns
        -------
        torch.Tensor
            Same leading dimensions as `x`, last dimension of size `out_features`

        """
        hidden = self.drop(self.act(self.fc(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """
    Single pre-norm transformer encoder block

    Parameters
    ----------
    dim: int
        Embedding dimension

    n_heads: int
        Number of attention heads

    mlp_ratio: float
        The hidden layer of the `MLP` module has `int(dim * mlp_ratio)` nodes

    qkv_bias: bool
        If True then we include bias to the query, key and value projections

    p, attn_p, proj_p: float
        Dropout probabilities. `p` is forwarded to the `MLP` module,
        `attn_p` and `proj_p` to the `Attention` module

    Attributes
    ----------
    norm1, norm2: nn.LayerNorm
        Layer normalization applied before the attention and before the MLP respectively

    attn: Attention
        Attention module

    mlp: MLP
        MLP module

    """

    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, p=0., attn_p=0., proj_p=0.):
        super().__init__()
        mlp_width = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=proj_p,
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_width,
            out_features=dim,
            p=p,
        )

    def forward(self, x):
        """
        Run forward pass.

        Parameters
        ----------

        x: torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`. The +1 is for the [CLS] token

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`

        """
        # Residual connection around the attention sub-layer
        attended = self.attn(self.norm1(x))
        x = x + attended

        # Residual connection around the MLP sub-layer
        transformed = self.mlp(self.norm2(x))
        return x + transformed
    

class VisionTransformer(nn.Module):
    """
    Vision Transformer classifier

    Parameters
    ----------
    img_size: int
        Both height and the width of the image

    patch_size: int
        Both height and the width of the patch

    in_chans: int
        Number of input channels

    n_classes: int
        Number of classes of the classification problem

    embed_dim: int
        Dimensionality of the token/patch embeddings

    depth: int
        Number of `Block` layers stacked on top of each other

    n_heads: int
        Number of attention heads passed to every `Block`

    mlp_ratio: float
        Passed to every `Block`; sets the MLP hidden size relative to `embed_dim`

    qkv_bias: bool
        If True then we include bias to the query, key and value projections

    p, attn_p, proj_p: float
        Dropout probabilities. `p` is also used for the dropout that follows
        the positional embedding

    Attributes
    ----------
    patch_embed: PatchEmbed
        Instance of `PatchEmbed` layer

    cls_token: nn.Parameter
        Learnable parameter prepended as the first token of every sequence.
        Shape `(1, 1, embed_dim)`, initialized to zeros

    pos_emb: nn.Parameter
        Positional embedding of the cls_token + all the patches.
        Shape `(1, n_patches + 1, embed_dim)`, initialized to zeros

    pos_drop: nn.Dropout
        Dropout layer applied after adding the positional embedding

    blocks: nn.ModuleList
        List of `depth` `Block` layers

    norm: nn.LayerNorm
        Layer normalization applied after the last block

    head: nn.Linear
        Classification layer mapping `embed_dim` to `n_classes`

    """

    def __init__(
        self,
        img_size=384,
        patch_size=16,
        in_chans=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        p=0.,
        attn_p=0.,
        proj_p=0.
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        # One extra position for the class token
        seq_len = 1 + self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=p)

        # Every block is built from the same settings
        block_kwargs = dict(
            dim=embed_dim,
            n_heads=n_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            p=p,
            attn_p=attn_p,
            proj_p=proj_p,
        )
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(depth)])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """
        Run forward pass.

        Parameters
        ----------

        x: torch.Tensor
            Shape `(n_samples, in_chans, img_size, img_size)`

        Returns
        -------
        torch.Tensor
            Logits over the classes, shape `(n_samples, n_classes)`

        """
        batch_size = x.shape[0]

        patches = self.patch_embed(x)  # (n_samples, n_patches, embed_dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (n_samples, 1, embed_dim)
        tokens = torch.cat((cls_tokens, patches), dim=1)  # (n_samples, 1 + n_patches, embed_dim)
        tokens = self.pos_drop(tokens + self.pos_emb)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Only the class token of the normalized sequence is classified
        cls_final = self.norm(tokens)[:, 0]
        return self.head(cls_final)

In [None]:
import numpy as np
import timm

# helpers
def get_n_params(module):
    """Return the total number of elements over all trainable parameters of `module`."""
    total = 0
    for param in module.parameters():
        # Frozen parameters (requires_grad == False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total

def assert_tensors_equal(t1, t2):
    """
    Assert that two tensors hold (numerically) the same values

    Both tensors are detached from the autograd graph and moved to the CPU before
    the NumPy conversion, so parameters and tensors living on another device can
    be compared as well. The comparison uses the default tolerances of
    `np.testing.assert_allclose`.

    Parameters
    ----------
    t1, t2: torch.Tensor
        Tensors to compare

    Raises
    ------
    AssertionError
        If the tensors differ
    """
    a1 = t1.detach().cpu().numpy()
    a2 = t2.detach().cpu().numpy()

    np.testing.assert_allclose(a1, a2)

# Load the pretrained reference model from timm
model_name = 'vit_base_patch16_384'
model_official = timm.create_model(model_name, pretrained=True)
model_official.eval()
print(type(model_official))

# Same architecture hyperparameters as the reference model
custom_config = {
    'img_size': 384,
    'patch_size': 16,
    'in_chans': 3,
    'n_classes': 1000,
    'embed_dim': 768,
    'depth': 12,
    'n_heads': 12,
    'mlp_ratio': 4
}

model_custom = VisionTransformer(**custom_config)
model_custom.eval()

# Parameters are paired by position. Their names are not required to match:
# this file uses `pos_emb` and `mlp.fc` where timm uses `pos_embed` and `mlp.fc1`,
# so the pairing is validated through the parameter count and the shapes instead.
params_official = list(model_official.named_parameters())
params_custom = list(model_custom.named_parameters())
assert len(params_official) == len(params_custom), (
    f"parameter count mismatch: {len(params_official)} (official) "
    f"vs {len(params_custom)} (custom)"
)

# Copy the values in place so that the two models do not share storage
with torch.no_grad():
    for (n_o, p_o), (n_c, p_c) in zip(params_official, params_custom):
        assert p_o.shape == p_c.shape, (
            f"shape mismatch: {n_o} {tuple(p_o.shape)} vs {n_c} {tuple(p_c.shape)}"
        )
        p_c.copy_(p_o)
        assert_tensors_equal(p_o, p_c)