<a href="https://colab.research.google.com/github/JAZ201107/PyTorch-DL/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

In [None]:
import matplotlib.pyplot as plt

# Vision Transformer Model

In [None]:
from dataclasses import dataclass


@dataclass
class Config:
    """Hyper-parameters for the ViT model and its training loop.

    Every attribute is a dataclass field with a default, so values can be read
    straight off the class (``Config.device``, as done when the class itself is
    passed around) or overridden per instance (``Config(epochs=5)``).
    """

    image_size: int = 32  # input images are image_size x image_size pixels
    patch_size: int = 4  # side length of each square patch
    num_channels: int = 3  # input image channels
    num_heads: int = 4  # attention heads; MultiHeadAttention asserts they divide hidden_size
    hidden_size: int = 48  # token embedding width
    num_classes: int = 10  # number of output logits
    num_layers: int = 4  # number of stacked TransformerBlocks
    ffn_hidden_size: int = 48 * 4  # MLP inner width (fixed value, not derived from hidden_size)
    dropout: float = 0.1

    # Training settings. These used to be unannotated class attributes, which
    # made them impossible to override through the dataclass constructor.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    epochs: int = 30
    # When None, ViTTrainer applies its own weight initialisation.
    pretrain_model: object = None

## Patch Embedding

In [None]:
class PatchEmbeddings(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``config.patch_size``
    projects every patch independently to ``config.hidden_size`` channels.
    """

    def __init__(self, config):
        super().__init__()

        patches_per_side = config.image_size // config.patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.projection = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

    def forward(self, x):
        """Map images (B, C, H, W) to patch tokens (B, num_patches, hidden_size)."""
        feature_map = self.projection(x)  # (B, hidden_size, H // p, W // p)
        tokens = feature_map.flatten(start_dim=2)  # (B, hidden_size, num_patches)
        return tokens.permute(0, 2, 1)  # (B, num_patches, hidden_size)

In [None]:
class Embeddings(nn.Module):
    """Patch embeddings plus a learnable [CLS] token and learnable position embeddings.

    The output sequence is ``[CLS] + patch tokens`` with a position embedding
    added to every slot, followed by dropout.
    """

    def __init__(self, config):
        super().__init__()

        self.patch_embedding = PatchEmbeddings(config)
        width = config.hidden_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, width))
        self.num_patches = (config.image_size // config.patch_size) ** 2
        sequence_length = self.num_patches + 1  # one extra slot for [CLS]
        self.position_embeddings = nn.Parameter(torch.randn(1, sequence_length, width))

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Turn images (B, C, H, W) into token embeddings (B, num_patches + 1, hidden_size)."""
        patch_tokens = self.patch_embedding(x)
        batch_size = patch_tokens.size(0)
        # Broadcast the single learned [CLS] vector across the batch.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        sequence = torch.cat([cls_tokens, patch_tokens], dim=1)
        return self.dropout(sequence + self.position_embeddings)

## Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One linear layer produces queries, keys and values in a single matmul; the
    hidden dimension is split evenly across ``config.num_heads`` heads. Dropout
    is applied to the projected output only (not to the attention weights).
    """

    def __init__(self, config):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        # Each head must receive the same whole number of channels.
        assert self.hidden_size % self.num_heads == 0
        self.qkv_lin = nn.Linear(self.hidden_size, 3 * self.hidden_size)
        self.output_projection = nn.Linear(self.hidden_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Self-attend over ``x`` of shape (B, N, D); returns a tensor of the same shape."""
        batch, seq_len, _ = x.shape

        # (B, N, 3 * D) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv_lin(x).reshape(batch, seq_len, 3, self.num_heads, -1)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention weights, one matrix per head.
        scale = key.size(-1) ** 0.5
        weights = ((query @ key.transpose(-2, -1)) / scale).softmax(dim=-1)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, D)
        context = (weights @ value).transpose(1, 2).reshape(batch, seq_len, -1)
        return self.output_dropout(self.output_projection(context))

In [None]:
class GELUActivation(nn.Module):
    """Tanh approximation of the GELU activation, applied element-wise.

    gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    """

    def forward(self, input):
        cubic = 0.044715 * torch.pow(input, 3.0)
        tanh_argument = np.sqrt(2.0 / np.pi) * (input + cubic)
        return 0.5 * input * (1.0 + torch.tanh(tanh_argument))

## MLP

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Linear(hidden -> ffn_hidden) -> GELU -> Dropout -> Linear(ffn_hidden -> hidden) -> Dropout.
    """

    def __init__(self, config):
        super().__init__()

        self.fc1 = nn.Linear(config.hidden_size, config.ffn_hidden_size)
        self.fc2 = nn.Linear(config.ffn_hidden_size, config.hidden_size)

        self.activation = GELUActivation()
        # One Dropout module, reused after both linear layers.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.dropout(self.activation(self.fc1(x)))
        return self.dropout(self.fc2(expanded))

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    LayerNorm runs before each sub-layer, and each sub-layer's output is added
    back through a residual connection:
        x = x + attention(ln1(x))
        x = x + mlp(ln2(x))
    """

    def __init__(self, config):
        super().__init__()

        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)

        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x):
        after_attention = x + self.attention(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))

In [None]:
class TransformerEncoder(nn.Module):
    """A stack of ``config.num_layers`` TransformerBlocks applied in sequence."""

    def __init__(self, config):
        super().__init__()

        blocks = [TransformerBlock(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

## Projection Head

In [None]:
class ProjectionHead(nn.Module):
    """Classification head: one linear layer applied to the first ([CLS]) token."""

    def __init__(self, config):
        super().__init__()

        self.linear = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        # Sequence position 0 holds the [CLS] token; the patch tokens are ignored.
        cls_representation = x.select(dim=1, index=0)
        return self.linear(cls_representation)

## ViT Model

In [None]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch + position embeddings (with a prepended [CLS] token)
    -> stack of Transformer blocks -> linear head on the [CLS] token.
    """

    def __init__(self, config):
        super().__init__()

        self.embeddings = Embeddings(config)
        self.encoder = TransformerEncoder(config)
        self.projection_head = ProjectionHead(config)

    def forward(self, x):
        """Map an image batch (B, C, H, W) to class logits (B, num_classes)."""
        return self.projection_head(self.encoder(self.embeddings(x)))

In [None]:
#  test
# Smoke test: a random batch of 8 RGB 32x32 images pushed through a freshly
# built model must yield one logit per class (10) for each image.
model = ViT(Config())
x = torch.randn(8, 3, 32, 32)
out = model(x)
assert tuple(out.shape) == (8, 10)

# Prepare Data


In [None]:
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

from torch.utils.data import DataLoader

In [None]:
# Quick look at the raw data: load the CIFAR-10 split with train=False.
# NOTE(review): despite its name, `trainset` holds the *test* split; the
# loader built from it is accordingly named `test_loader`.
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 20.7MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
# Unshuffled loader over `trainset` in batches of 20, used for inspection.
test_loader = DataLoader(
    dataset=trainset,
    batch_size=20,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# Fetch the first batch ([images, labels]) from the unshuffled loader.
batch = next(iter(test_loader))

In [None]:
# Display the integer class labels of that batch.
batch[1]

tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6])

In [None]:
def prepare_data(
    batch_size: int,
    image_size: int,
    num_workers: int = 4,
    pin_memory: bool = True,
    root: str = "data",
):
    """Build CIFAR-10 train/validation DataLoaders.

    Images are resized to ``image_size`` x ``image_size`` and converted to
    tensors with ``ToTensor``; no normalisation or augmentation is applied.
    Both splits are downloaded into ``root`` if they are not already there.

    Args:
        batch_size: samples per batch for both loaders.
        image_size: side length the images are resized to.
        num_workers: DataLoader worker processes.
        pin_memory: forwarded to both DataLoaders.
        root: directory holding (or receiving) the CIFAR-10 files.

    Returns:
        ``(dataloaders, classes)``. ``dataloaders`` is
        ``{"train": <shuffled train-split loader>, "val": <unshuffled test-split loader>}``;
        ``classes`` is a tuple of short names for the 10 labels, in label-index order.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    )

    train_dataset = datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform
    )
    test_dataset = datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform
    )

    # Settings shared by both loaders; only shuffling differs between them.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    dataloaders = {
        'train': train_loader,
        'val': test_loader
    }

    return dataloaders, classes

# Utility Functions

Utility functions for visualizing sample images and for saving and loading model checkpoints.


In [None]:
def save_checkpoint(model, optimizer, filename="best_model.pth.tar"):
    """Write the model's and optimizer's state dicts to ``filename`` with ``torch.save``.

    The saved object is a dict with keys ``"state_dict"`` (model) and
    ``"optimizer"``, which is the layout ``load_checkpoint`` reads.
    """
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(model, optimizer, filename="best_model.pth.tar", map_location=None):
    """Restore model and optimizer state from a file written by ``save_checkpoint``.

    Args:
        model: module whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        filename: path of the checkpoint file.
        map_location: forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a checkpoint that was saved from a GPU onto a machine without one;
            ``None`` keeps ``torch.load``'s default device handling.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

In [None]:
def visualize_images():
    """Plot 30 randomly chosen CIFAR-10 training images in a 6 x 5 grid.

    Each subplot is titled with the image's class name. The training split is
    downloaded into ``./data`` when missing. No transform is applied, so each
    raw sample is converted with ``np.asarray`` before plotting. ``plt.show()``
    is not called here.
    """
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # 30 distinct random indices; each (image, label) sample is read once.
    picks = torch.randperm(len(trainset))[:30]
    samples = [trainset[i] for i in picks]
    images = [np.asarray(img) for img, _ in samples]
    labels = [classes[target] for _, target in samples]

    fig = plt.figure(figsize=(15, 15))
    for slot, (img, title) in enumerate(zip(images, labels), start=1):
        ax = fig.add_subplot(6, 5, slot, xticks=[], yticks=[])
        ax.imshow(img)
        ax.set_title(title)

# Start Training


In [None]:
# Same device selection rule as Config.device: GPU when CUDA is available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


In [None]:
class ViTTrainer:
    """Train and evaluate a classifier, checkpointing whenever val accuracy improves.

    Args:
        model: network to train; moved to ``config.device`` on construction.
        dataloaders: dict with ``"train"`` and ``"val"`` loaders yielding
            ``(images, labels)`` batches.
        optimizer: optimizer already bound to ``model.parameters()``.
        criterion: loss called as ``criterion(outputs, labels)``.
        config: object exposing ``device``, ``epochs`` and ``pretrain_model``.
    """

    def __init__(self, model, dataloaders, optimizer, criterion, config):
        self.model = model.to(config.device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.dataloaders = dataloaders
        self.device = config.device
        self.config = config

        # Per-epoch history, appended to by train().
        self.train_losses = []
        self.val_losses = []
        self.train_acc = []
        self.val_acc = []

        if config.pretrain_model is None:
            # nn.Module.apply calls _init_weight once for every submodule.
            self.model.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise a single submodule in place.

        ``nn.Module.apply`` already handles the recursion, so only ``module``
        itself is touched here. Walking ``self.model.modules()`` on every
        callback would re-initialise the whole model once per submodule.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')  # He initialization
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm without elementwise affine has no weight/bias to reset.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _run_epoch(self, split, training):
        """Run one full pass over ``self.dataloaders[split]``.

        When ``training`` is True, gradients are enabled and the optimizer
        steps after every batch; otherwise the pass runs in eval mode without
        gradient tracking.

        Returns:
            ``(mean batch loss, accuracy over the whole dataset)``.
        """
        loader = self.dataloaders[split]
        self.model.train(training)

        total_loss = 0.0
        total_correct = 0
        samples_seen = 0
        description = "Training" if training else "Evaluation"
        with torch.set_grad_enabled(training), tqdm(loader, desc=description) as progress:
            # Iterating the tqdm wrapper advances the bar automatically.
            for step, (images, labels) in enumerate(progress, start=1):
                images, labels = images.to(self.device), labels.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                total_loss += loss.item()
                total_correct += (outputs.argmax(dim=-1) == labels).sum().item()
                # Count real samples so a smaller final batch keeps the running accuracy exact.
                samples_seen += labels.size(0)

                progress.set_postfix(
                    loss=total_loss / step,
                    acc=total_correct / samples_seen,
                )

        accuracy = total_correct / len(loader.dataset)
        return total_loss / len(loader), accuracy

    def train_one_epoch(self):
        """Train for one epoch on ``"train"``; return ``(mean batch loss, accuracy)``."""
        return self._run_epoch("train", training=True)

    def evaluate(self):
        """Evaluate on ``"val"`` without gradients; return ``(mean batch loss, accuracy)``."""
        return self._run_epoch("val", training=False)

    def train(self):
        """Run ``config.epochs`` epochs of training followed by validation.

        Per-epoch losses and accuracies are recorded, and a checkpoint is
        written with ``save_checkpoint`` whenever validation accuracy improves.
        """
        best_accuracy = 0
        # NOTE(review): config.pretrain_model only controls whether weights are
        # initialised in __init__; nothing here loads pretrained weights.
        if self.config.pretrain_model is None:
            print("Training from Scratch")

        for epoch in range(self.config.epochs):
            train_loss, train_accuracy = self.train_one_epoch()
            val_loss, val_accuracy = self.evaluate()

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_acc.append(train_accuracy)
            self.val_acc.append(val_accuracy)

            print(
                f"Epoch: {epoch + 1}/{self.config.epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
            )

            if val_accuracy > best_accuracy:
                print("Find Better Model")
                best_accuracy = val_accuracy
                save_checkpoint(self.model, self.optimizer)
                print("Saved Better Model")

        print(f"Best accuracy: {best_accuracy}")

In [None]:
model = ViT(Config())

dataloaders, classes = prepare_data(
    batch_size=64,
    image_size=32,
)

# AdamW (decoupled weight decay) at lr=1e-3. The earlier setting, Adam at
# lr=0.01 with coupled L2 decay, plateaued at ~1.81 train loss and ~33% val
# accuracy, which is typical of a learning rate far too high for a ViT
# trained from scratch.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified




In [None]:
# Bundle model, data, optimisation setup and config into a trainer. The
# `Config` class itself (not an instance) is passed; its attributes are read
# directly off the class.
trainier = ViTTrainer(model, dataloaders, optimizer, criterion, Config)

In [None]:
# Run config.epochs epochs; a checkpoint is saved each time val accuracy improves.
trainier.train()

Training from Scratch


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 1/30, Train Loss: 2.1774021123376346, Val Loss: 2.005177462936207, Val Accuracy: 0.2335
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 2/30, Train Loss: 1.9724548905706771, Val Loss: 1.9618319227437304, Val Accuracy: 0.2648
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 3/30, Train Loss: 1.9122105251492747, Val Loss: 2.035239169552068, Val Accuracy: 0.2638


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 4/30, Train Loss: 1.8774781842975665, Val Loss: 1.931493999851737, Val Accuracy: 0.2718
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 5/30, Train Loss: 1.8649314605366543, Val Loss: 1.950323467801331, Val Accuracy: 0.2841
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 6/30, Train Loss: 1.8527543973130034, Val Loss: 1.8383961240197444, Val Accuracy: 0.317
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 7/30, Train Loss: 1.8588098432401867, Val Loss: 1.8088606368204592, Val Accuracy: 0.3261
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 8/30, Train Loss: 1.8438994713756434, Val Loss: 1.7646391839738105, Val Accuracy: 0.3288
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 9/30, Train Loss: 1.8204318230109446, Val Loss: 1.7791368186853493, Val Accuracy: 0.322


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 10/30, Train Loss: 1.8184429076321595, Val Loss: 1.8072660523615065, Val Accuracy: 0.3082


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 11/30, Train Loss: 1.8193548698254558, Val Loss: 1.7555700115337494, Val Accuracy: 0.3284


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 12/30, Train Loss: 1.8157779961595755, Val Loss: 1.7625967613451041, Val Accuracy: 0.3268


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 13/30, Train Loss: 1.819080479919453, Val Loss: 1.7659614701179942, Val Accuracy: 0.3304
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 14/30, Train Loss: 1.8167556312382984, Val Loss: 1.8147197809948283, Val Accuracy: 0.3062


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 15/30, Train Loss: 1.8213676734043813, Val Loss: 1.7481985722377802, Val Accuracy: 0.3323
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 16/30, Train Loss: 1.81746095906743, Val Loss: 1.7902821651689567, Val Accuracy: 0.301


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 17/30, Train Loss: 1.817449292715858, Val Loss: 1.8344933827211902, Val Accuracy: 0.3069


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

Epoch: 18/30, Train Loss: 1.8132387674068247, Val Loss: 1.7482054127249749, Val Accuracy: 0.3387
Find Better Model
Saved Better Model


  0%|          | 0/782 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

In [None]:
# Visualise model predictions on the first 40 CIFAR-10 test images.
# The dataset must use the training preprocessing (Resize + ToTensor); without
# a transform it yields PIL images, which have no `.unsqueeze`. It is indexed
# directly, because a torchvision dataset has no "test" key.
plt.figure(figsize=(20, 10))
testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ]
    ),
)

# Run inference on whichever device the model's parameters live on.
model_device = next(model.parameters()).device
model.eval()
for index in range(40):
    image, label = testset[index]

    # Model inference
    with torch.inference_mode():
        pred = model(image.unsqueeze(dim=0).to(model_device)).argmax(dim=1).item()

    # Convert from CHW to HWC for visualization
    image = image.permute(1, 2, 0)

    # Convert from class indices to class names
    pred_name = testset.classes[pred]
    label_name = testset.classes[label]

    # Visualize the image
    plt.subplot(4, 10, index + 1)
    plt.imshow(image)
    plt.title(f"pred: {pred_name}" + "\n" + f"label: {label_name}")
    plt.axis("off")
plt.show()