# **Prototype ViT+CSP**

*   **Package install & Library import**
*   **Sub-design**
 *  Multi-Head Self-Attention layer
*   **Hyperparameter & Dataset**
*   **Vanilla ViT**
 *  Encoder
 *  ViT(Vision Transformer)
 *  Setting
 *  Train & Test
*   **ViT+CSP**
 *  Encoder
 *  ViT+CSP
 *  Setting
 *  Train & Test
*   **ViT+CSP version 2**
 *  MHA ver2
 *  Encoder
 *  ViT+CSP version2 
 *  Setting
 *  Train & Test
*  **Spec comparison**
 *  Accuracy
 *  FLOPS & Parameters
 *  Inference time
* **Model save & load (colab only)**

# Package install & Library import

In [None]:
""" Install einops, torchsummary packages """
!pip install einops           # Tensor dim
!pip install torchsummary     # Model summary
!pip install thop             # FLOPS check

In [2]:
""" Torch library """
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 
import torch.backends.cudnn as cudnn 
import torchvision 
import torchvision.transforms as transforms
from torchsummary import summary

""" Tensor(or data)-related library """
import numpy as np
import pandas as pd
from einops import rearrange

""" Others """
import matplotlib.pyplot as plt
import pdb
import albumentations
import time
import os
import csv
import random
from tqdm import tqdm
from thop import profile

# Sub-design

In [3]:
# Residual connection
class Residual(nn.Module):
  """Wrap `fn` with an identity skip connection: y = fn(x) + x.

  Args:
    fn: callable (typically an nn.Module) whose output has the same shape
      as its input, so the two can be added.
  """

  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x):
    branch = self.fn(x)  # residual branch f(x)
    return branch + x    # add the untouched skip path

In [4]:
# Pre-normalization wrapper
class PreNorm(nn.Module):
  """Apply LayerNorm over the last dimension, then call `fn` on the result.

  Args:
    dim: size of the last (feature) dimension normalized by LayerNorm.
    fn: callable applied to the normalized tensor.
  """

  def __init__(self, dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(dim)
    self.fn = fn

  def forward(self, x):
    normed = self.norm(x)
    return self.fn(normed)

In [5]:
# MLP (feed-forward) layer
class FeedForward(nn.Module):
  """Two-layer MLP applied to the last dimension.

  Pipeline: Linear(dim -> hidden_dim) -> GELU -> Dropout
            -> Linear(hidden_dim -> dim) -> Dropout.

  Args:
    dim: input and output feature size.
    hidden_dim: width of the hidden layer.
    dropout: dropout probability used after the activation and after the
      output projection.
  """

  def __init__(self, dim, hidden_dim, dropout=0.):
    super().__init__()
    # Built in the same order as the Sequential indices (0..4), so parameter
    # initialization order and state_dict keys stay stable.
    stages = [
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout),
    ]
    self.net = nn.Sequential(*stages)

  def forward(self, x):
    out = self.net(x)
    return out

## Multi-Head Self-Attention layer
<img src = "https://drive.google.com/uc?id=1PYw-0jfRuZR9c5YkZaDNFNGjvzpLA5ww" height = 300 width = 400>

In [6]:
""" Multi-Head Self-Attention """
class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention over a token sequence.

  Args:
    dim: token embedding size; split evenly across `heads` in forward().
    heads: number of attention heads.
    dropout: dropout probability applied after the output projection.
  """

  def __init__(self, dim, heads=8, dropout=0.):
    super().__init__()
    self.heads = heads
    # Scaled dot-product attention divides the logits by sqrt(d_k), where d_k
    # is the per-head feature size (dim / heads). The previous heads ** -0.5
    # used the number of heads instead, which mis-scales the softmax input.
    self.scales = (dim // heads) ** (-0.5)

    self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
    self.to_out = nn.Sequential(
        nn.Linear(dim, dim),
        nn.Dropout(dropout)
    )

  def forward(self, x):
    h = self.heads

    # Project the input into query, key and value: a tuple of three
    # (batch, n, dim) tensors.
    qkv = self.to_qkv(x).chunk(3, dim=-1)

    # Split each into heads: (batch, n, dim) -> (batch, head, n, dim/head).
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

    # Scaled similarity, batched over batch and head:
    # (batch, head, i, d), (batch, head, j, d) -> (batch, head, i, j)
    # where i indexes queries and j indexes keys.
    dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scales

    # Attention weights over the key axis: (batch, head, i, j).
    attn = dots.softmax(dim=-1)

    # Weighted sum of values:
    # (batch, head, i, j), (batch, head, j, d) -> (batch, head, i, d).
    out = torch.einsum("bhij,bhjd->bhid", attn, v)

    # Merge heads back: (batch, head, i, d) -> (batch, i, head * d).
    out = rearrange(out, "b h n d -> b n (h d)")

    return self.to_out(out)

# Hyperparameter & Dataset

In [7]:
# Initial hyperparameters shared by all three models:
#   learning rate : 1e-4
#   batch size    : 128 (training loader)
#   epochs        : 50
lr, bs, n_epochs = 1e-4, 128, 50

In [8]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# best_acc: best test accuracy seen so far; start_epoch: first epoch index.
best_acc = start_epoch = 0

In [9]:
# Preprocessing pipelines for CIFAR-10.
# Normalize takes per-channel (mean, std): one value per RGB channel.
# NOTE(review): these std values are the ones widely copied from public
# CIFAR-10 examples and may not match the dataset's actual per-pixel std.
# Confirm before reusing; changing them would invalidate trained weights.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # zero-pad 4 px per side, then take a random 32x32 crop
    transforms.RandomHorizontalFlip(),     # random left-right mirror
    transforms.ToTensor(),                 # image -> float tensor (C, H, W)
    transforms.Normalize(cifar10_mean, cifar10_std),
])  # Train

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])  # Test (validation): no augmentation

In [None]:
# CIFAR-10 train/test datasets and their loaders (stored under ./data).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=bs, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Class index -> class name, in CIFAR-10 label order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Vanilla ViT




## ViT encoder
<img src = "https://drive.google.com/uc?id=1IkQVsxWBwX28DXHHDFPHMcQd2lOe17b2" height = 200 width = 100>

In [11]:
# Vanilla ViT encoder
class Transformer(nn.Module):
  """Stack of `depth` pre-norm encoder blocks.

  Each block is an attention sub-layer followed by an MLP sub-layer, each
  wrapped as Residual(PreNorm(...)), i.e. x + f(LayerNorm(x)).
  """

  def __init__(self, dim, depth, heads, mlp_dim, dropout):
    super().__init__()
    blocks = []
    for _ in range(depth):
      attn_sublayer = Residual(PreNorm(dim, Attention(dim, heads=heads, dropout=dropout)))
      mlp_sublayer = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))
      blocks.append(nn.ModuleList([attn_sublayer, mlp_sublayer]))
    self.layers = nn.ModuleList(blocks)

  def forward(self, x):
    for attn_sublayer, mlp_sublayer in self.layers:
      # LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual.
      x = mlp_sublayer(attn_sublayer(x))
    return x

## ViT(Vision Transformer)
<img src = "https://drive.google.com/uc?id=18GXXnl2X3d12q-YUcGa05VrnGAKdA3cb" height = 200 width = 300>

In [12]:
MIN_NUM_PATCHES = 16  # minimum token count for attention to be meaningful

# Vanilla ViT
class ViT(nn.Module):
  """Vision Transformer classifier.

  Patch embedding + learnable [CLS] token + position embedding, followed by
  a Transformer encoder and an MLP head on the [CLS] output.

  Keyword-only args:
    image_size: side of the square input; must be divisible by patch_size.
    patch_size: side of each square patch.
    num_classes: number of output classes.
    dim: token embedding size.
    depth: number of encoder blocks.
    heads: attention heads per block.
    mlp_dim: hidden width of the encoder MLPs and of the head.
    channels: input image channels.
    dropout: dropout inside the encoder and the head.
    emb_dropout: dropout on the embedded token sequence.
  """

  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0., emb_dropout=0.):
    super().__init__()

    # Validate the patch grid.
    assert image_size % patch_size == 0, "image size must be divisible for the patch size"
    per_side = image_size // patch_size
    num_patches = per_side ** 2
    patch_dim = channels * patch_size**2
    assert num_patches >= MIN_NUM_PATCHES,  f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'

    self.patch_size = patch_size

    # Construction order is kept fixed so random initialization is reproducible.
    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.patch_to_embedding = nn.Linear(patch_dim, dim, bias=False)
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
    self.to_cls_token = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, mlp_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(mlp_dim, num_classes),
    )

  def forward(self, img):
    """Map images (batch, channels, height, width) to class scores (batch, num_classes)."""
    p = self.patch_size

    # Cut into p x p patches and flatten each:
    # (B, C, H, W) -> (B, num_patches, patch_dim).
    patches = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
    tokens = self.patch_to_embedding(patches)  # (B, N, dim)
    batch, num_tokens, _ = tokens.shape

    # Prepend one [CLS] token per sample: (B, N, dim) -> (B, N + 1, dim).
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.pos_embedding[:, :num_tokens + 1]
    tokens = self.dropout(tokens)

    encoded = self.transformer(tokens)          # (B, N + 1, dim)
    cls_out = self.to_cls_token(encoded[:, 0])  # (B, dim)
    return self.mlp_head(cls_out)               # (B, num_classes)

## ViT setting

In [13]:
# Vanilla ViT for CIFAR-10: 32x32 images in 4x4 patches -> 64 patch tokens (+1 CLS).
net = ViT(
    image_size=32, patch_size=4, num_classes=10,
    dim=512, depth=6, heads=8, mlp_dim=512,
    dropout=0.1, emb_dropout=0.1,
).to(device)

In [14]:
# Loss, optimizer and LR schedule for the vanilla ViT.
# The scheduler divides the LR by 10 after 3 non-improving steps (floor 1e-3*1e-5).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, min_lr=1e-3 * 1e-5, verbose=True)

In [None]:
# torchsummary defaults to device="cuda"; pass the active device so this also runs on CPU.
summary(net, (3, 32, 32), device=device)

## Train & Test

In [16]:
def train(model):
  """Run one training epoch of `model` over the module-level `trainloader`.

  Uses the module-level `optimizer`, `criterion` and `device`, so
  `optimizer` must already be bound to `model`'s parameters.

  Returns:
    Mean per-batch training loss, or 0.0 if the loader yields no batches.
  """
  model.train()
  running_loss = 0.0
  num_batches = 0

  for inputs, targets in trainloader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    num_batches += 1

  # An empty loader previously left `batch_idx` undefined (NameError).
  return running_loss / num_batches if num_batches else 0.0

In [17]:
def test(model):
  """Evaluate `model` on the module-level `testloader` and step the LR scheduler.

  Uses the module-level `criterion`, `scheduler` and `device`. Updates the
  module-level `best_acc` and prints it whenever it improves. `best_acc` is
  shared by every model evaluated in the session.

  Returns:
    (test_loss, acc):
      test_loss is the SUM of the per-batch mean losses (this is also the
        value fed to the scheduler).
      acc is top-1 accuracy in percent, or 0.0 if there were no samples.
  """
  global best_acc
  model.eval()
  test_loss = 0
  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, targets in testloader:
      inputs, targets = inputs.to(device), targets.to(device)
      outputs = model(inputs)

      test_loss += criterion(outputs, targets).item()

      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

  scheduler.step(test_loss)
  acc = 100. * correct / total if total else 0.0

  if acc > best_acc:
    # Previously best_acc was never updated, so every epoch was reported as "best".
    best_acc = acc
    print("Best accuracy: {}\n".format(acc))
  return test_loss, acc

In [None]:
# Train the vanilla ViT, recording test loss / accuracy after every epoch.
list_loss_vit, list_acc_vit = [], []
for epoch in range(n_epochs):
  print(f'Epoch: {epoch}')
  train(net)
  val_loss, val_acc = test(net)
  list_loss_vit.append(val_loss)
  list_acc_vit.append(val_acc)

In [19]:
def random_check(model, x):
  """Predict class indices for the batch `x` without tracking gradients.

  Puts `model` in eval mode as a side effect.

  Args:
    model: classifier returning scores of shape (batch, num_classes).
    x: input batch, already on the model's device.

  Returns:
    1-D tensor of argmax class indices, one per sample.
  """
  model.eval()
  with torch.no_grad():
    scores = model(x)
  return scores.max(1).indices

In [None]:
# Spot check: show one random CIFAR-10 test image with its label and the ViT prediction.
num = random.randint(0, 9999)  # test-set index in 0 ~ 9999

x, label = testset[num]  # normalized tensor fed to the model
x = x[None].to(device)   # add a batch dimension of 1

# Untransformed copy of the test split, used only for display.
compare_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)
img, _ = compare_set[num]

predicted = random_check(net, x)

print('label:', classes[label])
print('predicted:', classes[predicted])
plt.imshow(img)
print()

# ViT+CSP

## ViT+CSP encoder

In [21]:
# ViT+CSP encoder
class CSP_Transformer(nn.Module):
  """Encoder: three TR-CSP stages followed by (depth - 3) vanilla blocks.

  A TR-CSP stage splits the token sequence (dim=1) into two halves. It runs
  the attention sub-layer `CSPS` on one half only, concatenates the halves
  back, and applies the MLP sub-layer `FFN` to the whole sequence. The
  attended half alternates first / second / first across the three stages.

  `CSPS` and `FFN` are single modules reused by all three stages, so their
  weights are shared. With depth < 3 no vanilla blocks are added, but the
  three TR-CSP stages still run.
  """

  # Which half (0 = first, 1 = second) the shared attention sees in each stage.
  _CSP_SCHEDULE = (0, 1, 0)

  def __init__(self, dim, depth, heads, mlp_dim, dropout):
    super().__init__()
    self.layers = nn.ModuleList([])

    self.CSPS = Residual(PreNorm(dim, Attention(dim, heads=heads, dropout=dropout)))  # shared TR-CSP attention
    self.FFN = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))     # shared TR-CSP MLP

    # Remaining (depth - 3) blocks are standard pre-norm encoder blocks.
    for _ in range(depth - 3):
      self.layers.append(nn.ModuleList([
          Residual(PreNorm(dim, Attention(dim, heads=heads, dropout=dropout))),
          Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))),
      ]))

  def forward(self, x):
    for half in self._CSP_SCHEDULE:
      first, second = x.chunk(2, dim=1)
      if half == 0:
        first = self.CSPS(first)
      else:
        second = self.CSPS(second)
      x = self.FFN(torch.cat([first, second], dim=1))

    for attn, ff in self.layers:
      # LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual.
      x = ff(attn(x))
    return x

## ViT+CSP

In [22]:
MIN_NUM_PATCHES = 16  # minimum token count for attention to be meaningful

# CSP_ViT
class CSP_ViT(nn.Module):
  """ViT classifier whose encoder is CSP_Transformer (token-split TR-CSP stages).

  Keyword-only args:
    image_size: side of the square input; must be divisible by patch_size.
    patch_size: side of each square patch.
    num_classes: number of output classes.
    dim: token embedding size.
    depth: total encoder depth (3 TR-CSP stages + depth - 3 vanilla blocks).
    heads: attention heads.
    mlp_dim: hidden width of the encoder MLPs and of the head.
    channels: input image channels.
    dropout: dropout inside the encoder and the head.
    emb_dropout: dropout on the embedded token sequence.
  """

  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0., emb_dropout=0.):
    super().__init__()

    # Validate the patch grid.
    assert image_size % patch_size == 0, "image size must be divisible for the patch size"
    per_side = image_size // patch_size
    num_patches = per_side ** 2
    patch_dim = channels * patch_size**2
    assert num_patches >= MIN_NUM_PATCHES,  f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'

    self.patch_size = patch_size

    # Construction order is kept fixed so random initialization is reproducible.
    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.patch_to_embedding = nn.Linear(patch_dim, dim, bias=False)
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = CSP_Transformer(dim, depth, heads, mlp_dim, dropout)
    self.to_cls_token = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, mlp_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(mlp_dim, num_classes),
    )

  def forward(self, img):
    """Map images (batch, channels, height, width) to class scores (batch, num_classes)."""
    p = self.patch_size

    # (B, C, H, W) -> (B, num_patches, patch_dim): flatten p x p patches.
    patches = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
    tokens = self.patch_to_embedding(patches)  # (B, N, dim)
    batch, num_tokens, _ = tokens.shape

    # Prepend one [CLS] token per sample: (B, N, dim) -> (B, N + 1, dim).
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.pos_embedding[:, :num_tokens + 1]
    tokens = self.dropout(tokens)

    encoded = self.transformer(tokens)          # (B, N + 1, dim)
    cls_out = self.to_cls_token(encoded[:, 0])  # (B, dim)
    return self.mlp_head(cls_out)               # (B, num_classes)

## ViT+CSP setting

In [23]:
# ViT + CSP for CIFAR-10, same size settings as the vanilla ViT.
csp_vit_net = CSP_ViT(
    image_size=32, patch_size=4, num_classes=10,
    dim=512, depth=6, heads=8, mlp_dim=512,
    dropout=0.1, emb_dropout=0.1,
).to(device)

In [24]:
# Loss, optimizer and LR schedule for ViT + CSP (rebinds the shared globals used by train/test).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(csp_vit_net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, min_lr=1e-3 * 1e-5, verbose=True)

In [None]:
# torchsummary defaults to device="cuda"; pass the active device so this also runs on CPU.
summary(csp_vit_net, (3, 32, 32), device=device)

## Train & Test

In [None]:
# Train ViT + CSP with the shared train/test helpers, recording per-epoch test metrics.
list_loss_csps_vit, list_acc_csps_vit = [], []
for epoch in range(n_epochs):
  print(f'\nEpoch: {epoch}')
  train(csp_vit_net)
  val_loss, val_acc = test(csp_vit_net)
  list_loss_csps_vit.append(val_loss)
  list_acc_csps_vit.append(val_acc)

In [None]:
# Spot check: one random CIFAR-10 test image with its label and the ViT + CSP prediction.
num = random.randint(0, 9999)  # test-set index in 0 ~ 9999

x, label = testset[num]  # normalized tensor fed to the model
x = x[None].to(device)   # add a batch dimension of 1

# Untransformed copy of the test split, used only for display.
compare_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)
img, _ = compare_set[num]

predicted = random_check(csp_vit_net, x)

print('label:', classes[label])
print('predicted:', classes[predicted])
plt.imshow(img)
print()

# ViT+CSP ver2

## MHA ver2

In [28]:
""" Multi-Head Self-Attention for ViT + CSP ver2 """
class Attention_ver2(nn.Module):
  """Multi-head scaled dot-product self-attention operating on width `dim2`.

  Args:
    dim: full token width. Accepted for interface compatibility; not used
      by this layer.
    dim2: feature width this layer attends over (input and output); split
      evenly across `heads` in forward().
    heads: number of attention heads.
    dropout: dropout probability applied after the output projection.
  """

  def __init__(self, dim, dim2, heads=4, dropout=0.):
    super().__init__()
    self.heads = heads
    # Scale by 1/sqrt(d_k) with d_k = per-head width (dim2 / heads); the old
    # heads ** -0.5 used the number of heads instead.
    self.scales = (dim2 // heads) ** (-0.5)

    self.to_qkv = nn.Linear(dim2, dim2 * 3, bias=False)
    self.to_out = nn.Sequential(
        nn.Linear(dim2, dim2),
        nn.Dropout(dropout)
    )

  def forward(self, x):
    h = self.heads

    # Project into query, key and value: three (batch, n, dim2) tensors.
    qkv = self.to_qkv(x).chunk(3, dim=-1)

    # Split into heads: (batch, n, dim2) -> (batch, head, n, dim2/head).
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

    # Scaled similarity: (b, h, i, d), (b, h, j, d) -> (b, h, i, j),
    # i = queries, j = keys.
    dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scales

    # Attention weights over the key axis.
    attn = dots.softmax(dim=-1)

    # Weighted sum of values: (b, h, i, j), (b, h, j, d) -> (b, h, i, d).
    out = torch.einsum("bhij,bhjd->bhid", attn, v)

    # Merge heads: (b, h, i, d) -> (b, i, h * d).
    out = rearrange(out, "b h n d -> b n (h d)")

    return self.to_out(out)

## ViT+CSP ver2 encoder

In [29]:
# ViT + CSP ver2 encoder
class CSP_Transformer_ver2(nn.Module):
  """Encoder: three channel-split TR-CSP stages followed by (depth - 3) vanilla blocks.

  Each stage does the following:
    1. Split every token's features (dim=2) into two halves.
    2. Run the attention sub-layer `CSPS` (width dim2) on one half only.
    3. Concatenate the halves back to width `dim`.
    4. Apply the MLP sub-layer `FFN` to the full tensor.

  The attended half alternates first / second / first. `CSPS` and `FFN` are
  single modules reused by all three stages.

  Because `CSPS` normalizes width dim2, dim2 is expected to equal dim // 2.
  `CSPS` uses Attention_ver2's default head count; `heads` is only passed to
  the vanilla blocks.
  """

  # Which half (0 = first, 1 = second) the shared attention sees in each stage.
  _CSP_SCHEDULE = (0, 1, 0)

  def __init__(self, dim, dim2, depth, heads, mlp_dim, dropout):
    super().__init__()
    self.layers = nn.ModuleList([])

    self.CSPS = Residual(PreNorm(dim2, Attention_ver2(dim, dim2, dropout=dropout)))  # shared TR-CSP ver2 attention
    self.FFN = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))    # shared TR-CSP MLP

    # Remaining (depth - 3) blocks are standard pre-norm encoder blocks.
    for _ in range(depth - 3):
      self.layers.append(nn.ModuleList([
          Residual(PreNorm(dim, Attention(dim, heads=heads, dropout=dropout))),
          Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))),
      ]))

  def forward(self, x):
    for half in self._CSP_SCHEDULE:
      left, right = x.chunk(2, dim=2)
      if half == 0:
        left = self.CSPS(left)
      else:
        right = self.CSPS(right)
      x = self.FFN(torch.cat([left, right], dim=2))

    for attn, ff in self.layers:
      # LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual.
      x = ff(attn(x))
    return x

## ViT+CSP ver2

In [30]:
MIN_NUM_PATCHES = 16  # minimum token count for attention to be meaningful

# CSP_ViT ver2
class CSP_ViT_ver2(nn.Module):
  """ViT classifier whose encoder is CSP_Transformer_ver2 (channel-split TR-CSP stages).

  Keyword-only args:
    image_size: side of the square input; must be divisible by patch_size.
    patch_size: side of each square patch.
    num_classes: number of output classes.
    dim: token embedding size.
    dim2: width of the half-channel TR-CSP attention (expected dim // 2).
    depth: total encoder depth (3 TR-CSP stages + depth - 3 vanilla blocks).
    heads: attention heads of the vanilla blocks.
    mlp_dim: hidden width of the encoder MLPs and of the head.
    channels: input image channels.
    dropout: dropout inside the encoder and the head.
    emb_dropout: dropout on the embedded token sequence.
  """

  def __init__(self, *, image_size, patch_size, num_classes, dim, dim2, depth, heads, mlp_dim, channels=3, dropout=0., emb_dropout=0.):
    super().__init__()

    # Validate the patch grid.
    assert image_size % patch_size == 0, "image size must be divisible for the patch size"
    per_side = image_size // patch_size
    num_patches = per_side ** 2
    patch_dim = channels * patch_size**2
    assert num_patches >= MIN_NUM_PATCHES,  f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'

    self.patch_size = patch_size

    # Construction order is kept fixed so random initialization is reproducible.
    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.patch_to_embedding = nn.Linear(patch_dim, dim, bias=False)
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = CSP_Transformer_ver2(dim, dim2, depth, heads, mlp_dim, dropout)
    self.to_cls_token = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, mlp_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(mlp_dim, num_classes),
    )

  def forward(self, img):
    """Map images (batch, channels, height, width) to class scores (batch, num_classes)."""
    p = self.patch_size

    # (B, C, H, W) -> (B, num_patches, patch_dim): flatten p x p patches.
    patches = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
    tokens = self.patch_to_embedding(patches)  # (B, N, dim)
    batch, num_tokens, _ = tokens.shape

    # Prepend one [CLS] token per sample: (B, N, dim) -> (B, N + 1, dim).
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.pos_embedding[:, :num_tokens + 1]
    tokens = self.dropout(tokens)

    encoded = self.transformer(tokens)          # (B, N + 1, dim)
    cls_out = self.to_cls_token(encoded[:, 0])  # (B, dim)
    return self.mlp_head(cls_out)               # (B, num_classes)

## ViT+CSP ver2 setting

In [31]:
# ViT + CSP ver2 for CIFAR-10; the TR-CSP attention runs on half the channels (dim2 = dim / 2).
csp_vit_ver2_net = CSP_ViT_ver2(
    image_size=32, patch_size=4, num_classes=10,
    dim=512, dim2=256, depth=6, heads=8, mlp_dim=512,
    dropout=0.1, emb_dropout=0.1,
).to(device)

In [32]:
# Loss, optimizer and LR schedule for ViT + CSP ver2 (rebinds the shared globals used by train/test).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(csp_vit_ver2_net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, min_lr=1e-3 * 1e-5, verbose=True)

In [None]:
# torchsummary defaults to device="cuda"; pass the active device so this also runs on CPU.
summary(csp_vit_ver2_net, (3, 32, 32), device=device)

## Train & Test

In [None]:
# Train ViT + CSP ver2, recording test loss / accuracy after every epoch.
list_loss_csps_vit_ver2, list_acc_csps_vit_ver2 = [], []
for epoch in range(n_epochs):
  print(f'\nEpoch: {epoch}')
  train(csp_vit_ver2_net)
  val_loss, val_acc = test(csp_vit_ver2_net)
  list_loss_csps_vit_ver2.append(val_loss)
  list_acc_csps_vit_ver2.append(val_acc)

In [None]:
# Spot check: one random CIFAR-10 test image with its label and the ViT + CSP ver2 prediction.
num = random.randint(0, 9999)  # test-set index in 0 ~ 9999

x, label = testset[num]  # normalized tensor fed to the model
x = x[None].to(device)   # add a batch dimension of 1

# Untransformed copy of the test split, used only for display.
compare_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)
img, _ = compare_set[num]

predicted = random_check(csp_vit_ver2_net, x)

print('label:', classes[label])
print('predicted:', classes[predicted])
plt.imshow(img)
print()

# Spec comparison

## Accuracy

In [None]:
%matplotlib inline

""" LOSS """
plt.plot(list_loss_vit, label = "novel_loss", color = 'black')
plt.plot(list_loss_csps_vit, label = "csps_loss", color = 'blue')
plt.plot(list_loss_csps_vit_ver2, label = "csps2_loss", color = 'skyblue')
plt.legend()
plt.show()

""" ACCURACY """
plt.plot(list_acc_vit, label = "novel_accuracy", color = 'black')
plt.plot(list_acc_csps_vit, label = "csps_accuracy", color = 'blue')
plt.plot(list_acc_csps_vit_ver2, label = "csps2_accuracy", color = 'skyblue')
plt.legend()
plt.show()

In [None]:
""" In colab : download accuracy """

# Disabled by default: the code below is kept inside a string literal.
# Remove the surrounding quotes in Colab to save the per-epoch test curves of
# all three models to Google Drive. The vanilla ViT curves are stored in
# list_loss_vit / list_acc_vit.
"""
from google.colab import drive
drive.mount('/content/gdrive')

np.save(F"/content/gdrive/MyDrive/VIT_result/np_csps_ver2_list_loss.npy", list_loss_csps_vit_ver2)
np.save(F"/content/gdrive/MyDrive/VIT_result/np_csps_ver2_list_accu.npy", list_acc_csps_vit_ver2)
np.save(F"/content/gdrive/MyDrive/VIT_result/np_csps_list_loss.npy", list_loss_csps_vit)
np.save(F"/content/gdrive/MyDrive/VIT_result/np_csps_list_accu.npy", list_acc_csps_vit)
np.save(F"/content/gdrive/MyDrive/VIT_result/novel_vit_loss.npy", list_loss_vit)
np.save(F"/content/gdrive/MyDrive/VIT_result/novel_vit_accu.npy", list_acc_vit)

!ls /content/gdrive/MyDrive/VIT_result
"""

## FLOPS & Parameters

In [None]:
# Count operations and parameters of each model on one CIFAR-10-sized dummy image.
# thop.profile reports multiply-accumulates (MACs) for the nn.Module layers it
# hooks. Functional ops called inside forward (e.g. the torch.einsum attention
# products) are presumably not counted, so treat the totals as a lower bound.
# Named `dummy_input` rather than `input` to avoid shadowing the builtin.
dummy_input = torch.randn(1, 3, 32, 32).to(device)

nobel_flops, nobel_params = profile(net, inputs=(dummy_input, ))
csps_flops, csps_params = profile(csp_vit_net, inputs=(dummy_input, ))
csps2_flops, csps2_params = profile(csp_vit_ver2_net, inputs=(dummy_input, ))

In [None]:
# thop.profile returns multiply-accumulate counts (MACs), not FLOPs
# (one MAC is roughly two FLOPs), so the numbers are labelled as MACs.
print('MACs (from thop): ')
print('Vanilla ViT : MACs =', int(nobel_flops), ', parameters =', int(nobel_params))
print('ViT+CSP : MACs =', int(csps_flops), ', parameters =', int(csps_params))
print('CSP ver2 : MACs =', int(csps2_flops), ', parameters =', int(csps2_params))

## Inference time

In [40]:
def measure_inference_time_ver2(batch, repetitions=5000, warmup=10):
  """Time forward passes of `net`, `csp_vit_net` and `csp_vit_ver2_net` with CUDA events.

  The models are evaluated in eval mode without autograd, on one random
  (batch, 3, 32, 32) input placed on the module-level `device`. Each model's
  previous train/eval mode is restored afterwards.

  Args:
    batch: batch size of the random input.
    repetitions: number of timed forward passes per model.
    warmup: number of untimed forward passes per model before timing starts.

  Returns:
    Three numpy arrays of shape (repetitions, 1) holding per-pass latency in
    milliseconds, in the order (net, csp_vit_net, csp_vit_ver2_net).

  Raises:
    RuntimeError: if CUDA is not available (timing uses torch.cuda.Event).
  """
  if not torch.cuda.is_available():
    raise RuntimeError("measure_inference_time_ver2 requires CUDA (torch.cuda.Event timing)")

  models = (net, csp_vit_net, csp_vit_ver2_net)
  timings = [np.zeros((repetitions, 1)) for _ in models]
  starter = torch.cuda.Event(enable_timing=True)
  ender = torch.cuda.Event(enable_timing=True)

  input_image = torch.randn(batch, 3, 32, 32).to(device)

  # Measure inference behaviour: dropout off, no autograd graph.
  was_training = [m.training for m in models]
  for m in models:
    m.eval()

  try:
    with torch.no_grad():
      # Untimed warm-up passes.
      for _ in range(warmup):
        for m in models:
          m(input_image)

      for rep in tqdm(range(repetitions), desc='Iteration', total=repetitions, leave=False):
        for m, times in zip(models, timings):
          starter.record()
          m(input_image)
          ender.record()
          torch.cuda.synchronize()  # wait for the GPU before reading the events
          times[rep] = starter.elapsed_time(ender)
  finally:
    for m, mode in zip(models, was_training):
      m.train(mode)

  return tuple(timings)

In [41]:
def print_inference(batch):
  """Plot and print the latencies held in the module-level timing arrays.

  Reads the globals novel_vit_inference_time, csps_vit_inference_time and
  csps2_vit_inference_time. `batch` is only used in the plot title.
  """
  series = (
      (novel_vit_inference_time, "novel_vit", 'black', 'Valilla ViT :'),
      (csps_vit_inference_time, "csps_vit", 'blue', 'ViT+CSP :'),
      (csps2_vit_inference_time, "csps2_vit", 'skyblue', 'CSP ver2 :'),
  )

  for times, plot_label, color, _ in series:
    plt.plot(times, label=plot_label, color=color)
  plt.legend()
  plt.title('inference time (batch size = ' + str(batch) + ')')
  plt.ylabel('inference_time')
  plt.show()

  print("Inference time")
  for times, _, _, text_label in series:
    print(text_label, np.average(times))

In [None]:
# Latency comparison at batch size 100, then at batch size 1. The results are
# bound to the module-level names that print_inference reads.
for batch in (100, 1):
  novel_vit_inference_time, csps_vit_inference_time, csps2_vit_inference_time = measure_inference_time_ver2(batch)
  print_inference(batch)

# Model save & load (colab only)

In [None]:
# Disabled (kept inside a string literal): helpers that save / load a model's
# state_dict at /content/gdrive/MyDrive/<name> after mounting Google Drive.
# Remove the quotes in Colab to enable.
# NOTE(review): torch.load without map_location cannot restore CUDA-saved
# tensors on a CPU-only runtime; consider torch.load(path, map_location=device).
"""from google.colab import drive

def SAVE(name, model):
  path = F"/content/gdrive/MyDrive/{name}"
  torch.save(model.state_dict(), path)

def LOAD(name, model):
  path = F"/content/gdrive/MyDrive/{name}"
  model.load_state_dict(torch.load(path))

drive.mount('/content/gdrive')"""

In [None]:
# Disabled (kept inside a string literal): write the three models' weights to
# Google Drive with SAVE, then list the Drive folder.
"""SAVE('Vanilla_ViT.pt', net)
SAVE('CSP+ViT.pt', csp_vit_net)
SAVE('CSP+ViT_ver2.pt', csp_vit_ver2_net)

!ls /content/gdrive/MyDrive"""

In [45]:
# Disabled (kept inside a string literal): restore the three models' weights
# from Google Drive with LOAD.
"""LOAD('Vanilla_ViT.pt', net)
LOAD('CSP+ViT.pt', csp_vit_net)
LOAD('CSP+ViT_ver2.pt', csp_vit_ver2_net)"""