In [87]:
from typing import Tuple
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import MNIST
from tqdm import tqdm


def get_default_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
        # For multi-gpu workstations, PyTorch will use the first available GPU (cuda:0), unless specified otherwise
        # (cuda:1).
    if torch.backends.mps.is_available():
        return torch.device('mos')
    return torch.device('cpu')


def collate(x) -> Tensor:
    if isinstance(x, (tuple, list)):
        if isinstance(x[0], Tensor):
            return torch.stack(x)
        return torch.tensor(x)
    raise "Not supported yet"
    # see torch\utils\data\_utils\collate.py


def load_mnist(path: str = "./data", train: bool = True, pin_memory: bool = True):
    mnist_raw = MNIST(path, download=True, train=train)
    mnist_data = []
    mnist_labels = []
    for image, label in mnist_raw:
        tensor = torch.from_numpy(np.array(image))
        mnist_data.append(tensor)
        mnist_labels.append(label)

    mnist_data = collate(mnist_data).float()  # shape 60000, 28, 28
    mnist_data = mnist_data.flatten(start_dim=1)  # shape 60000, 784
    mnist_data /= mnist_data.max()  # min max normalize
    mnist_labels = collate(mnist_labels)  # shape 60000
    if train:
        mnist_labels = to_one_hot(mnist_labels)  # shape 60000, 10
    if pin_memory:
        return mnist_data.pin_memory(), mnist_labels.pin_memory()
    return mnist_data, mnist_labels


def to_one_hot(x: Tensor) -> Tensor:
    return torch.eye(x.max() + 1)[x]


def forward(x: Tensor, w: Tensor, b: Tensor) -> Tensor:
    return x @ w + b


def activate_softmax(x: Tensor) -> Tensor:
    return x.softmax(dim=1)


def activate_sigmoid(x: Tensor) -> Tensor:
    # return 1.0 / (1.0 + torch.exp(-x))
    return torch.sigmoid(x)


def train_epoch(data: Tensor, labels: Tensor, w_hidden_layer: Tensor, b_hidden_layer: Tensor, w_last_layer: Tensor, b_last_layer: Tensor, lr: float, batch_size: int) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor, float]:

    epoch_loss = 0
    non_blocking = w_hidden_layer.device.type == 'cuda'
    for i in range(0, data.shape[0], batch_size):
        x = data[i: i + batch_size].to(w_hidden_layer.device, non_blocking=non_blocking)
        y = labels[i: i + batch_size].to(w_hidden_layer.device, non_blocking=non_blocking)

        # forward propagation
        hidden_output = activate_sigmoid(forward(x, w_hidden_layer, b_hidden_layer))
        last_output = activate_softmax(forward(hidden_output, w_last_layer, b_last_layer))

        # loss
        epoch_loss += torch.nn.functional.cross_entropy(last_output, y).item()

        # backward propagation using chain rule
        last_error = last_output - y
        delta_w_last = hidden_output.T @ last_error
        delta_b_last = last_error.mean(dim=0)

        hidden_error = (hidden_output * (1 - hidden_output)) * (w_last_layer @ last_error.T).T
        delta_w_hidden = x.T @ hidden_error
        delta_b_hidden = hidden_error.mean(dim=0)

        w_last_layer -= lr * delta_w_last
        b_last_layer -= lr * delta_b_last

        w_hidden_layer -= lr * delta_w_hidden
        b_hidden_layer -= lr * delta_b_hidden

    return w_hidden_layer, b_hidden_layer, w_last_layer, b_last_layer, epoch_loss / batch_size



def evaluate(data: Tensor, labels: Tensor, w_hidden_layer: Tensor, b_hidden_layer: Tensor, w_last_layer: Tensor, b_last_layer: Tensor, batch_size: int) -> float:
    # Labels are not one hot encoded, because we do not need them as one hot.
    total_correct_predictions = 0
    total_len = data.shape[0]
    non_blocking = w_hidden_layer.device.type == 'cuda'
    for i in range(0, total_len, batch_size):
        x = data[i: i + batch_size].to(w_hidden_layer, non_blocking=non_blocking)
        y = labels[i: i + batch_size].to(w_hidden_layer.device, non_blocking=non_blocking)

        hidden_output = activate_sigmoid(forward(x, w_hidden_layer, b_hidden_layer))
        last_output = activate_softmax(forward(hidden_output, w_last_layer, b_last_layer))

        predicted_max_value, predicted_max_value_indices = torch.max(last_output, dim=1)
        # we check if the indices of the max value per line correspond to the correct label. We get a boolean mask
        # with True where the indices are the same, false otherwise
        equality_mask = predicted_max_value_indices == y
        # We sum the boolean mask, and get the number of True values in the mask. We use .item() to get the value out of
        # the tensor
        correct_predictions = equality_mask.sum().item()
        total_correct_predictions += correct_predictions

    return total_correct_predictions / total_len


def train(epochs: int = 1000, device: torch.device = get_default_device()):
    print(f"Using device {device}")
    pin_memory = device.type == 'cuda'
    w_hidden_layer = torch.empty((784, 100), device=device).normal_(mean=0, std=np.power(np.sqrt(784), (-1)))
    b_hidden_layer = torch.empty((1, 100), device=device).normal_(mean=0, std=1)

    w_last_layer = torch.empty((100, 10), device=device).normal_(mean=0, std=np.power(np.sqrt(100), (-1)))
    b_last_layer = torch.empty((1, 10), device=device).normal_(mean=0, std=1)

    lr = 0.005
    batch_size = 500
    eval_batch_size = 500
    data, labels = load_mnist(train=True, pin_memory=pin_memory)
    data_test, labels_test = load_mnist(train=False, pin_memory=pin_memory)
    epochs = tqdm(range(epochs))
    total_loss = 0
    for _ in epochs:
        w_hidden_layer, b_hidden_layer, w_last_layer, b_last_layer, epoch_loss = train_epoch(data, labels, w_hidden_layer, b_hidden_layer, w_last_layer, b_last_layer, lr, batch_size)
        total_loss += epoch_loss
        accuracy = evaluate(data_test, labels_test, w_hidden_layer, b_hidden_layer, w_last_layer, b_last_layer, eval_batch_size)
        epochs.set_postfix_str(f"accuracy = {accuracy}, epoch loss = {epoch_loss}, total loss = {total_loss}")
        if _ % 100 == 0:
            lr *= 0.7


In [None]:
if __name__ == '__main__':
    train(500, torch.device('cpu'))
    train(500)

Using device cpu


100%|██████████| 500/500 [04:50<00:00,  1.72it/s, accuracy = 0.9701, epoch loss = 0.3513817830085754, total loss = 176.7360966441631]


Using device cpu


 58%|█████▊    | 291/500 [02:48<01:51,  1.87it/s, accuracy = 0.9673, epoch loss = 0.3520513737201691, total loss = 103.41203617000582]