## ViT 구현

### Multi-Head Attention

In [7]:
!pip install einops
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.init as init
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three separate
    linear layers. The last dimension is split into ``n_heads`` heads, each
    head computes ``softmax(Q K^T / sqrt(d_k)) V``, and the heads are then
    concatenated and passed through a final output projection. No attention
    mask is applied, so every token attends to every token.

    Args:
        patches_dim: Size of the input's last (embedding) dimension; also the
            output size. Must be divisible by ``n_heads``.
        n_heads: Number of attention heads (must be positive).

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``patches_dim``.
    """

    def __init__(self, patches_dim, n_heads):
        super().__init__()

        # Fail early with a clear message; otherwise the per-head split in
        # forward() would fail later with an opaque shape error.
        if n_heads <= 0 or patches_dim % n_heads != 0:
            raise ValueError(
                f"patches_dim ({patches_dim}) must be divisible by a positive n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        # sqrt(d_k), where d_k = patches_dim / n_heads is the per-head size;
        # used to scale the attention logits.
        self.root_dk = torch.sqrt(torch.tensor(patches_dim / n_heads, dtype=torch.float32))

        self.q_Linear = nn.Linear(patches_dim, patches_dim)
        self.k_Linear = nn.Linear(patches_dim, patches_dim)
        self.v_Linear = nn.Linear(patches_dim, patches_dim)
        self.last_Linear = nn.Linear(patches_dim, patches_dim)
        # Plain list for convenient iteration; the layers are already
        # registered as submodules through the attributes above.
        self.linear_layers = [self.q_Linear, self.k_Linear, self.v_Linear, self.last_Linear]

        # Xavier-uniform weights and zero biases for all four projections.
        for layer in self.linear_layers:
            init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                init.constant_(layer.bias, 0)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, tokens, patches_dim), as required by
                the ``'b n (h d)'`` rearrange pattern below.

        Returns:
            Tensor with the same shape as ``x``.
        """
        Q, K, V = [linear(x) for linear in self.linear_layers[:3]]
        # Split the embedding dimension into heads: (b, n, h*d) -> (b, h, n, d).
        Q, K, V = [rearrange(tensor, 'b n (h d) -> b h n d', h=self.n_heads) for tensor in [Q, K, V]]

        # Scaled dot-product scores, softmax-normalised over the key axis.
        qkt_dk = torch.matmul(Q, K.transpose(-2, -1)) / self.root_dk
        s_qkt_dk = torch.softmax(qkt_dk, dim=-1)
        attention_result = torch.matmul(s_qkt_dk, V)

        # Merge heads back: (b, h, n, d) -> (b, n, h*d), then output projection.
        concat_out = rearrange(attention_result, 'b h n d -> b n (h d)')
        result = self.last_Linear(concat_out)

        return result



### MLP

In [8]:
# Feed-forward sub-layer of the encoder block.
class MLP(nn.Module):
    """Token-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Expands the last dimension from ``patches_dim`` to ``n_hidden_layer``
    units and projects it back to ``patches_dim``.

    Args:
        patches_dim: Input and output feature size.
        n_hidden_layer: Width of the hidden projection (a unit count, not a
            number of layers).
        drop_p: Dropout probability applied after the GELU activation.
    """

    def __init__(self, patches_dim, n_hidden_layer, drop_p):
        super().__init__()

        stages = [
            nn.Linear(patches_dim, n_hidden_layer),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(n_hidden_layer, patches_dim),
        ]
        # Keep the Sequential under the name ``mlp`` so state_dict keys
        # (mlp.0.*, mlp.3.*) stay stable.
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return the MLP applied over the last dimension of ``x``."""
        return self.mlp(x)

### Encoder

In [9]:
# Encoder block
class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder layer.

    Computes ``y = x + Dropout(MHA(LN1(x)))`` and then
    ``out = y + Dropout(MLP(LN2(y)))``. The output has the same shape as ``x``.

    Args:
        patches_dim: Embedding size of each token.
        n_hidden_layer: Hidden width of the feed-forward MLP.
        n_heads: Number of attention heads.
        drop_p: Dropout probability (used for both residual branches and
            inside the MLP).
    """

    def __init__(self, patches_dim, n_hidden_layer, n_heads, drop_p):
        super().__init__()

        # Submodules are created in a fixed order; this determines parameter
        # registration order and the RNG draws used for initialisation.
        self.first_norm = nn.LayerNorm(patches_dim, eps=1e-6)
        self.self_attention = MultiHeadAttention(patches_dim, n_heads)
        self.second_norm = nn.LayerNorm(patches_dim, eps=1e-6)
        self.mlp = MLP(patches_dim, n_hidden_layer, drop_p)
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x):
        # Attention sub-layer: normalise, attend, drop out, add the residual.
        h = x + self.dropout(self.self_attention(self.first_norm(x)))
        # Feed-forward sub-layer: normalise, MLP, drop out, add the residual.
        return h + self.dropout(self.mlp(self.second_norm(h)))

# Full encoder: position embeddings -> stacked EncoderBlocks -> CLS LayerNorm.
class Encoder(nn.Module):
    """Stack of ``n_layers`` EncoderBlocks preceded by learned position embeddings.

    Args:
        n_patches_with_cls: Sequence length including the CLS token; sets the
            number of rows of the learned position embedding.
        n_layers: Number of stacked EncoderBlocks (0 is allowed and means the
            input is only position-embedded and normalised).
        patches_dim: Token embedding size.
        n_hidden_layer: Hidden width of each block's MLP.
        n_heads: Number of attention heads per block.
        drop_p: Dropout probability.

    ``forward`` returns the LayerNorm of the token at index 0 (the CLS slot)
    of the final hidden state, with shape (batch, patches_dim).
    """

    def __init__(self, n_patches_with_cls, n_layers, patches_dim, n_hidden_layer, n_heads, drop_p):
        super().__init__()

        # (n_patches_with_cls, patches_dim); broadcasts over the batch axis.
        self.position_embedding = nn.Parameter(0.01 * torch.randn(n_patches_with_cls, patches_dim))
        self.encoder_blocks = nn.ModuleList([EncoderBlock(patches_dim, n_hidden_layer, n_heads, drop_p) for _ in range(n_layers)])
        self.norm_for_cls = nn.LayerNorm(patches_dim, eps=1e-6)
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x):
        """Encode ``x`` of shape (batch, n_patches_with_cls, patches_dim)."""
        hidden = self.dropout(x + self.position_embedding)

        # Chain the blocks: each block consumes the previous block's output.
        # Starting from ``hidden`` also makes n_layers == 0 well-defined.
        for encoder_block in self.encoder_blocks:
            hidden = encoder_block(hidden)

        cls_token = hidden[:, 0, :]
        return self.norm_for_cls(cls_token)

### ViT Model

In [10]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    A 3-channel image is cut into non-overlapping ``patch_size`` x
    ``patch_size`` patches by a strided convolution, and each patch is
    projected to ``patches_dim``. A learned CLS token is prepended, the
    sequence is run through ``Encoder``, and the normalised CLS output is
    classified by ``head``.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of each square patch; also the conv stride.
        n_layers: Number of encoder blocks.
        patches_dim: Token embedding size.
        n_hidden_layer: Hidden width of each encoder block's MLP.
        n_heads: Number of attention heads.
        n_MlpHead_hidden_layer: Hidden width of the Linear-Tanh-Linear head
            (only used when ``fine_tuning`` is False).
        drop_p: Dropout probability.
        n_classes: Number of output logits.
        fine_tuning: If True, use a single zero-initialised Linear head;
            otherwise use a Linear-Tanh-Linear head.
    """

    def __init__(self, image_size, patch_size, n_layers, patches_dim, n_hidden_layer, n_heads, n_MlpHead_hidden_layer, drop_p, n_classes, fine_tuning):
        super().__init__()

        n_patches = (image_size // patch_size) ** 2

        # Construction order (cls, encoder, patch_embedding, head) fixes both
        # parameter registration order and the RNG draws used for init.
        self.cls = nn.Parameter(torch.randn(patches_dim) * 0.02)
        self.encoder = Encoder(n_patches + 1, n_layers, patches_dim, n_hidden_layer, n_heads, drop_p)
        self.patch_embedding = nn.Conv2d(3, patches_dim, patch_size, stride=patch_size)
        self._init_patch_embedding()

        if fine_tuning:
            self.head = self._make_linear_head(patches_dim, n_classes)
        else:
            self.head = self._make_mlp_head(patches_dim, n_MlpHead_hidden_layer, n_classes)

    def _init_patch_embedding(self):
        """Truncated-normal conv weights with std = sqrt(1 / fan_in); zero bias."""
        conv = self.patch_embedding
        kernel_h, kernel_w = conv.kernel_size
        fan_in = conv.in_channels * kernel_h * kernel_w
        # NOTE(review): trunc_normal_'s default bounds (a=-2, b=2) are absolute
        # values, not multiples of std, so with this small std the truncation
        # rarely applies. Confirm this is intended.
        init.trunc_normal_(conv.weight, std=math.sqrt(1 / fan_in))
        if conv.bias is not None:
            init.zeros_(conv.bias)

    @staticmethod
    def _make_linear_head(in_dim, n_classes):
        """Single Linear classifier with all-zero weight and bias."""
        head = nn.Linear(in_dim, n_classes)
        init.zeros_(head.weight)
        init.zeros_(head.bias)
        return head

    @staticmethod
    def _make_mlp_head(in_dim, hidden_dim, n_classes):
        """Linear -> Tanh -> Linear classifier; only the first layer is re-initialised."""
        hidden = nn.Linear(in_dim, hidden_dim)
        output = nn.Linear(hidden_dim, n_classes)
        init.trunc_normal_(hidden.weight, std=math.sqrt(1 / hidden.in_features))
        init.zeros_(hidden.bias)
        return nn.Sequential(hidden, nn.Tanh(), output)

    def forward(self, x):
        # Conv feature map (b, d, ph, pw) -> token sequence (b, ph*pw, d).
        tokens = rearrange(self.patch_embedding(x), 'b d ph pw -> b (ph pw) d')

        # One CLS token per sample, prepended at sequence position 0.
        cls_tokens = self.cls.expand(tokens.shape[0], 1, -1)
        sequence = torch.cat([cls_tokens, tokens], dim=1)

        return self.head(self.encoder(sequence))

## 모델 생성 예시

In [21]:
!pip install torchinfo
from torchinfo import summary
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Encoder hyperparameters follow the ViT-Base/16 setup (224x224 input,
# 16x16 patches, 12 layers, width 768, MLP width 3072, 12 heads).
# fine_tuning=False selects the Linear-Tanh-Linear classification head.
model = ViT(
    image_size=224,
    patch_size=16,
    n_layers=12,
    patches_dim=768,
    n_hidden_layer=3072,
    n_heads=12,
    n_MlpHead_hidden_layer=512,
    drop_p=0.1,
    n_classes=1000,
    fine_tuning=False,
)
model = model.to(DEVICE)

# Layer/parameter summary for a dummy batch of two 3x224x224 images.
summary(model, input_size=(2, 3, 224, 224), device=DEVICE)