In [1]:
from typing import Tuple
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import MNIST
from tqdm import tqdm


In [2]:
def get_default_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
        # For multi-gpu workstations, PyTorch will use the first available GPU (cuda:0), unless specified otherwise
        # (cuda:1).
    if torch.backends.mps.is_available():
        return torch.device('mos')
    return torch.device('cpu')

In [3]:
def forward(x: Tensor, w: Tensor, b: Tensor) -> Tensor:
    return x @ w + b

In [4]:
def sigmoid(z: Tensor):
    return z.sigmoid()

In [5]:
def softmax(z: Tensor) -> Tensor:
    return z.softmax(dim=1)

In [6]:
def backward(x: Tensor, y: Tensor, y1: Tensor, y2: Tensor, w2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    error_layer2 = y2 - y # (batch_size, 10)
    delta_w2 = y1.T @ error_layer2 # (100, batch_size) @ (batch_size, 10) = (100, 10)
    delta_b2 = error_layer2.mean(dim=0) # (batch_size, 10)
    error_layer1 = y1 * (1 - y1) * (w2 @ error_layer2.T).T # (batch_size, 100) * (batch_size, 100) * ((100, 10) @ (10, batch_size)).T = (batch_size, 100)
    delta_w1 = x.T @ error_layer1 # (784, batch_size) @ (batch_size, 100) = (784, 100)
    delta_b1 = error_layer1.mean(dim=0) # (batch_size, 100)
    
    return delta_w1, delta_b1, delta_w2, delta_b2

In [7]:
def collate(x) -> Tensor:
    if isinstance(x, (tuple, list)):
        if isinstance(x[0], Tensor):
            return torch.stack(x)
        return torch.tensor(x)
    raise "Not supported yet"
    # see torch\utils\data\_utils\collate.py

In [8]:
def to_one_hot(x: Tensor) -> Tensor:
    return torch.eye(x.max() + 1)[x]

In [9]:
def train_batch(x: Tensor, y: Tensor, w1: Tensor, b1: Tensor, w2: Tensor, b2: Tensor, lr: float) -> Tuple[Tensor, Tensor, Tensor, Tensor, float]:
    y1 = sigmoid(forward(x, w1, b1))
    y2 = softmax(forward(y1, w2, b2))
    loss = torch.nn.functional.cross_entropy(y2, y)
    delta_w1, delta_b1, delta_w2, delta_b2 = backward(x, y, y1, y2, w2)
    w1 -= lr * delta_w1
    b1 -= lr * delta_b1
    w2 -= lr * delta_w2
    b2 -= lr * delta_b2
    return w1, b1, w2, b2, loss

In [10]:
def train_epoch(data: Tensor, labels: Tensor, w1: Tensor, b1: Tensor, w2: Tensor, b2: Tensor, lr: float, batch_size: int) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor, float]:
    non_blocking = w1.device.type == 'cuda'
    epoch_loss = 0
    for i in range(0, data.shape[0], batch_size):
        x = data[i: i + batch_size].to(w1.device, non_blocking=non_blocking)
        y = labels[i: i + batch_size].to(w1.device, non_blocking=non_blocking)
        w1, b1, w2, b2, batch_loss = train_batch(x, y, w1, b1, w2, b2, lr)
        epoch_loss += batch_loss
    return w1, b1, w2, b2, epoch_loss / batch_size

In [11]:
def load_mnist(path: str = "./data", train: bool = True, pin_memory: bool = True):
    mnist_raw = MNIST(path, download=True, train=train)
    mnist_data = []
    mnist_labels = []
    for image, label in mnist_raw:
        tensor = torch.from_numpy(np.array(image))
        mnist_data.append(tensor)
        mnist_labels.append(label)

    mnist_data = collate(mnist_data).float()  # shape 60000, 28, 28
    mnist_data = mnist_data.flatten(start_dim=1)  # shape 60000, 784
    mnist_data /= mnist_data.max()  # min max normalize
    mnist_labels = collate(mnist_labels)  # shape 60000
    if train:
        mnist_labels = to_one_hot(mnist_labels)  # shape 60000, 10
    if pin_memory:
        return mnist_data.pin_memory(), mnist_labels.pin_memory()
    return mnist_data, mnist_labels

In [12]:
def evaluate(data: Tensor, labels: Tensor, w1: Tensor, b1: Tensor, w2: Tensor, b2: Tensor, batch_size: int) -> float:
    # Labels are not one hot encoded, because we do not need them as one hot.
    total_correct_predictions = 0
    total_len = data.shape[0]
    non_blocking = w1.device.type == 'cuda'
    for i in range(0, total_len, batch_size):
        x = data[i: i + batch_size].to(w1.device, non_blocking=non_blocking)
        y = labels[i: i + batch_size].to(w1.device, non_blocking=non_blocking)
        predicted_distribution = softmax(forward(sigmoid(forward(x, w1, b1)), w2, b2))
        # check torch.max documentation
        predicted_max_value, predicted_max_value_indices = torch.max(predicted_distribution, dim=1)
        # we check if the indices of the max value per line correspond to the correct label. We get a boolean mask
        # with True where the indices are the same, false otherwise
        equality_mask = predicted_max_value_indices == y
        # We sum the boolean mask, and get the number of True values in the mask. We use .item() to get the value out of
        # the tensor
        correct_predictions = equality_mask.sum().item()
        # correct_predictions = (torch.max(predicted_distribution, dim=1)[1] == y).sum().item()
        total_correct_predictions += correct_predictions

    return total_correct_predictions / total_len

In [13]:
def train(epochs: int = 1000, device: torch.device = get_default_device()):
    print(f"Using device {device}")
    pin_memory = device.type == 'cuda'  # Check the provided references.
    w1 = torch.normal(0, 1 / np.sqrt(784) , (784, 100), device=device)
    b1 = torch.zeros((1, 100), device=device)
    w2 = torch.normal(0, 1 / np.sqrt(784), (100, 10), device=device)
    b2 = torch.zeros((1, 10), device=device)
    lr = 0.001
    batch_size = 500
    eval_batch_size = 500
    data, labels = load_mnist(train=True, pin_memory=pin_memory)
    data_test, labels_test = load_mnist(train=False, pin_memory=pin_memory)
    epochs = tqdm(range(epochs))
    total_loss = 0
    for epoch in epochs:
        epoch_loss = 0
        w1, b1, w2, b2, epoch_loss = train_epoch(data, labels, w1, b1, w2, b2, lr, batch_size)
        total_loss += epoch_loss
        accuracy = evaluate(data_test, labels_test, w1, b1, w2, b2, eval_batch_size)
        epochs.set_postfix_str(f"accuracy = {accuracy}, epoch_loss = {epoch_loss}, total_loss = {total_loss}")
        if epoch % 300 == 0:
            lr *= 0.95

In [14]:
train(500)
train(500, torch.device('cpu'))

Using device cuda


100%|█| 500/500 [00:37<00:00, 13.41it/s, accuracy = 0.9801, epoch_loss = 0.35152125358581543, total_loss = 177.83457946


Using device cpu


100%|█| 500/500 [01:30<00:00,  5.50it/s, accuracy = 0.9795, epoch_loss = 0.3515196740627289, total_loss = 177.881317138
