<a href="https://colab.research.google.com/github/VishnuAravind-RG/AIRL-Internship/blob/main/q1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time


# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# Hyper-parameters for a small Vision Transformer on 32x32 CIFAR-10 images.
# Model shape, regularisation and optimisation settings live in one dict so
# the whole run can be logged and saved alongside the checkpoint.
config = dict(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=256,
    depth=6,
    heads=8,
    mlp_dim=512,
    channels=3,
    dropout=0.1,
    emb_dropout=0.1,
    batch_size=128,
    epochs=100,
    lr=3e-4,
    weight_decay=0.03,
    warmup_epochs=10,
)

print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in config.items()))


# Training images get light augmentation (random crop with 4px padding,
# horizontal flip, +/-15 degree rotation); test images are only converted and
# normalised. Both pipelines share the same per-channel mean/std.
# NOTE(review): (0.2023, 0.1994, 0.2010) is the widely copied CIFAR-10 "std";
# the per-pixel std of the training set is closer to (0.247, 0.243, 0.261).
# Either works as long as train and test use the same values — confirm intent.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


# Build (and download on first use) the train split, then the test split.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=pipeline
    )
    for is_train, pipeline in ((True, train_transform), (False, test_transform))
)

# Only the training loader shuffles; evaluation order is fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=config['batch_size'], shuffle=is_train, num_workers=2)
    for dataset, is_train in ((train_dataset, True), (test_dataset, False))
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")


class PatchEmbedding(nn.Module):
    """Turn a batch of images into a token sequence for the transformer.

    Each image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided Conv2d (kernel == stride == patch_size), so every
    patch is linearly projected to ``dim`` features. A learnable [CLS] token is
    prepended, a learnable positional embedding is added, and dropout is
    applied.

    Args:
        image_size: height/width of the square input image.
        patch_size: side length of each square patch.
        channels: number of input channels.
        dim: embedding width of every token.
        emb_dropout: dropout probability applied to the embedded tokens.
            ``None`` (the default) falls back to the module-level
            ``config['emb_dropout']``, which is what this class always used
            before the argument existed, so existing callers are unaffected.

    Input ``x``: ``(B, channels, image_size, image_size)``.
    Output: ``(B, num_patches + 1, dim)`` with the [CLS] token at index 0.
    """

    def __init__(self, image_size=32, patch_size=4, channels=3, dim=256, emb_dropout=None):
        super().__init__()
        if emb_dropout is None:
            # Backward-compatible default: historically read from the global config.
            emb_dropout = config['emb_dropout']

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2  # raw values per patch (not used in this class)

        # Conv with kernel == stride == patch_size acts as a per-patch linear
        # projection; Flatten(2) turns (B, dim, h, w) into (B, dim, h*w).
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2),
        )

        # Both are drawn from a standard normal (std 1); creation order is kept
        # so seeded runs initialise identically.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

    def forward(self, x):
        batch_size, _, _, _ = x.shape  # also rejects non-4-D input early
        x = self.to_patch_embedding(x).transpose(1, 2)  # (B, num_patches, dim)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding  # out-of-place: no in-place op on the autograd graph
        return self.dropout(x)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: token embedding width; split evenly across ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied after the output projection
            (the attention weights themselves are not dropped out).
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        head_dim = dim // heads
        self.heads = heads
        self.scale = head_dim ** -0.5

        # One fused, bias-free projection yields queries, keys and values side by side.
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        batch, tokens, width = x.shape
        head_dim = width // self.heads

        # (B, N, 3*D) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim);
        # slot 0/1/2 along the leading axis holds q/k/v.
        qkv = self.to_qkv(x).reshape(batch, tokens, 3, self.heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        mixed = torch.matmul(weights, v)  # (B, heads, N, head_dim)
        mixed = mixed.transpose(1, 2).reshape(batch, tokens, width)
        return self.to_out(mixed)

class MLP(nn.Module):
    """Two-layer feed-forward sub-block of a transformer layer.

    Expands ``dim`` -> ``hidden_dim`` through GELU, then projects back to
    ``dim``; dropout follows both the activation and the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),  # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to the residual width
            nn.Dropout(dropout),
        ]
        # Kept as a single Sequential named ``net`` so state_dict keys stay net.0.* / net.3.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        for layer in self.net:
            x = layer(x)
        return x

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder layer.

    Computes ``x + Attn(LN1(x))`` and then ``x + MLP(LN2(x))``: each sub-layer
    sees a LayerNorm-ed input and its output is added back onto the residual
    stream.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        # Registration order (norm1, attn, norm2, mlp) fixes the state_dict
        # layout and the order in which sub-modules draw random init values.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, dropout)

    def forward(self, x):
        for norm, sublayer in ((self.norm1, self.attn), (self.norm2, self.mlp)):
            x = x + sublayer(norm(x))
        return x

class VisionTransformer(nn.Module):
    """ViT classifier: patch embedding -> stacked encoder blocks -> LayerNorm -> linear head.

    Classification reads the [CLS] token (position 0) after the final norm.

    Args:
        config: dict providing ``image_size``, ``patch_size``, ``channels``,
            ``dim``, ``depth``, ``heads``, ``mlp_dim``, ``dropout`` and
            ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        dim = config['dim']

        self.patch_embed = PatchEmbedding(
            config['image_size'], config['patch_size'], config['channels'], dim
        )

        blocks = [
            TransformerBlock(dim, config['heads'], config['mlp_dim'], config['dropout'])
            for _ in range(config['depth'])
        ]
        self.transformer = nn.Sequential(*blocks)

        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, config['num_classes'])

        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Only Linear and LayerNorm layers are re-initialised here; the patch
        # Conv2d, cls_token and pos_embedding keep their own initialisation.
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        tokens = self.transformer(self.patch_embed(x))
        cls_features = self.norm(tokens)[:, 0]
        return self.head(cls_features)


# Build the network on the selected device and an AdamW optimizer over it.
# NOTE(review): all parameters share one group, so weight decay also applies
# to biases, LayerNorm weights and the cls/positional embeddings; many ViT
# recipes exclude those via parameter groups.
model = VisionTransformer(config)
model = model.to(device)
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay'],
)


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Create a LambdaLR that warms up linearly, then decays along a half cosine.

    The multiplier applied to each param group's base LR at scheduler step ``s`` is
      * ``s / num_warmup_steps`` while ``s < num_warmup_steps`` (ramps 0 -> 1),
      * ``0.5 * (1 + cos(pi * progress))`` afterwards, where ``progress`` is the
        fraction of post-warmup steps completed, clamped to ``[0, 1]``
        (decays 1 -> 0 and then stays at 0).

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: number of scheduler steps spent in linear warmup.
        num_training_steps: total number of scheduler steps, warmup included.

    Returns:
        torch.optim.lr_scheduler.LambdaLR wrapping ``optimizer``.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        decay_steps = float(max(1, num_training_steps - num_warmup_steps))
        # Clamp so stepping past num_training_steps keeps the LR at 0 instead of
        # following the cosine back up; 0.5 * (1 + cos) is never negative, so a
        # max(0, ...) floor alone cannot prevent that rebound.
        progress = min(1.0, float(current_step - num_warmup_steps) / decay_steps)
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# The LR schedule is stepped once per batch, so both lengths are expressed in
# optimizer steps: batches-per-epoch times the epoch counts from config.
total_steps, warmup_steps = (
    len(train_loader) * config[key] for key in ('epochs', 'warmup_epochs')
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Cross-entropy over the class logits, mean-reduced over the batch (default).
criterion = nn.CrossEntropyLoss()


def train_epoch(model, loader, optimizer, scheduler, epoch):
    """Train ``model`` for one pass over ``loader``.

    For every batch the function:
      1. runs the forward pass and computes the loss;
      2. backpropagates and clips the global gradient norm to 1.0;
      3. takes one optimizer step and one scheduler step (the LR schedule is
         defined in batches, not epochs).

    A tqdm bar shows the running batch loss, accuracy and current LR. The
    function uses the module-level ``device`` and ``criterion``, and reads
    ``config["epochs"]`` only for the progress-bar label.

    Args:
        model: network to train; switched to train mode.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: LR scheduler, stepped once per batch.
        epoch: zero-based epoch index (displayed 1-based).

    Returns:
        ``(mean_loss, accuracy_percent)`` over every sample seen this epoch.
        The loss is weighted by batch size, so a smaller final batch counts in
        proportion to its size rather than as much as a full batch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0  # sum over samples of the per-sample loss
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{config["epochs"]}')
    for data, target in pbar:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()  # read the scalar once per batch
        batch_size = target.size(0)
        # criterion is expected to return a batch *mean* (nn.CrossEntropyLoss
        # default), so scale back up to a sum before averaging over samples.
        loss_sum += batch_loss * batch_size
        total += batch_size
        _, predicted = output.max(1)
        correct += predicted.eq(target).sum().item()

        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Acc': f'{100.*correct/total:.2f}%',
            'LR': f'{scheduler.get_last_lr()[0]:.6f}'
        })

    if total == 0:
        raise ValueError("train_epoch: loader yielded no batches")
    return loss_sum / total, 100. * correct / total


def test_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` and return top-1 accuracy in percent.

    Runs in eval mode with autograd disabled and moves each batch to the
    module-level ``device``.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = model(inputs).max(1)[1]
            n_seen += labels.size(0)
            n_correct += (preds == labels).sum().item()

    return 100. * n_correct / n_seen


# Per-epoch histories (used by the plotting code below) and the best test
# accuracy seen so far.
train_losses, train_accs, test_accs = [], [], []
best_acc = 0

print("Starting training...")
for epoch in range(config['epochs']):
    epoch_start = time.time()

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, epoch)
    test_acc = test_epoch(model, test_loader)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)

    epoch_time = time.time() - epoch_start
    print(f'Epoch {epoch+1:03d}: Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | '
          f'Test Acc: {test_acc:.2f}% | Time: {epoch_time:.2f}s')

    # NOTE: the checkpoint is chosen by *test* accuracy, so best_acc (and the
    # weights saved here) are selected on the test set itself; a held-out
    # validation split would give an unbiased final number.
    if test_acc <= best_acc:
        continue
    best_acc = test_acc
    torch.save(model.state_dict(), 'best_vit_model.pth')
    print(f'New best model saved! Test Acc: {best_acc:.2f}%')

print(f"\nBest Test Accuracy: {best_acc:.2f}%")

# Plot results
# Training loss (left) and train/test accuracy (right) per epoch. The x axis
# counts epochs from 1, matching the "Epoch 001" numbering in the training log
# (plotting the bare lists would label the first epoch as 0).
epoch_numbers = list(range(1, len(train_losses) + 1))

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(epoch_numbers, train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.subplot(1, 2, 2)
plt.plot(epoch_numbers, train_accs, label='Train Acc')
plt.plot(epoch_numbers, test_accs, label='Test Acc')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()


print("\nFinal Evaluation:")
# Restore the weights written to best_vit_model.pth during training and
# re-measure them on the test set.
model.load_state_dict(torch.load('best_vit_model.pth'))
final_test_acc = test_epoch(model, test_loader)
print(f"Final Test Accuracy: {final_test_acc:.2f}%")

# Bundle the weights, the hyper-parameters and the measured accuracy into one file.
final_checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': config,
    'test_accuracy': final_test_acc,
}
torch.save(final_checkpoint, 'final_vit_model.pth')

print("Training completed!")

Using device: cpu
Configuration:
  image_size: 32
  patch_size: 4
  num_classes: 10
  dim: 256
  depth: 6
  heads: 8
  mlp_dim: 512
  channels: 3
  dropout: 0.1
  emb_dropout: 0.1
  batch_size: 128
  epochs: 100
  lr: 0.0003
  weight_decay: 0.03
  warmup_epochs: 10


100%|██████████| 170M/170M [00:02<00:00, 57.4MB/s]


Training samples: 50000
Test samples: 10000
Starting training...


Epoch 1/100:  29%|██▉       | 114/391 [09:11<22:00,  4.77s/it, Loss=2.2844, Acc=11.32%, LR=0.000009]