In [None]:
# https://yhkim4504.tistory.com/5
# https://yhkim4504.tistory.com/6
# https://yhkim4504.tistory.com/7

In [6]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

In [34]:
# Visualization
def single_image_show(tensor):
    """Display a single image tensor laid out as (C, H, W) with matplotlib.

    The tensor is detached and moved to the CPU before plotting, so tensors
    that require grad or live on a GPU can be shown as well (matching what
    ``channels_plt`` does). Note: matplotlib only applies ``vmin``/``vmax``
    to single-channel data; they are ignored for RGB(A) images.
    """
    image = tensor.detach().cpu().permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    plt.imshow(image, vmin=0, vmax=255)

# Channel grid: the default 16 x 48 grid holds 768 cells (= emb_size).
def channels_plt(tensor, row=16, col=48):
    """Plot every channel of the first batch element as its own image in a row x col grid.

    Args:
        tensor: feature map indexed as ``tensor[0, i]`` per channel; presumably
            (B, C, H, W) so each channel is a 2-D image — TODO confirm for other layouts.
        row: number of grid rows.
        col: number of grid columns.

    Channels beyond ``row * col`` are not drawn; grid cells without a channel are
    left blank instead of raising IndexError.
    """
    # squeeze=False keeps `axs` a 2-D array even for a 1 x 1 grid, so `.flat` always works.
    fig, axs = plt.subplots(row, col, figsize=(32, 48), squeeze=False)
    # Detach and move to CPU once rather than once per channel.
    channels = tensor[0].detach().cpu()
    num_channels = channels.shape[0]
    for i, ax in enumerate(axs.flat):
        if i < num_channels:
            ax.imshow(channels[i].numpy())
        ax.axis('off')
    plt.show()


In [36]:
# Patch Embedding — step-by-step walkthrough on a dummy batch

img_size = (512, 512)

# Dummy batch: batch size 10, 3 channels, height 512, width 512.
x = torch.randn(10, 3, img_size[0], img_size[1])
print(x.shape) # torch.Size([10, 3, 512, 512])

# An image batch of shape B x C x H x W is embedded into B x N x E tokens, where
# P is the patch size, N = (H * W) / (P * P) is the number of patches and E = emb_size.
# A Conv2d with kernel_size = stride = P maps each non-overlapping P x P x C patch
# to one E-dimensional vector.
patch_size = 16 # pixels per patch side
in_channels = 3
emb_size = 768

# (A 3x3 / stride-1 / padding-1 conv would instead keep the full 512 x 512 resolution.)
projection = nn.Sequential(
    nn.Conv2d(in_channels=in_channels, out_channels=emb_size, kernel_size=patch_size, stride=patch_size),
)
# Run the convolution once and reuse the result; keep `projection` as the layer.
projected = projection(x)
print(projected.shape) # torch.Size([10, 768, 32, 32])

b, c, h, w = projected.shape

# (B, E, H/P, W/P) -> (B, E, N) -> (B, N, E).
# Note: projected.view(b, h * w, c) has the right shape but the wrong element order,
# because it reinterprets the channel-major memory instead of moving the channel axis.
patches = projected.flatten(2).transpose(1, 2)
print(patches.shape) # torch.Size([10, 1024, 768])

# Learnable CLS token and positional embedding (one row per patch plus one for CLS).
# Using H and W separately keeps this correct for non-square images.
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))
print(cls_token.shape) # torch.Size([1, 1, 768])
print(positions.shape) # torch.Size([1025, 768])

# Repeat the CLS token once per batch element so its batch size matches.
cls_tokens = cls_token.repeat(b, 1, 1)
print(cls_tokens.shape) # torch.Size([10, 1, 768])

# Concatenate along the token axis (dim 1): (10, 1, 768) + (10, 1024, 768).
cat_x = torch.cat([cls_tokens, patches], dim=1)
print(cat_x.shape) # torch.Size([10, 1025, 768])

# Add the positional embedding: (10, 1025, 768) + (1025, 768) broadcasts over the batch.
cat_x = cat_x + positions
print(cat_x.shape) # torch.Size([10, 1025, 768])

torch.Size([10, 3, 512, 512])
torch.Size([10, 768, 32, 32])
torch.Size([10, 1024, 768])
torch.Size([1, 1, 768])
torch.Size([1025, 768])
torch.Size([10, 1, 768])
torch.Size([10, 1025, 768])
torch.Size([10, 1025, 768])


In [37]:
# The patch-embedding walkthrough above, packaged as an nn.Module.
class PatchEmbedding(nn.Module):
    """Split images into patches, embed them, prepend a CLS token and add positions.

    Args:
        in_channels: number of input image channels.
        patch_size: side length in pixels of each square patch (also the conv stride).
        emb_size: embedding size of every token.
        img_size: input image size, either an int (square image) or an (H, W) tuple.
            H and W are assumed to be divisible by ``patch_size``.

    Input ``(B, in_channels, H, W)`` -> output ``(B, N + 1, emb_size)`` with
    ``N = (H // patch_size) * (W // patch_size)``.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 512):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=emb_size, kernel_size=patch_size, stride=patch_size),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        img_h, img_w = (img_size, img_size) if isinstance(img_size, int) else img_size
        num_patches = (img_h // patch_size) * (img_w // patch_size)
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)              # (B, E, H/P, W/P)
        batch_size = x.shape[0]
        x = x.flatten(2).transpose(1, 2)    # (B, N, E)
        # Use the registered parameter directly. Wrapping it in a new nn.Parameter here
        # would create a fresh detached leaf, and self.cls_token would never get gradients.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        return x + self.positions

In [38]:
# Multi Head Attention (MHA) — step-by-step walkthrough
emb_size = 768
num_heads = 8
head_dim = emb_size // num_heads  # 96 features per head

x = PatchEmbedding()(torch.randn(10, 3, 512, 512))
print(x.shape) # torch.Size([10, 1025, 768])

# Linear projection
# In ViT self-attention, Q, K and V are all computed from the same input tensor.
# Each goes through its own linear projection, is split into num_heads heads,
# and every head then runs scaled dot-product attention independently.
# The layers get a `_proj` suffix so they are not confused with the projected tensors.
keys_proj = nn.Linear(emb_size, emb_size)
queries_proj = nn.Linear(emb_size, emb_size)
values_proj = nn.Linear(emb_size, emb_size)
print(keys_proj)
print(queries_proj)
print(values_proj)
# Linear(in_features=768, out_features=768, bias=True)
# Linear(in_features=768, out_features=768, bias=True)
# Linear(in_features=768, out_features=768, bias=True)


# Multi-head split: (b, n, h*d) -> (b, h, n, d).
# Note: a plain .view(b, num_heads, n, d) on the (b, n, h*d) tensor has the right
# shape but mixes tokens and heads; rearrange moves the head axis correctly.
keys = rearrange(keys_proj(x), "b n (h d) -> b h n d", h=num_heads)
print(keys.shape) # torch.Size([10, 8, 1025, 96])

queries = rearrange(queries_proj(x), "b n (h d) -> b h n d", h=num_heads)
print(queries.shape) # torch.Size([10, 8, 1025, 96])

values = rearrange(values_proj(x), "b n (h d) -> b h n d", h=num_heads)
print(values.shape) # torch.Size([10, 8, 1025, 96])

# Queries x Keys^T for every head
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy :', energy.shape) # torch.Size([10, 8, 1025, 1025])

# Attention weights: scale the logits by sqrt(head_dim) BEFORE the softmax
# (scaled dot-product attention), so every row of `att` sums to 1 over the keys.
scaling = head_dim ** 0.5
att = F.softmax(energy / scaling, dim=-1)
print('att :', att.shape) # torch.Size([10, 8, 1025, 1025])

# Attention weights x values
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print('out :', out.shape) # torch.Size([10, 8, 1025, 96])

# Merge the heads back into emb_size
out = rearrange(out, "b h n d -> b n (h d)")
print('out2 : ', out.shape) # torch.Size([10, 1025, 768])

torch.Size([10, 1025, 768])
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
torch.Size([10, 8, 1025, 96])
torch.Size([10, 8, 1025, 96])
torch.Size([10, 8, 1025, 96])
energy : torch.Size([10, 8, 1025, 1025])
att : torch.Size([10, 8, 1025, 1025])
out : torch.Size([10, 8, 1025, 96])
out2 :  torch.Size([10, 1025, 768])


In [15]:
# The MHA walkthrough above, packaged as an nn.Module.
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention as used in the ViT encoder.

    Q, K and V come from one fused linear layer (emb_size -> 3 * emb_size) and are
    split into ``num_heads`` heads of size ``emb_size // num_heads``. The merged head
    outputs go through a final linear projection.

    Args:
        emb_size: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Fuse the query, key and value projections into one matrix.
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: token embeddings, (batch, tokens, emb_size).
            mask: optional boolean mask; True marks positions that may be attended to.
                Assumed broadcastable to (batch, heads, query_len, key_len) — TODO confirm
                for callers.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        # Split the fused projection into queries, keys and values, each (b, h, n, d).
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # Dot product over the head dimension: (batch, num_heads, query_len, key_len).
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        # Scaled dot-product attention: divide by sqrt(head_dim) BEFORE the softmax.
        head_dim = queries.shape[-1]
        energy = energy / (head_dim ** 0.5)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of the values over the key axis.
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.projection(out)

# QKV 당 각각 1개씩의 Linear Layer를 적용한 것을 텐서 연산을 한번에 하기 위해 Linear Layer를 emb_size*3으로 설정한 후 
# 연산시 QKV를 각각 나눠주게 됩니다. 또한 Attention 시 무시할 정보를 설정하기 위한 masking 코드도 추가
# 마지막으로 나오는 out은 최종적으로 한번의 Linear Layer를 거쳐서 나오게 되는게 MHA의 모든 구현

In [16]:
# VIT Encoder 구조

# Residual wrapper: runs `fn` and adds the original input back to its output.
class ResidualAdd(nn.Module):
    """Residual (skip) connection that returns ``fn(x, **kwargs) + x``.

    Args:
        fn: module or callable whose output has the same shape as its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. If `fn` returns a tensor that aliases `x` (e.g. nn.Identity),
        # an in-place `+=` would modify the caller's tensor and double-count the input.
        return self.fn(x, **kwargs) + x

# MLP that follows MHA inside each encoder block: Linear -> GELU -> Dropout -> Linear.
# The first Linear widens the embedding to expansion * emb_size; after GELU and
# dropout, the second Linear projects it back down to emb_size.
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network of a Transformer encoder block.

    Args:
        emb_size: input and output feature size.
        expansion: multiplier for the hidden layer width.
        drop_p: dropout probability between the two linear layers.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

# One encoder layer built from ResidualAdd, MultiHeadAttention and FeedForwardBlock:
# a pre-norm attention sub-layer followed by a pre-norm MLP sub-layer.
class TransformerEncoderBlock(nn.Sequential):
    """Single ViT encoder block.

    Computes ``x = x + Dropout(MHA(LayerNorm(x)))`` and then
    ``x = x + Dropout(MLP(LayerNorm(x)))``.

    Args:
        emb_size: token embedding size.
        drop_p: dropout applied to each sub-layer output before the residual add.
        forward_expansion: hidden-width multiplier of the MLP.
        forward_drop_p: dropout inside the MLP.
        **kwargs: forwarded to MultiHeadAttention (e.g. num_heads, dropout).
    """

    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        attention_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        ))
        mlp_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        ))
        super().__init__(attention_sublayer, mlp_sublayer)

In [17]:
x = torch.randn(10, 3, 512, 512) # batch size 10, 3 channels, height 512, width 512
# Shape sanity check only: no gradients are needed, so run under no_grad to avoid
# keeping autograd buffers (including the per-head attention weights) alive during the pass.
with torch.no_grad():
    patches_embedded = PatchEmbedding()(x)
    encoded = TransformerEncoderBlock()(patches_embedded)
encoded.shape # torch.Size([10, 1025, 768])

# Stack the encoder blocks.
class TransformerEncoder(nn.Sequential):
    """A sequence of ``depth`` TransformerEncoderBlock layers applied one after another.

    Args:
        depth: number of encoder blocks.
        **kwargs: forwarded unchanged to every TransformerEncoderBlock.
    """

    def __init__(self, depth: int = 12, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)

# Head
class ClassificationHead(nn.Sequential):
    """Map encoder tokens ``(b, n, e)`` to class logits ``(b, n_classes)``.

    Mean-pools over the token axis ``n`` (all tokens), applies LayerNorm,
    then a linear classifier.

    Args:
        emb_size: token embedding size ``e``.
        n_classes: number of output classes.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        pool = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(pool, norm, classifier)

In [18]:
class ViT(nn.Sequential):
    """Vision Transformer: patch embedding -> ``depth`` encoder blocks -> classification head.

    Args:
        in_channels: number of input image channels.
        patch_size: side length in pixels of each square patch.
        emb_size: token embedding size.
        img_size: input image side length.
        depth: number of encoder blocks.
        n_classes: number of output classes.
        **kwargs: forwarded to every TransformerEncoderBlock.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

summary(ViT(), (3, 224, 224), device='cpu')

Layer (type:depth-idx)                   Output Shape              Param #
├─PatchEmbedding: 1-1                    [-1, 197, 768]            --
|    └─Sequential: 2-1                   [-1, 768, 14, 14]         --
|    |    └─Conv2d: 3-1                  [-1, 768, 14, 14]         590,592
├─TransformerEncoder: 1-2                [-1, 197, 768]            --
|    └─TransformerEncoderBlock: 2-2      [-1, 197, 768]            --
|    |    └─ResidualAdd: 3-2             [-1, 197, 768]            2,363,904
|    |    └─ResidualAdd: 3-3             [-1, 197, 768]            4,723,968
|    └─TransformerEncoderBlock: 2-3      [-1, 197, 768]            --
|    |    └─ResidualAdd: 3-4             [-1, 197, 768]            2,363,904
|    |    └─ResidualAdd: 3-5             [-1, 197, 768]            4,723,968
|    └─TransformerEncoderBlock: 2-4      [-1, 197, 768]            --
|    |    └─ResidualAdd: 3-6             [-1, 197, 768]            2,363,904
|    |    └─ResidualAdd: 3-7             [-1,

Layer (type:depth-idx)                   Output Shape              Param #
├─PatchEmbedding: 1-1                    [-1, 197, 768]            --
|    └─Sequential: 2-1                   [-1, 768, 14, 14]         --
|    |    └─Conv2d: 3-1                  [-1, 768, 14, 14]         590,592
├─TransformerEncoder: 1-2                [-1, 197, 768]            --
|    └─TransformerEncoderBlock: 2-2      [-1, 197, 768]            --
|    |    └─ResidualAdd: 3-2             [-1, 197, 768]            2,363,904
|    |    └─ResidualAdd: 3-3             [-1, 197, 768]            4,723,968
|    └─TransformerEncoderBlock: 2-3      [-1, 197, 768]            --
|    |    └─ResidualAdd: 3-4             [-1, 197, 768]            2,363,904
|    |    └─ResidualAdd: 3-5             [-1, 197, 768]            4,723,968
|    └─TransformerEncoderBlock: 2-4      [-1, 197, 768]            --
|    |    └─ResidualAdd: 3-6             [-1, 197, 768]            2,363,904
|    |    └─ResidualAdd: 3-7             [-1,