## Imports

In [None]:
import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import platform
from tqdm import tqdm
import random

from data import get_dataloaders
from models import MLP
from train import train, evaluate


In [None]:
# Environment report: torch version, CUDA availability, Python version.
print(torch.__version__, torch.cuda.is_available(), platform.python_version())

# Pick the fastest available backend: CUDA first, then Apple-silicon MPS,
# falling back to CPU. Checking CUDA here keeps device selection consistent
# with the availability that is printed above.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Seed every RNG the notebook touches so Restart & Run All is reproducible
# (the random sample shown below, model weight init, and presumably the
# data loader's shuffling).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
BATCH_SIZE = 128
EPOCHS = 10
train_loader, test_loader = get_dataloaders(batch_size=BATCH_SIZE)
ds_train = train_loader.dataset

# Class balance of the training labels, sorted by label value.
labels = pl.Series(name='label', values=ds_train.targets)
print(labels.value_counts().sort('label'))

# Show one random training image as a sanity check.
# random.randint(a, b) includes b, which could index one past the end;
# randrange(n) draws uniformly from [0, n).
sample_idx = random.randrange(len(ds_train))
random_num = ds_train.data[sample_idx]
# NOTE(review): assumes 28x28 single-channel images (e.g. MNIST) — confirm.
plt.imshow(random_num.reshape((28, 28)), cmap='gray')
plt.title(f'Training sample #{sample_idx} (label {int(ds_train.targets[sample_idx])})')
plt.axis('off')
plt.show()


In [None]:
LEARNING_RATE = 0.05

model = MLP().to(device)
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

# Flatten each batch to (batch_size, features) before it reaches the MLP.
# reshape matches view's result but also works on non-contiguous tensors.
loss_history = train(model, train_loader, loss_fn, opt, device, epochs=EPOCHS,
                     preprocess=lambda x: x.reshape(x.size(0), -1))

fig, ax = plt.subplots()
ax.plot(loss_history)
# NOTE(review): the 'Epoch' label assumes train() returns one loss value per
# epoch; if it returns per-batch losses, relabel the x-axis — confirm.
ax.set(title='Training loss', xlabel='Epoch', ylabel='Loss')
plt.show()


In [None]:
# Evaluate on the held-out test set using the same flattening as training.
# reshape matches view's result but also works on non-contiguous tensors.
test_loss, accuracy = evaluate(
    model, test_loader, loss_fn, device,
    preprocess=lambda x: x.reshape(x.size(0), -1)
)
print(f'Test loss: {test_loss:.4f}')
# NOTE(review): the '%' suffix assumes evaluate() returns accuracy already
# scaled to 0-100 rather than a 0-1 fraction — confirm.
print(f'Accuracy: {accuracy:.2f}%')
