# Simple MNIST implementation
The goal of this notebook is to train a simple neural network to classify MNIST digits.  
This is meant to be a simple from-scratch implementation of an MLP (multi-layer perceptron), used to test the implementation of the layers (except the convolution layer) and the losses.  

## Setup

### Imports

In [9]:
from os.path import join

import plotly.express as px
import kagglehub
import numpy as np

from optimizers import SGD, SGD_with_momentum, RMSprop, Adam
from metrics import accuracy
from losses import BinaryCrossentropy
from layers import Linear, Relu, Sigmoid

### Data extraction

In [10]:
# Fetch the MNIST dataset through kagglehub, then build the paths of the four
# IDX files (train/test images and labels) inside the returned directory.
dataset_path = kagglehub.dataset_download("hojjatk/mnist-dataset")
(
    train_image_path,
    train_labels_path,
    test_image_path,
    test_labels_path,
) = (
    join(dataset_path, file_name)
    for file_name in (
        'train-images.idx3-ubyte',
        'train-labels.idx1-ubyte',
        't10k-images.idx3-ubyte',
        't10k-labels.idx1-ubyte',
    )
)

def load_images(path) -> np.ndarray:
    """Load an IDX3 image file (e.g. MNIST) as a float array scaled to [0, 1].

    The IDX3 header is four big-endian uint32 values: the magic number (2051),
    the image count, the row count and the column count. It is followed by one
    uint8 per pixel.

    Args:
        path: Path to an ``*.idx3-ubyte`` file.

    Returns:
        Array of shape ``(n_images, rows * cols)`` with float64 values in [0, 1].

    Raises:
        ValueError: If the file is too short, the magic number is not 2051, or
            the pixel payload does not match the size declared in the header.
    """
    with open(path, 'rb') as f:
        raw = f.read()

    if len(raw) < 16:
        raise ValueError(f"{path}: file too short to contain an IDX3 header")

    # Read the header instead of skipping it blindly, so that a wrong file
    # (e.g. a label file) fails loudly rather than being mis-parsed.
    magic, n_images, n_rows, n_cols = (
        int(v) for v in np.frombuffer(raw, dtype='>u4', count=4)
    )
    if magic != 2051:
        raise ValueError(f"{path}: expected IDX3 magic number 2051, got {magic}")

    pixels = np.frombuffer(raw, dtype=np.uint8, offset=16)
    expected_size = n_images * n_rows * n_cols
    if pixels.size != expected_size:
        raise ValueError(
            f"{path}: header declares {expected_size} pixels, found {pixels.size}"
        )

    # Flatten each image to a row vector; dividing by 255 yields float64 in [0, 1].
    return pixels.reshape(n_images, n_rows * n_cols) / 255

def load_labels(path, num_classes: int = 10) -> np.ndarray:
    """Load an IDX1 label file (e.g. MNIST) as one-hot encoded rows.

    The IDX1 header is two big-endian uint32 values: the magic number (2049)
    and the label count. It is followed by one uint8 class index per label.

    Args:
        path: Path to an ``*.idx1-ubyte`` file.
        num_classes: Width of the one-hot encoding (10 for MNIST digits).

    Returns:
        float64 array of shape ``(n_labels, num_classes)`` with a single 1.0 per row.

    Raises:
        ValueError: If the file is too short, the magic number is not 2049, or
            the number of labels does not match the header.
        IndexError: If a label index is >= ``num_classes`` (raised by numpy indexing).
    """
    with open(path, 'rb') as f:
        raw = f.read()

    if len(raw) < 8:
        raise ValueError(f"{path}: file too short to contain an IDX1 header")

    magic, n_labels = (int(v) for v in np.frombuffer(raw, dtype='>u4', count=2))
    if magic != 2049:
        raise ValueError(f"{path}: expected IDX1 magic number 2049, got {magic}")

    label_idxs = np.frombuffer(raw, dtype=np.uint8, offset=8)
    if label_idxs.size != n_labels:
        raise ValueError(
            f"{path}: header declares {n_labels} labels, found {label_idxs.size}"
        )

    # Indexing rows of the identity matrix turns each class index into a one-hot row.
    return np.eye(num_classes)[label_idxs]

train_dataset = load_images(train_image_path)
train_labels = load_labels(train_labels_path)
test_dataset = load_images(test_image_path)
test_labels = load_labels(test_labels_path)

# Sanity-check the loaded splits: every image needs exactly one label, and the
# flattened image size must agree between train and test.
assert len(train_dataset) == len(train_labels), "train images/labels count mismatch"
assert len(test_dataset) == len(test_labels), "test images/labels count mismatch"
assert train_dataset.shape[1] == test_dataset.shape[1], "train/test image size mismatch"



In [11]:
# Recover the integer class ids from the one-hot label matrix.
train_labels.argmax(axis=1)

array([5, 0, 4, ..., 5, 6, 8])

In [12]:
# Show the first training image with its class label in the title so the
# figure stands on its own; a grayscale map matches the single-channel pixels.
px.imshow(
    train_dataset[0].reshape(28, 28),
    color_continuous_scale="gray",
    title=f"First training image (label: {int(np.argmax(train_labels[0]))})",
)

## Model definition

In [13]:
# Two-layer MLP: flattened 28x28 image -> hidden ReLU layer -> one output per class.
INPUT_SIZE = 28**2
HIDDEN_SIZE = 64
NUM_CLASSES = 10  # must match the one-hot label width
nn: list[Linear|Relu|Sigmoid] = [
    Linear(INPUT_SIZE, HIDDEN_SIZE),
    Relu(),
    Linear(HIDDEN_SIZE, NUM_CLASSES),
    Sigmoid(),
]
# NOTE(review): the head is a Sigmoid scored with BinaryCrossentropy against
# one-hot targets. This presumably treats each class as an independent
# one-vs-rest problem; a softmax + categorical cross-entropy head is the usual
# choice for mutually exclusive classes, if `layers`/`losses` provide one.
loss = BinaryCrossentropy()

In [14]:
# Randomize the dataset and labels with one shared permutation so image/label
# pairs stay aligned. A fixed seed keeps the order reproducible under
# Restart & Run All. Note that this cell reassigns train_dataset/train_labels,
# so re-running it on its own applies a second shuffle.
SHUFFLE_SEED = 42
shuffle_rng = np.random.default_rng(SHUFFLE_SEED)
indices = shuffle_rng.permutation(train_dataset.shape[0])
train_dataset = train_dataset[indices]
train_labels = train_labels[indices]

## Training

In [15]:
# Build the optimizer separately from the training call so its hyperparameters
# are easy to find and tweak.
# NOTE(review): `ada_grad_weight` is presumably Adam's second-moment decay
# (beta2); 0.9 is well below the commonly used 0.999. Confirm it is intended.
optimizer = Adam(
    starting_lr=0.01,
    lr_decay=0.0005,
    momentum_weight=0.9,
    ada_grad_weight=0.9,
)
training_stats = optimizer.optimize_nn(
    nn,
    train_dataset,
    train_labels,
    epochs=15,
    batch_size=128,
    loss=loss,
)

Output()

In [16]:
# Reshape the stats to long format (presumably training_stats is a pandas
# DataFrame with one row per logged step) so each metric gets its own facet row.
fig = (
    px.scatter(
        training_stats.melt(id_vars="epoch", value_vars=["loss", "accuracy", "learning_rate"]),
        x="epoch",
        y="value",
        color="variable",
        facet_row="variable",
        title="Training Metrics Over Batches",
        height=600,
    )
    # The metrics live on very different scales: unlink the facet y-axes and
    # show tick labels on every one of them.
    .update_yaxes(matches=None, showticklabels=True)
)
fig.show()