In [27]:
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torchinfo import summary
from torchvision import datasets, transforms, models

# Choose the accelerator once so later cells can simply call `.to(DEVICE)`.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(DEVICE)

cuda


## Implementation

In [147]:
class SelfAttention(nn.Module):
    """Single head of scaled dot-product attention.

    Queries, keys and values are projected from ``input_dim`` to ``embedding_dim``
    by separate linear layers, and the head returns
    ``softmax(q @ k^T / sqrt(d_k)) @ v`` with ``d_k = embedding_dim``.

    Args:
        input_dim: size of the last dimension of ``x_q``, ``x_k`` and ``x_v``.
        embedding_dim: size of the projected query/key/value vectors (d_k).
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        self.Q = nn.Linear(input_dim, embedding_dim)
        self.K = nn.Linear(input_dim, embedding_dim)
        self.V = nn.Linear(input_dim, embedding_dim)

    def forward(self, x_q, x_k, x_v, mask=None):
        """Attend from ``x_q`` over ``x_k`` / ``x_v``.

        Args:
            x_q: queries, ``(..., L_q, input_dim)``.
            x_k: keys, ``(..., L_k, input_dim)``.
            x_v: values, ``(..., L_k, input_dim)``.
            mask: optional boolean tensor broadcastable to ``(..., L_q, L_k)``;
                ``True`` marks a key position the query may NOT attend to
                (same convention as a boolean ``attn_mask`` in
                ``nn.MultiheadAttention``).

        Returns:
            Tensor of shape ``(..., L_q, embedding_dim)``.
        """
        q = self.Q(x_q)
        k = self.K(x_k)
        v = self.V(x_v)

        # Scale by sqrt(d_k) of the projected keys, not by the model width:
        # dividing by sqrt(input_dim) over-flattens the softmax whenever the
        # model is split into several heads.
        scores = (q @ k.transpose(-2, -1)) / (self.embedding_dim ** 0.5)
        if mask is not None:
            # Out-of-place fill: broadcasts a (L_q, L_k) mask over leading batch
            # dims and keeps the op free of in-place writes into the graph.
            scores = scores.masked_fill(mask, -1e10)
        return torch.softmax(scores, dim=-1) @ v


class MultiheadSelfAttention(nn.Module):
    """Multi-head attention assembled from ``n_heads`` independent ``SelfAttention`` heads.

    Each head projects to ``input_dim // n_heads`` features. The head outputs are
    concatenated back to ``input_dim`` features and mixed by a final linear layer.

    Args:
        input_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads (positive).

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``input_dim``.
    """

    def __init__(self, input_dim, n_heads):
        super().__init__()
        # Without this check the concatenated heads would be narrower than
        # input_dim, and the mismatch would only surface in self.linear at
        # forward time.
        if n_heads <= 0 or input_dim % n_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by n_heads ({n_heads})"
            )
        head_dim = input_dim // n_heads
        self.attention_heads = nn.ModuleList(
            [SelfAttention(input_dim, head_dim) for _ in range(n_heads)]
        )
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, x_q, x_k, x_v, mask=None):
        """Run every head on the same inputs and mix the concatenated results.

        ``mask`` is forwarded unchanged to each head (boolean, ``True`` = blocked).
        """
        x = torch.concat(
            [head(x_q, x_k, x_v, mask) for head in self.attention_heads], dim=-1
        )
        x = self.linear(x)
        return x


# Smoke test: the hand-written multi-head attention and PyTorch's reference module
# should produce the same output shape under a causal mask (True above the diagonal).
x_batch = torch.randn(4, 100, 512)
mask = ~torch.ones(100, 100).tril().bool()

custom_mha = MultiheadSelfAttention(512, 16)
print(custom_mha(x_batch, x_batch, x_batch, mask=mask).shape)

reference_mha = nn.MultiheadAttention(512, 16, batch_first=True)
reference_out = reference_mha(x_batch, x_batch, x_batch, need_weights=False, attn_mask=mask)[0]
print(reference_out.shape)

torch.Size([4, 100, 512])
torch.Size([4, 100, 512])


In [148]:
class MSABlock(nn.Module):
    """Multi-head attention sub-layer followed by dropout.

    NOTE(review): the ``torch_msa`` flag reads inverted. ``torch_msa=True``
    (the default) selects the hand-written ``MultiheadSelfAttention`` defined in
    this notebook, while ``torch_msa=False`` selects ``nn.MultiheadAttention``.
    This behavior is kept as-is so existing callers are unaffected.

    Args:
        input_dim: model width.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention output.
        torch_msa: implementation switch (see the note above).
    """

    def __init__(self, input_dim, n_heads, dropout, torch_msa=True):
        super().__init__()
        self.torch_msa = torch_msa

        if torch_msa:
            self.msa = MultiheadSelfAttention(input_dim, n_heads)
        else:
            self.msa = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_q, x_k, x_v, mask=None):
        """Attend from ``x_q`` over ``x_k`` / ``x_v`` and apply dropout.

        ``mask`` is a boolean mask where ``True`` blocks a query/key pair.
        """
        if self.torch_msa:
            attended = self.msa(x_q, x_k, x_v, mask=mask)
        else:
            # nn.MultiheadAttention returns (output, weights); only the output is used.
            attended, _ = self.msa(x_q, x_k, x_v, need_weights=False, attn_mask=mask)
        return self.dropout(attended)


class MLPBlock(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    Args:
        input_dim: model width (both the input and output feature size).
        mlp_dim: hidden width of the feed-forward layer.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, input_dim, mlp_dim, dropout):
        super().__init__()
        # linear1 / linear2 names are the state_dict keys; keep them stable.
        self.linear1 = nn.Linear(input_dim, mlp_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(mlp_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.dropout(self.linear2(hidden))

In [149]:
class EncoderBlock(nn.Module):
    """Post-LN Transformer encoder layer.

    Two sub-layers, each wrapped as ``LayerNorm(x + sublayer(x))``:
    unmasked self-attention, then the position-wise feed-forward block.

    Args:
        input_dim: model width.
        mlp_dim: hidden width of the feed-forward block.
        n_heads: number of attention heads.
        dropout: dropout probability used by both sub-layers.
    """

    def __init__(self, input_dim, mlp_dim, n_heads, dropout):
        super().__init__()
        self.msa_block = MSABlock(input_dim, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.mlp_block = MLPBlock(input_dim, mlp_dim, dropout)
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # No mask: every position may attend to every other position.
        attended = self.layer_norm1(x + self.msa_block(x, x, x))
        return self.layer_norm2(attended + self.mlp_block(attended))


class Encoder(nn.Module):
    """Stack of ``n_layers`` identical ``EncoderBlock`` layers applied in order.

    Args:
        input_dim: model width.
        mlp_dim: hidden width of each feed-forward block.
        n_heads: number of attention heads per layer.
        dropout: dropout probability.
        n_layers: number of stacked encoder blocks.
    """

    def __init__(self, input_dim, mlp_dim, n_heads, dropout, n_layers):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(EncoderBlock(input_dim, mlp_dim, n_heads, dropout))
        self.encoder_layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.encoder_layers(x)


# summary(Encoder(512, 2048, 8, 0.1, 6), input_size=(1, 500, 512), device='cpu', col_names=['output_size', 'num_params', 'mult_adds'], col_width=15)

In [150]:
class DecoderBlock(nn.Module):
    """Post-LN Transformer decoder layer.

    Three sub-layers, each wrapped as ``LayerNorm(sublayer(x) + x)``:
    masked self-attention over the decoder input, cross-attention from the
    decoder to ``enc_out``, then the position-wise feed-forward block.

    Args:
        input_dim: model width.
        mlp_dim: hidden width of the feed-forward block.
        n_heads: number of attention heads.
        dropout: dropout probability used by every sub-layer.
    """

    def __init__(self, input_dim, mlp_dim, n_heads, dropout):
        super().__init__()
        self.masked_msa_block = MSABlock(input_dim, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.enc_msa_block = MSABlock(input_dim, n_heads, dropout)
        self.layer_norm2 = nn.LayerNorm(input_dim)
        self.mlp_block = MLPBlock(input_dim, mlp_dim, dropout)
        self.layer_norm3 = nn.LayerNorm(input_dim)

    def forward(self, x, enc_out, mask=None):
        """Run one decoder layer.

        Args:
            x: decoder input, presumably ``(batch, tgt_len, input_dim)``.
            enc_out: encoder output, presumably ``(batch, src_len, input_dim)``.
            mask: optional boolean ``(tgt_len, tgt_len)`` mask for the decoder
                self-attention; ``True`` blocks a position (e.g. the causal
                mask ``torch.ones(n, n).tril() == 0``).

        Returns:
            Tensor shaped like ``x``.
        """
        # The causal mask belongs on the decoder's *self*-attention, so position i
        # cannot look at later target positions. It must not be applied to the
        # cross-attention: every decoder position may read the whole encoder
        # output, and the mask is (tgt_len, tgt_len) rather than (tgt_len, src_len).
        x = self.masked_msa_block(x, x, x, mask=mask) + x
        x = self.layer_norm1(x)
        x = self.enc_msa_block(x, enc_out, enc_out) + x
        x = self.layer_norm2(x)
        x = self.mlp_block(x) + x
        x = self.layer_norm3(x)
        return x


# Smoke test: a decoder block should preserve the (batch, seq, dim) shape.
x_batch = torch.randn(4, 100, 512)
decoder_block = DecoderBlock(512, 2048, 8, 0.1)
print(decoder_block(x_batch, x_batch).shape)

torch.Size([4, 100, 512])


In [151]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer over dense per-token feature vectors.

    Both inputs are linearly embedded from ``word_dim`` to ``embedding_dim`` and
    summed with fixed sinusoidal positional encodings. They then pass through
    ``n_layers`` encoder blocks and ``n_layers`` decoder blocks (the decoder is
    given a causal mask) and are projected to ``n_classes`` logits per position.

    Args:
        seq_len: maximum sequence length; sizes the positional-encoding table.
        word_dim: feature size of each input/output token.
        embedding_dim: model width.
        mlp_dim: hidden width of the feed-forward blocks.
        n_heads: attention heads per layer.
        n_layers: number of encoder layers and of decoder layers.
        n_classes: size of the output logits.
        dropout: dropout probability.
    """

    def __init__(self, seq_len, word_dim, embedding_dim, mlp_dim, n_heads, n_layers, n_classes, dropout):
        super().__init__()
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim

        self.input_embedding = nn.Linear(word_dim, embedding_dim)
        self.output_embedding = nn.Linear(word_dim, embedding_dim)
        # A buffer follows the module through `.to(device)`. persistent=False keeps
        # the table out of the state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "positional_encoding", self.create_positional_encoding(), persistent=False
        )

        self.encoder_layers = Encoder(embedding_dim, mlp_dim, n_heads, dropout, n_layers)
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(embedding_dim, mlp_dim, n_heads, dropout) for _ in range(n_layers)
        ])

        self.classification_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, n_classes)
        )

    def create_positional_encoding(self):
        """Return the ``(seq_len, embedding_dim)`` sinusoidal position table.

        Even columns hold ``sin(pos / 10000^(2i/d))`` and odd columns the matching
        ``cos``, with ``d = embedding_dim``.
        """
        pos = torch.arange(self.seq_len).unsqueeze(1)
        denominator = 10000 ** (torch.arange(0, self.embedding_dim, 2) / self.embedding_dim)
        angles = pos / denominator  # (seq_len, ceil(d / 2))

        pos_encoding = torch.zeros(self.seq_len, self.embedding_dim)
        pos_encoding[:, 0::2] = torch.sin(angles)
        # An odd embedding_dim has one fewer odd column than even column.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : self.embedding_dim // 2])
        return pos_encoding

    def create_attention_mask(self, size=None, device=None):
        """Return a causal boolean mask: ``True`` above the diagonal (blocked).

        Args:
            size: side length of the square mask; defaults to ``self.seq_len``.
            device: device to allocate the mask on (default: PyTorch's default).
        """
        size = self.seq_len if size is None else size
        return torch.ones(size, size, device=device).tril() == 0

    def forward(self, input_seq, output_seq):
        """Compute per-position logits for ``output_seq`` conditioned on ``input_seq``.

        Args:
            input_seq: encoder input, presumably ``(batch, src_len, word_dim)``
                with ``src_len <= seq_len``.
            output_seq: decoder input, presumably ``(batch, tgt_len, word_dim)``
                with ``tgt_len <= seq_len``.

        Returns:
            Logits with ``n_classes`` features per decoder position.
        """
        src_len = input_seq.size(-2)
        tgt_len = output_seq.size(-2)
        # Slice the table so sequences shorter than seq_len are supported.
        input_embedding = self.input_embedding(input_seq) + self.positional_encoding[:src_len]
        output_embedding = self.output_embedding(output_seq) + self.positional_encoding[:tgt_len]

        encoder_out = self.encoder_layers(input_embedding)

        mask = self.create_attention_mask(tgt_len, device=output_seq.device)
        decoder_out = output_embedding
        for decoder_layer in self.decoder_layers:
            # Chain the layers: dropping the return value would make the decoder a no-op.
            decoder_out = decoder_layer(decoder_out, encoder_out, mask)
        logits = self.classification_head(decoder_out)

        return logits


# summary(Transformer(100, 1000, 512, 2048, 8, 6, 1000, 0.1), input_size=[(1, 100, 1000), (1, 100, 1000)], device='cpu', col_names=['output_size', 'num_params', 'mult_adds'], col_width=13)

In [152]:
# Model hyper-parameters (the width/heads/depth mirror the "base" Transformer).
seq_len = 100
word_dim = 1000
embedding_dim = 512
mlp_dim = 4 * embedding_dim
n_heads = 8
n_layers = 6
dropout = 0.1

transformer_model = Transformer(
    seq_len=seq_len,
    word_dim=word_dim,
    embedding_dim=embedding_dim,
    mlp_dim=mlp_dim,
    n_heads=n_heads,
    n_layers=n_layers,
    n_classes=1000,
    dropout=dropout,
)
summary(
    transformer_model,
    input_size=[(1, seq_len, word_dim), (1, seq_len, word_dim)],
    device='cpu',
    col_names=['output_size', 'num_params', 'mult_adds'],
    depth=3,
)

Layer (type:depth-idx)                                  Output Shape              Param #                   Mult-Adds
Transformer                                             [1, 100, 1000]            --                        --
├─Linear: 1-1                                           [1, 100, 512]             512,512                   512,512
├─Linear: 1-2                                           [1, 100, 512]             512,512                   512,512
├─Encoder: 1-3                                          [1, 100, 512]             --                        --
│    └─Sequential: 2-1                                  [1, 100, 512]             --                        --
│    │    └─EncoderBlock: 3-1                           [1, 100, 512]             3,152,384                 3,152,384
│    │    └─EncoderBlock: 3-2                           [1, 100, 512]             3,152,384                 3,152,384
│    │    └─EncoderBlock: 3-3                           [1, 100, 512]            

In [16]:
# Reference: PyTorch's built-in encoder-decoder with the same hyper-parameters.
# batch_first=True makes (1, seq_len, embedding_dim) read as (batch, seq, dim),
# matching the custom model; with the default (seq, batch, dim) layout the input
# would be treated as a length-1 sequence with a batch of seq_len.
transformer_torch_model = nn.Transformer(
    d_model=embedding_dim,
    nhead=n_heads,
    num_encoder_layers=n_layers,
    num_decoder_layers=n_layers,
    dim_feedforward=mlp_dim,
    dropout=dropout,
    batch_first=True,
)
summary(
    transformer_torch_model,
    input_size=[(1, seq_len, embedding_dim), (1, seq_len, embedding_dim)],
    device='cpu',
    col_names=['output_size', 'num_params', 'mult_adds'],
    depth=3,
)

Layer (type:depth-idx)                             Output Shape              Param #                   Mult-Adds
Transformer                                        [1, 100, 512]             --                        --
├─TransformerEncoder: 1-1                          [1, 100, 512]             --                        --
│    └─ModuleList: 2-1                             --                        --                        --
│    │    └─TransformerEncoderLayer: 3-1           [1, 100, 512]             3,152,384                 2,101,760
│    │    └─TransformerEncoderLayer: 3-2           [1, 100, 512]             3,152,384                 2,101,760
│    │    └─TransformerEncoderLayer: 3-3           [1, 100, 512]             3,152,384                 2,101,760
│    │    └─TransformerEncoderLayer: 3-4           [1, 100, 512]             3,152,384                 2,101,760
│    │    └─TransformerEncoderLayer: 3-5           [1, 100, 512]             3,152,384                 2,101,760
│   

# Training

In [8]:
from pathlib import Path

TRAIN_RATIO = 0.8
SPLIT_SEED = 42  # fixes the train/val split so runs are comparable
data_dir = Path('./data/')

# CIFAR images are 32x32; upsample them to the 224 input size used for training below.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

full_train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform)
train_ds, val_ds = random_split(
    full_train_ds,
    (TRAIN_RATIO, 1 - TRAIN_RATIO),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# A Subset has no transform of its own: both splits read samples through
# full_train_ds and so share its transform. A distinct validation transform
# would require wrapping val_ds in a separate dataset.
test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
import wandb
# NOTE(review): star import hides which names (e.g. `train`, possibly `ViT`) come
# from src.engine; prefer importing them explicitly.
from src.engine import *

config = dict(batch_size=32, lr=3e-4, epochs=20, dataset='CIFAR100')
with wandb.init(project='pytorch-study', name='ViT', config=config) as run:
    w_config = run.config
    train_dl = DataLoader(train_ds, batch_size=w_config.batch_size, shuffle=True)
    # Shuffling validation data serves no purpose and makes per-batch logs non-reproducible.
    val_dl = DataLoader(val_ds, batch_size=w_config.batch_size, shuffle=False)

    n_classes = len(train_ds.dataset.classes)
    # NOTE(review): `ViT` and `patch_size` are not defined in this notebook. `ViT`
    # presumably comes from the star import above and `patch_size` from leftover
    # kernel state; define both explicitly so Restart & Run All works.
    vit_model = ViT(224, 3, patch_size, embedding_dim, mlp_dim, n_heads, dropout, n_layers, n_classes).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(vit_model.parameters(), lr=w_config.lr)

    loss_history, acc_history = train(vit_model, train_dl, val_dl, criterion, optimizer, w_config.epochs, DEVICE, run)