# CIFAR-10 Image Classification

**Models:** Baseline CNN, ResNet-18 (transfer learning), EfficientNet-B0 (timm)

**Author:** Satya Narayan Mohanty

**University:** KIIT Deemed to be University

**Guide:** Mr. N Biraja Isac

---

## 1. Setup & Environment

This notebook contains the full code to train and evaluate three models on the CIFAR-10 dataset. It is designed to run on a machine with a CUDA-capable GPU and falls back to the CPU (much more slowly) when no GPU is available. Use the provided `requirements.txt` to create a conda or venv environment.

Files created in `/mnt/data/CIFAR10_Image_Classification_Project`:
- `train_notebook.ipynb` (this notebook)
- `requirements.txt`
- `README.md`
- `models/` (directory where trained model `.pth` files will be saved)

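Run-time note: training from scratch can take several hours, depending on hardware. For quick tests, reduce the number of epochs (or train on a subset of the data); smaller batch sizes mainly lower GPU memory use and generally do not make training faster.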

In [None]:
# Imports and helper functions
import os
import time
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import matplotlib.pyplot as plt

print("PyTorch version:", torch.__version__)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Directory for trained `.pth` checkpoints; exist_ok=True makes re-running this cell harmless.
os.makedirs("models", exist_ok=True)

## 2. Data Preparation

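We use the standard CIFAR-10 dataset from `torchvision.datasets`. Training images get the typical CIFAR augmentations (a random 32×32 crop after 4-pixel padding, plus a random horizontal flip); both training and test images are normalized with the same per-channel mean and standard deviation.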

In [None]:
# Data transforms and loaders.
# Per-channel normalization statistics (the commonly quoted CIFAR-10 mean/std),
# defined once so the train and test pipelines cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training-time augmentation: random 32x32 crop from a 4-pixel-padded image
# plus a random horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation: no augmentation, identical normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Pinned host memory speeds up host-to-GPU copies; on a CPU-only machine it is
# useless and DataLoader emits a warning, so only enable it when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=PIN_MEMORY)
testloader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=4, pin_memory=PIN_MEMORY)

classes = trainset.classes
print('Classes:', classes)
print('Train samples:', len(trainset), 'Test samples:', len(testset))

## 3. Baseline CNN Model

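A simple CNN with 3 convolutional blocks (2, 2 and 1 Conv–BatchNorm–ReLU layers, each block followed by 2×2 max-pooling) and 2 fully connected layers with dropout.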

In [None]:
class BaselineCNN(nn.Module):
    """Small VGG-style CNN for 32x32 RGB images.

    Three convolutional stages (2, 2 and 1 Conv-BatchNorm-ReLU layers), each
    ending in a 2x2 max-pool, followed by a two-layer MLP head with dropout.

    Args:
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch):
            # 3x3 convolution with padding=1 keeps the spatial size unchanged.
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # (in_channels, out_channels) for each conv layer, grouped by stage.
        stages = [
            [(3, 64), (64, 64)],
            [(64, 128), (128, 128)],
            [(128, 256)],
        ]
        feature_layers = []
        for stage in stages:
            for in_ch, out_ch in stage:
                feature_layers.extend(conv_bn_relu(in_ch, out_ch))
            feature_layers.append(nn.MaxPool2d(2, 2))
        # A single flat Sequential keeps state_dict keys as features.0, features.1, ...
        # so previously saved checkpoints still load.
        self.features = nn.Sequential(*feature_layers)

        # Three 2x max-pools shrink a 32x32 input to 4x4 with 256 channels.
        flat_dim = 256 * 4 * 4
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(flat_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.shape[0], -1)
        return self.classifier(flat)

def count_params(m):
    """Return the number of trainable (requires_grad=True) parameter elements in module m."""
    trainable = (param for param in m.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)


# Quick sanity check: build the baseline model on the chosen device and report its size.
model = BaselineCNN().to(device)
print('BaselineCNN params:', count_params(model))

## 4. ResNet-18 (Transfer Learning)

We load a ResNet-18 pretrained on ImageNet and replace its final fully connected layer with a new 10-class layer for CIFAR-10.

In [None]:
# Load an ImageNet-pretrained ResNet-18. torchvision >= 0.13 deprecates the
# `pretrained=True` flag in favour of the `weights=` enum; IMAGENET1K_V1 is the
# checkpoint `pretrained=True` selected. Older torchvision lacks the enum.
if hasattr(models, 'ResNet18_Weights'):
    resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    resnet18 = models.resnet18(pretrained=True)
# Replace the ImageNet classification head with a freshly initialised 10-class layer.
# NOTE(review): inputs are 32x32 and use CIFAR normalization rather than the
# ImageNet statistics the backbone was trained with; resizing / ImageNet
# normalization may transfer better — confirm experimentally.
resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)
resnet18 = resnet18.to(device)
print('ResNet-18 params:', count_params(resnet18))

## 5. EfficientNet-B0 (via timm)

We use `timm` to load pretrained EfficientNet-B0 weights and replace the classifier with a 10-class head (`num_classes=10`).

In [None]:
# EfficientNet via timm (optional dependency). `num_classes=10` asks timm to
# build a new 10-way classifier on top of the pretrained backbone.
# `efficient` stays None if timm is missing or the model cannot be created,
# so later cells fail with a clear None check instead of a NameError.
efficient = None
try:
    import timm
except ImportError:
    print('timm not installed. To use EfficientNet install timm via pip: pip install timm')
else:
    try:
        efficient = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
        efficient = efficient.to(device)
        print('EfficientNet-B0 params:', count_params(efficient))
    except Exception as e:
        # Best-effort cell (e.g. weight download failure): report the real cause
        # instead of wrongly claiming timm is not installed.
        print('Could not create EfficientNet-B0:', e)

## 6. Training Utilities

Training loop, evaluation, and save/load helpers.

In [None]:
# Training and evaluation utilities
from tqdm.notebook import tqdm

def train_epoch(model, loader, criterion, optimizer, device):
    """Train `model` for one full pass over `loader`.

    Args:
        model: network to train; it is switched to train mode.
        loader: iterable of (inputs, targets) batches (wrapped in a tqdm progress bar).
        criterion: loss function; assumed to return the batch-mean loss
            (reduction='mean'), which is re-weighted by batch size so the
            returned loss is a per-sample average.
        optimizer: optimizer over the model's parameters.
        device: device that each batch is moved to.

    Returns:
        Tuple (mean_loss, accuracy) over all samples seen this epoch.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in tqdm(loader, leave=False):
        # non_blocking lets host->GPU copies from pinned memory overlap with compute.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        _, preds = outputs.max(1)
        correct += preds.eq(targets).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError('train_epoch: loader yielded no samples')
    return running_loss / total, correct / total

def evaluate(model, loader, criterion, device):
    """Compute mean loss and accuracy of `model` over `loader` without updating it.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable of (inputs, targets) batches.
        criterion: loss function; assumed to return the batch-mean loss
            (reduction='mean'), which is re-weighted by batch size so the
            returned loss is a per-sample average.
        device: device that each batch is moved to.

    Returns:
        Tuple (mean_loss, accuracy) over all samples in `loader`.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            # non_blocking lets host->GPU copies from pinned memory overlap with compute.
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            _, preds = outputs.max(1)
            correct += preds.eq(targets).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError('evaluate: loader yielded no samples')
    return running_loss / total, correct / total

def save_model(model, path):
    """Write the model's state_dict (parameters and buffers only) to `path` via torch.save."""
    weights = model.state_dict()
    torch.save(weights, path)
    print('Saved model to', path)

def load_model_state(model, path, device):
    """Load a state_dict saved at `path` into `model` and move the model to `device`.

    Tensors are remapped onto `device` at load time, so a checkpoint saved on a
    GPU can be restored on a CPU-only machine.
    NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    print('Loaded model from', path)

## 7. Example Training Run (Baseline CNN)

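Below is an example training loop. Uncomment it and run the cell to start training; for quick experiments, reduce `num_epochs` to 5. Note that the loop uses the test set to choose the best checkpoint, so the reported best accuracy is optimistically biased — hold out a validation split from the training data if you need an unbiased final estimate.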

In [None]:
# Example training run (uncomment to run)
# model = BaselineCNN().to(device)
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
# num_epochs = 50
# best_acc = 0.0
# for epoch in range(num_epochs):
#     train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)
#     val_loss, val_acc = evaluate(model, testloader, criterion, device)
#     scheduler.step()
#     print(f'Epoch {epoch+1}/{num_epochs} - Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
#     if val_acc > best_acc:
#         best_acc = val_acc
#         save_model(model, f'models/baseline_cnn_best.pth')


## 8. Evaluation & Visualization

Code to collect test-set predictions, compute a confusion matrix and a per-class classification report, and plot the confusion matrix as a heatmap.

In [None]:
# Confusion matrix and some predictions (example helper code)
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def get_all_preds(model, loader, device):
    """Run `model` over every batch in `loader` and collect predicted and true labels.

    Args:
        model: classifier returning per-class scores; it is switched to eval mode.
        loader: iterable of (inputs, labels) batches; labels must be CPU tensors.
        device: device the inputs are moved to before the forward pass.

    Returns:
        Tuple (preds, targets) of 1-D NumPy arrays, concatenated in loader order.
    """
    model.eval()
    pred_chunks, target_chunks = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            scores = model(batch_inputs.to(device))
            # Index of the highest score per row is the predicted class.
            pred_chunks.append(scores.max(1).indices.cpu().numpy())
            target_chunks.append(batch_labels.numpy())
    return np.concatenate(pred_chunks), np.concatenate(target_chunks)

# Example usage (after loading/saving models)
# load_model_state(model, 'models/baseline_cnn_best.pth', device)
# preds, targets = get_all_preds(model, testloader, device)
# cm = confusion_matrix(targets, preds)
# print(classification_report(targets, preds, target_names=classes))
# sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes)
# plt.xlabel('Predicted'); plt.ylabel('True'); plt.show()