Let's train AlexNet from scratch!

In [6]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torch.nn as nn

from torchvision import transforms

from trainer import Trainer
from utils import get_device, imshow

In [9]:
# Pick the compute device through the project helper `get_device` (imported from utils).
# On the recorded run the helper reported "Using MPS...".
device = get_device()

Using MPS...


### Training configurations

In [9]:
# -- optimisation schedule --
NUM_EPOCHS = 90  # roughly ~90 cycles specified in the AlexNet paper
BATCH_SIZE = 128

# -- optimiser hyperparameters --
LEARNING_RATE = 1e-2
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# -- dataset --
NUM_CLASSES = 200  # Tiny ImageNet

# Dataset

We're going to use Hugging Face `datasets` to load a small stand-in dataset (Tiny ImageNet) for training.

In [11]:
from datasets import load_dataset

# Load both splits of Tiny ImageNet from the Hugging Face Hub in one go.
# The hub's 'valid' split is what this notebook uses as its test set.
train_set, test_set = (
    load_dataset('zh-plus/tiny-imagenet', split=split_name)
    for split_name in ('train', 'valid')
)

In [12]:
# Sanity check: report how many examples each loaded split contains.
print(f'train: {len(train_set)}, test: {len(test_set)}')

train: 100000, test: 10000


In [15]:
# data preprocessing
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
SPLIT_SEED = 42  # makes the 90-10 partition reproducible across kernel restarts

# Functions defined inside a notebook cannot be pickled for spawn/forkserver
# DataLoader workers (the default start method on macOS and Windows), so only
# use worker processes when they are forked; otherwise load in the main process.
NUM_WORKERS = 2 if torch.multiprocessing.get_start_method() == 'fork' else 0

# Training pipeline: random scale/aspect crop resized to 227x227 plus a random
# horizontal flip, then tensor conversion and ImageNet normalisation.
train_transform = transforms.Compose([
        transforms.RandomResizedCrop(227),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Evaluation pipeline (deterministic). Tiny ImageNet images are 64x64 and
# CenterCrop zero-pads inputs smaller than the crop size, so the image is
# upscaled first; otherwise the network would see a small picture surrounded by
# black borders, which does not match what it is trained on.
test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(227),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)


def make_batch_transform(image_transform):
    """Adapt a per-image torchvision transform to the Hugging Face batch format.

    Args:
        image_transform: callable mapping a single PIL image to a tensor.

    Returns:
        A function that takes a dict of column lists containing an 'image'
        column of PIL images and returns the same dict with 'image' replaced
        by transformed tensors. All other columns (e.g. the labels) pass
        through unchanged.
    """
    def apply(batch):
        # convert('RGB') guards against non-RGB (e.g. grayscale) images, which
        # would otherwise break the 3-channel Normalize step.
        batch['image'] = [image_transform(img.convert('RGB')) for img in batch['image']]
        return batch
    return apply


# split into 90-10 train and validation dataset
# The split is computed once with a seeded generator. The resulting indices are
# then applied to two differently-transformed views of the same underlying data,
# so that training samples are augmented while validation samples are not.
# NOTE: train_set/test_set are rebound below; re-run the dataset loading cell
# before re-running this one.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_part, val_part = torch.utils.data.random_split(train_set, [0.9, 0.1], generator=split_generator)

train_view = train_set.with_transform(make_batch_transform(train_transform))
val_view = train_set.with_transform(make_batch_transform(test_transform))

train_set = torch.utils.data.Subset(train_view, train_part.indices)
val_set = torch.utils.data.Subset(val_view, val_part.indices)
test_set = test_set.with_transform(make_batch_transform(test_transform))

# Each batch is a dict of collated columns; 'image' holds the transformed tensors.
# Only the training loader shuffles: evaluation order does not affect metrics.
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Observe Pretrained AlexNet
Let's see what the PyTorch (torchvision) AlexNet looks like.

In [12]:
import torch
# Fetch torchvision's reference AlexNet (pinned to the v0.10.0 tag) together with
# its pretrained weights through torch.hub. On the recorded run this downloaded the
# repository archive and a ~233 MB checkpoint into the local torch hub cache.
model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /Users/andy/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /Users/andy/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [07:51<00:00, 519kB/s] 


In [16]:
# Display the module tree of the reference model loaded via torch.hub.
model

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

Some of the implementation details of the current PyTorch model differ from the original AlexNet paper; this GitHub issue goes into some detail about the differences:

https://github.com/pytorch/vision/issues/549

NOTE: 
> "we added the avg_pool to make the model support images of different sizes. For images of size 224x224, the avg_pool is a no-op."

**We will try our best to implement the EXACT model mentioned in the paper instead of the PyTorch implementation**

# Our AlexNet Implementation

In [16]:
from alexnet import AlexNet
# Instantiate the project's own AlexNet with its default constructor arguments.
# This rebinds `model`, replacing the torch.hub reference network loaded earlier.
# NOTE(review): NUM_CLASSES (200) is not passed here, and the parameter count
# reported further down (62,378,344) appears consistent with a 1000-way output
# layer — confirm whether AlexNet accepts a class-count argument and, if so,
# pass NUM_CLASSES so the head matches Tiny ImageNet.
model = AlexNet()
model

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): LocalResponseNorm(5, alpha=0.001, beta=0.75, k=2)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): ReLU()
    (6): LocalResponseNorm(5, alpha=0.001, beta=0.75, k=2)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2):

In [17]:
# Tally parameter elements in a single pass over the model: every tensor counts
# towards the total, and those with requires_grad set also count as trainable.
n_params = 0
n_trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    n_params += n_elements
    if param.requires_grad:
        n_trainable_params += n_elements

print(f"Number of parameters: {n_params}")
print(f"Number of trainable parameters: {n_trainable_params}")

Number of parameters: 62378344
Number of trainable parameters: 62378344


In [None]:
# Collect the settings handed to the project Trainer in one place.
# NOTE(review): MOMENTUM and WEIGHT_DECAY from the configuration cell are not
# forwarded here — confirm whether Trainer applies them internally.
trainer_kwargs = dict(
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    num_epochs=NUM_EPOCHS,
    device=device,
)
alexnet_trainer = Trainer(model, **trainer_kwargs)

# Run training on the train loader, with the validation loader passed alongside.
alexnet_trainer.train(train_dataloader, val_dataloader)