In [1]:
import os
# List one sample file per directory to see how the input dataset is laid out.
for dirpath, _subdirs, files in os.walk('/kaggle/input'):
    for fname in files[:1]:
        print(os.path.join(dirpath, fname))

/kaggle/input/cifar10/readme.txt
/kaggle/input/cifar10/cifar10/labels.txt
/kaggle/input/cifar10/cifar10/test/airplane/6221_airplane.png
/kaggle/input/cifar10/cifar10/test/horse/3009_horse.png
/kaggle/input/cifar10/cifar10/test/truck/9692_truck.png
/kaggle/input/cifar10/cifar10/test/automobile/4067_automobile.png
/kaggle/input/cifar10/cifar10/test/ship/8546_ship.png
/kaggle/input/cifar10/cifar10/test/dog/2915_dog.png
/kaggle/input/cifar10/cifar10/test/bird/1043_bird.png
/kaggle/input/cifar10/cifar10/test/frog/4119_frog.png
/kaggle/input/cifar10/cifar10/test/cat/2585_cat.png
/kaggle/input/cifar10/cifar10/test/deer/6012_deer.png
/kaggle/input/cifar10/cifar10/train/airplane/29606_airplane.png
/kaggle/input/cifar10/cifar10/train/horse/46129_horse.png
/kaggle/input/cifar10/cifar10/train/truck/42362_truck.png
/kaggle/input/cifar10/cifar10/train/automobile/19301_automobile.png
/kaggle/input/cifar10/cifar10/train/ship/32578_ship.png
/kaggle/input/cifar10/cifar10/train/dog/38840_dog.png
/kaggle/

In [2]:
import numpy as np 
import pandas as pd 
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
from typing import Tuple, Union

In [3]:
IMG_SIZE = 224   # images are resized to IMG_SIZE x IMG_SIZE before patching
BS = 16          # batch size used by the DataLoaders
PS = 16          # patch size: square PS x PS patches
C = 3            # number of input channels (RGB)
SEED = 42        # fixes the torch RNG so weight init and data shuffling repeat across runs

# The patch grid must tile the image exactly; otherwise the strided conv silently
# drops border pixels and N below no longer matches the real patch count.
assert IMG_SIZE % PS == 0, f"IMG_SIZE ({IMG_SIZE}) must be divisible by PS ({PS})"

N = (IMG_SIZE**2)//PS**2   # number of patches per image
D_MODEL = PS**2*C          # token width: one flattened PS x PS x C patch
N_ENCODER_BLOCKS = 4
N_HEADS = 4

# nn.MultiheadAttention requires the embedding width to split evenly across heads.
assert D_MODEL % N_HEADS == 0, f"D_MODEL ({D_MODEL}) must be divisible by N_HEADS ({N_HEADS})"

torch.manual_seed(SEED)

In [4]:
# Dataset root on Kaggle; it holds `train/` and `test/` folders with one subfolder per class.
root_path = os.path.join("/kaggle/input", "cifar10", "cifar10")

In [5]:
# Peek at one raw training image before any transform; pixels come out as H x W x C.
with Image.open(f"{root_path}/train/airplane/29606_airplane.png") as img:
    # np.array makes a fresh writable copy, so as_tensor can wrap it without another copy.
    image_tensor = torch.as_tensor(np.array(img))
    print(image_tensor.shape)

torch.Size([32, 32, 3])


In [6]:
# Upsample the 32x32 CIFAR images to the ViT input resolution, then convert to tensors;
# both splits share this exact pipeline.
transform = transforms.Compose(
    [transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), transforms.ToTensor()]
)

# One ImageFolder per split: every class subfolder becomes one label.
train_ds, test_ds = (
    ImageFolder(root=f"{root_path}/{split}", transform=transform)
    for split in ("train", "test")
)

In [7]:
# Both splits must share one class -> index mapping, otherwise test labels would not
# line up with the indices the model is trained on.
assert train_ds.class_to_idx == test_ds.class_to_idx, "train/test class mappings differ"
train_ds.class_to_idx, test_ds.class_to_idx

({'airplane': 0,
  'automobile': 1,
  'bird': 2,
  'cat': 3,
  'deer': 4,
  'dog': 5,
  'frog': 6,
  'horse': 7,
  'ship': 8,
  'truck': 9},
 {'airplane': 0,
  'automobile': 1,
  'bird': 2,
  'cat': 3,
  'deer': 4,
  'dog': 5,
  'frog': 6,
  'horse': 7,
  'ship': 8,
  'truck': 9})

In [8]:
classes = train_ds.classes  # class names, ordered by label index

# Training batches are reshuffled every epoch; evaluation keeps a fixed order.
train_dl = DataLoader(train_ds, batch_size=BS, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BS, shuffle=False)

## Embeddings Layer

In [9]:
# ImageFolder's default loader converts every file to RGB before applying the transform;
# mirror that here so this sample matches exactly what the DataLoaders yield.
with Image.open(f"{root_path}/train/airplane/29606_airplane.png") as img:
    sample_img = transform(img.convert("RGB"))

In [10]:
# (C, H, W) after Resize + ToTensor
sample_img.size()

torch.Size([3, 224, 224])

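Patchification reshapes each image from $C \times H \times W$ to $N \times (P^2 \cdot C)$, where $P$ is the patch size `PS` and $N = HW / P^2$ is the number of patches.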

In [11]:
# Number of patches N, and the width of one flattened patch (PS^2 * C).
N, C * PS**2

(196, 768)

In [12]:
# Patchify with a strided conv: kernel == stride == PS, so each output location covers
# exactly one non-overlapping patch. Use the config constant C rather than a literal 3.
patch_conv_layer = nn.Conv2d(in_channels=C,
                             out_channels=D_MODEL,
                             kernel_size=PS,
                             stride=PS)

patch_conv_layer

Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))

In [13]:
img_conv = patch_conv_layer(sample_img)  # unbatched (C, H, W) in -> (D_MODEL, H/PS, W/PS) out
img_conv.size()

torch.Size([768, 14, 14])

In [14]:
# Merge the spatial grid into one patch axis, then put patches first: (N, D_MODEL).
torch.flatten(img_conv, start_dim=1).transpose(0, 1).shape

torch.Size([196, 768])

In [15]:
class PatchEmbedding(nn.Module):
    """Split a batch of images into non-overlapping patches and project each to a token.

    A Conv2d whose kernel size and stride both equal ``patch_size`` is used, so every
    output position of the conv corresponds to exactly one patch.
    """

    def __init__(self,
                channels: int=C,
                d_model: int=D_MODEL,
                patch_size: int=PS) -> None:
        super().__init__()
        # kernel == stride -> one non-overlapping window per patch
        self.conv_layer = nn.Conv2d(channels, d_model,
                                    kernel_size=patch_size,
                                    stride=patch_size)
        # merges the two spatial dims of the conv output into a single patch axis
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Map (B, channels, H, W) images to (B, num_patches, d_model) patch tokens."""
        feature_map = self.conv_layer(image)      # B, d_model, H // P, W // P
        patch_tokens = self.flatten(feature_map)  # B, d_model, num_patches
        return patch_tokens.transpose(1, 2)       # B, num_patches, d_model

class Embedding(nn.Module):
    """ViT input embedding: patch tokens plus a learnable class token and learnable
    positional embeddings.

    The output sequence has ``num_patches + 1`` tokens of width ``d_model``; the class
    token sits at index 0.
    """

    def __init__(self,
                channels: int=C,
                d_model: int=D_MODEL,
                patch_size: int=PS,
                num_patches: int=N) -> None:

        super().__init__()
        # Expected patch count per image; checked in forward so a wrong image size
        # fails with a clear message instead of an opaque broadcasting error.
        self.num_patches = num_patches
        self.patch_embedding_layer = PatchEmbedding(channels=channels, d_model=d_model, patch_size=patch_size)
        # One learnable token prepended to every sequence, shared across the batch.
        self.class_token_embedding = nn.Parameter(
            data = torch.randn(size=(1, 1, d_model)),
            requires_grad = True
        )
        # One learnable vector per position (class token + each patch), broadcast over the batch.
        self.positional_embedding = nn.Parameter(
            data = torch.randn(size=(1, num_patches+1, d_model)),
            requires_grad = True
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            image: (B, channels, H, W) batch.
        Returns:
            (B, num_patches + 1, d_model) token sequence.
        Raises:
            ValueError: if the image size produces a patch count other than ``num_patches``.
        """
        # Named batch_size (not BS) so it does not shadow the global batch-size constant.
        batch_size = image.shape[0]
        patch_tokens = self.patch_embedding_layer(image)  # B, num_patches, d_model
        if patch_tokens.shape[1] != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches per image, got {patch_tokens.shape[1]}; "
                "check that the image size matches the num_patches this layer was built for"
            )
        cls_tokens = self.class_token_embedding.expand(batch_size, 1, -1)  # B, 1, d_model
        embedding = torch.cat((cls_tokens, patch_tokens), dim=1) + \
                    self.positional_embedding  # B, num_patches + 1, d_model
        return embedding

In [16]:
# Pull a single (shuffled) batch to sanity-check tensor shapes.
images, labels = next(iter(train_dl))
images.size()

torch.Size([16, 3, 224, 224])

In [24]:
# Run the embedding on the sample batch: expect (B, N + 1, D_MODEL) tokens and B labels.
embedding_layer = Embedding()
embeddings = embedding_layer(images)
embeddings.size(), labels.size()

(torch.Size([16, 197, 768]), torch.Size([16]))

## Multi-Head Attention (MHA) Block

In [25]:
class MHABlock(nn.Module):
    """Self-attention sub-layer: multi-head attention, residual add, then LayerNorm (post-norm)."""

    def __init__(self,
                d_model:int=D_MODEL,
                n_heads:int=N_HEADS) -> None:
        super().__init__()
        # batch_first=True -> inputs and outputs are laid out as (batch, seq, d_model)
        self.mha_layer = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln_layer = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Query, key and value are the same sequence (self-attention); the returned
        # attention weights are discarded.
        attended = self.mha_layer(query=x, key=x, value=x)[0]
        return self.ln_layer(attended + x)

## MLP Block

In [26]:
class MLPBlock(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear, residual add, then LayerNorm (post-norm).

    Args:
        d_model: token width (input and output size).
        mlp_ratio: hidden-layer width as a multiple of ``d_model``; the default 4
            reproduces the original ``4 * d_model`` hidden size.

    Raises:
        ValueError: if ``d_model * mlp_ratio`` rounds down to an empty hidden layer.
    """

    def __init__(self,
                d_model:int=D_MODEL,
                mlp_ratio:float=4) -> None:

        super().__init__()
        hidden_dim = int(d_model * mlp_ratio)
        if hidden_dim < 1:
            raise ValueError(f"mlp_ratio={mlp_ratio} gives an empty hidden layer for d_model={d_model}")
        # Sequential layout (indices 0/1/2) is unchanged so existing state_dict keys still match.
        self.mlp_sub_block = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=hidden_dim),
            nn.GELU(),
            nn.Linear(in_features=hidden_dim, out_features=d_model)
        )
        self.ln_layer = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(x + MLP(x))``; the input shape is preserved."""
        return self.ln_layer(x + self.mlp_sub_block(x))

## Encoder Block

In [27]:
class EncoderBlock(nn.Module):
    """One transformer encoder layer: the attention sub-layer followed by the MLP sub-layer."""

    def __init__(self,
                d_model:int=D_MODEL,
                n_heads:int=N_HEADS) -> None:
        super().__init__()
        self.mha_block = MHABlock(d_model=d_model, n_heads=n_heads)
        self.mlp_block = MLPBlock(d_model=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = self.mha_block(x)
        return self.mlp_block(attended)

In [28]:
class Encoder(nn.Module):
    """Stack of ``n_encoder_blocks`` identical encoder blocks applied in sequence."""

    def __init__(self,
                n_encoder_blocks:int=N_ENCODER_BLOCKS,
                d_model:int=D_MODEL,
                n_heads:int=N_HEADS) -> None:
        super().__init__()
        self.encoder_blocks = nn.ModuleList(
            EncoderBlock(d_model=d_model, n_heads=n_heads)
            for _ in range(n_encoder_blocks)
        )

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        hidden = x
        for encoder_block in self.encoder_blocks:
            hidden = encoder_block(hidden)
        return hidden

## ViT

In [29]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch + class-token + positional embedding -> stack of encoder blocks ->
    linear head applied to the class-token output (sequence index 0).
    """

    def __init__(self,
                 channels: int=C,
                 d_model: int=D_MODEL,
                 patch_size: int=PS,
                 num_patches: int=N,
                 n_heads: int=N_HEADS,
                 n_encoder_blocks: int=N_ENCODER_BLOCKS,
                 # NOTE: evaluated once, when this class is defined, from the global `classes`.
                 n_classes: int=len(classes)) -> None:

        super().__init__()

        # Pass the constructor's own `channels` through (not the global C) so inputs
        # with a different channel count are honoured.
        self.embedding_layer = Embedding(channels=channels,
                                         d_model=d_model,
                                         patch_size=patch_size,
                                         num_patches=num_patches)
        self.encoder = Encoder(n_encoder_blocks=n_encoder_blocks,
                               d_model=d_model,
                               n_heads=n_heads)
        self.classification_layer = nn.Linear(in_features=d_model, out_features=n_classes)

    def forward(self,
                image: torch.Tensor,
                label: Union[torch.Tensor, None]=None) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """Classify a batch of images.

        Args:
            image: (B, channels, H, W) batch.
            label: optional (B,) class indices; when given, the loss is computed as well.
        Returns:
            ``(logits, loss)``: logits are (B, n_classes); loss is ``F.cross_entropy`` of
            the logits against ``label``, or None when no label is passed.
        """
        embeddings = self.embedding_layer(image)   # B, NUM_PATCHES+1, D_MODEL
        encoder_output = self.encoder(embeddings)  # B, NUM_PATCHES+1, D_MODEL
        # Only the class token's final state feeds the classifier.
        logits = self.classification_layer(encoder_output[:, 0, :])  # B, NUM_CLASSES

        loss = None
        if label is not None:
            loss = F.cross_entropy(logits, label)

        return logits, loss

## Training

In [30]:
# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [31]:
# Build the model on the selected device, then AdamW over all of its parameters.
vit = ViT()
vit = vit.to(device)
optimizer = torch.optim.AdamW(
    vit.parameters(),
    lr=3e-4,
    weight_decay=1e-3,
)

In [32]:
def train_step(model: Union[nn.Module, None]=None,
               dataloader: Union[DataLoader, None]=None,
               opt: Union[torch.optim.Optimizer, None]=None,
               dev: Union[str, torch.device, None]=None) -> Tuple[float, float]:
    """Run one training epoch and return ``(mean loss per sample, accuracy)``.

    Every argument defaults to the notebook-level object with the same role
    (``vit``, ``train_dl``, ``optimizer``, ``device``), so ``train_step()`` still
    trains the notebook's model.

    The model must accept ``(image, label)`` and return ``(logits, loss)``, where
    ``loss`` is assumed to be the batch mean (F.cross_entropy's default reduction,
    as used by ViT.forward).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model = vit if model is None else model
    dataloader = train_dl if dataloader is None else dataloader
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    model.train()
    total_loss = 0.0
    n_correct = 0
    n_total = 0
    for image, label in dataloader:
        image, label = image.to(dev), label.to(dev)
        logits, loss = model(image, label)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_size = image.shape[0]
        # `loss` is a batch mean: weight it by batch size so a smaller final batch
        # does not skew the epoch average.
        total_loss += loss.item() * batch_size
        n_correct += (logits.argmax(dim=-1) == label).sum().item()
        n_total += batch_size

    if n_total == 0:
        raise ValueError("train_step: dataloader yielded no samples")
    return total_loss / n_total, n_correct / n_total

@torch.inference_mode()
def eval_step(model: Union[nn.Module, None]=None,
              dataloader: Union[DataLoader, None]=None,
              dev: Union[str, torch.device, None]=None) -> Tuple[float, float]:
    """Evaluate once over a dataloader and return ``(mean loss per sample, accuracy)``.

    Every argument defaults to the notebook-level object with the same role
    (``vit``, ``test_dl``, ``device``). Runs under ``torch.inference_mode`` and leaves
    the model in eval mode.

    The model must accept ``(image, label)`` and return ``(logits, loss)``, where
    ``loss`` is assumed to be the batch mean (F.cross_entropy's default reduction,
    as used by ViT.forward).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model = vit if model is None else model
    dataloader = test_dl if dataloader is None else dataloader
    dev = device if dev is None else dev

    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_total = 0
    for image, label in dataloader:
        image, label = image.to(dev), label.to(dev)
        logits, loss = model(image, label)
        batch_size = image.shape[0]
        # Weight the batch-mean loss by batch size so uneven batches average correctly.
        total_loss += loss.item() * batch_size
        n_correct += (logits.argmax(dim=-1) == label).sum().item()
        n_total += batch_size

    if n_total == 0:
        raise ValueError("eval_step: dataloader yielded no samples")
    return total_loss / n_total, n_correct / n_total

In [33]:
def train(epochs: int=2) -> None:
    """Alternate one training epoch and one evaluation pass `epochs` times, printing metrics each time."""
    vit.train()
    for epoch in range(1, epochs + 1):
        # train_step switches the model to train mode; eval_step switches it to eval mode.
        train_metrics = train_step()
        eval_metrics = eval_step()
        train_loss, train_acc = train_metrics
        eval_loss, eval_acc = eval_metrics
        report = f"""
Epoch: {epoch}/{epochs}
              train_loss: {train_loss:.4f} train_acc: {train_acc:.4f}
              eval_loss:  {eval_loss:.4f}  eval_acc:  {eval_acc:.4f}
"""
        print(report)

In [None]:
train(epochs=2)