In [None]:
# Quick einops demo: split an image batch into flattened 16x16 patches.
# This cell runs before the imports cell below, so it imports what it needs and builds its own
# sample batch instead of relying on an `x` left over in the kernel.
import torch
from einops import rearrange

patch_size = 16
x = torch.randn(8, 3, 224, 224)  # (batch, channels, height, width); H and W must be divisible by patch_size
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1 = patch_size, s2 = patch_size)
# patches: (batch, (224 // 16) ** 2 = 196 patches, 16 * 16 * 3 = 768 values per patch)



In [None]:
import os

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as img

import cv2

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, CenterCrop
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary


from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import utils
from torch.optim.lr_scheduler import ReduceLROnPlateau

import copy
import time


print(f"{torch.__version__}")  # record the PyTorch build this notebook was run with

1.13.1+cu116


In [None]:
# PatchEmbedding Ver1
class PatchEmbedding(nn.Module):
    """Patch embedding, variant 1: einops split + Linear projection.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches. Each patch is
    flattened to ``patch_size * patch_size * in_channels`` values and projected to ``emb_size``.

    NOTE: ``PatchEmbedding`` is redefined several more times later in this cell; only the last
    definition stays bound to the name.
    """
    def __init__(self, in_channels = 3, patch_size = 16, emb_size = 768):
        super().__init__()
        flat_patch_len = patch_size * patch_size * in_channels
        to_patch_tokens = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1 = patch_size, s2 = patch_size)
        self.projection = nn.Sequential(to_patch_tokens, nn.Linear(flat_patch_len, emb_size))

    def forward(self, x):
        return self.projection(x)
    
# PatchEmbedding Ver2
class PatchEmbedding(nn.Module):
    """Patch embedding, variant 2: strided Conv2d + flatten.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects each non-overlapping
    patch to ``emb_size`` channels. The resulting (h, w) grid is then flattened into a token axis:
    (b, emb_size, h, w) -> (b, h*w, emb_size).

    NOTE: ``PatchEmbedding`` is redefined again later in this cell; only the last definition
    stays bound to the name.
    """
    def __init__(self, in_channels = 3, patch_size = 16, emb_size = 768):
        super().__init__()
        patch_conv = nn.Conv2d(in_channels, emb_size, kernel_size = patch_size, stride = patch_size)
        grid_to_tokens = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(patch_conv, grid_to_tokens)

    def forward(self, x):
        return self.projection(x)
    
# CLS Token
"""
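PE를 위한 'cls token' 추가, 각 시퀀스 앞에 붙여주기
(English: add a learnable 'cls token' to the patch embedding and prepend it to every token sequence.)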
"""

class PatchEmbedding(nn.Module):
    """Patch embedding with a learnable cls token prepended to every sequence.

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch (the Conv2d kernel size and stride).
        emb_size: embedding size of every token.

    forward(x) maps (batch, in_channels, H, W) to (batch, 1 + num_patches, emb_size). The cls
    token is at index 0 of the token axis.

    NOTE: ``PatchEmbedding`` is redefined again later in this cell; only the last definition
    stays bound to the name.
    """
    def __init__(self, in_channels = 3, patch_size = 16, emb_size = 768):
        # nn.Module.__init__() must run before any attribute is assigned on the module.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size = patch_size, stride = patch_size),
            Rearrange('b e (h) (w) -> b (h w ) e')
        )
        # One learnable token shared by all images; it is broadcast over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_token = repeat(self.cls_token, '() n e -> b n e', b = b)

        # Prepend the cls token along the token axis.
        x = torch.cat([cls_token, x], dim = 1)

        return x
    
# Position Embedding
class PatchEmbedding(nn.Module):
    """Patch embedding + cls token + learnable position embeddings (the version ViT uses).

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch (the Conv2d kernel size and stride).
        emb_size: embedding size of every token.
        img_size: side length of the square input image. It fixes the number of position
            embeddings at (img_size // patch_size) ** 2 + 1 (the patches plus the cls token).

    forward(x) maps (batch, in_channels, H, W) to (batch, 1 + num_patches, emb_size). The input
    must yield exactly (img_size // patch_size) ** 2 patches; otherwise the position-embedding
    add cannot broadcast and raises.
    """
    def __init__(self, in_channels = 3, patch_size = 16, emb_size = 768, img_size = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size = patch_size, stride = patch_size),
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One embedding per token position: every patch plus the cls token at index 0.
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x : Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_token = repeat(self.cls_token, '() n e -> b n e', b = b)

        # Prepend the cls token to every sequence of patch embeddings.
        x = torch.cat([cls_token, x], dim = 1)

        # Add the position embeddings out-of-place. (num_tokens, emb) broadcasts over the batch,
        # and avoiding in-place mutation of an autograd intermediate keeps backward safe.
        x = x + self.positions

        return x

In [None]:
# MultiHeadAttention

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        emb_size: token embedding size. It is split evenly across ``num_heads`` heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``num_heads``.

    forward(x, mask=None): ``x`` is rearranged with 'b n (h d)', i.e. (batch, tokens, emb_size).
    If ``mask`` is given, it must be a boolean tensor that broadcasts to
    (batch, num_heads, query_len, key_len). True marks key positions that may be attended to.
    Returns a tensor with the same shape as ``x``.
    """
    def __init__(self, emb_size = 768, num_heads = 8, dropout = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(f'emb_size ({emb_size}) must be divisible by num_heads ({num_heads})')
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)

        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        # 1 / sqrt(head_dim), the standard dot-product attention temperature.
        self.scaling = (self.emb_size // self.num_heads) ** -0.5

    def forward(self, x : Tensor, mask : Tensor = None) -> Tensor:
        # Project, then split the embedding into heads: (b, n, h*d) -> (b, h, n, d).
        queries = rearrange(self.queries(x), 'b n (h d) -> b h n d', h = self.num_heads)
        keys = rearrange(self.keys(x), 'b n (h d) -> b h n d', h = self.num_heads)
        values = rearrange(self.values(x), 'b n (h d) -> b h n d', h = self.num_heads)

        # Scaled attention logits: (batch, num_heads, query_len, key_len).
        # The scaling must be applied to the logits BEFORE softmax. Scaling the probabilities
        # afterwards leaves the softmax temperature unchanged, and the rows no longer sum to 1.
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) * self.scaling
        if mask is not None:
            # masked_fill is out-of-place, so keep its result. Use the logits' own dtype minimum
            # so the fill value stays representable under reduced precision.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        att = F.softmax(energy, dim = -1)
        att = self.att_drop(att)

        # Weighted sum of the values over the key axis, then merge the heads back.
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.projection(out)

        return out

In [None]:
# Residuals

class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a residual (skip) connection: returns ``fn(x, **kwargs) + x``.

    Keyword arguments given to ``forward`` are passed through to ``fn``.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. If ``fn`` returns its input tensor as-is (e.g. nn.Identity), an
        # in-place ``+=`` would double ``x`` and silently mutate the caller's tensor.
        return self.fn(x, **kwargs) + x
# MLP
class FeedForwardBlock(nn.Sequential):
    """Two-layer MLP applied along the last (embedding) dimension.

    Layers: Linear(emb_size -> expansion*emb_size) -> GELU -> Dropout(drop_p)
    -> Linear(expansion*emb_size -> emb_size). The output has the same shape as the input.
    """
    def __init__(self, emb_size, expansion = 4, drop_p = 0):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)

In [None]:
# Transformer

class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder block made of two residual sub-blocks applied in order.

    1. x + Dropout(MultiHeadAttention(LayerNorm(x)))
    2. x + Dropout(FeedForwardBlock(LayerNorm(x)))

    Extra ``**kwargs`` (e.g. ``num_heads``, ``dropout``) are forwarded to MultiHeadAttention.
    """
    def __init__(self, emb_size = 768, drop_p = 0, forward_expansion = 4, forward_drop_p = 0, **kwargs):
        attention_branch = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        ))
        mlp_branch = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion = forward_expansion, drop_p = forward_drop_p),
            nn.Dropout(drop_p),
        ))
        super().__init__(attention_branch, mlp_branch)

In [None]:
# Transformer
"""
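ViT에서는 original Transformer의 encoder 부분만을 사용함.
encoder는 TransformerBlock의 L block.
(English: ViT uses only the encoder part of the original Transformer. The encoder is a stack
of L TransformerEncoderBlocks, where L is the depth:)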

ViT-Base : 12
ViT-Large : 24
ViT-Huge : 32
"""

class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlocks applied in sequence.

    ``**kwargs`` are passed unchanged to every block.
    """
    def __init__(self, depth = 12, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)

In [None]:
# Head
"""
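마지막 레이어는 normal FC (레이블에 대한 확률 값)
(English: the last layer is a plain fully connected layer that outputs one score (logit) per label;
softmax, applied inside nn.CrossEntropyLoss during training, turns these scores into probabilities.)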
"""

class ClassificationHead(nn.Sequential):
    """Map a token sequence (b, n, e) to class logits (b, n_classes).

    Tokens are mean-pooled over the sequence axis (every token, including the cls token that
    PatchEmbedding prepends), normalized with LayerNorm, and then projected by a Linear layer.
    NOTE(review): the original ViT classifies from the cls token alone; mean pooling is a
    variant.
    """
    def __init__(self, emb_size = 768, n_classes = 1000):
        pool_tokens = Reduce('b n e -> b e', reduction = 'mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        super().__init__(pool_tokens, norm, classifier)

In [None]:
# ViT (PatchEmbedding + TransformerEncoder + ClassificationHead)

class ViT(nn.Sequential):
    """Vision Transformer: PatchEmbedding -> TransformerEncoder(depth) -> ClassificationHead.

    Input images are (batch, in_channels, img_size, img_size); the output is
    (batch, n_classes) logits. Extra ``**kwargs`` (e.g. ``drop_p``, ``forward_expansion``,
    ``num_heads``) are forwarded to every TransformerEncoderBlock.
    """
    def __init__(self,
                 in_channels = 3,
                 patch_size = 16,
                 emb_size = 768,
                 img_size = 224,
                 depth = 12,
                 n_classes = 1000,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size = emb_size, **kwargs)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)

In [None]:
# Use the GPU when one is available; the rest of the notebook moves models and tensors to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNGs (CPU and CUDA) so weight init, random test inputs, and DataLoader
# shuffling repeat across runs.
SEED = 42
torch.manual_seed(SEED)

In [None]:
## Dataset: STL-10 (downloaded into data_path on first run)

data_path = './data'

# Shared by both splits. Resize(224) scales the smaller image edge to 224, matching the ViT
# default img_size=224.
stl10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224)
])

# load
train_ds = datasets.STL10(data_path, split = 'train', download = True, transform = stl10_transform)
val_ds = datasets.STL10(data_path, split = 'test', download = True, transform = stl10_transform)

print(len(train_ds))
print(len(val_ds))
train_dl = DataLoader(train_ds, batch_size = 32, shuffle = True)
# Validation metrics are summed over the whole split, so shuffling it gains nothing;
# a fixed order keeps evaluation deterministic.
val_dl = DataLoader(val_ds, batch_size = 64, shuffle = False)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


  0%|          | 0/2640397119 [00:00<?, ?it/s]

Extracting ./data/stl10_binary.tar.gz to ./data
Files already downloaded and verified
5000
8000


In [None]:
# Shape sanity checks: run each building block once on random data.
# no_grad because these forwards are never backpropagated (saves activation memory).
with torch.no_grad():
    images = torch.randn(16, 3, 224, 224).to(device)

    patch_embedding = PatchEmbedding().to(device)
    patch_output = patch_embedding(images)
    print(patch_output.shape)  # (batch, 1 cls + (224 // 16) ** 2 patches, emb_size)

    MHA = MultiHeadAttention().to(device)
    MHA_output = MHA(patch_output)
    print(MHA_output.shape)

    mlp_input = torch.randn(16, 1, 128).to(device)
    feed_forward = FeedForwardBlock(128).to(device)
    print(feed_forward(mlp_input).shape)

    encoder_block = TransformerEncoderBlock().to(device)
    print(encoder_block(patch_output).shape)

    encoder = TransformerEncoder().to(device)
    print(encoder(patch_output).shape)

    head_input = torch.randn(16, 1, 768).to(device)
    head = ClassificationHead().to(device)
    print(head(head_input).shape)

    vit_check = ViT().to(device)
    print(vit_check(images).shape)

# The model that is actually trained. STL-10 has 10 classes, so use a 10-way head instead of
# ViT's default n_classes=1000.
model = ViT(n_classes = 10).to(device)

torch.Size([16, 197, 768])
torch.Size([16, 197, 768])
torch.Size([16, 1, 128])
torch.Size([16, 197, 768])
torch.Size([16, 197, 768])
torch.Size([16, 1000])
torch.Size([16, 1000])


In [None]:
# Train
# Summed (not averaged) batch loss; loss_epoch divides the running total by the sample count.
loss_fn = nn.CrossEntropyLoss(reduction = 'sum')
# NOTE(review): lr=1e-2 is high for Adam on a ViT trained from scratch; consider a smaller value — TODO confirm.
optimizer = optim.Adam(params = model.parameters(), lr = 1e-2)

# Cut the LR by 10x once validation loss has not improved for more than `patience` epochs.
lr_scheduler = ReduceLROnPlateau(optimizer, mode = 'min', factor = 1e-1, patience = 10)
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or None if it has none."""
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']
def metric_batch(output, target):
    """Count correct top-1 predictions in a batch.

    ``output`` holds per-class scores with classes on dim 1. ``target`` holds the true class
    indices and is reshaped to match the predictions. Returns a Python int.
    """
    predicted = output.argmax(dim = 1, keepdim = True)
    is_correct = target.view_as(predicted) == predicted
    return is_correct.sum().item()

def loss_batch(loss_fn, output, target, optimizer = None):
    """Compute one batch's loss and correct-prediction count; step the optimizer if one is given.

    Returns:
        (loss as a Python number, count of correct predictions from metric_batch).
    """
    batch_loss = loss_fn(output, target)
    n_correct = metric_batch(output, target)

    # No optimizer means evaluation mode: report the numbers without updating parameters.
    if optimizer is None:
        return batch_loss.item(), n_correct

    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item(), n_correct
    
def loss_epoch(model, loss_fn, dataset_dl, sanity_check = False, optimizer = None):
    """Run one pass over ``dataset_dl`` and return (mean loss, accuracy) per sample.

    Batches are moved to the global ``device``. Each batch's loss is summed into a running
    total, which is divided by the number of samples actually processed, so ``loss_fn`` is
    expected to use ``reduction='sum'``. If ``optimizer`` is given, every batch also takes a
    training step (see loss_batch). With ``sanity_check`` set, only the first batch is processed.

    Raises:
        ValueError: if the loader yields no samples.
    """
    running_loss = 0.0
    running_metric = 0.0
    num_samples = 0

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)

        loss_b, metric_b = loss_batch(loss_fn, output, yb, optimizer)
        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        # Count what was actually processed, so a one-batch sanity-check run is averaged over
        # that batch rather than over the whole dataset.
        num_samples += yb.size(0)

        if sanity_check:
            break

    if num_samples == 0:
        raise ValueError('dataset_dl yielded no samples')

    loss = running_loss / num_samples
    metric = running_metric / num_samples
    return loss, metric
def train_val(model, params):
    """Train ``model`` for ``params['num_epochs']`` epochs, validating after each one.

    ``params`` must provide 'num_epochs', 'loss_func', 'optimizer', 'train_dl', 'val_dl',
    'sanity_check', 'lr_scheduler' and 'path2weights'.

    Whenever validation loss reaches a new minimum, the weights are copied in memory and saved
    to ``path2weights``. The scheduler is stepped on the validation loss; if that lowers the
    learning rate, the best weights so far are reloaded into ``model``.

    Returns:
        (model with the best weights loaded, loss history, metric history). Each history is a
        dict with 'train' and 'val' lists holding one entry per epoch.
    """
    (num_epochs, loss_func, opt, train_dl, val_dl,
     sanity_check, lr_scheduler, path2weights) = (
        params[key] for key in ('num_epochs', 'loss_func', 'optimizer', 'train_dl',
                                'val_dl', 'sanity_check', 'lr_scheduler', 'path2weights'))

    loss_history = {phase: [] for phase in ('train', 'val')}
    metric_history = {phase: [] for phase in ('train', 'val')}

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    start_time = time.time()

    for epoch in range(num_epochs):
        lr_before = get_lr(opt)
        print('Epoch {}/{}, current lr= {}'.format(epoch, num_epochs-1, lr_before))

        # Training pass: the optimizer is passed, so every batch takes a step.
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)

        # Validation pass: no optimizer and no autograd bookkeeping.
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)

        loss_history['train'].append(train_loss)
        loss_history['val'].append(val_loss)
        metric_history['train'].append(train_metric)
        metric_history['val'].append(val_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print('Copied best model weights!')

        lr_scheduler.step(val_loss)
        if get_lr(opt) != lr_before:
            # The scheduler just lowered the LR: continue from the best weights seen so far.
            print('Loading best model weights!')
            model.load_state_dict(best_model_wts)

        elapsed_min = (time.time() - start_time) / 60
        print('train loss: %.6f, val loss: %.6f, accuracy: %.2f, time: %.4f min'
              % (train_loss, val_loss, 100 * val_metric, elapsed_min))
        print('-' * 10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
# Everything train_val() needs for one run, collected in a single dict.
params_train = dict(
    num_epochs = 20,
    optimizer = optimizer,
    loss_func = loss_fn,
    train_dl = train_dl,
    val_dl = val_dl,
    sanity_check = False,
    lr_scheduler = lr_scheduler,
    path2weights = './models/weights.pt',
)

# check the directory to save weights.pt
def createFolder(directory):
    """Create ``directory`` and any missing parent directories if it does not exist yet.

    Best-effort: an ``OSError`` (e.g. a path component is a regular file, or permission is
    denied) is reported on stdout instead of being raised.
    """
    try:
        # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
        os.makedirs(directory, exist_ok = True)
    except OSError as err:
        print('Error: could not create directory {!r}: {}'.format(directory, err))
# Create the folder that params_train['path2weights'] points into, so the path is defined in one place.
createFolder(os.path.dirname(params_train['path2weights']))

In [None]:
# Train, keeping the best-validation-loss weights in `model`, plus per-epoch loss/accuracy histories.
model, loss_hist, metric_hist = train_val(model = model, params = params_train)

Epoch 0/19, current lr= 0.01
Copied best model weights!
train loss: 2.912359, val loss: 2.375173, accuracy: 14.84, time: 4.3038 min
----------
Epoch 1/19, current lr= 0.01
Copied best model weights!
train loss: 2.224996, val loss: 2.251253, accuracy: 14.22, time: 8.6805 min
----------
Epoch 2/19, current lr= 0.01
Copied best model weights!
train loss: 2.100978, val loss: 2.048102, accuracy: 16.18, time: 13.0611 min
----------
Epoch 3/19, current lr= 0.01
Copied best model weights!
train loss: 2.014524, val loss: 2.013837, accuracy: 18.21, time: 17.4371 min
----------
Epoch 4/19, current lr= 0.01
Copied best model weights!
train loss: 2.016816, val loss: 2.010707, accuracy: 19.61, time: 21.8034 min
----------
Epoch 5/19, current lr= 0.01
Copied best model weights!
train loss: 1.997171, val loss: 1.992174, accuracy: 20.44, time: 26.1699 min
----------
Epoch 6/19, current lr= 0.01
Copied best model weights!
train loss: 1.987082, val loss: 1.966063, accuracy: 21.76, time: 30.5523 min
-----