#  Mini Vision Transformer (ViT) on CIFAR-10
### End-to-End Training & Inference Pipeline

Welcome to this self-contained Mini-ViT demo! This notebook is designed to train a Vision Transformer from scratch on your local machine, respecting your hardware constraints.

**Key Features:**
*   **Automatic Resource Management**: Automatically falls back to smaller models or batch sizes if CUDA OOM occurs.
*   **Time Budgeting**: Ensures training fits within your specified time limit (default: 30 mins) by adjusting epochs dynamically.
*   **Interactive Inference**: Test the model with your own images!

---
**Instructions:**
1.  Run all cells in order.
2.  Watch the **Training Loop** section for live progress.
3.  Use the final cells to save your model and run predictions on your own images.



## 1. Environment Setup & Imports
We start by importing PyTorch and setting a global random seed for reproducibility. We also detect if a GPU is available.



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import sys
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        
set_seed(42)

print(f"üì¶ PyTorch Version: {torch.__version__}")
print(f"üîß Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")


üì¶ PyTorch Version: 2.7.1+cu118
üîß Device: cuda
   GPU: NVIDIA GeForce RTX 4050 Laptop GPU


## 2. Dynamic Configuration
This class handles the training configuration. It includes a smart `downgrade()` method that automatically reduces the batch size, model dimension, or network depth if an Out-Of-Memory (OOM) error causes a crash.



In [2]:
class Config:
    def __init__(self, mode='full', time_budget_min=30, force_cpu=False):
        self.device = 'cuda' if torch.cuda.is_available() and not force_cpu else 'cpu'
        self.image_size = 32  # CIFAR-10 default
        self.patch_size = 4   # 32/4 = 8x8 patches
        self.num_classes = 10
        
        # Primary Config (Target)
        self.dim = 128
        self.depth = 4
        self.heads = 4
        self.mlp_dim = 256
        self.batch_size = 64
        
        # Training Settings
        self.epochs = 10 if mode == 'demo' else 50
        self.lr = 1e-3
        self.weight_decay = 0.05
        
        self.time_budget = time_budget_min * 60  # seconds
        self.fallback_level = 0
        
    def downgrade(self):
        """Attempts to reduce resource usage. Returns True if downgraded, False if min limit reached."""
        self.fallback_level += 1
        
        # Level 1: Reduce Batch Size
        if self.batch_size > 16:
            print(f"‚ö†Ô∏è [Fallback] Reducing batch size from {self.batch_size} to {self.batch_size // 2}")
            self.batch_size //= 2
            return True
        
        # Level 2: Reduce Model Dim
        if self.dim == 128:
            print(f"‚ö†Ô∏è [Fallback] Reducing model dim from 128 to 64")
            self.dim = 64
            self.mlp_dim = 128
            self.batch_size = 64 # Reset batch size to try again
            return True
            
        # Level 3: Reduce Depth
        if self.depth == 4:
            print(f"‚ö†Ô∏è [Fallback] Reducing depth from 4 to 2")
            self.depth = 2
            self.batch_size = 64 # Reset batch size
            return True
            
        return False

    def __str__(self):
        return (f"Img:{self.image_size}, Patch:{self.patch_size}, Dim:{self.dim}, "
                f"Depth:{self.depth}, Heads:{self.heads}, Batch:{self.batch_size}, Epochs:{self.epochs}")


## 3. Mini-ViT Architecture
We implement the Vision Transformer from scratch.
*   **PatchEmbedding**: Breaks the image into flattened patches and projects them.
*   **TransformerBlock**: The core unit with Multi-Head Attention and an MLP.
*   **MiniViT**: The main class combining embeddings, position tokens, blocks, and the classification head.



In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, dim, channels=3):
        super().__init__()
        self.proj = nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        x = self.proj(x) # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2) # (B, N, D)
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim)

    def forward(self, x):
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class MiniViT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_embed = PatchEmbedding(config.image_size, config.patch_size, config.dim)
        num_patches = (config.image_size // config.patch_size) ** 2
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, config.dim))
        
        self.blocks = nn.ModuleList([
            TransformerBlock(config.dim, config.heads, config.mlp_dim)
            for _ in range(config.depth)
        ])
        
        self.norm = nn.LayerNorm(config.dim)
        self.head = nn.Linear(config.dim, config.num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return self.head(x[:, 0])

    def count_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


## 4. Data Preparation (CIFAR-10)
We use `torchvision` to download and load CIFAR-10.
*   If you placed the `cifar-10-python.tar.gz` in `./data`, it will find and extract it.
*   If the file is missing and no internet is available, it falls back to **synthetic noise** (for testing the code).



In [4]:
def get_dataloaders(config, subsample=False):
    print(f"üì• Loading Data... (Subsample={subsample})")
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    try:
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    except Exception as e:
        print(f"‚ö†Ô∏è Failed to download CIFAR10: {e}. Using synthetic data.")
        trainset = torchvision.datasets.FakeData(size=20000 if subsample else 50000, image_size=(3, 32, 32), num_classes=10, transform=transforms.ToTensor())
        testset = torchvision.datasets.FakeData(size=10000, image_size=(3, 32, 32), num_classes=10, transform=transforms.ToTensor())

    if subsample and len(trainset) > 20000:
        indices = list(range(20000))
        trainset = torch.utils.data.Subset(trainset, indices)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=2, pin_memory=(config.device=='cuda'))
    testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=2)
    return trainloader, testloader, trainset, testset


## 5. Training Utilities
Define `train_one_epoch` (with Mixed Precision support) and `evaluate`.



In [5]:
def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        if device == 'cuda':
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    return total_loss / len(loader), 100. * correct / total

def evaluate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total


## 6. Main Execution Loop
This runs the full pipeline:
1.  **Config & OOM Check**: Re-initializes model if memory error occurs.
2.  **Warmup**: Checks speed.
3.  **Training**: Runs epochs.
4.  **Auto-Save**: Saves `model_final.pt` automatically.



In [6]:
# --- User Settings ---
MODE = 'full'  # 'demo' (30mins, fast) or 'full' (90mins, better acc)
TIME_BUDGET_MIN = 30
FORCE_CPU = False
# ---------------------

config = Config(mode=MODE, time_budget_min=TIME_BUDGET_MIN, force_cpu=FORCE_CPU)
print(f"üöÄ Starting Main Loop with Budget: {TIME_BUDGET_MIN} mins")

# OOM Retry Loop
while True:
    try:
        # Clean up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
        print(f"\n----------------------------------------------------------------")
        print(f"‚ñ∂Ô∏è Attempting Config: {config}")
        print(f"----------------------------------------------------------------")
        
        trainloader, testloader, trainset, testset = get_dataloaders(config, subsample=(MODE == 'demo'))
        
        model = MiniViT(config).to(config.device)
        params = model.count_params()
        print(f"üß† Model Parameters: {params:,}")
        
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        scaler = torch.cuda.amp.GradScaler(enabled=(config.device == 'cuda'))
        
        # --- Warmup Phase ---
        start_time = time.perf_counter()
        print("\nüî• Running warmup epoch to measure speed...")
        train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer, scaler, config.device)
        epoch_time = time.perf_counter() - start_time
        print(f"   Warmup done in {epoch_time:.2f}s.")
        
        # --- Time Budgeting ---
        total_time_needed = epoch_time * config.epochs
        if total_time_needed > config.time_budget:
            new_epochs = max(int(config.time_budget / epoch_time), 1)
            print(f"‚ö†Ô∏è Projected time {total_time_needed/60:.1f}m > budget {config.time_budget/60}m.")
            print(f"   Adjusting epochs: {config.epochs} -> {new_epochs}")
            config.epochs = new_epochs

        # --- Training Phase ---
        print(f"\nüèãÔ∏è Starting training for {config.epochs} epochs...")
        history = []
        total_start_time = time.time()
        
        for epoch in range(1, config.epochs + 1):
            ep_start = time.perf_counter()
            loss, acc = train_one_epoch(model, trainloader, criterion, optimizer, scaler, config.device)
            val_acc = evaluate(model, testloader, config.device)
            ep_duration = time.perf_counter() - ep_start
            
            print(f"Epoch {epoch:02d}/{config.epochs} | ‚è±Ô∏è {ep_duration:5.1f}s | üìâ Loss: {loss:.4f} | ‚úÖ Train: {acc:5.1f}% | ‚≠ê Val: {val_acc:5.1f}%")
            history.append((epoch, loss, acc, val_acc))
            
        total_wall_time = time.time() - total_start_time
        
        # --- Save Phase (Auto) ---
        torch.save(model.state_dict(), 'model_final.pt')
        print(f"\nüíæ Model saved to 'model_final.pt'")
        
        # --- Report Generation ---
        report_content = (
            f"Mini-ViT Notebook Report\n"
            f"========================\n"
            f"Config: {config}\n"
            f"Params: {params:,}\n"
            f"Total Time: {total_wall_time:.1f}s\n"
            f"Avg Sec/Epoch: {total_wall_time/config.epochs:.2f}s\n"
            f"Final Val Acc: {history[-1][3]:.2f}%\n"
            f"Final Train Loss: {history[-1][1]:.4f}\n"
            f"Fallback Triggered: {config.fallback_level > 0}\n"
        )
        with open("report.txt", "w") as f:
            f.write(report_content)
        print("üìÑ Report saved to 'report.txt'")
        print("‚úÖ Demo Complete!")
        break # Exit retry loop on success
        
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print("\nüö® CUDA OOM caught! Scaling down configuration...")
            if not config.downgrade():
                print("‚ùå Could not downgrade further. Aborting.")
                raise e
        else:
            raise e


üöÄ Starting Main Loop with Budget: 30 mins

----------------------------------------------------------------
‚ñ∂Ô∏è Attempting Config: Img:32, Patch:4, Dim:128, Depth:4, Heads:4, Batch:64, Epochs:50
----------------------------------------------------------------
üì• Loading Data... (Subsample=False)
üß† Model Parameters: 546,186

üî• Running warmup epoch to measure speed...


  scaler = torch.cuda.amp.GradScaler(enabled=(config.device == 'cuda'))
  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):


   Warmup done in 22.85s.

üèãÔ∏è Starting training for 50 epochs...
Epoch 01/50 | ‚è±Ô∏è  47.5s | üìâ Loss: 1.4583 | ‚úÖ Train:  47.1% | ‚≠ê Val:  50.2%
Epoch 02/50 | ‚è±Ô∏è  56.8s | üìâ Loss: 1.3470 | ‚úÖ Train:  51.0% | ‚≠ê Val:  53.4%
Epoch 03/50 | ‚è±Ô∏è  48.3s | üìâ Loss: 1.2610 | ‚úÖ Train:  54.2% | ‚≠ê Val:  57.0%
Epoch 04/50 | ‚è±Ô∏è  42.7s | üìâ Loss: 1.1973 | ‚úÖ Train:  56.8% | ‚≠ê Val:  58.4%
Epoch 05/50 | ‚è±Ô∏è  48.4s | üìâ Loss: 1.1285 | ‚úÖ Train:  59.6% | ‚≠ê Val:  61.1%
Epoch 06/50 | ‚è±Ô∏è  50.2s | üìâ Loss: 1.0687 | ‚úÖ Train:  61.8% | ‚≠ê Val:  63.0%
Epoch 07/50 | ‚è±Ô∏è  58.1s | üìâ Loss: 1.0274 | ‚úÖ Train:  63.2% | ‚≠ê Val:  64.1%
Epoch 08/50 | ‚è±Ô∏è  57.5s | üìâ Loss: 0.9789 | ‚úÖ Train:  65.1% | ‚≠ê Val:  64.9%
Epoch 09/50 | ‚è±Ô∏è  50.9s | üìâ Loss: 0.9417 | ‚úÖ Train:  66.4% | ‚≠ê Val:  68.2%
Epoch 10/50 | ‚è±Ô∏è  30.8s | üìâ Loss: 0.9096 | ‚úÖ Train:  67.7% | ‚≠ê Val:  67.9%
Epoch 11/50 | ‚è±Ô∏è  46.1s | üìâ Loss: 0.8748 | ‚úÖ Train:  68.7% | 