Imports.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import platform
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import polars as pl

from data import get_dataloaders
from models import LeNet5, lenet5_init_
from train import train, evaluate
from levenberg_marquadt_optim import DiagLM


In [None]:
# Report the torch version, whether CUDA is visible, and the Python version.
print(torch.__version__, torch.cuda.is_available(), platform.python_version())

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
# The CUDA check mirrors what is printed above so the reported availability
# and the selected device cannot disagree.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


In [None]:
# Run configuration shared by the cells below.
BATCH_SIZE = 256
EPOCHS = 10

train_loader, test_loader = get_dataloaders(batch_size=BATCH_SIZE)
ds_train = train_loader.dataset

# Per-label counts for the dataset behind train_loader.
labels = pl.Series(name='label', values=ds_train.targets)
print(labels.value_counts().sort('label'))

# Display one randomly chosen sample as a sanity check.
# random.randrange excludes the upper bound; random.randint(0, len(ds_train))
# is inclusive on both ends and could index one element past the end.
random_num = ds_train.data[random.randrange(len(ds_train))]
plt.imshow(random_num.reshape((28, 28)), cmap='gray')
plt.show()


In [None]:
# Seed torch's global RNG so that re-running this cell repeats the same run
# (assumes lenet5_init_ and the loader shuffling draw from torch's global
# RNG — TODO confirm; some backends may still be non-deterministic).
SEED = 0
torch.manual_seed(SEED)

model = LeNet5().to(device)
model.apply(lenet5_init_)
opt = DiagLM(model.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()


def preprocess(x):
    """Zero-pad the last two dimensions of ``x`` by 2 on every side.

    With the 28x28 images shown earlier this yields 32x32 inputs.
    """
    return F.pad(x, (2, 2, 2, 2))


loss_history = train(model, train_loader, loss_fn, opt, device, epochs=EPOCHS, preprocess=preprocess)

# Plot the loss values returned by train().
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
# Render explicitly so the cell does not echo the Text repr of the last call.
plt.show()


In [None]:
# Run evaluate() on test_loader with the same loss function, device and
# preprocessing callable that were used for training, then report both
# returned metrics.
test_loss, accuracy = evaluate(
    model,
    test_loader,
    loss_fn,
    device,
    preprocess=preprocess,
)
print(f'Test Loss: {test_loss:.4f}')
print(f'Accuracy: {accuracy:.2f}%')
