<a href="https://colab.research.google.com/github/Vivek-1499/Deep-Learning-Vision-Transformer-ViT-vs-ResNet-18-on-CIFAR-10/blob/main/DL_IA2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
# Environment report: always print the PyTorch build and CUDA availability,
# then either troubleshooting hints (no GPU) or details of the visible GPU.
print("PyTorch Version:", torch.__version__)
print("CUDA Available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    # No CUDA device is visible to PyTorch: list likely causes for the user.
    print("""
    ❌ GPU Still Not Detected - Possible Reasons:
    1. Organization restrictions (school/work account)
    2. Regional GPU shortage
    3. Browser compatibility issue
    """)
else:
    print("Device Count:", torch.cuda.device_count())
    # Only the device at index 0 is named, even if several are present.
    print("Device Name:", torch.cuda.get_device_name(0))
    print("CUDA Version:", torch.version.cuda)

In [None]:
# -*- coding: utf-8 -*-
"""ViT vs ResNet-18 CIFAR-10 Comparison.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/[YOUR_LINK]
"""

# Install third-party dependencies: timm (ViT model factory) and torchmetrics (F1 score)
!pip install timm  # For Vision Transformer
!pip install torchmetrics

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm.auto import tqdm
from torchmetrics import F1Score
from timm import create_model
from timm.layers import PatchEmbed

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Training configuration shared by both models.
BATCH_SIZE = 128      # samples per mini-batch for both loaders
NUM_EPOCHS = 5        # passes over the training set per model
LEARNING_RATE = 1e-3  # step size handed to the optimizers
IMAGE_SIZE = 224      # target size passed to transforms.Resize

# Preprocessing: resize to IMAGE_SIZE, convert to a tensor, then normalize
# every channel with mean 0.5 and std 0.5. Only the training pipeline adds a
# random horizontal flip as augmentation.
train_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

test_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train/test splits, downloaded into ./data when not already there.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# The training loader shuffles and loads in the main process (num_workers=0);
# the test loader keeps dataset order and uses two worker processes.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Fixed ViT Initialization
def initialize_vit(num_frozen_blocks=10):
    """Create a pretrained ViT-Tiny with a 10-class head and freeze its early layers.

    The previous implementation counted entries of ``model.parameters()`` and
    stopped at ten, which freezes only the first ten parameter *tensors*
    (a single transformer block already owns more than that), not ten layers.
    Parameters are now selected by name so that whole blocks are frozen.

    Args:
        num_frozen_blocks: Number of leading transformer blocks to freeze.
            When positive, the embedding parameters that feed the first block
            (class token, position embedding, patch embedding) are frozen as
            well. ``0`` leaves the whole network trainable.

    Returns:
        The model moved to the module-level ``device``.
    """
    model = create_model('vit_tiny_patch16_224', pretrained=True, num_classes=10)

    if num_frozen_blocks > 0:
        # The trailing dot on "blocks.<i>." keeps "blocks.1." from also
        # matching "blocks.10." and "blocks.11.".
        frozen_prefixes = ("cls_token", "pos_embed", "patch_embed.") + tuple(
            f"blocks.{i}." for i in range(num_frozen_blocks)
        )
        for name, param in model.named_parameters():
            if name.startswith(frozen_prefixes):
                param.requires_grad = False

    return model.to(device)

# ResNet-18 Initialization
def initialize_resnet():
    """Create an ImageNet-pretrained ResNet-18 adapted to 10 classes.

    Every pretrained parameter is frozen except those of ``layer4``. The
    classifier ``fc`` is then replaced by a new ``nn.Linear`` with 10 outputs;
    being freshly constructed, its parameters are trainable.

    Returns:
        The model moved to the module-level ``device``.
    """
    # `pretrained=True` is deprecated in torchvision and emits a warning; the
    # `weights` enum below selects the same ImageNet-1k V1 checkpoint.
    weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    model = torchvision.models.resnet18(weights=weights)

    # Freeze the whole backbone first ...
    for param in model.parameters():
        param.requires_grad = False

    # ... then re-enable gradients for the last residual stage only.
    for param in model.layer4.parameters():
        param.requires_grad = True

    # Replace the classifier after freezing so the new head stays trainable.
    model.fc = nn.Linear(model.fc.in_features, 10)
    return model.to(device)

# Simplified Training Function (CPU compatible)
def train_model(model, criterion, optimizer, num_epochs=5):
    """Train ``model`` on ``train_loader`` and evaluate on ``test_loader`` each epoch.

    Args:
        model: Network called on image batches; switched between train and
            eval mode inside every epoch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per training batch.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(train_losses, train_acc, val_losses, val_acc)``: four lists with one
        entry per epoch. Losses are the sum of batch losses divided by the
        number of batches; accuracies are percentages. The per-epoch F1 score
        is printed but not returned.
    """
    train_losses, train_acc = [], []
    val_losses, val_acc = [], []
    f1_metric = F1Score(task="multiclass", num_classes=10).to(device)

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        loss_sum = 0.0
        hits = 0
        seen = 0
        start_time = time.time()

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            predicted = outputs.max(1)[1]  # index of the highest logit per sample
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()

        # Timing covers the training pass only, not the evaluation below.
        epoch_time = time.time() - start_time
        epoch_loss = loss_sum / len(train_loader)
        epoch_acc = 100 * hits / seen
        train_losses.append(epoch_loss)
        train_acc.append(epoch_acc)

        # ---- evaluation pass ----
        model.eval()
        val_loss_sum = 0.0
        val_hits = 0
        val_seen = 0
        pred_batches = []
        label_batches = []

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss_sum += loss.item()
                predicted = outputs.max(1)[1]
                val_seen += labels.size(0)
                val_hits += (predicted == labels).sum().item()
                pred_batches.append(predicted)
                label_batches.append(labels)

        val_epoch_loss = val_loss_sum / len(test_loader)
        val_epoch_acc = 100 * val_hits / val_seen
        # F1 is computed once over all predictions gathered this epoch.
        val_f1 = f1_metric(torch.cat(pred_batches), torch.cat(label_batches)).item()

        val_losses.append(val_epoch_loss)
        val_acc.append(val_epoch_acc)

        print(f"\nEpoch {epoch+1} Stats:")
        print(f"Time: {epoch_time:.1f}s | Loss: {epoch_loss:.3f}→{val_epoch_loss:.3f}")
        print(f"Accuracy: {epoch_acc:.1f}%→{val_epoch_acc:.1f}% | F1: {val_f1:.4f}")

    return train_losses, train_acc, val_losses, val_acc

# Run training
def run_experiment():
    """Train the ViT and then the ResNet-18 with identical settings.

    Each model gets its own Adam optimizer (``LEARNING_RATE``) and its own
    cross-entropy loss instance, and is trained for ``NUM_EPOCHS`` epochs.

    Returns:
        ``(vit_metrics, resnet_metrics, vit_model, resnet_model)`` where each
        ``*_metrics`` value is the tuple returned by ``train_model``.
    """
    # --- Vision Transformer ---
    vit_model = initialize_vit()
    vit_optimizer = optim.Adam(vit_model.parameters(), lr=LEARNING_RATE)
    print("\nTraining Vision Transformer...")
    vit_metrics = train_model(
        vit_model,
        nn.CrossEntropyLoss(),
        vit_optimizer,
        NUM_EPOCHS,
    )

    # --- ResNet-18 ---
    resnet_model = initialize_resnet()
    resnet_optimizer = optim.Adam(resnet_model.parameters(), lr=LEARNING_RATE)
    print("\nTraining ResNet-18...")
    resnet_metrics = train_model(
        resnet_model,
        nn.CrossEntropyLoss(),
        resnet_optimizer,
        NUM_EPOCHS,
    )

    return vit_metrics, resnet_metrics, vit_model, resnet_model

# Execute experiment: trains both models when this cell runs and keeps the
# per-epoch metrics and the trained models as module-level names.
vit_metrics, resnet_metrics, vit_model, resnet_model = run_experiment()

# Plot results
def plot_metrics(vit_metrics, resnet_metrics):
    """Plot loss and accuracy curves of both models side by side.

    Args:
        vit_metrics: ``(train_loss, train_acc, val_loss, val_acc)`` sequences
            for the ViT, one value per epoch.
        resnet_metrics: Same layout for the ResNet.

    The left panel shows the losses and the right panel the accuracies;
    validation curves are dashed. The figure is displayed with ``plt.show()``.
    """
    vit_train_loss, vit_train_acc, vit_val_loss, vit_val_acc = vit_metrics
    res_train_loss, res_train_acc, res_val_loss, res_val_acc = resnet_metrics

    # One entry per panel: (title, y-axis label, curves). Each curve is the
    # positional arguments for plt.plot plus its legend label.
    panels = [
        (
            'Training/Validation Loss',
            'Loss',
            [
                ((vit_train_loss,), 'ViT Train'),
                ((vit_val_loss, '--'), 'ViT Val'),
                ((res_train_loss,), 'ResNet Train'),
                ((res_val_loss, '--'), 'ResNet Val'),
            ],
        ),
        (
            'Training/Validation Accuracy',
            'Accuracy (%)',
            [
                ((vit_train_acc,), 'ViT Train'),
                ((vit_val_acc, '--'), 'ViT Val'),
                ((res_train_acc,), 'ResNet Train'),
                ((res_val_acc, '--'), 'ResNet Val'),
            ],
        ),
    ]

    plt.figure(figsize=(15, 5))

    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for plot_args, legend_label in curves:
            plt.plot(*plot_args, label=legend_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Draw the loss/accuracy comparison from the metrics collected during training.
plot_metrics(vit_metrics, resnet_metrics)

# Final Evaluation
@torch.no_grad()
def benchmark_model(model, test_loader):
    """Measure accuracy and inference speed of ``model`` over ``test_loader``.

    Batches are moved to the device that holds the model's first parameter
    (CPU for a model without parameters), so the function works on CPU-only
    machines as well as on CUDA. On CUDA the device is synchronized before the
    timer starts and again after the forward pass, so each measurement covers
    exactly one forward pass rather than kernel-launch time or earlier
    pending work.

    Args:
        model: Module to evaluate; it is put into eval mode.
        test_loader: Iterable of ``(images, labels)`` batches exposing a
            ``dataset`` attribute whose length is the sample count.

    Returns:
        dict with
        ``'accuracy'``: percentage of correctly classified samples,
        ``'throughput'``: ``len(test_loader.dataset)`` divided by the summed
        forward-pass time in seconds,
        ``'latency'``: mean forward-pass time per batch in milliseconds.
    """
    model.eval()

    # Follow the model's own device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    on_cuda = run_device.type == "cuda"

    correct = 0
    total = 0
    times = []

    for images, labels in test_loader:
        images, labels = images.to(run_device), labels.to(run_device)

        # Warmup: one untimed forward pass before the first measurement.
        if not times:
            _ = model(images)

        # Drain pending transfers and warmup kernels before starting the clock.
        if on_cuda:
            torch.cuda.synchronize(run_device)
        start_time = time.perf_counter()
        outputs = model(images)
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        if on_cuda:
            torch.cuda.synchronize(run_device)
        times.append(time.perf_counter() - start_time)

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return {
        'accuracy': 100 * correct / total,
        'throughput': len(test_loader.dataset) / sum(times),
        'latency': 1000 * np.mean(times)  # ms per batch
    }

# Benchmark the trained ViT on the test loader and report accuracy (%),
# throughput (samples per second) and mean latency (ms per batch).
print("\nBenchmarking ViT:")
vit_stats = benchmark_model(vit_model, test_loader)
print(f"Accuracy: {vit_stats['accuracy']:.2f}% | Throughput: {vit_stats['throughput']:.1f} samples/s | Latency: {vit_stats['latency']:.1f}ms/batch")

# Same benchmark and report format for the trained ResNet-18.
print("\nBenchmarking ResNet-18:")
resnet_stats = benchmark_model(resnet_model, test_loader)
print(f"Accuracy: {resnet_stats['accuracy']:.2f}% | Throughput: {resnet_stats['throughput']:.1f} samples/s | Latency: {resnet_stats['latency']:.1f}ms/batch")