In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class PatchEmbeddings(nn.Module):
    def __init__(self, img_size, num_channels, patch_size, hidden_size):
        super().__init__()
        self.img_size = img_size
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.patch_num = (self.img_size // self.patch_size)**2
        self.projection = nn.Conv2d(self.num_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)   # here use conv2d to replace cropping+fc

    def forward(self, x):
        x = self.projection(x)
        x = x.flatten(2).permute(0, 2, 1)
        return x

In [3]:
p = PatchEmbeddings(256, 3, 16, 16*16*3)
x = torch.randn((4, 3, 256, 256))
y = p(x)
print(y.shape)

torch.Size([4, 256, 768])


In [4]:
class Embeddings(nn.Module):
    def __init__(self, img_size, num_channels, patch_size, hidden_size):
        super().__init__()
        self.img_size = img_size
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size

        self.patch_embedding = PatchEmbeddings(self.img_size, self.num_channels, self.patch_size, self.hidden_size)
        self.cls = nn.Parameter(torch.randn(1, 1, self.hidden_size))
        self.position_embedding = nn.Parameter(torch.randn(1, self.patch_embedding.patch_num+1, self.hidden_size))
        self.dropout = nn.Dropout()
    def forward(self, x):
        x = self.patch_embedding(x)
        bs = x.shape[0]
        cls_token = self.cls.expand(bs, -1, -1)
        x = torch.concat([x, cls_token], dim=1)
        x = x + self.position_embedding
        x = self.dropout(x)
        return x

In [5]:
class AttentionHead(nn.Module):
    def __init__(self, hidden_size, attention_head_size, bias=True) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size

        self.query = nn.Linear(hidden_size, attention_head_size, bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias)
        self.dropout = nn.Dropout()

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        attention_score = query @ key.permute(0, 2, 1)
        attention_score = attention_score / math.sqrt(self.attention_head_size)
        attention_score = nn.functional.softmax(attention_score, dim=-1)
        attention_score = self.dropout(attention_score)
        attention_map = attention_score @ value
        return attention_map
        

In [6]:
a = AttentionHead(768, 64)
x = torch.randn(4, 257, 768)
y = a(x)
y.shape

torch.Size([4, 257, 64])

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.total_head_size = self.attention_head_size * self.num_attention_heads

        self.heads = nn.ModuleList()
        [self.heads.append(AttentionHead(self.hidden_size, self.attention_head_size)) for _ in range(self.num_attention_heads)]

        self.output_projection = nn.Linear(self.total_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout()

    def forward(self, x):
        attention_outputs = [head(x) for head in self.heads]
        attention_output = torch.cat(attention_outputs, dim=-1)
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        return attention_output


In [8]:
a = MultiHeadAttention(768, 12)
x = torch.randn(4, 257, 768)
y = a(x)
y.shape

torch.Size([4, 257, 768])

In [9]:
class MLP(nn.Module):
    def __init__(self, hidden_size, middle_size) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.middle_size = middle_size
        self.dense1 = nn.Linear(self.hidden_size, self.middle_size)
        self.dense2 = nn.Linear(self.middle_size, self.hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout()
    
    def forward(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x

In [10]:
x = torch.randn(4, 257, 768)
mlp = MLP(768, 768)
y = mlp(x)
y.shape

torch.Size([4, 257, 768])

In [11]:
class Block(nn.Module):
    def __init__(self, hidden1, hidden2) -> None:
        super().__init__()
        self.attention1 = MultiHeadAttention(hidden1, 12)
        self.laynorm1 = nn.LayerNorm(hidden1)
        self.attention2 = MultiHeadAttention(hidden2, 12)
        self.laynorm2 = nn.LayerNorm(hidden2)
        self.mlp = MLP(hidden2, hidden2)

    def forward(self, x):
        attention_output = self.attention1(self.laynorm1(x))
        x = x + attention_output
        mlp_output = self.mlp(self.laynorm2(x))
        x = x + mlp_output
        return x

In [12]:
x = torch.randn(4, 257, 768)
b = Block(768, 768)
y = b(x)
y.shape

torch.Size([4, 257, 768])

In [13]:
class Encoder(nn.Module):
    def __init__(self, hidden1, hidden2, block_num) -> None:
        super().__init__()
        self.blocks = nn.ModuleList()
        [self.blocks.append(Block(hidden1, hidden2)) for _ in range(block_num)]
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
        

In [14]:
class ViT(nn.Module):
    def __init__(self, img_size, num_channels, patch_size, hidden_size, block_num, num_classes) -> None:
        super().__init__()
        self.img_size = img_size
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.block_num = block_num
        self.num_classes = num_classes

        self.embedding = Embeddings(self.img_size, self.num_channels, self.patch_size, self.hidden_size)
        self.encoder = Encoder(self.hidden_size, self.hidden_size, self.block_num)
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)

    def forward(self, x):
        embedding = self.embedding(x)
        encoder_output = self.encoder(embedding)
        logits = self.classifier(encoder_output[:, 0])
        return logits

In [16]:
def count_model_param(model):
    return sum(param.numel() for param in model.parameters() if param.requires_grad)

In [17]:
x = torch.randn(4, 3, 256, 256)
vit = ViT(img_size=256, num_channels=3, patch_size=16, hidden_size=768, block_num=4, num_classes=10)
y = vit(x)
print(count_model_param(vit))

24432394
