In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

# Data augmentation and normalization
# Per-channel statistics fed to Normalize for both splits
# (presumably the Tiny ImageNet RGB mean/std -- TODO confirm the source).
mean = [0.4802, 0.4481, 0.3975]
std = [0.2302, 0.2265, 0.2262]

# Training images get random geometric and colour perturbations before being
# converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(64),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Validation images are only resized / centre-cropped (no randomness).
transform_val = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Paths to the training and validation directories
train_dir = './tiny-imagenet-200/train'
val_dir = './tiny-imagenet-200/val/sorted_val'

# Get sorted list of class names.
# Only sub-directories count as classes: ImageFolder ignores plain files, so a
# stray entry such as ".DS_Store" must not be given an index, otherwise every
# label is shifted and the largest one falls outside the classifier's range.
class_names = sorted(
    entry.name for entry in os.scandir(train_dir) if entry.is_dir()
)

# Create a mapping from class names to indices
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}

# Create datasets with sorted class labels.
# Each target_transform converts the dataset's own label back to its class
# name and then to the index defined by the *training* directory, so both
# splits share one label space. The lambdas look up `trainset` / `valset`
# lazily (at call time), which is why they may reference the variable that is
# being assigned.
trainset = torchvision.datasets.ImageFolder(
    root=train_dir,
    transform=transform_train,
    target_transform=lambda label: class_to_idx[trainset.classes[label]],
)

valset = torchvision.datasets.ImageFolder(
    root=val_dir,
    transform=transform_val,
    target_transform=lambda label: class_to_idx[valset.classes[label]],
)

# Fail fast with a readable message instead of a KeyError raised later from
# inside a DataLoader worker.
unknown_val_classes = sorted(set(valset.classes) - set(class_to_idx))
if unknown_val_classes:
    raise ValueError(
        f'Validation classes missing from {train_dir}: {unknown_val_classes}'
    )

# Data loaders
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=4)

classes = 200  # Tiny ImageNet has 200 classes

# Model Definition
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    cuts the image into non-overlapping patches and projects each patch to
    ``embed_dim`` channels in a single step.
    """

    def __init__(self, img_size=64, patch_size=8, in_channels=3, embed_dim=512):
        super().__init__()
        patches_per_side = img_size // patch_size
        # Number of tokens produced for a square img_size x img_size input.
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens shaped ``[batch_size, num_patches, embed_dim]``."""
        feature_map = self.proj(x)                    # [batch_size, embed_dim, H', W']
        tokens = feature_map.flatten(start_dim=2)     # [batch_size, embed_dim, num_patches]
        return tokens.permute(0, 2, 1)                # [batch_size, num_patches, embed_dim]

class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies ``LayerNorm -> multi-head self-attention`` and
    ``LayerNorm -> MLP``, each wrapped in a residual connection.

    Args:
        embed_dim: Size of every token vector.
        num_heads: Number of attention heads.
        dropout: Dropout probability used inside attention and the MLP.
    """

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True lets the module consume [batch, seq, embed] tensors
        # directly, so forward() no longer needs manual transposes. It does not
        # change the parameters, so existing state_dicts still load.
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Transform tokens ``x`` of shape ``[batch, seq, embed_dim]``; shape is preserved."""
        x_norm = self.layer_norm1(x)
        # Passing the *same* tensor object as query, key and value lets the
        # module take its packed self-attention projection, and
        # need_weights=False skips building the averaged attention map that
        # was previously computed and thrown away.
        attn_output, _ = self.self_attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_output

        x = x + self.mlp(self.layer_norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add a
    learnable position embedding -> ``depth`` encoder layers -> LayerNorm ->
    linear head applied to the class token.
    """

    def __init__(self, img_size=64, patch_size=8, in_channels=3, num_classes=200, embed_dim=512, depth=6, num_heads=8, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # The sequence holds every patch plus one extra slot for the class token.
        seq_len = self.patch_embed.num_patches + 1

        # Class token and position embedding both start at zero and are learned.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # Stack of encoder blocks, built in order.
        encoder_blocks = []
        for _ in range(depth):
            encoder_blocks.append(TransformerEncoderLayer(embed_dim, num_heads, dropout))
        self.layers = nn.ModuleList(encoder_blocks)

        # Final normalisation and classifier.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``[batch_size, num_classes]`` for image batch ``x``."""
        tokens = self.patch_embed(x)

        # Broadcast the single class token across the batch and put it first.
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)

        # Position information, then dropout.
        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.layers:
            tokens = block(tokens)

        # Only the class token (index 0) feeds the classifier.
        tokens = self.norm(tokens)
        return self.head(tokens[:, 0])

# Calculate total number of parameters
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Resolve the compute device once; the model here and every batch in the
# training / evaluation loops below use this same object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionTransformer().to(device)
total_params = count_parameters(model)
print(f'Total number of parameters: {total_params}')

# Training Components

# Loss function with label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Optimizer with weight decay
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

# Learning rate scheduler.
# T_max is the number of scheduler steps in the cosine schedule. The training
# loop steps the scheduler once per epoch, so keep this equal to the number of
# epochs trained.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Accuracy calculation
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Score tensor indexed as ``[sample, class]``.
        target: Tensor holding the true class index of every sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, in the same
        order as ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Class indices of the largest_k highest scores, best first; after the
        # transpose row r holds every sample's rank-r prediction.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        to_percent = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent)
            for k in topk
        ]

# Training Loop
num_epochs = 100  # Adjust as needed
best_acc = 0.0
TOPK = (1, 3, 5)  # k values reported for every split


def _evaluate(model, loader, device, topk=(1, 3, 5)):
    """Return dataset-level top-k accuracies (percent) of ``model`` on ``loader``.

    The model is switched to eval mode and run without gradients. Each batch's
    accuracy is weighted by its size, so a smaller final batch does not skew
    the result. Relies on ``accuracy`` defined earlier in this file.

    Args:
        model: Module mapping an input batch to class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        topk: k values to evaluate.

    Returns:
        A list of floats, one per entry of ``topk``; ``nan`` for every entry
        if the loader yields no samples.
    """
    model.eval()
    weighted = [0.0] * len(topk)
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            batch = labels.size(0)
            for i, acc in enumerate(accuracy(outputs, labels, topk=topk)):
                weighted[i] += acc.item() * batch
            seen += batch
    if seen == 0:
        return [float('nan')] * len(topk)
    return [total / seen for total in weighted]


for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_weighted = [0.0] * len(TOPK)
    train_seen = 0

    for inputs, labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Weight per-batch means by the batch size so the epoch figures are
        # true per-sample averages even when the last batch is smaller.
        batch = inputs.size(0)
        running_loss += loss.item() * batch
        for i, acc in enumerate(accuracy(outputs, labels, topk=TOPK)):
            train_weighted[i] += acc.item() * batch
        train_seen += batch

    # The learning rate schedule advances once per epoch.
    scheduler.step()

    # Average training loss and accuracy over the samples seen this epoch
    # (accuracy is measured on train-mode outputs, i.e. with dropout active).
    epoch_loss = running_loss / train_seen
    train_acc1, train_acc3, train_acc5 = (
        total / train_seen for total in train_weighted
    )

    # Validation phase (leaves the model in eval mode; model.train() at the
    # top of the loop switches it back).
    val_acc1, val_acc3, val_acc5 = _evaluate(model, valloader, device, TOPK)

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    print(f'Train Accuracies: Top-1 {train_acc1:.2f}%, Top-3 {train_acc3:.2f}%, Top-5 {train_acc5:.2f}%')
    print(f'Val Accuracies:   Top-1 {val_acc1:.2f}%, Top-3 {val_acc3:.2f}%, Top-5 {val_acc5:.2f}%\n')

    # Save the best model (by validation top-1)
    if val_acc1 > best_acc:
        best_acc = val_acc1
        torch.save(model.state_dict(), 'best_vit_tiny_imagenet.pth')

# Reporting Final Results
# Load the best model; map_location keeps the load working when the
# checkpoint was written from a different device than the current one.
model.load_state_dict(torch.load('best_vit_tiny_imagenet.pth', map_location=device))

# Evaluate on the training set.
# NOTE: trainloader still applies the random training augmentations, so this
# is accuracy on augmented images rather than on the raw training files.
train_acc1, train_acc3, train_acc5 = _evaluate(model, trainloader, device, TOPK)

print(f'Final Training Accuracies: Top-1 {train_acc1:.2f}%, Top-3 {train_acc3:.2f}%, Top-5 {train_acc5:.2f}%')

# Evaluate on the validation set
val_acc1, val_acc3, val_acc5 = _evaluate(model, valloader, device, TOPK)

print(f'Final Validation Accuracies: Top-1 {val_acc1:.2f}%, Top-3 {val_acc3:.2f}%, Top-5 {val_acc5:.2f}%')

# Report the number of parameters
print(f'Total number of parameters: {total_params}')
