In [42]:
import torch
from torch import Tensor
from torchvision.datasets import MNIST
import numpy as np
from tqdm import tqdm

In [57]:
def sigmoid(x):
    return 1/(1+torch.exp(-x))

def softmax(x):
    return x.softmax(dim=1)

def forward(x, w, b):
    z = x@w + b
    return z

def backward(x, y_pred, y_true, y_hidden, w2):
    error2 = y_true - y_pred
    delta_w2 = y_hidden.T @ error2
    delta_b2 = error2.mean(axis=0)
    error1 = (y_hidden * (1-y_hidden))*(w2 @ error2.T).T
    delta_w1 = x.T @ error1
    delta_b1 = error1.mean(axis=0)
    return delta_w1, delta_b1, delta_w2, delta_b2

In [55]:
def train_epoch(data, labels, w1, b1, w2, b2, mu, size):
    loss_epoch = 0
    non_blocking = w1.device.type == 'cuda'
    for i in range(0, data.shape[0], size):
        x = data[i: i + size].to(w1.device, non_blocking=non_blocking)
        y = labels[i: i + size].to(w1.device, non_blocking=non_blocking)
    y_hidden = sigmoid(forward(x, w1, b1))
    y_pred = softmax(forward(y_hidden, w2, b2))
    loss = torch.nn.functional.cross_entropy(y_pred, y)
    loss_epoch += loss
    delta_w1, delta_b1, delta_w2, delta_b2 = backward(x, y_pred, y, y_hidden, w2)
    w1 += mu*delta_w1
    b1 += mu*delta_b1
    w2 += mu*delta_w2
    b2 += mu*delta_b2
    return w1, b1, w2, b2, loss_epoch/size

In [45]:
def evaluate(data, labels, w1, b1, w2, b2, size):
    total_correct_predictions = 0
    total_length = data.shape[0]
    non_blocking = w1.device.type == 'cuda'
    for i in range(0, total_length, size):
        x = data[i: i + size].to(w1.device, non_blocking=non_blocking)
        y = labels[i: i + size].to(w1.device, non_blocking=non_blocking)
        y_hidden = sigmoid(forward(x, w1, b1)) 
        y_pred = sigmoid(forward(y_hidden, w2, b2))
        predicted_max_value, predicted_max_value_indices = torch.max(y_pred, dim=1)
        equality_mask = predicted_max_value_indices == y
        correct_predictions = equality_mask.sum().item()
        total_correct_predictions += correct_predictions
    return total_correct_predictions / total_length

In [46]:
def collate(x):
    if isinstance(x, (tuple, list)):
        if isinstance(x[0], Tensor):
            return torch.stack(x)
        else:
            return torch.tensor(x)
    else: 
        raise "Not supported yet"

def to_one_hot(x):
    return torch.eye(x.max()+1)[x]

def load_mnist(path: str = "./data", train: bool = True, pin_memory: bool = True):
    mnist_raw = MNIST(path, download=True, train=train)
    mnist_data = []
    mnist_labels = []
    for image, label in mnist_raw:
        tensor = torch.from_numpy(np.array(image))
        mnist_data.append(tensor)
        mnist_labels.append(label)

    mnist_data = collate(mnist_data).float()  # shape 60000, 28, 28
    mnist_data = mnist_data.flatten(start_dim=1)  # shape 60000, 784
    mnist_data /= mnist_data.max()  # min max normalize
    mnist_labels = collate(mnist_labels)  # shape 60000
    if train:
        mnist_labels = to_one_hot(mnist_labels)  # shape 60000, 10
    if pin_memory:
        return mnist_data.pin_memory(), mnist_labels.pin_memory()
    return mnist_data, mnist_labels

In [82]:
def get_default_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

def train(epochs: int = 1000, device: torch.device = get_default_device()):
    print(f"Using device {device}")
    pin_memory = device.type == 'cuda'  
    # w1 = torch.normal(0, 1/np.sqrt(784), (784, 100), device=device)
    w1 = torch.empty((784,100), device=device).normal_(mean=0, std=np.power(np.sqrt(784), (-1)))
    # b1 = torch.zeros((1, 100), device=device) 
    b1 = torch.empty((1, 100), device=device).normal_(mean=0, std=1)
    # w2 = torch.normal(0, 1/np.sqrt(784), (100, 10), device=device)
    w2 = torch.empty((100,10), device=device).normal_(mean=0, std=np.power(np.sqrt(100), (-1)))
    # b2 = torch.zeros((1, 10), device=device) 
    b2 = torch.empty((1, 10), device=device).normal_(mean=0, std=1)
    mu = 0.0
    size = 500
    eval_size = 500
    data, labels = load_mnist(train=True, pin_memory=pin_memory)
    data_test, labels_test = load_mnist(train=False, pin_memory=pin_memory)
    epochs = tqdm(range(epochs))
    loss = 0
    for e in epochs:
        loss_epoch = 0
        w1, b1, w2, b2, loss_epoch = train_epoch(data, labels, w1, b1, w2, b2, mu, size)
        loss += loss_epoch
        accuracy = evaluate(data_test, labels_test, w1, b1, w2, b2, eval_size)
        epochs.set_postfix_str(f"accuracy = {accuracy}, loss_epoch = {loss_epoch}, loss_total = {loss}")
        if e % 100 == 0: 
            mu *= 0.5

In [83]:
train(500)

Using device cpu


100%|█| 500/500 [00:15<00:00, 32.76it/s, accuracy = 0.7416, loss_epoch = 0.00294921244494617, loss_total = 1.5218460559
