In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [6]:
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

## 1. Image를 patch로 Projection

In [12]:
x = torch.randn(8, 3, 224, 224)  # dummy batch: (batch=8, channels=3, height=224, width=224)
print(f'x: {x.shape}')

# Cut every image into non-overlapping patch_size x patch_size patches and
# flatten each patch (pixels and channels) into one token vector:
# (8, 3, 224, 224) -> (8, 14*14, 16*16*3) = (8, 196, 768).
patch_size = 16
patches = rearrange(
    x,
    'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
    s1=patch_size,
    s2=patch_size,
)
print('patches : ', patches.shape)

x: torch.Size([8, 3, 224, 224])
patches :  torch.Size([8, 196, 768])


In [18]:
# The patches can be produced with a plain rearrange (reshape) as above, but
# building them with a Conv2d is reported to give a performance benefit.
# With kernel_size == stride == patch_size, each conv output position sees
# exactly one non-overlapping patch, so the conv acts as a learnable linear
# projection applied to every patch.

patch_size = 16
in_channels = 3
emb_size = 768  # here equal to in_channels * patch_size * patch_size (3 * 16 * 16)

projection = nn.Sequential(
  # (b, 3, 224, 224) -> (b, 768, 14, 14)
  nn.Conv2d(
    in_channels,
    emb_size,
    kernel_size=patch_size,
    stride=patch_size,
  ),
  # (b, 768, 14, 14) -> (b, 196, 768)
  Rearrange('b e (h) (w) -> b (h w) e'),
)

summary(projection, x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


## 2. Class token 추가 및 Positional Embedding


In [26]:

emb_size = 768
img_size = 224
patch_size = 16

# Split each image into patch_size x patch_size patches and flatten them into tokens.
projected_x = projection(x)
print(f'Projected X shape: {projected_x.shape}')  # torch.Size([8, 196, 768])

# Define the class token and the position-embedding parameters.
cls_token = nn.Parameter(torch.randn(1, 1, emb_size))  # torch.Size([1, 1, 768])
# One position embedding per patch plus one for the class token.
positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))  # torch.Size([197, 768])
print(f'cls : {cls_token.shape}, pos: {positions.shape}')

# Repeat the class token once per sample. The batch size is read from the data
# instead of being hard-coded, so this cell stays correct if x's batch changes.
batch_size = projected_x.shape[0]
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)  # torch.Size([8, 1, 768])
print(f'Repeated Cls shape: {cls_tokens.shape}')

# Prepend the class token to the patch tokens.
cat_x = torch.cat([cls_tokens, projected_x], dim=1)  # torch.Size([8, 197, 768])

# Add the position embeddings (broadcast over the batch dimension).
cat_x += positions
print(f'output: {cat_x.shape}')

Projected X shape: torch.Size([8, 196, 768])
cls : torch.Size([1, 1, 768]), pos: torch.Size([197, 768])
Repeated Cls shape: torch.Size([8, 1, 768])
output: torch.Size([8, 197, 768])


### 1 + 2) Patch Projection과 Positional Embedding을 하나의 Class(`PatchEmbedding`)로 묶기

In [45]:
class PatchEmbedding(nn.Module):
  """Turn an image batch into a token sequence for the transformer.

  Each image is split into non-overlapping square patches that are linearly
  embedded (via a strided Conv2d). A learnable class token is prepended, and a
  learnable position embedding is added to every token.

  Args:
    in_channels: number of channels of the input images.
    patch_size: side length of each square patch (also the conv kernel/stride).
    emb_size: embedding dimension of every token.
    img_size: side length of the square input image. Exactly
      (img_size // patch_size) ** 2 + 1 position embeddings are allocated, so
      inputs must produce that many tokens for the final addition to work.
  """
  def __init__(self, in_channels: int=3, patch_size: int=16,
                emb_size: int=768, img_size: int=224):
    # nn.Module.__init__ must run before any attribute is assigned on the module.
    super().__init__()
    self.patch_size = patch_size
    self.projection = nn.Sequential(
      # kernel_size == stride == patch_size: each output position embeds exactly
      # one non-overlapping patch -> (b, emb_size, h, w)
      nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
      # Flatten the patch grid into a sequence: (b, e, h, w) -> (b, h*w, e)
      Rearrange('b e (h) (w) -> b (h w) e')
    )
    # A single learnable class token, shared by every sample in the batch.
    self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
    # One learnable position embedding per patch, plus one for the class token.
    self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

  def forward(self, x: Tensor) -> Tensor:
    """Embed images into tokens.

    Args:
      x: image batch (b, in_channels, H, W).

    Returns:
      Tokens of shape (b, n_patches + 1, emb_size); the class token is at index 0.
    """
    batch_size = x.shape[0]
    x = self.projection(x)
    # Copy the class token once per sample: (1, 1, e) -> (b, 1, e).
    cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
    # Prepend the class token to the patch tokens.
    x = torch.cat([cls_tokens, x], dim=1)
    # Add position embeddings, broadcast over the batch dimension.
    x += self.positions  # e.g. torch.Size([8, 197, 768]) with the defaults

    return x
      



In [46]:
# Defaults: 3 input channels, 16px patches, 768-dim tokens, 224px images.
PE = PatchEmbedding()
# torchsummary takes the per-sample input shape (no batch dimension).
summary(PE, input_size=(3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 2.30
Params size (MB): 2.25
Estimated Total Size (MB): 5.12
----------------------------------------------------------------


## 3. Multi Head Attention (MHA)

patch들에 대해 Self-Attention 메커니즘을 적용함

In [52]:
x = torch.randn(8, 3, 224, 224)  # (batch=8, channels=3, h=224, w=224)
emb_size = 768
num_heads = 8

# Separate linear projections for keys, queries and values. They get their own
# names so the per-head tensors computed below do not shadow the layers
# (re-using one name for both leaves confusing hidden state in the kernel).
key_proj = nn.Linear(emb_size, emb_size)
query_proj = nn.Linear(emb_size, emb_size)
value_proj = nn.Linear(emb_size, emb_size)
print(f'keys: {key_proj},\nqueries: {query_proj},\nvalues: {value_proj}\n')

x = PE(x)  # patch tokens with class token: [batch, n, emb_size]
print(f'queries(x): {query_proj(x).shape}\n')  # torch.Size([8, 197, 768]) = [batch, n, emb_size]

# Split the embedding dimension into num_heads heads of size emb_size // num_heads.
queries = rearrange(query_proj(x), 'b n (h d) -> b h n d', h=num_heads)
keys = rearrange(key_proj(x), 'b n (h d) -> b h n d', h=num_heads)
values = rearrange(value_proj(x), 'b n (h d) -> b h n d', h=num_heads)

print(f'queries: {queries.shape}\nkeys: {keys.shape}\nvalues: {values.shape}')


keys: Linear(in_features=768, out_features=768, bias=True),
queries: Linear(in_features=768, out_features=768, bias=True),
values: Linear(in_features=768, out_features=768, bias=True)

queries(x): torch.Size([8, 197, 768])

queires: torch.Size([8, 8, 197, 96])
keys: torch.Size([8, 8, 197, 96])
values: torch.Size([8, 8, 197, 96])


In [56]:
# Queries * Keys^T for every head: (b, h, query_len, key_len)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print(f'energy: {energy.shape}')

# Attention weights. Scaled dot-product attention divides by sqrt(d_k), the
# per-head dimension (here 768 // 8 = 96), not by sqrt(emb_size).
scaling = queries.shape[-1] ** (1/2)
att = F.softmax(energy/scaling, dim=-1)
print(f'att: {att.shape}')

# Attention-weighted sum of the values: (b, h, query_len, d)
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print('out:', out.shape)

# Merge the heads back into a single emb_size dimension.
out = rearrange(out, 'b h n d -> b n (h d)')
print('rearranged_out: ', out.shape)


energy: torch.Size([8, 8, 197, 197])
att: torch.Size([8, 8, 197, 197])
out: torch.Size([8, 8, 197, 96])
rearranged_out:  torch.Size([8, 197, 768])


### MultiHead Attention을 Class로 묶어주기

In [62]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Args:
    emb_size: size of each input token embedding. It is split evenly across
      the heads, so it must be divisible by num_heads.
    num_heads: number of attention heads.
    dropout: dropout probability applied to the attention weights.
  """
  def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
    super().__init__()
    self.emb_size = emb_size
    self.num_heads = num_heads
    # Fuse the query, key and value projections into one matrix.
    self.qkv = nn.Linear(emb_size, emb_size*3)
    self.att_drop = nn.Dropout(dropout)
    self.projection = nn.Linear(emb_size, emb_size)

  def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
    """Apply self-attention to a token sequence.

    Args:
      x: tokens of shape (batch, n, emb_size).
      mask: optional mask; positions where it is False are excluded from
        attention. Assumed to be a bool tensor broadcastable to
        (batch, num_heads, n, n) -- TODO confirm with callers.

    Returns:
      Tensor of shape (batch, n, emb_size).
    """
    # Split queries, keys and values and spread them over num_heads heads.
    qkv = rearrange(self.qkv(x), 'b n (h d qkv) -> (qkv) b h n d', h=self.num_heads, qkv=3)
    queries, keys, values = qkv[0], qkv[1], qkv[2]
    # Dot product over the head dimension: (batch, num_heads, query_len, key_len)
    energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)

    # Scaled dot-product attention divides by sqrt(d_k), the per-head size
    # (emb_size // num_heads), not by sqrt(emb_size).
    scaling = queries.shape[-1] ** (1/2)
    energy = energy / scaling

    if mask is not None:
      # masked_fill is out-of-place, so the result must be assigned back.
      # The smallest finite value of energy's dtype makes softmax weights ~0.
      fill_value = torch.finfo(energy.dtype).min
      energy = energy.masked_fill(~mask, fill_value)

    att = F.softmax(energy, dim=-1)
    att = self.att_drop(att)

    # Attention-weighted sum of the values, then merge the heads back together.
    out = torch.einsum('bhal, bhlv -> bhav', att, values)
    out = rearrange(out, 'b h n d -> b n (h d)')
    out = self.projection(out)
    return out

In [63]:
x = torch.randn(8, 3, 224, 224)  # dummy image batch: (batch=8, channels=3, h=224, w=224)
PE = PatchEmbedding()  # NOTE: PE and MHA are reused by later cells
x = PE(x)  # image batch -> token sequence with the class token prepended
print(x.shape)
MHA = MultiHeadAttention()
# torchsummary takes the per-sample input shape, i.e. without the batch dimension.
summary(MHA, x.shape[1:], device='cpu')

torch.Size([8, 197, 768])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1            [-1, 197, 2304]       1,771,776
           Dropout-2          [-1, 8, 197, 197]               0
            Linear-3             [-1, 197, 768]         590,592
Total params: 2,362,368
Trainable params: 2,362,368
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.58
Forward/backward pass size (MB): 6.99
Params size (MB): 9.01
Estimated Total Size (MB): 16.57
----------------------------------------------------------------


## 4. Transformer Encoder Block

In [66]:
class ResidualAdd(nn.Module):
  """Wrap a sub-module with a residual (skip) connection: ``fn(x) + x``.

  Args:
    fn: module applied to the input; its output must be broadcast-compatible
      with the input.
  """
  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x, **kwargs):
    """Return ``fn(x, **kwargs) + x``; extra kwargs are passed to ``fn``.

    The addition is deliberately out-of-place. If ``fn`` returns its own input
    (e.g. an identity), an in-place ``+=`` would silently modify the caller's
    tensor, and would raise in autograd when ``x`` is a leaf requiring grad.
    """
    return self.fn(x, **kwargs) + x

class FeedForwardBlock(nn.Sequential):
  """Token-wise MLP: Linear(emb -> expansion*emb) -> GELU -> Dropout -> Linear(back to emb).

  Args:
    emb_size: input and output feature size.
    expansion: multiplier for the hidden layer width.
    drop_p: dropout probability applied after the activation.
  """
  def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
    hidden_size = expansion * emb_size
    layers = [
      nn.Linear(emb_size, hidden_size),
      nn.GELU(),
      nn.Dropout(drop_p),
      nn.Linear(hidden_size, emb_size),
    ]
    super().__init__(*layers)

class TransformerEncoderBlock(nn.Sequential):
  """One pre-norm encoder block: x + Drop(MHA(LN(x))), then x + Drop(MLP(LN(x))).

  Args:
    emb_size: token embedding size.
    drop_p: dropout applied to each sub-layer's output before the residual add.
    forward_expansion: hidden-width multiplier of the MLP.
    forward_drop_p: dropout inside the MLP.
    **kwargs: forwarded to MultiHeadAttention (e.g. num_heads, dropout).
  """
  def __init__(self, emb_size: int = 768, drop_p: float = 0., 
               forward_expansion: int = 4, forward_drop_p: float=0.,
               **kwargs):
    attention_sublayer = ResidualAdd(nn.Sequential(
      nn.LayerNorm(emb_size),
      MultiHeadAttention(emb_size, **kwargs),
      nn.Dropout(drop_p),
    ))
    mlp_sublayer = ResidualAdd(nn.Sequential(
      nn.LayerNorm(emb_size),
      FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
      nn.Dropout(drop_p),
    ))
    super().__init__(attention_sublayer, mlp_sublayer)

In [67]:
x = torch.randn(8, 3, 224, 224)
x = PE(x)  # (8, 197, 768): the token sequence an encoder block consumes
# An encoder block maps (b, n, e) -> (b, n, e), so the patch-embedding output
# already has the right per-sample shape for the summary. No extra MHA forward
# pass is needed, which also removes this cell's dependency on `MHA`.
TE = TransformerEncoderBlock()
summary(TE, x.shape[1:], device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1             [-1, 197, 768]           1,536
            Linear-2            [-1, 197, 2304]       1,771,776
           Dropout-3          [-1, 8, 197, 197]               0
            Linear-4             [-1, 197, 768]         590,592
MultiHeadAttention-5             [-1, 197, 768]               0
           Dropout-6             [-1, 197, 768]               0
       ResidualAdd-7             [-1, 197, 768]               0
         LayerNorm-8             [-1, 197, 768]           1,536
            Linear-9            [-1, 197, 3072]       2,362,368
             GELU-10            [-1, 197, 3072]               0
          Dropout-11            [-1, 197, 3072]               0
           Linear-12             [-1, 197, 768]       2,360,064
          Dropout-13             [-1, 197, 768]               0
      ResidualAdd-14             [-1, 1

ViT-Base 설정(depth=12)에서는 이런 Encoder block 12개가 차례로 쌓인다. (ViT-Large는 24개, ViT-Huge는 32개)



## 5. 마지막으로 다 묶어서 ViT 빌드
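classification을 위한 ClassificationHead를 만들어 모델의 마지막 단에 넣어준다. 단, 이 구현은 원 논문처럼 class token의 출력만 사용하지 않는다. 대신 class token을 포함한 모든 token의 embedding을 평균(mean pooling)한 뒤, LayerNorm과 Linear를 거쳐 class logit을 만든다.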

In [69]:
class TransformerEncoder(nn.Sequential):
  """A stack of ``depth`` TransformerEncoderBlocks applied in sequence.

  Every keyword argument is forwarded unchanged to each block.
  """
  def __init__(self, depth: int = 12, **kwargs):
    blocks = []
    for _ in range(depth):
      blocks.append(TransformerEncoderBlock(**kwargs))
    super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
  """Map a token sequence to class logits.

  Mean-pools over all tokens (the class token included), then applies
  LayerNorm and a Linear layer.

  Args:
    emb_size: token embedding size.
    n_classes: number of output logits.
  """
  def __init__(self, emb_size: int = 768, n_classes: int = 1000):
    # (b, n, e) -> (b, e); n is 197 for 224px images with 16px patches (196 + class token)
    pool = Reduce('b n e -> b e', reduction='mean')
    norm = nn.LayerNorm(emb_size)
    classifier = nn.Linear(emb_size, n_classes)
    super().__init__(pool, norm, classifier)

class ViT(nn.Sequential):
  """Vision Transformer: patch embedding -> ``depth`` encoder blocks -> classification head.

  Args:
    in_channels, patch_size, emb_size, img_size: passed to PatchEmbedding.
    depth: number of TransformerEncoderBlocks.
    n_classes: number of logits produced by the head.
    **kwargs: forwarded to every TransformerEncoderBlock (drop_p,
      forward_expansion, forward_drop_p), and from there to
      MultiHeadAttention (num_heads, dropout).
  """
  def __init__(self, in_channels: int=3, patch_size: int=16, emb_size:int = 768, 
               img_size: int = 224, depth: int = 12, n_classes: int = 1000,
               **kwargs):
    embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
    encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
    head = ClassificationHead(emb_size, n_classes)
    super().__init__(embedding, encoder, head)

# Full-model summary for one 3x224x224 image (torchsummary omits the batch dimension).
summary(ViT(), input_size=(3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 19