In [1]:
%pip install -r requirements.txt

import torch
from torch.utils.data import DataLoader

# Device handed to the tensor constructors below (weights, one-hot targets, demo inputs).
device = "cpu"

In [2]:
import torchvision.datasets as datasets

class CacheMNISTDataset(datasets.MNIST):
    """MNIST dataset that remembers every item it has already produced.

    The first access to an index delegates to ``datasets.MNIST.__getitem__``;
    the result is stored in ``self.cache`` and handed back unchanged on every
    later access to the same index.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``datasets.MNIST`` and start with an empty cache."""
        super().__init__(*args, **kwargs)
        # Maps index -> the value returned by the parent __getitem__.
        self.cache = {}

    def __getitem__(self, index):
        """Return the item at ``index``, computing it only on first access."""
        if index not in self.cache:
            self.cache[index] = super().__getitem__(index)
        return self.cache[index]

def pil_to_tensor(image):
    """Convert an image exposing ``getdata()`` into a 28x28 float tensor.

    The flat pixel sequence is reshaped to 28 rows by 28 columns, cast to
    float and divided by 255.
    """
    pixels = torch.tensor(image.getdata())
    grid = pixels.view(28, 28)
    return grid.float() / 255

# Build the cached MNIST splits; images are converted by pil_to_tensor on access.
train_dataset = CacheMNISTDataset(root='./data', train=True, download=True, transform=pil_to_tensor)
test_dataset = CacheMNISTDataset(root='./data', train=False, download=True, transform=pil_to_tensor)

# Materialise each split as one image tensor and one label tensor, both bfloat16.
# Each split is iterated twice; the second pass is served from the dataset cache.
train_inp = torch.stack([image for image, _label in train_dataset]).to(torch.bfloat16)
train_out = torch.tensor([label for _image, label in train_dataset]).to(torch.bfloat16)
test_inp = torch.stack([image for image, _label in test_dataset]).to(torch.bfloat16)
test_out = torch.tensor([label for _image, label in test_dataset]).to(torch.bfloat16)

# Persist (inputs, labels) tuples so later cells can reload them with torch.load.
torch.save((train_inp, train_out), "mnist.train.pt")
torch.save((test_inp, test_out), "mnist.test.pt")

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

In [3]:
import numpy as np

class SimpleDataset:
    """Map-style dataset backed by an ``(inputs, targets)`` tuple saved with ``torch.save``.

    ``self.data[0]`` holds the inputs and ``self.data[1]`` the targets; both are
    indexed with the same index.
    """

    def __init__(self, name):
        """Load the saved tuple from the file path ``name``."""
        # weights_only=True restricts unpickling to tensors and plain containers,
        # which is all this file format needs, and avoids the FutureWarning that
        # torch.load emits when the argument is left unspecified.
        self.data = torch.load(name, weights_only=True)

    def __len__(self):
        """Return the number of samples, taken from the inputs tensor."""
        return len(self.data[0])

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        return self.data[0][index], self.data[1][index]

# Swap in the tensors saved to disk, replacing the earlier dataset/loader bindings.
train_dataset, test_dataset = SimpleDataset("mnist.train.pt"), SimpleDataset("mnist.test.pt")

# Only the training loader shuffles; the test loader keeps the stored order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

  self.data = torch.load(name)


In [4]:
# Layer widths: 784 inputs (the loops flatten batches to 28*28), 100 hidden units, 10 outputs.
input_size, hidden_size, output_size = 784, 100, 10

# Weights are small random bfloat16 values; W1 is drawn before W2.
W1 = 0.01 * torch.randn(input_size, hidden_size, dtype=torch.bfloat16, device=device)
W2 = 0.01 * torch.randn(hidden_size, output_size, dtype=torch.bfloat16, device=device)

# Biases start at zero.
b1 = torch.zeros(hidden_size, dtype=torch.bfloat16, device=device)
b2 = torch.zeros(output_size, dtype=torch.bfloat16, device=device)

# Step size used by the manual gradient-descent update.
learning_rate = 0.1

In [5]:
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))`` of ``x``."""
    denominator = 1 + torch.exp(-x)
    return 1 / denominator

def sigmoid_derivative(x):
    """Return the element-wise derivative of ``sigmoid`` at ``x``: ``s * (1 - s)``."""
    # Evaluate the sigmoid once and reuse it instead of computing it twice.
    s = sigmoid(x)
    return s * (1 - s)

def relu(x):
    """Return ``max(0, x)`` element-wise, with the result cast to bfloat16."""
    floor = torch.zeros_like(x)
    rectified = torch.maximum(floor, x)
    return rectified.to(torch.bfloat16)

def relu_derivative(x):
    """Return a bfloat16 tensor that is 1 where ``x > 0`` and 0 elsewhere."""
    positive = torch.gt(x, 0)
    return positive.to(torch.bfloat16)

In [6]:
def forward(x):
    """Run the two-layer network on ``x`` using the module-level weights.

    Returns a tuple ``(z1, a1, z2)``: the hidden pre-activation, the hidden
    activation after ``relu``, and the output-layer values.
    """
    hidden_pre = torch.matmul(x, W1) + b1
    hidden_act = relu(hidden_pre)
    outputs = torch.matmul(hidden_act, W2) + b2
    return hidden_pre, hidden_act, outputs

# Smoke-test the forward pass on one random batch of 64 flattened inputs.
forward(torch.randn(64, 784, dtype=torch.bfloat16, device=device))

(tensor([[ 0.6172,  0.1118,  0.0845,  ..., -0.0102, -0.1484,  0.0879],
         [ 0.2490,  0.4258,  0.2637,  ...,  0.4199, -0.1050,  0.6484],
         [ 0.0299, -0.0287, -0.2031,  ..., -0.0530, -0.0854,  0.2305],
         ...,
         [-0.0148, -0.5820,  0.3438,  ..., -0.0459,  0.4688, -0.3105],
         [-0.3184, -0.0684,  0.1001,  ...,  0.1465,  0.0977, -0.4570],
         [-0.2930, -0.4043,  0.2100,  ...,  0.2451, -0.1025, -0.3281]],
        dtype=torch.bfloat16),
 tensor([[0.6172, 0.1118, 0.0845,  ..., 0.0000, 0.0000, 0.0879],
         [0.2490, 0.4258, 0.2637,  ..., 0.4199, 0.0000, 0.6484],
         [0.0299, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.2305],
         ...,
         [0.0000, 0.0000, 0.3438,  ..., 0.0000, 0.4688, 0.0000],
         [0.0000, 0.0000, 0.1001,  ..., 0.1465, 0.0977, 0.0000],
         [0.0000, 0.0000, 0.2100,  ..., 0.2451, 0.0000, 0.0000]],
        dtype=torch.bfloat16),
 tensor([[-2.7344e-02, -6.5613e-03, -8.5449e-03,  5.8899e-03,  1.0986e-02,
           4.2969

In [7]:
@torch.compile
@torch.no_grad
def evaluate(loader):
    """Count correct predictions of the current network over ``loader``.

    Each batch is flattened to rows of 28*28 values, passed through
    ``forward``, and the index of the largest output per row is compared
    with the label.

    Returns:
        ``(correct, total, accuracy)`` where ``accuracy`` is ``correct / total``,
        or ``0.0`` when the loader yields no samples.
    """
    correct = 0
    total = 0
    for x, y in loader:
        x = x.view(-1, 28*28)
        _, _, z2 = forward(x)
        # torch.max over dim 1 returns (values, indices); the index is the predicted class.
        _, predicted = torch.max(z2, 1)
        total += y.size(0)
        correct += (predicted.to("cpu") == y).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = correct / total if total else 0.0

    return correct, total, accuracy

# Accuracy on the test loader; in cell order this runs before the training cell.
evaluate(test_loader)

(1112, 10000, 0.1112)

In [8]:
def backward(x, y, z1, a1, z2):
    """Compute batch-averaged gradients for both layers.

    ``y`` holds class labels; it is cast to int and expanded to a one-hot
    bfloat16 matrix. The output error is ``z2`` minus that one-hot target,
    and it is propagated back through ``W2`` and the ReLU mask of ``z1``.

    Returns:
        ``(dW1, db1, dW2, db2)`` in that order.
    """
    batch = y.size(0)

    # One row per sample with a single 1 in the column of its label.
    targets = torch.zeros(batch, output_size, dtype=torch.bfloat16, device=device)
    targets[range(batch), y.to(int)] = 1

    # Output layer.
    out_err = z2 - targets
    grad_W2 = (a1.T @ out_err) / batch
    grad_b2 = out_err.mean(dim=0)

    # Hidden layer: error flows back through W2, gated by where z1 was positive.
    hid_err = (out_err @ W2.t()) * relu_derivative(z1)
    grad_W1 = (x.T @ hid_err) / batch
    grad_b1 = hid_err.mean(dim=0)

    return grad_W1, grad_b1, grad_W2, grad_b2

# Smoke-test backward on random data: 64 random inputs, 64 random labels in [0, 10),
# and the forward outputs of a second, independently drawn random batch.
backward(torch.randn(64, 784, dtype=torch.bfloat16, device=device), torch.randint(0, 10, (64,), dtype=torch.bfloat16, device=device), *forward(torch.randn(64, 784, dtype=torch.bfloat16, device=device)))

(tensor([[-4.2534e-04,  8.8501e-04, -5.6028e-05,  ...,  7.6675e-04,
           3.4904e-04,  1.5640e-03],
         [-4.9973e-04,  4.6921e-04, -1.1749e-03,  ...,  6.3324e-04,
           2.8801e-04,  4.8161e-05],
         [ 1.2493e-04, -1.8239e-05, -9.4604e-04,  ...,  7.7057e-04,
          -1.9379e-03, -7.7724e-05],
         ...,
         [ 1.3199e-03,  3.7956e-04, -9.4223e-04,  ...,  4.0817e-04,
           1.6556e-03, -2.0885e-04],
         [ 1.8768e-03, -1.1292e-03, -1.3065e-04,  ...,  9.7275e-04,
           1.5869e-03, -4.7112e-04],
         [ 3.6430e-04,  1.9932e-04,  7.0953e-04,  ..., -1.6499e-04,
          -1.5488e-03,  1.0834e-03]], dtype=torch.bfloat16),
 tensor([ 5.4169e-04,  1.9169e-04,  1.7166e-03, -2.7008e-03, -6.4468e-04,
         -2.3499e-03,  2.2411e-04,  8.0872e-04,  2.4719e-03,  1.1520e-03,
          1.0834e-03,  4.5300e-06,  1.6937e-03, -1.6251e-03,  3.2997e-04,
          5.3406e-04, -1.5640e-04, -5.7602e-04, -2.0447e-03,  8.6975e-04,
          1.2894e-03,  3.7842e-03, -

In [9]:
@torch.compile
def epoch(W1, b1, W2, b2):
    """Run one pass of gradient descent over ``train_loader``, updating the tensors in place.

    Note that ``forward`` and ``backward`` read the module-level weights, so
    the arguments are only the tensors being trained when they are those same
    objects.
    """
    params = (W1, b1, W2, b2)
    for x, y in train_loader:
        flat = x.view(-1, 28*28)

        # Forward pass, then gradients in the same (W1, b1, W2, b2) order as params.
        z1, a1, z2 = forward(flat)
        grads = backward(flat, y, z1, a1, z2)

        # In-place step so callers holding these tensors see the update.
        for param, grad in zip(params, grads):
            param.sub_(learning_rate * grad)


def train(W1, b1, W2, b2, num_epochs=10):
    """Train for ``num_epochs`` epochs, printing train and test results after each.

    Args:
        W1, b1, W2, b2: tensors passed through to ``epoch``, which updates them in place.
        num_epochs: number of passes over the training loader; defaults to 10,
            the value that was previously hard-coded.
    """
    for e in range(num_epochs):
        epoch(W1, b1, W2, b2)

        # Each evaluate call returns (correct, total, accuracy).
        print(f"Epoch: {e}, Train: {evaluate(train_loader)}, Test: {evaluate(test_loader)}")

# Train the module-level weights; epoch() updates these tensors in place.
train(W1, b1, W2, b2)

Epoch: 0, Train: (55462, 60000, 0.9243666666666667), Test: (9305, 10000, 0.9305)
Epoch: 1, Train: (56606, 60000, 0.9434333333333333), Test: (9458, 10000, 0.9458)
Epoch: 2, Train: (57216, 60000, 0.9536), Test: (9533, 10000, 0.9533)
Epoch: 3, Train: (57425, 60000, 0.9570833333333333), Test: (9538, 10000, 0.9538)
Epoch: 4, Train: (57803, 60000, 0.9633833333333334), Test: (9611, 10000, 0.9611)
Epoch: 5, Train: (57951, 60000, 0.96585), Test: (9616, 10000, 0.9616)
Epoch: 6, Train: (58051, 60000, 0.9675166666666667), Test: (9635, 10000, 0.9635)
Epoch: 7, Train: (58145, 60000, 0.9690833333333333), Test: (9647, 10000, 0.9647)
Epoch: 8, Train: (58271, 60000, 0.9711833333333333), Test: (9676, 10000, 0.9676)
Epoch: 9, Train: (58319, 60000, 0.9719833333333333), Test: (9667, 10000, 0.9667)
