In [None]:
from typing import Optional, Tuple

import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import MNIST
from tqdm import tqdm

In [None]:
def get_default_device() -> torch.device:
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU.

    Returns:
        ``torch.device('cuda')`` when CUDA is available, otherwise
        ``torch.device('mps')`` when the MPS backend is available, otherwise
        ``torch.device('cpu')``.
    """
    # For multi-gpu workstations, PyTorch will use the first available GPU
    # (cuda:0), unless specified otherwise (cuda:1).
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Older torch builds do not ship torch.backends.mps at all, so look it up
    # defensively instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        # The device type string is 'mps'; the previous 'mos' made
        # torch.device() raise on Apple Silicon machines.
        return torch.device('mps')
    return torch.device('cpu')

def collate(x) -> Tensor:
    """Combine a tuple or list of items into a single tensor.

    Args:
        x: A tuple or list. If its first element is a ``Tensor`` the elements
            are stacked along a new leading dimension with ``torch.stack``;
            otherwise the sequence is handed to ``torch.tensor``.

    Returns:
        The combined tensor. An empty sequence yields an empty tensor.

    Raises:
        TypeError: If ``x`` is neither a tuple nor a list.
    """
    if not isinstance(x, (tuple, list)):
        # Raising a bare string is itself a TypeError in Python 3 and hides
        # the intended message, so raise a real exception carrying it.
        raise TypeError(f"Not supported yet: cannot collate {type(x).__name__}")
    # `x and ...` keeps an empty sequence from failing on x[0].
    if x and isinstance(x[0], Tensor):
        return torch.stack(x)
    return torch.tensor(x)

def to_one_hot(x: Tensor, num_classes: Optional[int] = None) -> Tensor:
    """One-hot encode a tensor of integer class indices.

    Args:
        x: Integer tensor of class indices (used to index rows of an identity
            matrix, so it must be a valid index tensor).
        num_classes: Width of the encoding. When omitted it is inferred as
            ``x.max() + 1``, which requires ``x`` to be non-empty and silently
            yields a narrower encoding if the largest class is absent from
            ``x``.

    Returns:
        A float tensor of shape ``x.shape + (num_classes,)`` on ``x``'s device.
    """
    if num_classes is None:
        num_classes = int(x.max()) + 1
    # Build the identity on x's device: indexing a CPU matrix with an index
    # tensor that lives on another device is not supported.
    return torch.eye(num_classes, device=x.device)[x]

def load_mnist(path: str = "./data", train: bool = True, pin_memory: bool = True):
    """Load MNIST as flattened, max-normalized float tensors plus labels.

    Args:
        path: Directory passed to ``torchvision.datasets.MNIST`` (the dataset
            is downloaded there if missing).
        train: Selects the training split when True, the test split otherwise.
            Labels are one-hot encoded only for the training split; for the
            test split they stay as integer class ids.
        pin_memory: Request page-locked host memory for the returned tensors.
            Ignored when CUDA is not available.

    Returns:
        ``(data, labels)``: ``data`` is a float tensor with one flattened
        image per row, divided by the maximum pixel value of the split;
        ``labels`` is described under ``train``.
    """
    mnist_raw = MNIST(path, download=True, train=train)
    mnist_data = []
    mnist_labels = []
    for image, label in mnist_raw:
        mnist_data.append(torch.from_numpy(np.array(image)))
        mnist_labels.append(label)

    # N is the number of samples in the chosen split (60000 for the training
    # split according to the original notes; images are 28x28).
    mnist_data = collate(mnist_data).float()  # shape N, 28, 28
    mnist_data = mnist_data.flatten(start_dim=1)  # shape N, 784
    mnist_data /= mnist_data.max()  # scale by the max so values end up in [0, 1]
    mnist_labels = collate(mnist_labels)  # shape N
    if train:
        mnist_labels = to_one_hot(mnist_labels)  # shape N, number of classes
    # Pinning only makes sense (and only works reliably) when a CUDA runtime
    # is present; with the default pin_memory=True a CPU-only machine would
    # otherwise fail here.
    if pin_memory and torch.cuda.is_available():
        return mnist_data.pin_memory(), mnist_labels.pin_memory()
    return mnist_data, mnist_labels

def activate(x: Tensor) -> Tensor:
    """Apply the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    Args:
        x: Input tensor of any shape.

    Returns:
        A tensor of the same shape with values in ``[0, 1]``.
    """
    print(f"x: {x.shape}")  # debug trace of the input shape
    # Use the built-in kernel instead of the hand-rolled formula: one fused op
    # and no intermediate exp() overflow to inf for large negative inputs.
    return torch.sigmoid(x)

def forward(x: Tensor, w: Tensor, b: Tensor) -> Tensor:
    """Compute the affine pre-activation ``w @ x + b``.

    Args:
        x: Input tensor; must be matrix-multipliable as ``w @ x``.
        w: Weight tensor.
        b: Bias tensor, broadcast against ``w @ x``.

    Returns:
        The tensor ``w @ x + b``.
    """
    print("forward")
    print(f"x: {x.shape}")
    print(f"w: {w.shape}")
    # `+` already broadcasts b against the product, so the explicit
    # broadcast_tensors call (which also triggered a second, discarded
    # matmul) is unnecessary.
    return w @ x + b

def backward_chain_rule(w: Tensor, w2: Tensor, w3: Tensor,
                        z: Tensor, z2: Tensor, z3: Tensor,
                        a: Tensor, a1: Tensor, a3: Tensor,
                        y):
    """Print the squared error and evaluate one chain-rule product.

    Only ``z2``, ``z3``, ``a3`` and ``y`` are read; the remaining parameters
    are currently unused. The function returns ``None``.

    NOTE(review): the chain-rule product is computed and then dropped, so no
    gradient reaches the caller yet. This looks like unfinished work.
    """
    # Squared error between the network output and the target.
    c = (a3 - y) ** 2
    print(f"c: {c}")

    # activate() is evaluated twice on purpose to keep the existing trace
    # output; sigma * (1 - sigma) is the sigmoid derivative at z3.
    sigma = activate(z3)
    sigma_prime = sigma * (1 - activate(z3))
    # NOTE(review): presumably the gradient w.r.t. the last layer's weights;
    # the value is not used or returned. Confirm intent.
    delta_c = z2 * sigma_prime * 2 * (a3 - y)



def train_batch(x: Tensor, y: Tensor, w: Tensor, b: Tensor, lr: float) -> Tuple[Tensor, Tensor]:
    """Run one batch through a three-layer forward pass and the backward stub.

    Args:
        x: Input batch; must be compatible with ``w @ x``.
        y: Targets for the batch.
        w: First-layer weights.
        b: First-layer bias.
        lr: Learning rate (currently unused).

    Returns:
        ``(w, b)`` exactly as passed in; no parameter update happens yet.

    NOTE(review): the second and third layers are re-drawn from
    ``torch.rand`` on every call and are fed the pre-activations (``z``,
    ``z2``) rather than the activations (``a``, ``a2``). Both look
    unintended, but are left unchanged here; confirm before training.
    """
    print("train_batch")
    print(f"x: {x.shape}")
    print(f"y: {y.shape}")  # previously printed x.shape under the "y" label
    print(f"w: {w.shape}")
    # Allocate the extra layers next to the first-layer weights instead of
    # relying on a module-level `device` that only exists when run as a script.
    device = w.device
    z = forward(x, w, b)
    a = activate(z)
    w2 = torch.rand((10, 100), device=device)
    b2 = torch.rand((10, 1), device=device)
    z2 = w2 @ z + b2
    a2 = activate(z2)
    print(f"a2: {a2.shape}")
    w3 = torch.rand((1, 10), device=device)
    b3 = torch.rand((1, 1), device=device)
    z3 = w3 @ z2 + b3
    a3 = activate(z3)
    print(f"a3: {a3.shape}")
    backward_chain_rule(w, w2, w3, z, z2, z3, a, a2, a3, y)

    return w, b

def train_epoch(x: Tensor, y: Tensor, w: Tensor, b: Tensor, lr: float, batch: int) \
        -> Tuple[Tensor, Tensor]:
    """Run ``train_batch`` over every mini-batch of one epoch.

    Samples are taken along axis 1 of ``x`` and ``y`` (one column per sample),
    in consecutive chunks of ``batch`` columns; the last chunk may be smaller.

    Args:
        x: Inputs with samples along axis 1.
        y: Targets with samples along axis 1.
        w: Weights passed to and returned from ``train_batch``.
        b: Bias passed to and returned from ``train_batch``.
        lr: Learning rate forwarded to ``train_batch``.
        batch: Number of samples per mini-batch; must be positive.

    Returns:
        The ``(w, b)`` pair returned by the last ``train_batch`` call, or the
        inputs unchanged when there are no samples.

    Raises:
        ValueError: If ``batch`` is not positive.
    """
    if batch <= 0:
        raise ValueError(f"batch must be a positive integer, got {batch}")
    non_blocking = w.device.type == 'cuda'
    # Step by `batch` and slice the sample axis only. The previous loop
    # advanced one column at a time, sliced away feature rows, and overwrote
    # x/y with the slice, so the second iteration failed on a shape mismatch.
    for start in range(0, x.shape[1], batch):
        stop = start + batch
        # Move just the current batch next to the parameters.
        x_batch = x[:, start:stop].to(w.device, non_blocking=non_blocking)
        y_batch = y[:, start:stop].to(w.device, non_blocking=non_blocking)
        w, b = train_batch(x_batch, y_batch, w, b, lr)
    return w, b

def train(epochs, batch, device, x, y):
    """Train the first-layer parameters for a number of epochs.

    Args:
        epochs: Number of passes over the data.
        batch: Mini-batch size forwarded to ``train_epoch``.
        device: ``torch.device`` on which the parameters are created.
        x: Training inputs, forwarded to ``train_epoch``.
        y: Training targets, forwarded to ``train_epoch``.

    Returns:
        The final ``(w, b)`` pair (previously the result was discarded).
    """
    print(f"Using device {device}")
    w = torch.rand((100, 784), device=device)
    b = torch.zeros((100, 1), device=device)
    lr = 0.0005
    # Keep the progress bar under its own name instead of shadowing `epochs`.
    progress = tqdm(range(epochs))
    for _ in progress:
        w, b = train_epoch(x, y, w, b, lr, batch)
        # NOTE(review): evaluate, data_test, labels_test and eval_batch_size
        # are not defined in the visible part of this file; they must exist
        # at module level before this runs. Confirm.
        accuracy = evaluate(data_test, labels_test, w, b, eval_batch_size)
        progress.set_postfix_str(f"accuracy = {accuracy}")
    return w, b

if __name__ == '__main__':
    # Script entry point: pick a device, load the training split, and train.
    device = get_default_device()
    pin_memory = device.type == 'cuda'  # pinned host memory is only requested on CUDA

    # load_mnist returns one sample per row; flip both tensors so that
    # samples run along axis 1.
    x, y = (
        split.transpose(0, 1)
        for split in load_mnist(train=True, pin_memory=pin_memory)
    )
    print(f"x: {x.shape}")
    print(f"y: {y.shape}")

    batch, epoch = 200, 300
    train(epoch, batch, device, x, y)