In [None]:
import math
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch import nn
from dataclasses import dataclass
import torch.nn.functional as F

In [None]:
# Define a configuration for the model using a data class
@dataclass
class ModelArgs:
    """Hyper-parameters shared by every module of the Vision Transformer.

    The defaults describe the CIFAR10 setup used in this notebook: 32x32 RGB
    images cut into 4x4 patches (an 8x8 grid, i.e. 64 patches) and classified
    into the 10 CIFAR10 categories.
    """

    dim: int = 256          # Dimension of the model embeddings
    hidden_dim: int = 512   # Dimension of the hidden layer inside each feed-forward block
    n_heads: int = 8        # Number of attention heads
    n_layers: int = 6       # Number of attention blocks in the transformer
    patch_size: int = 4     # Side length of the (square) patches
    n_channels: int = 3     # Number of input channels (e.g., 3 for RGB images)
    n_patches: int = 64     # Number of patches per image; sizes the positional embedding table
    # CIFAR10 labels run from 0 to 9. A smaller value makes the classifier emit
    # too few logits and nn.CrossEntropyLoss rejects the out-of-range targets.
    n_classes: int = 10     # Number of target classes
    dropout: float = 0.2    # Dropout rate for regularization


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        args: configuration object providing ``dim`` (embedding size) and
            ``n_heads`` (number of heads). Each head works on
            ``dim // n_heads`` features.

    ``forward`` returns both the projected output and the attention
    probabilities so callers can inspect or visualise the attention maps.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        # Linear projections for Q, K, and V (bias-free), plus the output projection
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``[b, seq_len, dim]``.

        Returns:
            A tuple ``(out, attn_scores)`` where ``out`` is ``[b, seq_len, dim]``
            and ``attn_scores`` is the softmax-normalised attention of shape
            ``[b, n_heads, seq_len, seq_len]``.
        """
        b, seq_len, dim = x.shape  # b: batch size, seq_len: sequence length

        assert dim == self.dim, "dim is not matching"

        q = self.wq(x)  # [b, seq_len, n_heads*head_dim]
        k = self.wk(x)  # [b, seq_len, n_heads*head_dim]
        v = self.wv(x)  # [b, seq_len, n_heads*head_dim]

        # Split the feature axis into heads, then move the head axis in front
        # of the sequence axis so every head attends independently.
        q = q.contiguous().view(b, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [b, n_heads, seq_len, head_dim]
        k = k.contiguous().view(b, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [b, n_heads, seq_len, head_dim]
        v = v.contiguous().view(b, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # [b, n_heads, seq_len, head_dim]

        # Compute attention scores and apply softmax over the key axis
        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)  # [b, n_heads, seq_len, seq_len]
        attn_scores = F.softmax(attn, dim=-1)  # [b, n_heads, seq_len, seq_len]

        # Compute the attended features
        out = torch.matmul(attn_scores, v)  # [b, n_heads, seq_len, head_dim]

        # Undo the head/sequence swap BEFORE merging the heads. Viewing the
        # [b, n_heads, seq_len, head_dim] tensor directly as [b, seq_len, -1]
        # would mix features belonging to different tokens and heads.
        # NOTE: weights trained without this transpose were fitted to the
        # scrambled layout and should be retrained.
        out = out.transpose(1, 2).contiguous().view(b, seq_len, -1)  # [b, seq_len, n_heads*head_dim]

        return self.wo(out), attn_scores  # Return both output and attention scores


In [None]:
class AttentionBlock(nn.Module):
    """One transformer encoder block with pre-normalisation.

    The input is layer-normalised and passed through multi-head attention,
    then layer-normalised again and passed through a two-layer GELU
    feed-forward network. Each sub-layer's result is added back onto its
    input (residual connection).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Attention sub-layer and the normalisation applied to its input
        self.layer_norm_1 = nn.LayerNorm(args.dim)
        self.attn = MultiHeadAttention(args)

        # Feed-forward sub-layer and the normalisation applied to its input
        self.layer_norm_2 = nn.LayerNorm(args.dim)
        feed_forward_layers = [
            nn.Linear(args.dim, args.hidden_dim),
            nn.GELU(),
            nn.Dropout(args.dropout),
            nn.Linear(args.hidden_dim, args.dim),
            nn.Dropout(args.dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Return ``(block_output, attention_scores)`` for the input sequence ``x``."""
        # The attention module hands back (projected output, attention scores)
        attended, scores = self.attn(self.layer_norm_1(x))
        after_attention = x + attended

        # Second residual branch: feed-forward network on the normalised activations
        after_ffn = after_attention + self.ffn(self.layer_norm_2(after_attention))
        return after_ffn, scores


In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """Cut a batch of images into non-overlapping square patches.

    Args:
        x: image tensor of shape (B, C, H, W).
        patch_size: side length of each square patch.
        flatten_channels: when True every patch is flattened into a single
            feature vector; otherwise the (C, patch_size, patch_size) layout
            of each patch is kept.

    Returns:
        (B, num_patches, C * patch_size * patch_size) when ``flatten_channels``
        is True, else (B, num_patches, C, patch_size, patch_size), where
        num_patches = (H // patch_size) * (W // patch_size). Patches are
        ordered row by row over the patch grid.
    """
    batch, channels, height, width = x.shape
    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # Expose the patch grid: (B, C, rows, patch_size, cols, patch_size)
    tiled = x.reshape(batch, channels, grid_rows, patch_size, grid_cols, patch_size)

    # Bring the grid axes forward, then fuse them into one patch axis:
    # (B, rows, cols, C, patch_size, patch_size) -> (B, rows * cols, C, patch_size, patch_size)
    patches = tiled.permute(0, 2, 4, 1, 3, 5).flatten(1, 2)

    # Optionally collapse channel and spatial axes into one feature axis
    return patches.flatten(2, 4) if flatten_channels else patches

In [None]:
import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The image is split into patches, every patch is linearly embedded, a
    learnable class token is prepended, learnable positional embeddings are
    added, and the sequence is processed by ``args.n_layers`` attention
    blocks. The final class-token representation is fed to a
    LayerNorm + Linear head.

    Args:
        args: configuration object providing ``patch_size``, ``n_channels``,
            ``dim``, ``n_layers``, ``n_patches``, ``n_classes`` and
            ``dropout`` (plus whatever the attention blocks read).
    """

    def __init__(self, args):
        super().__init__()

        # Define the patch size
        self.patch_size = args.patch_size

        # Embedding layer to transform flattened patches to the desired dimension
        self.input_layer = nn.Linear(args.n_channels * (args.patch_size ** 2), args.dim)

        # Stack of attention blocks. Each block returns an (output, attention)
        # tuple, so the blocks cannot be chained by calling an nn.Sequential;
        # nn.ModuleList is the container meant for manual iteration. Parameter
        # names ("transformer.<i>....") are the same for both containers.
        self.transformer = nn.ModuleList([AttentionBlock(args) for _ in range(args.n_layers)])

        # Define the classifier
        self.mlp = nn.Sequential(
            nn.LayerNorm(args.dim),
            nn.Linear(args.dim, args.n_classes)
        )

        # Dropout layer for regularization
        self.dropout = nn.Dropout(args.dropout)

        # Define the class token (similar to BERT's [CLS] token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, args.dim))

        # Positional embeddings to give positional information to the transformer
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + args.n_patches, args.dim))

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image tensor of shape (B, C, H, W).

        Returns:
            ``(logits, attn_weights)``: logits of shape (B, n_classes) and a
            list holding one attention tensor per block.

        Raises:
            ValueError: if the image yields more patches than the positional
                embedding table was sized for.
        """
        # Convert image to patches and flatten
        x_patches = img_to_patch(x, self.patch_size)
        b, seq_len, _ = x_patches.shape

        # Fail early with a readable message instead of a broadcasting error
        max_patches = self.pos_embedding.shape[1] - 1
        if seq_len > max_patches:
            raise ValueError(
                f"input produces {seq_len} patches but positional embeddings "
                f"only cover {max_patches}"
            )

        # Transform patches using the embedding layer
        x = self.input_layer(x_patches)

        # Prepend the class token to every sequence (expand avoids a copy; cat copies anyway)
        cls_token = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        # Add positional embeddings to the sequence
        x = x + self.pos_embedding[:, :seq_len + 1]

        # Apply dropout
        x = self.dropout(x)

        # Process sequence through the transformer, capturing attention weights
        attn_weights = []
        for block in self.transformer:
            x, attn_weight = block(x)
            attn_weights.append(attn_weight)

        # The class token sits at sequence position 0
        cls = x[:, 0]

        # Classify using the representation of the class token
        out = self.mlp(cls)
        return out, attn_weights


In [None]:
# Path to the directory where CIFAR10 data will be stored/downloaded
DATA_DIR = "../data"

# Per-channel normalization statistics (mean and standard deviation of the CIFAR10 dataset).
CIFAR10_MEAN = [0.49139968, 0.48215841, 0.44653091]
CIFAR10_STD = [0.24703223, 0.24348513, 0.26158784]

# Number of training images held out for validation, and the seed that fixes
# which images those are so the split is identical on every run.
VAL_SIZE = 5000
SPLIT_SEED = 42

# Transformation for evaluation (validation and test) data:
# 1. Convert images to tensors.
# 2. Normalize the tensors.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])

# Transformation for training data:
# 1. Apply random horizontal flip for data augmentation.
# 2. Perform random resizing and cropping of images for data augmentation.
# 3. Convert images to tensors.
# 4. Normalize the tensors.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])

# Load the CIFAR10 training images twice: once with augmentation for training
# and once with the deterministic evaluation transform for validation.
# Splitting a single augmented dataset would also apply the random
# flips/crops to the validation images and make validation metrics noisy.
# The dataset will be downloaded if not present in the DATA_DIR.
train_dataset = CIFAR10(root=DATA_DIR, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATA_DIR, train=True, transform=test_transform, download=True)

# Load the CIFAR10 testing dataset with the evaluation transformation.
test_set = CIFAR10(root=DATA_DIR, train=False, transform=test_transform, download=True)

# Split the training images into disjoint training and validation subsets
# (45000 / 5000 for CIFAR10). One seeded permutation is shared by both
# subsets so that no image can land in both.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(train_dataset), generator=split_generator).tolist()
train_set = torch.utils.data.Subset(train_dataset, shuffled_indices[:-VAL_SIZE])
val_set = torch.utils.data.Subset(val_dataset, shuffled_indices[-VAL_SIZE:])

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48531066.15it/s]


Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified


In [None]:
# Batch size shared by the training, validation and test loaders.
batch_size = 64

# Number of subprocesses each loader uses for data loading.
num_workers = 16


def _make_loader(dataset, *, shuffle, drop_last):
    """Build a DataLoader using the notebook-wide batch size and worker count."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )


# Training: reshuffle every epoch and drop a trailing batch smaller than
# `batch_size` so every training step sees the same batch size.
train_loader = _make_loader(train_set, shuffle=True, drop_last=True)

# Validation and test: fixed order, and every sample is kept.
val_loader = _make_loader(val_set, shuffle=False, drop_last=False)
test_loader = _make_loader(test_set, shuffle=False, drop_last=False)




In [None]:
# Model, Loss and Optimizer
# Use the first GPU when available, otherwise the CPU. The fallback must be the
# string "cpu": a bare integer such as 0 is interpreted by PyTorch as a CUDA
# device ordinal, so `.to(0)` fails on exactly the machines that have no CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
args = ModelArgs()
model = VisionTransformer(args).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# The learning rate is multiplied by 0.1 after 80 and again after 130 scheduler
# steps (one step per epoch); shorter runs keep the initial learning rate.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 130], gamma=0.1)

In [None]:
def train(model, criterion, optimizer, train_loader, val_loader, device, lr_scheduler, num_epochs=10):
    """Train ``model`` and report training/validation metrics after every epoch.

    Args:
        model: module whose forward pass returns a tuple whose first element
            is the class logits (the second element is ignored here).
        criterion: loss called as ``criterion(logits, labels)``. The epoch
            averages assume it returns the mean loss of the batch, which is
            the default reduction of ``nn.CrossEntropyLoss``.
        optimizer: optimizer updating ``model``'s parameters.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        device: device the batches are moved to before the forward pass.
        lr_scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        None. Metrics are printed.
    """
    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        train_loss_sum = 0.0
        train_seen = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Weight each batch by its size so the epoch figure is a true
            # per-sample mean even when batches differ in size.
            batch_count = labels.size(0)
            train_loss_sum += loss.item() * batch_count
            train_seen += batch_count

        # An empty loader reports NaN instead of raising ZeroDivisionError
        avg_train_loss = train_loss_sum / train_seen if train_seen else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

        # Validation Phase
        model.eval()
        val_loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs, _ = model(inputs)
                loss = criterion(outputs, labels)

                # The last validation batch is usually smaller than the rest;
                # weighting by batch size keeps it from being over-represented.
                val_loss_sum += loss.item() * labels.size(0)

                _, predicted = outputs.max(dim=-1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        avg_val_loss = val_loss_sum / total if total else float("nan")
        val_accuracy = 100 * correct / total if total else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

        lr_scheduler.step()

    print("Training complete!")

# Run the training loop. Only 10 epochs are used here because of limited time.
train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    lr_scheduler=lr_scheduler,
    num_epochs=10,
)




Epoch [1/10], Training Loss: 1.4116
Epoch [1/10], Validation Loss: 1.3596, Validation Accuracy: 50.88%
Epoch [2/10], Training Loss: 1.3060
Epoch [2/10], Validation Loss: 1.2757, Validation Accuracy: 53.82%
Epoch [3/10], Training Loss: 1.2396
Epoch [3/10], Validation Loss: 1.1979, Validation Accuracy: 57.24%
Epoch [4/10], Training Loss: 1.1857
Epoch [4/10], Validation Loss: 1.2308, Validation Accuracy: 56.66%
Epoch [5/10], Training Loss: 1.1428
Epoch [5/10], Validation Loss: 1.1661, Validation Accuracy: 58.76%
Epoch [6/10], Training Loss: 1.1092
Epoch [6/10], Validation Loss: 1.1597, Validation Accuracy: 59.34%
Epoch [7/10], Training Loss: 1.0796
Epoch [7/10], Validation Loss: 1.0899, Validation Accuracy: 61.56%
Epoch [8/10], Training Loss: 1.0443
Epoch [8/10], Validation Loss: 1.1017, Validation Accuracy: 60.42%
Epoch [9/10], Training Loss: 1.0181
Epoch [9/10], Validation Loss: 1.1031, Validation Accuracy: 60.36%
Epoch [10/10], Training Loss: 0.9953
Epoch [10/10], Validation Loss: 1.08

In [None]:
# Measure classification accuracy on the test loader.
model.eval()  # evaluation mode: dropout layers are disabled
correct, total = 0, 0
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The model returns (logits, attention maps); only the logits are used
        outputs, _ = model(inputs)
        predicted = outputs.max(dim=-1).indices

        correct += (predicted == labels).sum().item()
        total += labels.size(0)

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")

Test Accuracy: 63.29%


In [None]:
# Persist only the parameters (state dict) of the model, not the module object itself
torch.save(obj=model.state_dict(), f='model_weights.pth')
