In [None]:
import numpy as np
from typing import List
from collections.abc import Callable
import operator as op

import math
from collections import deque, defaultdict

In [None]:
class Function:
    """Pair a callable with the callable that computes its derivative.

    Attributes:
        func: The function itself.
        deriv: The derivative of ``func``.
    """

    def __init__(self, func: Callable, deriv: Callable):
        # Both callables are stored as given; nothing is validated or wrapped.
        self.func, self.deriv = func, deriv

class Layer:
    """Fully connected layer computing ``activation.func(x @ W + b)``.

    Attributes:
        W: Weight matrix of shape ``(n, m)``, initialised uniformly in [0, 1).
        b: Bias vector of shape ``(m,)``, initialised to zeros.
        activation: Object exposing ``func`` and ``deriv`` callables.
        cache: Input seen by the most recent forward pass (``None`` before).
        z: Pre-activation ``x @ W + b`` of the most recent forward pass
            (``None`` before).
        dW: Gradient of the loss w.r.t. ``W`` from the most recent backward
            pass, shape ``(n, m)``.
        db: Gradient of the loss w.r.t. ``b`` from the most recent backward
            pass, shape ``(m,)``.

    NOTE(review): ``activation.deriv(z)`` is assumed to return the ``(m, m)``
    Jacobian df/dz, as the shape note in ``_backward`` indicates -- confirm
    against the activations actually used.
    """

    def __init__(self, n: int, m: int, activation: "Function"):
        self.W = np.random.rand(n, m)
        self.b = np.zeros(m)
        self.activation = activation
        # Nothing has been propagated yet. Starting from None (rather than a
        # random placeholder) lets _backward detect a missing forward pass
        # instead of silently producing meaningless gradients.
        self.cache = None
        self.z = None
        self.dW = np.zeros_like(self.W)
        self.db = np.zeros_like(self.b)

    @property
    def adjoint(self) -> np.ndarray:
        """Transpose of the current weights.

        Computed on access so it stays correct when ``W`` is rebound
        (e.g. ``layer.W = layer.W - lr * layer.dW``).
        """
        return self.W.T

    def _forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``f(xW + b)`` and remember ``x`` and ``xW + b`` for backward."""
        self.cache = x
        self.z = x @ self.W + self.b
        return self.activation.func(self.z)

    def _backward(self, upstream: np.ndarray) -> np.ndarray:
        """Back-propagate ``upstream`` (dL/df) through the layer.

        Stores the parameter gradients in ``self.dW`` and ``self.db`` and
        returns dL/dx.

        Raises:
            RuntimeError: If no forward pass has been run yet.
        """
        if self.z is None:
            raise RuntimeError("_backward called before _forward")
        # dL/dz = dL/df @ df/dz = upstream @ df/dz; (1xm) @ (mxm).
        # The derivative is taken at the pre-activation z, not at the input x.
        delta = upstream @ self.activation.deriv(self.z)
        # dL/db = dL/dz @ I_m; rows are summed so the result matches b's
        # shape (for a 1-D delta this is delta itself).
        self.db = np.atleast_2d(delta).sum(axis=0)
        # dL/dW = x^T @ dL/dz; atleast_2d turns 1-D operands into row vectors
        # so this is an (n, m) outer product rather than a dot product.
        self.dW = np.atleast_2d(self.cache).T @ np.atleast_2d(delta)
        # dL/dx = dL/dz @ dz/dx = dL/dz @ W^T
        return delta @ self.adjoint

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Public entry point for ``_forward``, for callers using ``layer.forward(x)``."""
        return self._forward(x)

    def backward(self, upstream: np.ndarray) -> np.ndarray:
        """Public entry point for ``_backward``."""
        return self._backward(upstream)

class Model:
    """Sequential container that feeds an input through a list of layers.

    Attributes:
        layers: Layers applied in order by ``forward``.
        N: Stored as given; not used by the code visible here.
        M: Stored as given; not used by the code visible here.
    """

    def __init__(self, N: int, M: int, layers: List["Layer"] = None):
        # A fresh list per instance avoids sharing a mutable default.
        self.layers = [] if layers is None else layers
        self.N = N
        self.M = M

    def forward(self, input: np.ndarray) -> np.ndarray:
        """Return the result of passing ``input`` through every layer in order.

        With no layers, ``input`` is returned unchanged.
        """
        x = input
        for layer in self.layers:
            # Use a public ``forward`` when the layer has one; otherwise fall
            # back to ``_forward``, which is the name the Layer class in this
            # file defines. Calling ``layer.forward`` unconditionally raised
            # AttributeError for such layers.
            step = getattr(layer, "forward", None)
            if step is None:
                step = layer._forward
            x = step(x)
        return x

    def backward(self, loss: "Function"):
        """Placeholder for back-propagation; currently does nothing.

        TODO: implement the backward pass over ``self.layers``.
        """
        return None