# Image Classification on CIFAR-10 with ResNet & ViT
This notebook consolidates the full pipeline for training and evaluating **ResNet** and **Vision Transformer (ViT)** models on the **CIFAR-10** dataset.

## 1) Environment Setup

In [None]:

!pip install torch torchvision matplotlib --quiet
import torch
# Pick the compute device once for the notebook: CUDA when PyTorch reports a
# usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)


## 2) Data: CIFAR-10 Download & Dataloaders

In [None]:
#!/usr/bin/env python3
import os
import torchvision

# Directory that receives the CIFAR-10 archive and its extracted batches.
DATA_DIR = "./data/cifar10"


def main():
    """Fetch the CIFAR-10 training and test splits into DATA_DIR.

    Creates DATA_DIR when it is missing, then instantiates the torchvision
    dataset once per split with ``download=True``. The dataset objects are
    discarded; only the files written under DATA_DIR are kept.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    print(f"[INFO] Downloading CIFAR-10 dataset to {DATA_DIR} ...")
    # Training split first, then the test split.
    for is_train in (True, False):
        torchvision.datasets.CIFAR10(root=DATA_DIR, train=is_train, download=True)
    print("[INFO] CIFAR-10 dataset downloaded successfully!")


if __name__ == "__main__":
    main()

## 3) Utilities

In [None]:
import os
import torch
import torchvision
import torchvision.transforms as transforms

def get_cifar10_loaders(batch_size=128, data_dir="./data/cifar10", num_workers=2):
    """
    Build the CIFAR-10 train and test DataLoaders from a local copy.

    Nothing is downloaded here; the extracted ``cifar-10-batches-py`` folder
    must already exist inside ``data_dir``.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Root directory that contains ``cifar-10-batches-py``.
        num_workers: Worker process count handed to both DataLoaders.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        and applies random crop (32 px, padding 4) plus random horizontal
        flip before ``ToTensor``; the test loader keeps dataset order and
        applies ``ToTensor`` only. No normalization is applied to either.

    Raises:
        FileNotFoundError: If ``cifar-10-batches-py`` is missing in ``data_dir``.
    """
    extracted_dir = os.path.join(data_dir, "cifar-10-batches-py")
    if not os.path.exists(extracted_dir):
        message = (
            f"CIFAR-10 dataset not found in {data_dir}. "
            f"Run `python download_cifar10.py` first."
        )
        raise FileNotFoundError(message)

    # Augmentation is applied to the training split only.
    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([transforms.ToTensor()])

    cifar10 = torchvision.datasets.CIFAR10
    train_split = cifar10(root=data_dir, train=True, download=False, transform=augment)
    test_split = cifar10(root=data_dir, train=False, download=False, transform=plain)

    # Options common to both loaders; only `shuffle` differs between them.
    shared = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
    }
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_split, shuffle=True, **shared)
    test_loader = make_loader(test_split, shuffle=False, **shared)

    return train_loader, test_loader


## 4) Model Definitions (ResNet & ViT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------
# Vision Transformer (ViT)
# -----------------
class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence with a class token.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to ``embed_dim`` channels. A learnable
    class token (initialised to zeros) is prepended, and a learnable
    positional embedding (truncated-normal init, std 0.02) is added.

    Args:
        img_size: Side length used to size the positional embedding.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Width of every output token.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch on the square grid, plus one for the class token.
        grid_side = img_size // patch_size
        seq_len = grid_side * grid_side + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Return tokens of shape [B, num_patches + 1, embed_dim] for images ``x``."""
        batch = x.size(0)
        patches = self.proj(x)                        # [B, embed_dim, H', W']
        tokens = patches.flatten(2).transpose(1, 2)   # [B, num_patches, embed_dim]
        cls = self.cls_token.expand(batch, -1, -1)
        return torch.cat([cls, tokens], dim=1) + self.pos_embed

class VisionTransformer(nn.Module):
    """Compact ViT classifier: patch embedding, encoder stack, linear head.

    Args:
        img_size: Forwarded to PatchEmbedding.
        patch_size: Forwarded to PatchEmbedding.
        in_chans: Forwarded to PatchEmbedding.
        num_classes: Number of output logits.
        embed_dim: Token width used throughout the model.
        depth: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        mlp_ratio: Feed-forward width as a multiple of ``embed_dim``
            (truncated to int).
        dropout: Dropout probability passed to each encoder layer.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=128, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer_template, num_layers=depth)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits computed from the first (class) token."""
        tokens = self.patch_embed(x)
        encoded = self.encoder(tokens)
        # Only the token at sequence position 0 feeds the classifier.
        cls_repr = self.norm(encoded[:, 0])
        return self.head(cls_repr)

# -----------------
# Small ResNet
# -----------------
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) when the block keeps
    both resolution and channel count; otherwise it is a 1x1 convolution with
    the same stride followed by batch norm, so both branches can be summed.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the 1x1 shortcut).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv3x3 = {"kernel_size": 3, "padding": 1, "bias": False}
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply both conv stages, add the shortcut branch, then ReLU."""
        branch = self.conv1(x)
        branch = F.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        branch += self.shortcut(x)
        return F.relu(branch)

class SmallResNet(nn.Module):
    """Three-stage residual network.

    Layout: a 3x3 stem convolution (3 -> 16 channels) with batch norm and
    ReLU, three residual stages producing 16, 32 and 64 channels with strides
    1, 2 and 2, adaptive average pooling to 1x1, and a linear classifier.

    Args:
        block: Residual block class, called as
            ``block(in_channels, out_channels, stride)`` for the first block
            of a stage and ``block(in_channels, out_channels)`` for the rest.
        num_blocks: Sequence of three ints giving the block count of each
            stage. The default is an immutable tuple so the value cannot be
            shared and mutated across instances.
        num_classes: Number of output logits.
    """

    def __init__(self, block=BasicBlock, num_blocks=(2, 2, 2), num_classes=10):
        super().__init__()
        # Running channel count; _make_layer reads and updates it per stage.
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        """Build one stage of ``blocks`` residual blocks as an nn.Sequential.

        Only the first block receives ``stride`` and the channel change; the
        remaining blocks map ``out_channels`` to ``out_channels`` with the
        block's default stride. ``self.in_channels`` is updated to
        ``out_channels`` for the next stage.
        """
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        # Collapse the pooled 1x1 spatial dims: [B, 64, 1, 1] -> [B, 64].
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


## 5) Training Pipeline: ResNet

In [None]:
#!/usr/bin/env python3
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from utils import get_cifar10_loaders
from models import SmallResNet

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size so dividing by `seen` yields a
        # per-sample mean even when the last batch is smaller (assumes the
        # criterion averages over the batch, as nn.CrossEntropyLoss does by
        # default).
        loss_sum += loss.item() * imgs.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += imgs.size(0)
    return loss_sum / seen, correct / seen


def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` without gradients."""
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += labels.size(0)
    return correct / seen


def main(args):
    """Train SmallResNet and keep the checkpoint with the best accuracy.

    Uses SGD (momentum 0.9, weight decay 5e-4) with a StepLR schedule that
    multiplies the learning rate by 0.1 every 20 epochs. After each epoch the
    model is evaluated on the second loader returned by
    ``get_cifar10_loaders``; whenever that accuracy improves, a checkpoint is
    written to ``<out_dir>/resnet_best.pt``.

    Args:
        args: Namespace-like object providing ``batch_size``, ``num_classes``,
            ``lr``, ``out_dir`` and ``epochs``.

    NOTE(review): the evaluation loader is named ``testloader``; if it is the
    CIFAR-10 test split, best-checkpoint selection is done on test data, so
    the reported "Val Acc" is not a held-out estimate. Confirm with the data
    utilities.
    """
    # Device preference order: Apple MPS, then CUDA, then CPU.
    device = torch.device("mps" if torch.backends.mps.is_available()
                          else "cuda" if torch.cuda.is_available()
                          else "cpu")
    print("Using device:", device)

    trainloader, testloader = get_cifar10_loaders(batch_size=args.batch_size)

    model = SmallResNet(num_classes=args.num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    os.makedirs(args.out_dir, exist_ok=True)
    # The checkpoint path never changes, so compute it once outside the loop.
    ckpt_path = os.path.join(args.out_dir, "resnet_best.pt")
    best_acc = 0.0

    for epoch in range(args.epochs):
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, criterion, optimizer, device
        )
        val_acc = _evaluate(model, testloader, device)

        print(f"Epoch {epoch+1}/{args.epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            # model_type and model_config are stored alongside the weights so
            # the checkpoint describes which architecture produced it.
            torch.save({
                "model_type": "resnet",
                "model_config": {"num_classes": args.num_classes},
                "model_state": model.state_dict(),
                "val_acc": val_acc
            }, ckpt_path)
            print(f"[INFO] Saved new best model to {ckpt_path} (Val Acc: {val_acc:.4f})")

        # Advance the LR schedule once per epoch, after the optimizer steps.
        scheduler.step()



## 6) Training Pipeline: Vision Transformer (ViT)

In [None]:
#!/usr/bin/env python3
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from utils import get_cifar10_loaders
from models import VisionTransformer

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size so dividing by `seen` yields a
        # per-sample mean even when the last batch is smaller (assumes the
        # criterion averages over the batch, as nn.CrossEntropyLoss does by
        # default).
        loss_sum += loss.item() * imgs.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += imgs.size(0)
    return loss_sum / seen, correct / seen


def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` without gradients."""
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += labels.size(0)
    return correct / seen


def main(args):
    """Train VisionTransformer and keep the checkpoint with the best accuracy.

    Uses Adam at a constant learning rate (no scheduler). After each epoch the
    model is evaluated on the second loader returned by
    ``get_cifar10_loaders``; whenever that accuracy improves, a checkpoint is
    written to ``<out_dir>/vit_best.pt``.

    Args:
        args: Namespace-like object providing ``batch_size``, ``num_classes``,
            ``lr``, ``out_dir`` and ``epochs``.

    NOTE(review): the evaluation loader is named ``testloader``; if it is the
    CIFAR-10 test split, best-checkpoint selection is done on test data, so
    the reported "Val Acc" is not a held-out estimate. Confirm with the data
    utilities.
    """
    # Device preference order: Apple MPS, then CUDA, then CPU.
    device = torch.device("mps" if torch.backends.mps.is_available()
                          else "cuda" if torch.cuda.is_available()
                          else "cpu")
    print("Using device:", device)

    trainloader, testloader = get_cifar10_loaders(batch_size=args.batch_size)

    model = VisionTransformer(num_classes=args.num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    os.makedirs(args.out_dir, exist_ok=True)
    # The checkpoint path never changes, so compute it once outside the loop.
    ckpt_path = os.path.join(args.out_dir, "vit_best.pt")
    best_acc = 0.0

    for epoch in range(args.epochs):
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, criterion, optimizer, device
        )
        val_acc = _evaluate(model, testloader, device)

        print(f"Epoch {epoch+1}/{args.epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            # model_type and model_config are stored alongside the weights so
            # the checkpoint describes which architecture produced it.
            torch.save({
                "model_type": "vit",
                "model_config": {"num_classes": args.num_classes},
                "model_state": model.state_dict(),
                "val_acc": val_acc
            }, ckpt_path)
            print(f"[INFO] Saved new best model to {ckpt_path} (Val Acc: {val_acc:.4f})")



## 7) Inference Pipeline

In [None]:
import torch
from PIL import Image
import torchvision.transforms as transforms
import argparse
from models import VisionTransformer, SmallResNet

# Device preference order for inference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Inference preprocessing: force the image to 32x32, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Human-readable label for each predicted class index (presumably in
# CIFAR-10's label order; verify against the dataset's class list).
CLASSES = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

def predict_image(path, ckpt_path):
    """Classify one image file with a saved checkpoint and return its label.

    The checkpoint is expected to be a dict with ``"model_state"`` and,
    optionally, ``"model_type"`` (``"vit"`` selects VisionTransformer; any
    other value, or a missing key, selects SmallResNet) and
    ``"model_config"`` containing ``"num_classes"`` (default 10).

    Args:
        path: Path of the image to classify; it is converted to RGB and
            passed through the module-level ``transform``.
        ckpt_path: Path of the checkpoint file to load.

    Returns:
        The predicted class name from ``CLASSES``, or ``"class_<idx>"`` when
        the predicted index has no entry in ``CLASSES`` (possible when the
        checkpoint was trained with more classes than ``CLASSES`` lists).
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code depending on the PyTorch version and its `weights_only`
    # default. Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    model_type = ckpt.get("model_type", "resnet")  # Default to resnet if missing
    num_classes = ckpt.get("model_config", {}).get("num_classes", 10)

    # Rebuild the architecture recorded in the checkpoint.
    if model_type == "vit":
        model = VisionTransformer(num_classes=num_classes).to(device)
    else:
        model = SmallResNet(num_classes=num_classes).to(device)

    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Open the file in a context manager so its handle is closed promptly;
    # convert() returns a separate in-memory image that outlives the handle.
    with Image.open(path) as raw:
        rgb = raw.convert("RGB")
    img = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img)
        probs = torch.softmax(outputs, dim=1)
        pred_idx = probs.argmax(1).item()

    # Guard against checkpoints whose class count exceeds the name table.
    label = CLASSES[pred_idx] if pred_idx < len(CLASSES) else f"class_{pred_idx}"
    confidence = probs[0][pred_idx].item()

    print(f"Predicted: {label} (Prob: {confidence:.4f})")
    return label



## 8) Quick Start (Summary Logs Only)

In [None]:

# Quick-start scaffold: as written, each try-block below only prints a banner.
# The training call itself is left commented out so this cell is safe to run
# unchanged; uncomment and adapt it to launch a short sanity run.
# NOTE: the training cells above each define `main(args)` (there is no
# `train_resnet` / `train_vit` in this notebook), and both use the same name,
# so `main` refers to whichever training cell was executed last. `args` must
# provide epochs, batch_size, lr, num_classes and out_dir.
# NOTE: Keep epochs low in Colab for demo; increase for real training.
try:
    # ResNet quick run: execute the ResNet training cell first, then call main.
    print("=== Quick ResNet Training (demo) ===")
    # e.g., main(argparse.Namespace(epochs=2, batch_size=128, lr=0.001, num_classes=10, out_dir="checkpoints_resnet_demo"))
except Exception as e:
    print("ResNet demo skipped:", e)

try:
    # ViT quick run: execute the ViT training cell first, then call main.
    print("=== Quick ViT Training (demo) ===")
    # e.g., main(argparse.Namespace(epochs=2, batch_size=128, lr=0.001, num_classes=10, out_dir="checkpoints_vit_demo"))
except Exception as e:
    print("ViT demo skipped:", e)
