In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from typing import Tuple, Dict
import wandb

In [None]:
!wandb login

In [23]:
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()

    # first conv block
    self.conv1 = nn.Sequential(
        # input channels (RGB image): 3
        # output channels: 32 feature maps
        # kernel size: 3x3
        # padding: 1 to maintain spatial dimensions
        # input size: 32x32x3
        # output size: 32x32x32
        nn.Conv2d(3, 32, kernel_size=3, padding=1),

        # normalizes each of the 32 feature maps
        # helps training by reducing internal covariate shift (see notes)
        nn.BatchNorm2d(32),

        # introduces non-linearity
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),

        # takes 2x2 window and keeps maximum value
        # stride of 2 means non-overlapping windows
        # reduces spatial dimensions by half
        # input size: 32x32x32
        # output size: 16x16x32
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

    # second conv block
    self.conv2 = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),

        # output size: 64x8x8
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

    # third conv block
    self.conv3 = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, 128, kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),

        # output size: 128x4x4
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

    # fully connected layers
    self.fc = nn.Sequential(
        # linear transformation
        # This layer performs: output = weight_matrix(512 × 2048) × input(2048) + bias(512)
        # 128 filters, 4x4 feature map from previous layer
        # output size: 512 neurons
        # see notes as to why 512 neurons was chosen
        nn.Linear(128 * 4 * 4, 512),

        # f(x) = max(0,x)
        # see notes about ReLU
        nn.ReLU(),

        # input: 512 features from previous layer
        # output: 10 neurons (one for each CIFAR-10 class)
        # This layer performs: output = weight_matrix(10 × 512) × input(512) + bias(10)
        # the outputs represent the logits/score for each class
        nn.Linear(512, 10)
    )

  # x = batch_size, 3 RGB channels, 32x32 image dimensions
  # x.shape = (batch_size, 3, 32, 32)
  def forward(self, x):
    # Shape changes: (batch_size, 3, 32, 32) → (batch_size, 32, 16, 16)
    x = self.conv1(x)

    # Shape changes: (batch_size, 32, 16, 16) → (batch_size, 64, 8, 8)
    x = self.conv2(x)

    # Shape changes: (batch_size, 64, 8, 8) → (batch_size, 128, 4, 4)
    x = self.conv3(x)

    # flattens 3D feature maps into 1D vector
    # x.size(0) keeps the batch dimension
    # -1: automatically calculates the flatten dimension (128*4*4 = 2048)
    # Shape changes: (batch_size, 128, 4, 4) → (batch_size, 2048)
    x = x.view(x.size(0), -1)
    x = self.fc(x)

    # returns final logits
    # will typically go into a loss function for training or softmax for predictions
    return x

In [24]:
# data preprocessing
transform = transforms.Compose([
  transforms.RandomHorizontalFlip(),
  transforms.RandomCrop(32, padding=4),
  transforms.ToTensor(),

  # see notes on how to compute these values
  transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2470, 0.2435, 0.2616]
  )
])

In [None]:
# load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# split training data into train and validation sets (80/20 split)
# use int() because we can't have fractional samples
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
train_dataset, val_dataset = random_split(trainset, [train_size, val_size])

In [26]:
# create data loaders
# larger batches = better GPU utilization, more stable gradients
# shuffle=True: randomly reorders data each epoch
# helps prevent the model from learning the order of data
# num_workers=2: uses 2 parallel processes to load data. 2 is a conservative choice
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

# no shuffling needed for validation because:
# we don't train on this data.
# consistent order helps with debugging
# makes results reproducible
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

In [27]:
# initialize the model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [28]:
# training function
def train_epoch(model, train_loader, criterion, optimizer, device):
  model.train() # sets model to training mode (affects batchnorm, dropout)
  running_loss = 0.0 # accumulates loss over epoch
  correct = 0 # counts correct predictions
  total = 0 # counts total samples

  for inputs, labels in train_loader: # iterates through batches
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad() # clears previous gradients
    outputs = model(inputs) # forward pass: gets model predictions
    loss = criterion(outputs, labels) # calculates loss
    loss.backward() # computes gradients
    optimizer.step() # updates model weights

    running_loss += loss.item() # accumulates batch loss


    _, predicted = outputs.max(1) # gets predicted class (highest probability)
    total += labels.size(0) # adds batch size to total

    # predicted.eq(labels) performs element-wise equality comparison between predicted and labels
    # .sum() adds up all True values. returns single-element tensor with # of correct predictions
    # .item() extracts scalar value from tensor
    correct += predicted.eq(labels).sum().item() # counts correct predictions

  # returns two metrics as a tuple
  # average loss per batch for the epoch
  # average correct percentage
  return running_loss / len(train_loader), 100. * correct / total

In [29]:
# validate function
def validate(model, val_loader, criterion, device):
  model.eval()
  running_loss = 0.0
  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in val_loader:
      inputs, labels = inputs.to(device), labels.to(device)
      outputs = model(inputs)
      loss = criterion(outputs, labels)

      running_loss += loss.item()
      _, predicted = outputs.max(1)
      total += labels.size(0)
      correct += predicted.eq(labels).sum().item()

  return running_loss / len(val_loader), 100. * correct / total

In [30]:
# train model function
def train_model(config: Dict = None):
  # initialize wandb
  if config is None:
    config = {
        'learning_rate': 0.001,
        'batch_size': 128,
        'num_epochs': 50,
        'architecture': 'CIFAR10CNN',
        'optimizer': 'Adam'
    }

  wandb.init(project='cifar10-cnn', config=config)
  # log gradients and model parameters
  wandb.watch(model)

  # training loop
  for epoch in range(config['num_epochs']):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # log metrics to wandb
    wandb.log({
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc
    })

    print(f'Epoch {epoch+1}/{config["num_epochs"]}:')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
    print('-' * 50)

  test_loss, test_acc = validate(model, test_loader, criterion, device)
  wandb.log({
    'test_loss': test_loss,
    'test_acc': test_acc
  })

  print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

  return model

In [31]:
model = train_model()

Epoch 1/50:
Train Loss: 1.4526 | Train Acc: 46.79%
Val Loss: 1.2163 | Val Acc: 55.83%
--------------------------------------------------
Epoch 2/50:
Train Loss: 0.9919 | Train Acc: 64.53%
Val Loss: 1.0289 | Val Acc: 63.60%
--------------------------------------------------
Epoch 3/50:
Train Loss: 0.8113 | Train Acc: 71.27%
Val Loss: 0.8233 | Val Acc: 70.68%
--------------------------------------------------
Epoch 4/50:
Train Loss: 0.7047 | Train Acc: 75.12%
Val Loss: 0.8215 | Val Acc: 72.05%
--------------------------------------------------
Epoch 5/50:
Train Loss: 0.6441 | Train Acc: 77.61%
Val Loss: 0.7605 | Val Acc: 73.76%
--------------------------------------------------
Epoch 6/50:
Train Loss: 0.5900 | Train Acc: 79.33%
Val Loss: 0.7646 | Val Acc: 73.55%
--------------------------------------------------
Epoch 7/50:
Train Loss: 0.5426 | Train Acc: 80.86%
Val Loss: 0.6576 | Val Acc: 77.06%
--------------------------------------------------
Epoch 8/50:
Train Loss: 0.5186 | Train Ac

In [None]:
wandb.finish()