In [None]:
import torch
from torch import Tensor
from torch import optim
from torch.optim.adam import Adam
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST # type: ignore
from torchvision.transforms import Compose, ToTensor, Normalize # type: ignore

import utils

In [None]:
# --- Configuration: everything a reader might want to tune -------------------
SEED = 42                # seed for PyTorch's global RNG
DATA_ROOT = './data'     # directory MNIST is downloaded to / read from
TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 64
LEARNING_RATE = 1e-3

# Seed the global RNG *before* building the shuffling loader and the model so
# that "Restart Kernel -> Run All" reproduces the same run. This fixes the
# shuffle order of train_loader and, assuming utils.LeNet is a regular
# torch.nn.Module, its initial weights as well.
# NOTE(review): this alone does not guarantee bit-for-bit determinism on GPU;
# see the PyTorch reproducibility notes if that is required.
torch.manual_seed(SEED)

# Convert PIL images to tensors, then standardise the single grey-scale
# channel. (0.1307, 0.3081) are the mean/std commonly used for MNIST
# (e.g. in the official PyTorch MNIST example).
transform = Compose([ # type: ignore
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
])

# Both splits share the same preprocessing; download=True fetches the data
# into DATA_ROOT if it is not already there.
train_dataset = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training data is shuffled; the test loader keeps dataset order.
train_loader: DataLoader[tuple[Tensor, Tensor]] = DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True
)
test_loader: DataLoader[tuple[Tensor, Tensor]] = DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False
)

lenet = utils.LeNet()

optimizer = Adam(lenet.parameters(), lr=LEARNING_RATE)
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler
# steps. With gamma=1.0 the learning rate therefore never changes; the
# scheduler is effectively a placeholder. Lower gamma to enable real decay.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1.0)

In [None]:
# Number of epochs handed to the Trainer via its `epochs` keyword.
NUM_EPOCHS = 100

# Wire the model, both data loaders, the optimizer and the LR scheduler built
# in the setup cell into the project's Trainer, then start training.
trainer = utils.Trainer(
    lenet,
    train_loader,
    test_loader,
    optimizer,
    scheduler,
    epochs=NUM_EPOCHS,
)
trainer.train()