In [59]:
import torch
import torch.nn as nn
import math

In [60]:
class Embedding(nn.Module):
    """Split an image into patches, prepend a learnable [CLS] token, add sinusoidal positions.

    Args:
        img_size: side length (pixels) of the square input image.
        num_channels: number of input channels.
        patch_size: side length of each square patch; used as the conv kernel and stride.

    The token width ``d_model`` is ``num_channels * patch_size**2`` (the number of values
    in one flattened patch). The output of ``forward`` is ``(batch, 1 + patch_num, d_model)``
    with ``patch_num = (img_size // patch_size)**2``.
    """

    def __init__(self, img_size, num_channels, patch_size) -> None:
        super().__init__()
        self.img_size = img_size
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.patch_num = (img_size // patch_size)**2
        d_model = self.num_channels*self.patch_size**2

        self.proj = nn.Conv2d(self.num_channels, d_model, kernel_size=self.patch_size, stride=self.patch_size)
        # A (non-persistent) buffer follows the module through .to()/.cuda()/.half(), so the
        # addition in forward() never mixes devices/dtypes; persistent=False keeps it out of
        # the state_dict because it is fully determined by the constructor arguments.
        self.register_buffer("positional_embedding", self.get_positional_embedding(), persistent=False)
        self.cls = nn.Parameter(torch.randn(1, 1, d_model))

    def get_positional_embedding(self, n=10000):
        """Build the fixed sinusoidal position table.

        Args:
            n: base of the geometric frequency progression.

        Returns:
            Float tensor of shape ``(patch_num + 1, d_model)``. Column ``2i`` holds
            ``sin(pos / n**(2i / d_model))`` and column ``2i + 1`` holds the cosine of the
            same angle, so every (sin, cos) pair shares one frequency. An odd ``d_model``
            is supported (the last column is then a sine).
        """
        d_model = self.num_channels*self.patch_size**2
        seq_len = self.patch_num + 1
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)   # (seq_len, 1)
        even_i = torch.arange(0, d_model, 2, dtype=torch.float)           # 0, 2, 4, ...
        angle = position / torch.pow(n, even_i/d_model)                   # (seq_len, ceil(d_model / 2))
        positional_embedding = torch.zeros(seq_len, d_model)
        positional_embedding[:, 0::2] = torch.sin(angle)
        # There are only d_model // 2 odd columns; drop the last angle when d_model is odd.
        positional_embedding[:, 1::2] = torch.cos(angle[:, : d_model // 2])
        return positional_embedding

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch, presumably ``(batch, num_channels, img_size, img_size)``. Its
                patch grid must contain exactly ``patch_num`` patches, otherwise the
                positional addition fails.

        Returns:
            Tensor of shape ``(batch, 1 + patch_num, d_model)``; index 0 on axis 1 is [CLS].
        """
        bs = x.shape[0]
        # (bs, d_model, H/P, W/P) -> (bs, d_model, patches) -> (bs, patches, d_model)
        x = self.proj(x).flatten(2).transpose(1, 2)
        cls = self.cls.expand(bs, -1, -1)
        x = torch.cat([cls, x], dim=1)
        # The (seq_len, d_model) table broadcasts over the batch axis.
        return x + self.positional_embedding

In [61]:
embedding = Embedding(img_size=256, num_channels=3, patch_size=16)
x = torch.randn(4, 3, 256, 256)
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    y = embedding(x)
# 1 CLS token + (256 // 16)**2 = 256 patch tokens, each of width 3 * 16**2 = 768
assert y.shape == (4, 257, 768)
y.shape

torch.Size([4, 257, 768])

In [62]:
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Projects the input into query/key/value spaces of width ``dim_out`` and returns
    ``dropout(softmax(Q K^T / sqrt(dim_out))) @ V``.

    Args:
        dim_in: size of the last axis of the input.
        dim_out: width of the query/key/value projections and of the output.
        dropout_p: dropout probability applied to the attention weights.

    The input is ``(..., seq_len, dim_in)`` with any number of leading axes (including
    none); the output is ``(..., seq_len, dim_out)``.
    """

    def __init__(self, dim_in, dim_out, dropout_p=0.1) -> None:
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.W_q = nn.Linear(self.dim_in, self.dim_out)
        self.W_k = nn.Linear(self.dim_in, self.dim_out)
        self.W_v = nn.Linear(self.dim_in, self.dim_out)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        query = self.W_q(x)
        key = self.W_k(x)
        value = self.W_v(x)
        # Scaled dot-product. Transposing the last two axes (not axes 1 and 2) makes this
        # work for unbatched (seq, dim) input and for extra leading batch axes.
        scores = query @ key.transpose(-2, -1) / math.sqrt(self.dim_out)
        attention_weights = nn.functional.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        return attention_weights @ value

In [63]:
x = torch.randn(4, 257, 768)
head = AttentionHead(dim_in=768, dim_out=64)
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    y = head(x)
assert y.shape == (4, 257, 64)  # one head projects each 768-wide token to 64
y.shape

torch.Size([4, 257, 64])

In [64]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    Each head maps ``att_head_dim_in -> att_head_dim_out``. The head outputs are
    concatenated along the last axis (width ``num_head * att_head_dim_out``), then passed
    through a square linear projection and dropout.

    Args:
        att_head_dim_in: size of the last axis of the input.
        att_head_dim_out: output width of each head.
        num_head: number of heads.
        dropout_p: dropout probability, used both for the attention weights inside every
            head and after the output projection.
    """

    def __init__(self, att_head_dim_in, att_head_dim_out, num_head, dropout_p=0.1) -> None:
        super().__init__()
        self.att_head_dim_in = att_head_dim_in
        self.att_head_dim_out = att_head_dim_out
        self.num_head = num_head
        self.total_dim_out = num_head * att_head_dim_out

        # Forward dropout_p to the heads as well; otherwise their attention dropout stays
        # at AttentionHead's own default no matter what the caller requested.
        self.heads = nn.ModuleList(
            AttentionHead(self.att_head_dim_in, self.att_head_dim_out, dropout_p=dropout_p)
            for _ in range(self.num_head)
        )
        self.proj = nn.Linear(self.total_dim_out, self.total_dim_out)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis, project, apply dropout."""
        head_outputs = [head(x) for head in self.heads]
        x = torch.cat(head_outputs, dim=-1)
        x = self.proj(x)
        return self.dropout(x)

In [65]:
x = torch.randn(4, 257, 768)
mha = MultiheadAttention(att_head_dim_in=768, att_head_dim_out=64, num_head=12)
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    y = mha(x)
assert y.shape == (4, 257, 768)  # 12 heads * 64 = 768 after concatenation
y.shape

torch.Size([4, 257, 768])

In [66]:
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        dim_in: size of the last axis of the input.
        dim_mid: hidden width between the two linear layers.
        dim_out: size of the last axis of the output.
        dropout_p: dropout probability applied after the second linear layer.
    """

    def __init__(self, dim_in, dim_mid, dim_out, dropout_p=0.1) -> None:
        super().__init__()
        self.dim_in, self.dim_mid, self.dim_out = dim_in, dim_mid, dim_out
        # Keep this creation order: it fixes the parameter registration order and the
        # sequence of RNG draws used to initialise the linear layers.
        self.dense1 = nn.Linear(dim_in, dim_mid)
        self.activation = nn.GELU()
        self.dense2 = nn.Linear(dim_mid, dim_out)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply the network along the last axis of ``x``; leading axes pass through."""
        hidden = self.activation(self.dense1(x))
        return self.dropout(self.dense2(hidden))

In [67]:
x = torch.randn(4, 257, 768)
mlp = MLP(dim_in=768, dim_mid=768, dim_out=768)
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    y = mlp(x)
assert y.shape == (4, 257, 768)
y.shape

torch.Size([4, 257, 768])

In [68]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: ``x + MHA(LN(x))`` then ``x + MLP(LN(x))``.

    Args:
        dim_in: token width of the input.
        dim_out: token width of the output; the residual additions require it to equal
            ``dim_in``.
        num_head: number of attention heads; must be positive and divide ``dim_in`` so the
            concatenated heads are ``dim_in`` wide again.
        dropout_p: dropout probability passed to the attention and MLP sub-layers
            (0.1, matching their own default).

    Raises:
        ValueError: if ``dim_out != dim_in``, or ``num_head`` is not a positive divisor of
            ``dim_in``.
    """

    def __init__(self, dim_in, dim_out, num_head, dropout_p=0.1) -> None:
        super().__init__()
        # Reject shapes the residual connections cannot handle here, instead of an opaque
        # shape error in forward() (or a silent broadcast when dim_out == 1).
        if dim_out != dim_in:
            raise ValueError(
                f"TransformerBlock requires dim_out == dim_in for its residual connections, "
                f"got dim_in={dim_in}, dim_out={dim_out}"
            )
        if num_head <= 0 or dim_in % num_head != 0:
            raise ValueError(
                f"num_head must be a positive divisor of dim_in, got dim_in={dim_in}, num_head={num_head}"
            )
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.num_head = num_head

        self.layernorm1 = nn.LayerNorm(self.dim_in)
        self.multihead_attention = MultiheadAttention(
            self.dim_in, self.dim_in//self.num_head, self.num_head, dropout_p=dropout_p
        )
        self.layernorm2 = nn.LayerNorm(self.dim_in)
        self.mlp = MLP(self.dim_in, self.dim_in, self.dim_out, dropout_p=dropout_p)

    def forward(self, x):
        """Apply both pre-norm residual sub-layers; output has the same shape as ``x``."""
        x = x + self.multihead_attention(self.layernorm1(x))
        x = x + self.mlp(self.layernorm2(x))
        return x


In [69]:
x = torch.randn(4, 257, 768)
block = TransformerBlock(dim_in=768, dim_out=768, num_head=12)
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    y = block(x)
assert y.shape == x.shape  # residual block preserves the token shape
y.shape

torch.Size([4, 257, 768])

In [70]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder blocks -> linear head on [CLS].

    Args:
        img_size: side length (pixels) of the square input image.
        num_channels: number of input channels.
        patch_size: side length of each square patch.
        num_blocks: number of stacked ``TransformerBlock``s.
        num_heads: attention heads per block.
        num_classes: number of output logits.

    The token width is ``num_channels * patch_size**2``, fixed by the patch embedding.
    """

    def __init__(self, img_size, num_channels, patch_size, num_blocks, num_heads, num_classes) -> None:
        super().__init__()
        self.img_size = img_size
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.hidden_size = num_channels*patch_size**2

        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.num_classes = num_classes

        self.embedding = Embedding(self.img_size, self.num_channels, self.patch_size)
        # Build the stack through the constructor rather than a side-effect-only list
        # comprehension; the submodule keys stay "encoder.0", "encoder.1", ...
        self.encoder = nn.Sequential(
            *(TransformerBlock(self.hidden_size, self.hidden_size, self.num_heads) for _ in range(self.num_blocks))
        )
        # NOTE(review): the blocks are pre-norm, so the encoder output reaching the head is
        # never layer-normalised; the reference ViT adds a final LayerNorm here. Adding one
        # would change the state_dict, so it is left out.
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits computed from the [CLS] token."""
        x = self.embedding(x)        # (batch, 1 + patch_num, hidden_size)
        x = self.encoder(x)
        return self.classifier(x[:, 0])

In [71]:
x = torch.randn(4, 3, 256, 256)
vit = ViT(img_size=256, num_channels=3, patch_size=16, num_heads=12, num_blocks=4, num_classes=10)
with torch.no_grad():  # shape smoke test only; no autograd graph needed
    y = vit(x)
assert y.shape == (4, 10)  # one logit per class for each image
y.shape

torch.Size([4, 10])