In [None]:
import sys, os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from src import *

# Setup

In [None]:
# load MNIST data
train_img_path = "../data/train_img.idx"
train_label_path = "../data/train_label.idx"
test_img_path = "../data/test_img.idx"
test_label_path = "../data/test_label.idx"

train_samples, train_labels, test_samples, test_labels = normalize_mnist_data(
    train_img_path, train_label_path, test_img_path, test_label_path
)

# use tensors: float32 inputs for the Linear layers, and int64 (long) labels because
# F.cross_entropy expects class-index targets of dtype long (a no-op if already int64)
train_samples = torch.from_numpy(train_samples).float()
train_labels = torch.from_numpy(train_labels).long()
test_samples = torch.from_numpy(test_samples).float()
test_labels = torch.from_numpy(test_labels).long()

# samples and labels are sliced with the same indices later, so their counts must agree
assert len(train_samples) == len(train_labels), "train samples/labels length mismatch"
assert len(test_samples) == len(test_labels), "test samples/labels length mismatch"

# hyperparameters
epochs = 100
batch_size = 64
lr = 1e-4
nin = 28 * 28  # each MNIST image is flattened to a 784-element vector
nhidden = 100
nout = 10  # one output per digit class 0-9
seed = 0

# seed torch's global RNG so weight init (and any torch-based shuffling) is reproducible
torch.manual_seed(seed)

# initialize model and optimizer
# The network ends in a Linear layer and outputs raw logits: F.cross_entropy applies
# log_softmax internally, so a trailing Softmax would normalize twice and shrink the
# gradients. argmax over logits selects the same class as argmax over softmax output.
m = nn.Sequential(
    nn.Linear(nin, nhidden),
    nn.ReLU(),
    nn.Linear(nhidden, nhidden),
    nn.ReLU(),
    nn.Linear(nhidden, nhidden),
    nn.ReLU(),
    nn.Linear(nhidden, nhidden),
    nn.ReLU(),
    nn.Linear(nhidden, nout),
)
optimizer = torch.optim.AdamW(m.parameters(), lr=lr)


def loss_fn(y, y_hat):
    """Mean cross-entropy loss for a batch.

    Args:
        y: class-index targets for the batch.
        y_hat: model outputs for the same batch, one row per sample.

    Returns:
        Scalar tensor from ``F.cross_entropy`` with its default mean reduction.
    """
    batch_loss = F.cross_entropy(input=y_hat, target=y)
    return batch_loss

# Train Loop

In [None]:
# train model: mini-batch optimization over a freshly shuffled sample order each epoch
m.train()
n_train = len(train_samples)
# ceil division (last batch may be partial); max(1, ...) guards an empty training set
n_batches = max(1, (n_train + batch_size - 1) // batch_size)

for epoch in range(epochs):
    # reshuffle so batch composition differs between epochs; randperm uses torch's global RNG
    perm = torch.randperm(n_train)
    epoch_loss = 0.0
    for i in range(0, n_train, batch_size):
        batch_idx = perm[i : i + batch_size]
        optimizer.zero_grad()
        y_hat = m(train_samples[batch_idx].flatten(1))
        loss = loss_fn(train_labels[batch_idx].flatten(), y_hat)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # report the mean per-batch loss so the number does not scale with the batch count
    print(f"epoch: {epoch}, loss: {epoch_loss / n_batches:.4f}")

# Test-Set Accuracy

In [None]:
# test model: one batched forward pass in eval mode, without autograd bookkeeping
m.eval()
with torch.no_grad():
    # flatten(1) gives each sample the same 1-D layout as flattening it individually
    test_logits = m(test_samples.flatten(1))

# flatten labels so both (N,) and (N, 1) label layouts compare element-wise with (N,) predictions
predictions = torch.argmax(test_logits, dim=1)
correct = (predictions == test_labels.flatten()).sum().item()

print(f"accuracy: {correct / len(test_samples) * 100}%")

In [None]:
# Visual sanity check: render test sample 1, whose predicted class is computed in the next cell.
sample_image = test_samples[1]
plot_image(sample_image)

In [None]:
# Predicted class for test sample 1 shown next to its ground-truth label;
# no_grad avoids building an autograd graph for this inference-only call.
with torch.no_grad():
    predicted_digit = torch.argmax(m(test_samples[1].flatten())).item()
predicted_digit, test_labels[1].item()