# CIFAR-10 Classification


In [12]:
# Smoke-test cell: build the dataset wrapper and the baseline CNN, then print
# the model's parameter count. (Original note: "this should show execution time".)
# NOTE(review): DataConfig, CIFAR10Data and build_model are defined in the cells
# below, so those cells must have been executed before this one.
data_config = DataConfig(batch_size=32)  # every other field keeps its default
data = CIFAR10Data(data_config)
print("Data loaded successfully!")

model = build_model("cnn")  # "cnn" selects BasicCNN
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

Data loaded successfully!
Model created with 1149770 parameters


## Dependencies and Utilities

In [7]:
import os, random
from dataclasses import dataclass
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

def set_seed(seed: int = 42, *, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Value handed to ``random``, ``numpy.random`` and the torch
            CPU/CUDA generators.
        deterministic: When False (the default, matching the previous
            behaviour) cuDNN benchmarking is enabled and deterministic mode is
            disabled. When True the two flags are flipped: deterministic mode
            on, benchmarking off.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    deterministic = bool(deterministic)
    # The two cuDNN switches are kept opposite to each other.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

def accuracy(outputs, targets):
    """Return the fraction of rows whose highest-scoring column equals the target.

    ``outputs`` is reduced along dimension 1 to obtain the predicted index per
    row; the result is a Python float.
    """
    predicted = outputs.max(dim=1).indices
    hits = (predicted == targets).float()
    return hits.mean().item()

def save_checkpoint(state, path):
    """Serialize ``state`` to ``path`` with ``torch.save``.

    Missing parent directories are created first. A ``path`` with no directory
    component (e.g. ``"model.pt"``) is written to the current working
    directory.
    """
    parent = os.path.dirname(path)
    # dirname() is '' for a bare filename and os.makedirs('') raises, so only
    # create directories when there is actually a parent to create.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, path)

def load_checkpoint(path, map_location=None):
    """Load and return whatever object was stored at ``path``.

    ``map_location`` is forwarded unchanged to ``torch.load``.
    """
    # NOTE(review): torch.load is pickle-based; only open checkpoint files that
    # come from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint


## Data Config & Augmentations

In [None]:
@dataclass
class DataConfig:
    """Options read by CIFAR10Data when building transforms, datasets and loaders.

    Field order matters: evaluate() constructs this positionally as
    ``DataConfig(data_dir, batch_size, workers)``.
    """
    data_dir: str = "./"   # parent dir of cifar-10-batches-py
    batch_size: int = 128  # batch size used for the train, val and test loaders
    num_workers: int = 4  # DataLoader worker count for all three loaders
    randaugment: bool = False  # if set, RandAugment is placed first in the train pipeline
    cutout: bool = False  # if set, Cutout is appended after normalisation in the train pipeline
    cutout_holes: int = 1  # forwarded to Cutout as n_holes
    cutout_len: int = 16  # forwarded to Cutout as length

class Cutout:
    """Zero out ``n_holes`` square patches of an image tensor.

    The transform is applied to a tensor whose dimensions 1 and 2 are the
    spatial height and width (it is used after ``ToTensor``/``Normalize`` in
    this file). Each patch is centred on a uniformly drawn pixel and has side
    ``length`` before clipping at the image border.

    Patch centres are drawn from the torch RNG rather than NumPy's global RNG.
    PyTorch seeds the torch generator differently in every DataLoader worker,
    while the seed of other libraries may be duplicated across workers
    (depending on the PyTorch version), which would make workers cut identical
    holes.
    """

    def __init__(self, n_holes: int = 1, length: int = 16):
        self.n_holes = n_holes
        self.length = length

    @staticmethod
    def _clamp(value: int, upper: int) -> int:
        """Clamp ``value`` into ``[0, upper]``."""
        return min(max(value, 0), upper)

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask with the patches removed."""
        h, w = img.size(1), img.size(2)
        half = self.length // 2
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)
        for _ in range(self.n_holes):
            # Row first, then column, one draw each.
            cy = int(torch.randint(0, h, (1,)).item())
            cx = int(torch.randint(0, w, (1,)).item())
            top, bottom = self._clamp(cy - half, h), self._clamp(cy + half, h)
            left, right = self._clamp(cx - half, w), self._clamp(cx + half, w)
            mask[top:bottom, left:right] = 0.
        # The same 2-D mask is broadcast over every channel.
        return img * mask.expand_as(img)

class CIFAR10Data:
    """CIFAR-10 datasets, a fixed train/validation split, and their loaders.

    The training split is downloaded into ``cfg.data_dir`` if missing and
    opened twice: once with the augmenting training transform and once with
    the plain test transform. The last ``val_size`` images of the training
    split form the validation subset (seen through the test transform); the
    remaining leading images form the training subset.

    Args:
        cfg: Data options (directory, batch size, workers, augmentation flags).
        val_size: Number of trailing training images held out for validation.
            Defaults to 5000, the value that was previously hard-coded.

    Raises:
        ValueError: If ``val_size`` is negative or would leave no training
            images.
    """

    def __init__(self, cfg: DataConfig, val_size: int = 5000):
        self.cfg = cfg

        normalize = transforms.Normalize([0.4914, 0.4822, 0.4465],
                                         [0.2023, 0.1994, 0.2010])

        # Order: (RandAugment) -> crop -> flip -> tensor -> normalize -> (Cutout).
        # Cutout comes last because it operates on the tensor, not the PIL image.
        train_tfms = []
        if cfg.randaugment:
            train_tfms.append(RandAugment())
        train_tfms += [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
        if cfg.cutout:
            train_tfms.append(Cutout(cfg.cutout_holes, cfg.cutout_len))
        self.train_transform = transforms.Compose(train_tfms)
        self.test_transform = transforms.Compose([transforms.ToTensor(), normalize])

        # train_set is created first so that it performs the download; val_set
        # then reopens the same files with the non-augmenting transform.
        self.train_set = datasets.CIFAR10(cfg.data_dir, train=True, download=True,
                                          transform=self.train_transform)
        self.val_set = datasets.CIFAR10(cfg.data_dir, train=True, download=False,
                                        transform=self.test_transform)
        self.test_set = datasets.CIFAR10(cfg.data_dir, train=False, download=True,
                                         transform=self.test_transform)

        num_train = len(self.train_set)
        if not 0 <= val_size < num_train:
            raise ValueError(
                f"val_size must be in [0, {num_train}), got {val_size}")
        boundary = num_train - val_size
        self.train_subset = torch.utils.data.Subset(
            self.train_set, list(range(boundary)))
        self.val_subset = torch.utils.data.Subset(
            self.val_set, list(range(boundary, num_train)))

    def loaders(self):
        """Return ``(train_loader, val_loader, test_loader)``.

        Only the training loader shuffles. Page-locked host memory is requested
        only when CUDA is available, since pinning is useless on a CPU-only
        machine.
        """
        shared = dict(batch_size=self.cfg.batch_size,
                      num_workers=self.cfg.num_workers,
                      pin_memory=torch.cuda.is_available())
        train_loader = DataLoader(self.train_subset, shuffle=True, **shared)
        val_loader = DataLoader(self.val_subset, shuffle=False, **shared)
        test_loader = DataLoader(self.test_set, shuffle=False, **shared)
        return train_loader, val_loader, test_loader


## Models

In [9]:
class BasicCNN(nn.Module):
    """Plain six-convolution network used as the baseline classifier.

    ``features`` is three stages of two conv/batch-norm/ReLU units with widths
    64, 128 and 256. The first two stages end in 2x2 max pooling and the last
    in global average pooling. ``classifier`` is a single linear layer from
    256 features to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_unit(c_in, c_out):
            # One 3x3 same-padding convolution followed by BN and ReLU.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Layers are created strictly in forward order so the positions inside
        # the Sequential (and hence the state-dict keys) stay as before.
        stack = []
        stack += conv_unit(3, 64)
        stack += conv_unit(64, 64)
        stack.append(nn.MaxPool2d(2))
        stack += conv_unit(64, 128)
        stack += conv_unit(128, 128)
        stack.append(nn.MaxPool2d(2))
        stack += conv_unit(128, 256)
        stack += conv_unit(256, 256)
        stack.append(nn.AdaptiveAvgPool2d(1))

        self.features = nn.Sequential(*stack)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class scores with one row per input sample."""
        pooled = self.features(x)
        flat = pooled.view(x.size(0), -1)
        return self.classifier(flat)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with batch norm.

    The skip path is an empty ``nn.Sequential`` (identity) unless the block
    changes stride or channel count, in which case it is a 1x1 convolution
    followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != planes
        if shape_changes:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the two conv stages, add the skip path, then ReLU."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return F.relu(main)

class CIFARResNet18(nn.Module):
    """ResNet-18 style network with a 3x3, stride-1 stem.

    After the stem come four stages of two ``BasicBlock`` modules each, with
    widths 64/128/256/512. Every stage except the first starts with a
    stride-2 block. Global average pooling and a linear layer produce the
    class scores.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running input width for the next block; updated by _make_layer.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        stage_specs = [(64, 1), (128, 2), (256, 2), (512, 2)]
        for number, (width, first_stride) in enumerate(stage_specs, start=1):
            # Registers self.layer1 ... self.layer4 in order.
            setattr(self, f"layer{number}", self._make_layer(width, 2, first_stride))

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride):
        """Build one stage: a first block with ``stride``, the rest with stride 1."""
        # A stage always contains at least the first block.
        per_block_strides = [stride] + [1] * (blocks - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(BasicBlock(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class scores with one row per input sample."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.fc(pooled.view(out.size(0), -1))

def build_model(name, num_classes=10):
    """Instantiate a model by name.

    ``"cnn"`` / ``"baseline"`` select ``BasicCNN``; ``"resnet18"`` /
    ``"resnet"`` select ``CIFARResNet18``.

    Raises:
        ValueError: If ``name`` is none of the accepted names; the offending
            name is the exception argument.
    """
    if name in ("cnn", "baseline"):
        factory = BasicCNN
    elif name in ("resnet18", "resnet"):
        factory = CIFARResNet18
    else:
        raise ValueError(name)
    return factory(num_classes)


## Training and Validation

In [10]:
@torch.no_grad()
def evaluate(args):
    """Score a saved checkpoint on the CIFAR-10 test set and print the metrics.

    Attributes read from ``args``: ``cpu``, ``ckpt``, ``model``, ``data_dir``,
    ``batch_size`` and ``workers``. When ``args.model`` is falsy the model name
    is taken from ``ckpt["args"]["model"]``. Prints a per-class classification
    report followed by the confusion matrix; returns nothing.

    NOTE(review): a function with this same name is defined again in the next
    cell, so whichever cell ran last provides the ``evaluate`` in use.
    """
    class_names = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

    use_cuda = torch.cuda.is_available() and not args.cpu
    device = torch.device("cuda" if use_cuda else "cpu")

    checkpoint = load_checkpoint(args.ckpt, map_location=device)
    model_name = args.model or checkpoint["args"]["model"]

    net = build_model(model_name, 10).to(device)
    net.load_state_dict(checkpoint["model"])
    net.eval()

    cfg = DataConfig(args.data_dir, args.batch_size, args.workers)
    # loaders() returns (train, val, test); only the test loader is needed.
    test_loader = CIFAR10Data(cfg).loaders()[2]

    pred_chunks, target_chunks = [], []
    for images, targets in tqdm(test_loader, desc="Evaluating"):
        logits = net(images.to(device))
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        target_chunks.append(targets.numpy())

    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(target_chunks)

    report = classification_report(y_true, y_pred, digits=4, target_names=class_names)
    print(report)
    print(confusion_matrix(y_true, y_pred))


In [11]:
@torch.no_grad()
def evaluate(args):
    """Run a checkpointed model over the CIFAR-10 test loader and print results.

    ``args`` supplies ``cpu``, ``ckpt``, ``model``, ``data_dir``,
    ``batch_size`` and ``workers``. The architecture name falls back to the
    one stored in the checkpoint under ``["args"]["model"]`` when
    ``args.model`` is falsy. Output is a classification report and a confusion
    matrix written to stdout; nothing is returned.

    NOTE(review): this re-declares the ``evaluate`` from the previous cell and
    replaces it once this cell is executed.
    """
    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    ckpt = load_checkpoint(args.ckpt, map_location=device)
    arch = args.model if args.model else ckpt["args"]["model"]
    model = build_model(arch, 10).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    data = CIFAR10Data(DataConfig(args.data_dir, args.batch_size, args.workers))
    _, _, test_loader = data.loaders()

    # One (predictions, targets) pair of NumPy arrays per batch.
    per_batch = [
        (model(images.to(device)).argmax(1).cpu().numpy(), targets.numpy())
        for images, targets in tqdm(test_loader, desc="Evaluating")
    ]
    y_pred = np.concatenate([preds for preds, _ in per_batch])
    y_true = np.concatenate([labels for _, labels in per_batch])

    labels_in_order = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]
    print(classification_report(y_true, y_pred, digits=4, target_names=labels_in_order))
    print(confusion_matrix(y_true, y_pred))
