In [4]:
!pip install --upgrade torchinfo


Collecting torchinfo
  Using cached torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Using cached torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [6]:
# -*- coding: utf-8 -*-
"""
ViT from Scratch for CIFAR-100 Classification and Comparison with ResNet-18.
Based on user-provided sample, extended for multiple configurations and comparison.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchinfo import summary # For FLOPs/Params estimation
import time
import math
import copy
import pandas as pd # For better table display

# --- Configuration ---
# Define configurations to test
# Format: (config_name, patch_size, embed_dim, num_layers, num_heads, mlp_ratio)
vit_configs = [
    ("ViT_P4_E256_L4_H2_M2", 4, 256, 4, 2, 2.0),
    ("ViT_P4_E512_L4_H4_M4", 4, 512, 4, 4, 4.0), # Example with 4x embed_dim for MLP
    ("ViT_P8_E256_L8_H4_M2", 8, 256, 8, 4, 2.0),
    ("ViT_P8_E512_L8_H4_M4", 8, 512, 8, 4, 4.0), # Example with 4x embed_dim for MLP
    # Add your required configurations from the prompt:
    ("ViT_P4_E256_L8_H4_M4", 4, 256, 8, 4, 4.0),
    ("ViT_P8_E512_L4_H2_M2", 8, 512, 4, 2, 2.0),
]

# Training Hyperparameters
BATCH_SIZE = 64 # From prompt
EPOCHS = 10     # Train for 10 epochs for baseline comparison as requested
LR = 0.001      # From prompt
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_CLASSES = 100 # For CIFAR-100
IMG_SIZE = 32     # For CIFAR-100
IN_CHANNELS = 3

print(f"Using device: {DEVICE}")

# --- Data Loading and Preprocessing ---
# Using standard CIFAR-100 stats for normalization
transform_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)), # Ensure correct size
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.TrivialAugmentWide(), # Add more augmentation
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

transform_test = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)), # Ensure correct size
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Use CIFAR100 as requested
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# --- Vision Transformer Implementation (Based on Sample) ---

class PatchEmbedding(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, image_size, patch_size, in_channels=3, embed_dim=256):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, embed_dim, H', W']
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x

# Using nn.MultiheadAttention requires careful handling of shapes (or batch_first=True)
# Adjusted TransformerEncoder based on sample but using standard blocks
class TransformerEncoder(nn.Module):
    """Transformer Encoder Block (using Pre-LN)."""
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # Ensure batch_first=True matches input shape (B, N, E)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout) # Add dropout layer if needed outside attention/mlp

    def forward(self, x):
        # Pre-LN structure
        residual = x
        x_norm = self.norm1(x)
        attn_output, _ = self.attention(x_norm, x_norm, x_norm) # Self-attention
        x = residual + self.dropout(attn_output) # Apply dropout after attention

        residual = x
        x_norm = self.norm2(x)
        mlp_output = self.mlp(x_norm)
        x = residual + mlp_output # MLP dropout is inside the sequential
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer main model (based on sample)."""
    def __init__(self, image_size, patch_size, num_classes, embed_dim,
                 num_heads, num_layers, mlp_ratio, dropout=0.1): # Added mlp_ratio
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, IN_CHANNELS, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout) # Use pos_dropout for positional embedding

        # Calculate mlp_dim based on embed_dim and mlp_ratio
        mlp_dim_calculated = int(embed_dim * mlp_ratio)

        # Corrected Transformer blocks using ModuleList
        self.transformer_layers = nn.ModuleList([
            TransformerEncoder(embed_dim, num_heads, mlp_dim_calculated, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim) # Final LayerNorm
        self.head = nn.Linear(embed_dim, num_classes)

        # Weight Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights_fn)

    def _init_weights_fn(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_dropout(x) # Apply dropout here

        for layer in self.transformer_layers:
            x = layer(x)

        x = self.norm(x) # Apply final norm
        cls_token_final = x[:, 0] # Get CLS token output
        x = self.head(cls_token_final)
        return x

# --- ResNet-18 Baseline ---
# Using torchvision's ResNet-18, modified for CIFAR
resnet18_model = torchvision.models.resnet18(weights=None, num_classes=NUM_CLASSES) # Use weights=None for training from scratch
# Adjust first conv layer and remove maxpool for CIFAR-100
resnet18_model.conv1 = nn.Conv2d(IN_CHANNELS, 64, kernel_size=3, stride=1, padding=1, bias=False)
resnet18_model.maxpool = nn.Identity()

# --- Training and Evaluation Loop ---
def train_model(model, model_name, trainloader, testloader, criterion, optimizer, epochs, device):
    print(f"\n--- Training {model_name} ---")
    model.to(device)
    # Use a scheduler for potentially better results, especially for ViT
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    results = {'train_loss': [], 'test_loss': [], 'test_acc': [], 'epoch_time': []}
    best_acc = 0.0
    # model_checkpoint = copy.deepcopy(model.state_dict()) # Keep track of best model

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Optional: Print batch loss
            # if (i+1) % 100 == 0:
            #    print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(trainloader)}], Batch Loss: {loss.item():.4f}')


        epoch_loss = running_loss / len(trainloader)
        results['train_loss'].append(epoch_loss)

        # Evaluation
        model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_test_loss = test_loss / len(testloader)
        epoch_test_acc = 100 * correct / total
        results['test_loss'].append(epoch_test_loss)
        results['test_acc'].append(epoch_test_acc)

        # Update best accuracy
        if epoch_test_acc > best_acc:
             best_acc = epoch_test_acc
             # model_checkpoint = copy.deepcopy(model.state_dict()) # Save best weights

        scheduler.step() # Step the scheduler

        end_time = time.time()
        epoch_duration = end_time - start_time
        results['epoch_time'].append(epoch_duration)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}, Test Loss: {epoch_test_loss:.4f}, Test Acc: {epoch_test_acc:.2f}%, LR: {optimizer.param_groups[0]['lr']:.6f}, Time: {epoch_duration:.2f}s")

    print(f"Finished Training {model_name}. Best Test Accuracy: {best_acc:.2f}%")
    avg_epoch_time = sum(results['epoch_time']) / len(results['epoch_time']) if results['epoch_time'] else 0
    # Return final epoch accuracy or best accuracy - returning best as requested by prompt analysis
    final_acc = results['test_acc'][-1] if results['test_acc'] else 0
    print(f"Final Epoch Test Accuracy: {final_acc:.2f}%")
    return best_acc, avg_epoch_time # Return best accuracy achieved over the epochs

# --- Model Analysis and Training Execution ---
criterion = nn.CrossEntropyLoss()
results_data = [] # List to store results for pandas DataFrame

# Analyze and Train ViT Configurations
print("\n=== Processing ViT Configurations ===")
for config in vit_configs:
    config_name, patch_size, embed_dim, num_layers, num_heads, mlp_ratio = config
    print(f"\n--- Analyzing and Training {config_name} ---")

    # Check if patch size is valid for image size
    if IMG_SIZE % patch_size != 0:
        print(f"Skipping config {config_name}: Image size {IMG_SIZE} not divisible by patch size {patch_size}")
        results_data.append({
            "Model Configuration": config_name,
            "Params (M)": "N/A", "FLOPs (GMACs)": "N/A",
            "Avg Epoch Time (s)": "N/A", f"Test Acc (%) @{EPOCHS} epochs": "N/A (Skipped)"
        })
        continue

    vit_model = VisionTransformer(
        image_size=IMG_SIZE,
        patch_size=patch_size,
        num_classes=NUM_CLASSES,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
    ).to(DEVICE)

    # Get model summary (Params, FLOPs) using torchinfo
    try:
        model_stats = summary(vit_model, input_size=(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE), verbose=0)
        num_params = model_stats.total_params
        flops = model_stats.total_mult_adds / 1e9 # GigaMACs
        params_m = f"{num_params/1e6:.2f}"
        flops_g = f"{flops:.2f}"
        print(f"Parameters: {params_m} M")
        print(f"Estimated FLOPs (GMACs): {flops_g} G")
    except Exception as e:
        print(f"Could not get summary for {config_name}: {e}")
        num_params = -1
        flops = -1
        params_m = "Error"
        flops_g = "Error"


    # Train the model
    optimizer = optim.Adam(vit_model.parameters(), lr=LR, weight_decay=0.01) # Added weight decay
    best_acc, avg_epoch_time = train_model(vit_model, config_name, train_loader, test_loader, criterion, optimizer, EPOCHS, DEVICE)

    results_data.append({
        "Model Configuration": config_name,
        "Params (M)": params_m,
        "FLOPs (GMACs)": flops_g,
        "Avg Epoch Time (s)": f"{avg_epoch_time:.2f}",
        f"Test Acc (%) @{EPOCHS} epochs": f"{best_acc:.2f}"
    })
    del vit_model # Free up memory
    if DEVICE == 'cuda': torch.cuda.empty_cache() # Clear CUDA cache


# Analyze and Train ResNet-18 Baseline
print("\n=== Processing ResNet-18 Baseline ===")
print("\n--- Analyzing and Training ResNet-18 ---")
resnet18_model_instance = copy.deepcopy(resnet18_model).to(DEVICE) # Use deepcopy to avoid modifying original
try:
    model_stats_resnet = summary(resnet18_model_instance, input_size=(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE), verbose=0)
    num_params_resnet = model_stats_resnet.total_params
    flops_resnet = model_stats_resnet.total_mult_adds / 1e9 # GigaMACs
    params_m_resnet = f"{num_params_resnet/1e6:.2f}"
    flops_g_resnet = f"{flops_resnet:.2f}"
    print(f"Parameters: {params_m_resnet} M")
    print(f"Estimated FLOPs (GMACs): {flops_g_resnet} G")
except Exception as e:
    print(f"Could not get summary for ResNet-18: {e}")
    num_params_resnet = -1
    flops_resnet = -1
    params_m_resnet = "Error"
    flops_g_resnet = "Error"

# Train ResNet-18
optimizer_resnet = optim.Adam(resnet18_model_instance.parameters(), lr=LR, weight_decay=0.01) # Added weight decay
best_acc_resnet, avg_epoch_time_resnet = train_model(resnet18_model_instance, "ResNet-18", train_loader, test_loader, criterion, optimizer_resnet, EPOCHS, DEVICE)

results_data.append({
    "Model Configuration": "ResNet-18",
    "Params (M)": params_m_resnet,
    "FLOPs (GMACs)": flops_g_resnet,
    "Avg Epoch Time (s)": f"{avg_epoch_time_resnet:.2f}",
    f"Test Acc (%) @{EPOCHS} epochs": f"{best_acc_resnet:.2f}"
})

# --- Display Summary Table ---
print("\n--- Results Summary ---")
results_df = pd.DataFrame(results_data)
print(results_df.to_string())

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Processing ViT Configurations ===

--- Analyzing and Training ViT_P4_E256_L4_H2_M2 ---
Parameters: 2.16 M
Estimated FLOPs (GMACs): 0.12 G

--- Training ViT_P4_E256_L4_H2_M2 ---
Epoch 1/10, Train Loss: 4.3519, Test Loss: 4.1155, Test Acc: 7.02%, LR: 0.000976, Time: 9.14s
Epoch 2/10, Train Loss: 4.2415, Test Loss: 4.0648, Test Acc: 7.83%, LR: 0.000905, Time: 8.96s
Epoch 3/10, Train Loss: 4.1865, Test Loss: 3.9802, Test Acc: 9.34%, LR: 0.000794, Time: 9.11s
Epoch 4/10, Train Loss: 4.1616, Test Loss: 3.9669, Test Acc: 9.13%, LR: 0.000655, Time: 8.81s
Epoch 5/10, Train Loss: 4.1416, Test Loss: 3.9366, Test Acc: 9.89%, LR: 0.000500, Time: 8.82s
Epoch 6/10, Train Loss: 4.1247, Test Loss: 3.9048, Test Acc: 10.51%, LR: 0.000345, Time: 8.81s
Epoch 7/10, Train Loss: 4.1013, Test Loss: 3.8826, Test Acc: 10.74%, LR: 0.000206, Time: 8.86s
Epoch 8/10, Train Loss: 4.0872, Test Loss: 3.8828, Test Acc: 11