In [None]:
from typing import Tuple
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import MNIST
from tqdm import tqdm

In [None]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

def collate(x) -> Tensor:
    """Combine a list or tuple of items into a single tensor.

    Args:
        x: A list or tuple. If its first element is a ``Tensor`` the elements
            are stacked with ``torch.stack``; otherwise the whole sequence is
            passed to ``torch.tensor``. An empty sequence yields an empty
            tensor.

    Returns:
        The collated tensor.

    Raises:
        TypeError: If ``x`` is neither a list nor a tuple.
    """
    if not isinstance(x, (tuple, list)):
        # Raising a bare string is itself a TypeError in Python 3 and hides
        # the intended message; raise a real exception carrying it instead.
        raise TypeError("Not supported yet")
    # `x and ...` guards the x[0] lookup so an empty sequence does not crash.
    if x and isinstance(x[0], Tensor):
        return torch.stack(x)
    return torch.tensor(x)


def to_one_hot(x: Tensor) -> Tensor:
    """Encode integer indices as one-hot rows.

    The number of columns is ``x.max() + 1``, i.e. it is derived from the
    largest index present in ``x``; each index selects the matching row of an
    identity matrix.
    """
    num_classes = x.max() + 1
    identity = torch.eye(num_classes)
    return identity[x]

def load_mnist(path: str = "./data", train: bool = True):
    """Load MNIST into flat, normalised float tensors.

    Args:
        path: Directory handed to ``MNIST`` (downloaded there if needed).
        train: Selects the training split when true, the test split otherwise.

    Returns:
        ``(data, labels, one_hot_labels)`` when ``train`` is true, otherwise
        ``(data, labels)``. ``data`` has every sample flattened from
        dimension 1 onwards and is divided by its global maximum.
    """
    dataset = MNIST(path, download=True, train=train)
    # Convert every (image, label) pair once, then split into two columns.
    samples = [
        (torch.from_numpy(np.array(picture)), target)
        for picture, target in dataset
    ]

    pixels = collate([pixel for pixel, _ in samples]).float()
    pixels = pixels.flatten(start_dim=1)
    pixels /= pixels.max()  # scale by the largest value in the whole split

    targets = collate([target for _, target in samples])
    if not train:
        return pixels, targets
    return pixels, targets, to_one_hot(targets)


def activate(x: Tensor) -> Tensor:
    """Apply softmax along dimension 1 of ``x``."""
    return torch.softmax(x, dim=1)

def relu(x: Tensor) -> Tensor:
    """Return ``x`` with the rectified linear unit applied element-wise."""
    return x.relu()

def train_batch(x: Tensor, y: Tensor, w: Tensor, hidden_w: Tensor, b: Tensor, hidden_b: Tensor, lr: float, batch_size: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Run one gradient step of the two-layer network on a single mini-batch.

    The network is ``softmax(relu(x @ hidden_w + hidden_b) @ w + b)``. The
    update is the gradient-descent step for the summed softmax cross-entropy
    loss, scaled by ``lr / batch_size``.

    Args:
        x: Input batch, one sample per row.
        y: Target distribution per row (assumed one-hot, same shape as the
            network output).
        w: Output-layer weights, updated in place.
        hidden_w: Hidden-layer weights, updated in place.
        b: Output-layer bias, updated in place.
        hidden_b: Hidden-layer bias, updated in place.
        lr: Learning rate.
        batch_size: Normaliser applied to every summed gradient.

    Returns:
        ``(w, b, hidden_w, hidden_b)`` -- note the order differs from the
        parameter order.
    """
    # Forward. The activations are applied directly (rather than through the
    # module-level helpers) because the backward pass below hard-codes the
    # derivatives of exactly these two functions.
    hidden_forward = torch.relu(x @ hidden_w + hidden_b)
    y_hat = (hidden_forward @ w + b).softmax(dim=1)

    # Backward. For softmax + cross-entropy the negative gradient with respect
    # to the logits is simply (y - y_hat).
    error = y - y_hat
    # ReLU passes gradient only through units that were active in the forward
    # pass; without this mask inactive units would still receive updates.
    relu_mask = (hidden_forward > 0).to(error.dtype)
    hidden_error = (error @ w.T) * relu_mask

    # All deltas are computed before any parameter is modified, so the
    # hidden-layer gradient uses the pre-update value of `w`.
    delta_hidden_w = x.T @ hidden_error
    delta_w = hidden_forward.T @ error
    # Biases are summed over the batch (not averaged) so that, after the
    # division by batch_size below, they get the same scaling as the weights.
    delta_hidden_b = hidden_error.sum(dim=0)
    delta_b = error.sum(dim=0)

    hidden_w += lr * delta_hidden_w / batch_size
    hidden_b += lr * delta_hidden_b / batch_size
    w += lr * delta_w / batch_size
    b += lr * delta_b / batch_size
    return w, b, hidden_w, hidden_b

def train_epoch(
    data: Tensor,
    labels: Tensor,
    w: Tensor,
    hidden_w: Tensor,
    b: Tensor,
    hidden_b: Tensor,
    lr: float,
    batch_size: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Make one sequential pass over ``data`` in chunks of ``batch_size`` rows.

    Each chunk of ``data`` and ``labels`` is moved to the device of ``w`` and
    handed to ``train_batch``; the parameters it returns feed the next chunk.

    Returns:
        ``(w, b, hidden_w, hidden_b)`` after the last chunk.
    """
    for start in range(0, data.shape[0], batch_size):
        window = slice(start, start + batch_size)
        batch_x = data[window].to(w.device)
        batch_y = labels[window].to(w.device)
        w, b, hidden_w, hidden_b = train_batch(
            batch_x, batch_y, w, hidden_w, b, hidden_b, lr, batch_size
        )
    return w, b, hidden_w, hidden_b

def evaluate(data: Tensor, labels: Tensor, w: Tensor, b: Tensor, hidden_w: Tensor, hidden_b: Tensor, batch_size: int, train: bool, epochs) -> float:
    """Compute classification accuracy of the two-layer network on a dataset.

    Args:
        data: Inputs, one sample per row.
        labels: Class index per sample (compared directly with the predicted
            index and passed to ``cross_entropy`` as the target).
        w: Output-layer weights.
        b: Output-layer bias.
        hidden_w: Hidden-layer weights.
        hidden_b: Hidden-layer bias.
        batch_size: Number of rows evaluated per step.
        train: Chooses the "train" or "validation" wording of the loss postfix.
        epochs: Object exposing ``set_postfix_str`` (the tqdm bar in this
            file); it receives the cross-entropy of each batch as it is
            processed, so what stays visible is the last batch's loss.

    Returns:
        Fraction of samples whose predicted class equals the label, or ``0.0``
        when ``data`` has no rows.
    """
    total_len = data.shape[0]
    if total_len == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0

    # Both losses are computed, but they might be printed one over the other
    # If one looks for a few seconds at the output, it will be able to see both losses
    split_name = "train" if train else "validation"
    total_correct_predictions = 0
    for i in range(0, total_len, batch_size):
        x = data[i: i + batch_size].to(w.device)
        y = labels[i: i + batch_size].to(w.device)
        forward = torch.relu(x @ hidden_w + hidden_b) @ w + b
        cross_entropy_result = torch.nn.functional.cross_entropy(forward, y)
        epochs.set_postfix_str(f"Loss {split_name}: {cross_entropy_result}")
        # Softmax is monotonic, so the arg-max of the raw logits is already
        # the predicted class; no need to normalise first.
        predicted_classes = forward.argmax(dim=1)
        total_correct_predictions += (predicted_classes == y).sum().item()

    return total_correct_predictions / total_len

def train(device: torch.device = get_default_device()):
    """Train the two-layer network on MNIST and print the best test accuracy.

    Parameters are created with ``torch.rand`` (weights) and ``torch.zeros``
    (biases) on ``device``. Training runs for 100 epochs with batch size 128
    and a learning rate of 0.06 that is halved at epochs 25, 50 and 75.

    Args:
        device: Device on which the parameters live; batches are moved to it
            by the training and evaluation helpers. The default is resolved
            once, when this function is defined.
    """
    w = torch.rand((100, 10), device=device)
    hidden_w = torch.rand((784, 100), device=device)
    b = torch.zeros((1, 10), device=device)
    hidden_b = torch.zeros((1, 100), device=device)
    lr = 0.06
    batch_size = 128
    # The training split yields both index labels (for evaluation) and their
    # one-hot encoding (for the gradient step).
    data, labels_index, labels_one_hot = load_mnist(train=True)
    data_test, labels_test = load_mnist(train=False)
    max_accuracy = 0.0
    epoch_number = 100
    epochs = tqdm(range(epoch_number))
    # Iterate over the bar itself: looping over a separate range() would leave
    # the progress bar stuck at 0 for the whole run.
    for i in epochs:
        print("Epoch: ", i)
        w, b, hidden_w, hidden_b = train_epoch(data, labels_one_hot, w, hidden_w, b, hidden_b, lr, batch_size)
        # The training accuracy itself is not used; the call is kept because
        # it publishes the training loss on the progress bar.
        evaluate(data, labels_index, w, b, hidden_w, hidden_b, batch_size, True, epochs)
        test_accuracy = evaluate(data_test, labels_test, w, b, hidden_w, hidden_b, batch_size, False, epochs)
        print("Current test accuracy: ", test_accuracy)
        max_accuracy = max(max_accuracy, test_accuracy)
        if i > 0 and i % 25 == 0:
            lr *= 0.5

    print("Best accuracy: ", max_accuracy)

# Start training with the default device as soon as this cell/script runs.
train()