In [None]:
import matplotlib.pyplot as plt
import torch
import torchsummary
import torchvision
import tqdm.notebook

from traditional.lenet import LeNet5
from traditional.manual_scheduler import ManualLRScheduler

# Constants

In [None]:
# Data
dataset_location: str = "../data"  # root directory handed to the torchvision MNIST loader
batch_size: int = 256  # samples per DataLoader batch
train_validation_split: float = 0.7  # fraction of the training set kept for training; the rest is validation

# Torch
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility
# Seed PyTorch's global random number generator so that later steps that draw from it
# (e.g. the random train/validation split and loader shuffling) repeat on Restart & Run All.
seed: int = 42
torch.manual_seed(seed)

# Training
epochs: int = 20

# Load data
Load the MNIST dataset from torchvision and apply padding and normalisation as part of the transform.

In [None]:
# Per-image preprocessing; the steps run in the order listed.
transform = torchvision.transforms.Compose(
    [
        # Add a 2-pixel border on every side (28x28 MNIST digits become 32x32).
        torchvision.transforms.Pad(padding=2),
        # Convert the PIL image to a float tensor scaled to [0, 1].
        torchvision.transforms.ToTensor(),
        # Normalise with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
        torchvision.transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [None]:
# MNIST training split, downloaded into `dataset_location` when not already present.
train_validation_data = torchvision.datasets.MNIST(
    root=dataset_location,
    train=True,
    transform=transform,
    download=True,
)

# Randomly partition the training split into train / validation subsets by fraction.
train_data, validation_data = torch.utils.data.random_split(
    train_validation_data,
    [train_validation_split, 1 - train_validation_split],
)

# MNIST test split, kept apart from the train / validation data.
test_data = torchvision.datasets.MNIST(
    root=dataset_location,
    train=False,
    transform=transform,
    download=True,
)

In [None]:
def get_loader(dataset: torch.utils.data.Dataset, shuffle: bool = True) -> torch.utils.data.DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` that uses the notebook-wide ``batch_size``.

    Args:
        dataset: Dataset to draw batches from.
        shuffle: Whether the samples are reshuffled every epoch. Defaults to ``True``
            (the previously hard-coded behaviour); pass ``False`` for a fixed,
            sequential order, e.g. when evaluating.

    Returns:
        A ``DataLoader`` over ``dataset`` yielding batches of up to ``batch_size`` samples.
    """
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# One loader per split, built in train / validation / test order with identical settings.
train_loader, validation_loader, test_loader = (
    get_loader(split) for split in (train_data, validation_data, test_data)
)

In [None]:
def get_sample() -> tuple[torch.Tensor, str]:
    """Return the first example of a freshly drawn ``train_loader`` batch.

    Returns:
        A pair ``(image, class_name)``: the first image of the batch with its
        leading dimension squeezed out when that dimension has size 1, and the
        entry of ``train_validation_data.classes`` selected by the first label.
    """
    batch = next(iter(train_loader))
    images, labels = batch[0], batch[1]
    first_image = images[0].squeeze(0)
    # The label is used directly as an index into the list of class names.
    class_name = train_validation_data.classes[labels[0]]
    return first_image, class_name

# Draw one training example, print its class name and display the image.
image, label = get_sample()
print(f"Class: {label}")
# The trailing ';' keeps the returned image object's repr out of the cell output.
plt.imshow(image);

# Training

In [None]:
# Model
# Instantiate LeNet5 and move it to the selected device.
model = LeNet5().to(device)
# Print a summary of the model for a single-channel 32x32 input
# (presumably the padded MNIST image size the network expects — confirm against LeNet5).
torchsummary.summary(model, (1, 32, 32))

In [None]:
# Optimizer and scheduler
# Learning rates for the manual schedule, and the `counts` passed alongside them.
# NOTE(review): there are 5 rates but only 4 counts — presumably each count is the number of
# epochs spent at the matching rate and the last rate runs until training ends; confirm
# against ManualLRScheduler.
learning_rates: list[float] = [5e-4, 2e-4, 1e-4, 5e-5, 1e-5]
counts: list[int] = [2, 3, 3, 4]

manual_lr_scheduler = ManualLRScheduler(learning_rates, counts)
# Plain SGD, starting at the first scheduled learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rates[0])
# NOTE(review): LambdaLR multiplies the optimizer's initial lr by the value the callable
# returns for the current epoch. That is only correct if ManualLRScheduler.step returns a
# multiplicative factor (e.g. rate / learning_rates[0]); if it returns the absolute rate,
# the effective lr becomes learning_rates[0] * rate. Confirm against ManualLRScheduler.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, manual_lr_scheduler.step)