In [6]:
import numpy as np
from typing import Union, Tuple, Self, Iterable
import inspect

In [7]:
RNG = np.random.default_rng()
DTYPE = 'float64' 

In [None]:
'''TODO 
- input and gradient sanitisation (ensure both are defined)
- add in auto grad visulisation
- add in a way to test the primative operations
- add in convolutions
- todo implement dataloaders 
- add logging, WandB and also terminal logging
'''

class Tensor():
    def __init__(self, data, children=(), op=''):
        self.data: np.ndarray = np.array(data, dtype='float64')
        self.grad = np.zeros_like(data, dtype='float64')
        self._prev = set(children)
        self._backward = lambda : None
        self._op = op

    @property
    def shape(self) -> Tuple[int]:
        return self.data.shape
    
    @property
    def size(self) -> int: 
        return self.data.size
    
    def zero_grad(self) -> None:
        self.grad = np.zeros_like(self.data, dtype='float64')

    def item(self) -> np.ndarray:
        return self.data
    
    def _unbroadcast(self, grad: np.ndarray) -> Self:
        dims_to_remove = tuple(i for i in range(len(grad.shape) - len(self.shape))) 
        # remove prepended padding dimensions
        grad = np.sum(grad, axis=dims_to_remove, keepdims=False) 
        dims_to_reduce = tuple(i for i, (d1,d2) in enumerate(zip(grad.shape, self.shape)) if d1!=d2)
        # reduce broadcasted dimensions
        return np.sum(grad, axis=dims_to_reduce, keepdims=True)

    # need to build topo graph and then go through it and call backwards on each of the tensors
    def backward(self) -> None:
        self.grad = np.ones_like(self.data)
        topo = []
        visited = set()

        # do DFS on un-visited nodes, add node to topo-when all its children have been visited
        def build_topo(node):
            if node not in visited:
                visited.add(node)
                for child in node._prev:
                    build_topo(child)
                topo.append(node)
        build_topo(self)

        for node in reversed(topo):
            node._backward()
            
    def __add__(self, rhs) -> Self:
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data + rhs.data, (self, rhs), '+')

        def _backward():
            self.grad += self._unbroadcast(out.grad)
            rhs.grad += rhs._unbroadcast(out.grad)
        out._backward = _backward
        return out
    
    def __neg__(self) -> Self:
        out = Tensor(-self.data, (self,), 'neg')

        def _backward():
            self.grad += -out.grad
        out._backward = _backward
        return out
    
    def __sub__(self, rhs) -> Self:
        return self + (-rhs)

    def __mul__(self, rhs) -> Self:
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data*rhs.data, (self,), f'*')

        def _backward():
            self.grad += self._unbroadcast(out.grad * rhs.data)
            rhs.grad += rhs._unbroadcast(out.grad * self.data)
        out._backward = _backward
        return out
        
    def __truediv__(self, rhs) -> Self:
        return self * (rhs**-1)
      
    # TODO need to restrict the rhs input when lhs contains negative values and check grad is defined
    def __pow__(self, rhs) -> Self: 
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        # lhs_is_neg = np.all(self.data < 0)
        # rhs_is_frac = 
        out = Tensor(self.data**rhs.data, (self,), f'**')

        def _backward():
            self.grad += out.grad * ((rhs.data)*(self.data**(rhs.data-1)))
            rhs.grad += out.grad * (self.data ** rhs.data) * np.log(rhs.data)
        out._backward = _backward
        return out
    
    '''data shape: (da, ..., d2, d1, n, k) rhs shape: (ob, ..., o2, o1, k, m)
       inputs are broadcast so that they have the same shape by expanding along
       dimensions if possible, out shape: (tc, ..., t2, t1, n, m), where ti = max(di, oi)
       if di or oi does not exist it is treated as 1, and c = max d, a
       if self is 1d shape is prepended with a 1, for rhs it would be appended'''
    def __matmul__(self, rhs) -> Self:
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data @ rhs.data, (self, rhs), '@')

        def _backward():
            A, B, = self.data, rhs.data
            g = out.grad
            # broadcast 1d arrays to be 2d 
            A2 = A.reshape(1, -1) if len(A.shape) == 1 else A
            B2 = B.reshape(-1, 1) if len(B.shape) == 1 else B
            # extend g to have reduced dims
            g = np.expand_dims(g, -1) if len(B.shape) == 1 else g
            g = np.expand_dims(g, -2) if len(A.shape) == 1 else g
            # transpose last 2 dimensions, as matmul treats tensors as batched matricies
            self.grad += self._unbroadcast(g @ B2.swapaxes(-2, -1))
            rhs.grad += rhs._unbroadcast(A2.swapaxes(-2, -1) @ g)
        out._backward = _backward
        return out

    def relu(self) -> Self:
        out = Tensor((self.data > 0) * self.data, (self,), 'Relu')

        def _backward():
            self.grad += (self.data > 0) * out.grad
        out._backward = _backward
        return out
    
    # need to check inp is non-negative
    def log(self) -> Self:
        out = Tensor(np.log(self.data), (self,), 'log')

        def _backward():
            self.grad = self.data ** -1
        out.backward = _backward
        return out
    
    def exp(self) -> Self:
        out = Tensor(np.exp(self.data), (self,), 'exp')

        def _backward():
            self.grad = np.exp(self.data)
        out.backward = _backward
        return out
    
    def sum(self, axis=None) -> Self:
        out = Tensor(np.sum(self.data, axis=axis), (self,), 'sum')

        def _backward():
            g = np.expand_dims(out.grad, axis) if axis is not None else out.grad
            self.grad += g
        out._backward = _backward
        return out

    def mean(self, axis=None) -> Self:
        out = Tensor(np.mean(self.data, axis=axis), (self,), 'mean')

        def _backward():
            N = out.size // self.size 
            g = np.expand_dims(out.grad, axis) if axis is not None else out.grad
            self.grad += g / N
        out._backward = _backward
        return out
    
    def __radd__(self, lhs) -> Self:
        return self + lhs
    
    def __rsub__(self, lhs) -> Self:
        return self + lhs
    
    def __rmul__(self, lhs) -> Self:
        return self * lhs
    
    def __rtruediv__(self, lhs) -> Self:
        try:
            lhs = Tensor(lhs)
        except TypeError:
            return NotImplementedError
        return lhs / self
    
    def __rpow__(self, lhs) -> Self:
        try:
            lhs = Tensor(lhs)
        except TypeError:
            return NotImplementedError
        return lhs ** self
    
    def __rmatmul__(self, lhs) -> Self:
        try:
            lhs = Tensor(lhs)
        except TypeError:
            return NotImplementedError
        return lhs @ self
    
    @classmethod
    def random(cls, shape: tuple, bounds = (0,1)) -> Self:
        lower, upper = bounds
        data = RNG.random(shape, dtype=DTYPE)*(upper-lower) + lower
        return cls(data)
    
    def __repr__(self) -> str:
        return f'tensor shape: {self.shape}, op:{self._op}'        


In [9]:
class Parameter(Tensor):
    def __init__(self, data):
        super().__init__(data)
    
    @classmethod
    def kaiming(cls, fan_in, shape):
        std = np.sqrt(2/fan_in)
        weights = RNG.normal(0, std, shape)
        return cls(weights)
    
    @classmethod
    def zeros(cls, shape):
        return cls(np.zeros(shape, dtype=DTYPE))
    
    def __repr__(self) -> str:
        return f'parameter shape: {self.shape}, size: {self.size}' 

In [10]:
from abc import ABC, abstractmethod

class Module(ABC):
    
    def __call__(self, input: Tensor) -> Tensor:
        return self.forward(input)
    
    @property
    def modules(self) -> list[Self]:
        modules: list[Self] = []
        for value in self.__dict__.values():
            if isinstance(value, Module):
                modules.append(value)

            elif isinstance(value, dict):
                for v in value.values():
                    if isinstance(v, Module):
                        modules.append(v)

            elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
                for v in value:
                    if isinstance(v, Module):
                        modules.append(v)
                    
        return modules
    
    @property
    def params(self) -> list[Parameter]:
        immediate_params = [attr for attr in self.__dict__.values() 
                                    if isinstance(attr, Parameter)]
        modules_params = [param for module in self.modules 
                                    for param in module.params]
        return immediate_params + modules_params
    
    @abstractmethod
    def forward(self, input: Tensor) -> Tensor:
        pass
    
    def zero_grad(self) -> None:
        for param in self.params:
            param.zero_grad()

class Sequential(Module):
    def __init__(self, layers):
        self.layers = layers
    
    def forward(self, input: Tensor) -> Tensor:
        x = input
        for layer in self.layers:
            x = layer(x)
        return x
    
class Affine(Module):
    def __init__(self, in_dim, out_dim):
        self.A = Parameter.kaiming(in_dim, (in_dim, out_dim))
        self.b = Parameter.zeros((out_dim))

    def forward(self, x):
        # x: (B, in), A : (in, out), B: out
        return (x @ self.A) + self.b

class Relu(Module):
    def forward(self, x):
        return x.relu()

In [11]:
class SoftMaxCrossEntropy():

    def __call__(z: Tensor, y) -> Tensor:
        '''logits z, shape (B, C), true lables y, shape (B, C)'''
        loss = ((z * y).sum(axis=-1) + ((z.exp()).sum(axis=-1)).log()).mean()
        return loss

class SGD():
    def __init__(self, params: list[Parameter], lr: float):
        self.lr = lr
        self.params = params
    
    def step(self) -> None:
        for param in self.params:
            param.data += -self.lr * param.grad

In [12]:
class Trainer():
    def __init__(self, model, optimiser, loss, train_loader, test_loader, logger, wandb_run = None):
        self.model = model
        self.optimiser = optimiser
        self.loss = loss
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epoch = 1
        self.logger = logger
        self.wandb_run = wandb_run

    def train_epoch():
        pass

    def validate():
        pass
    
    def fit():
        pass
    
    def log_metrics():
        pass

In [13]:
feedforward = Sequential([Affine(50, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10)])
feedforward.params

[parameter shape: (50, 100), size: 5000,
 parameter shape: (100,), size: 100,
 parameter shape: (100, 200), size: 20000,
 parameter shape: (200,), size: 200,
 parameter shape: (200, 10), size: 2000,
 parameter shape: (10,), size: 10]

In [14]:
''' auto-grad testing suite
    TODO:
    - test all of the auto-grad primatives, 
    - test using central differences
    - test by modifying each parameter individually i.e. only do scalar pertubations
    - for now 
'''
RNG = np.random.default_rng()
EPS, ATOL, RTOL = 5e-7, 1e-6, 1e-3
K = 20

def compute_central_diff_error(test_fn, test_input, 
            other_inputs, eps, perturbed_idx, tols):
    '''verify auto-grad of funciton f: R^n -> R'''
    atol, rtol = tols

    # rescale epsilon and convert to tensor
    perturbed_val = test_input.data[perturbed_idx]
    eps = eps * (1 + abs(perturbed_val))
    pertubation_tensor = np.zeros_like(test_input.data, dtype='float64')
    pertubation_tensor[perturbed_idx] += eps 
    pertubation_tensor = Tensor(pertubation_tensor)

    # Compute grad
    for tensor in [test_input, *other_inputs]:
        tensor.zero_grad()
    clean_out = test_fn(test_input, other_inputs)
    clean_out.backward()
    auto_grad = test_input.grad[perturbed_idx]

    # Compute central diff Grad approximaiton
    test_forward = test_input + pertubation_tensor
    forward_out = test_fn(test_forward, other_inputs).item()
    test_back = test_input - pertubation_tensor
    back_out = test_fn(test_back, other_inputs).item()
    approx_grad = (forward_out - back_out)/ (2*eps)


    abs_err = abs(approx_grad - auto_grad)
    rel_err = abs_err / (abs(auto_grad) + atol)
    is_close = abs_err <= atol + rtol*abs(auto_grad)

    return (is_close, abs_err, rel_err)

# need to generate inputs, compute cd err and output/format test result, to log file maybe?
def test_fn_random_inputs(test_fn, test_shape, other_shapes=[], input_bounds=(-10, 10),
                          num_samples=K, eps=EPS, tols=(ATOL, RTOL)):
    
    test_input = Tensor.random(test_shape, input_bounds)
    other_inputs = [Tensor.random(shape, input_bounds) for shape in other_shapes]

    num_samples = min(test_input.size, num_samples)
    pertubation_nums = RNG.choice(test_input.size, size=num_samples, replace=False)
    pretubation_idxs = np.unravel_index(pertubation_nums, test_shape)

    all_close = True
    failed = 0
    log = inspect.getsource(test_fn) + '\n' 
    log += f'test input \n {test_input.data} \nother inputs \n'
    for other_input in other_inputs:
        log += f' {other_input.data} \n'
    for sample_i in range(num_samples):
        perturbed_idx = tuple(pert_dim[sample_i] for pert_dim in pretubation_idxs)
        is_close, abs_err, rel_err = compute_central_diff_error(test_fn, test_input, 
                                        other_inputs, eps, perturbed_idx, tols)
        log += f'test {'passed' if is_close else 'failed'}: abs err = {abs_err}, rel err = {rel_err}, perturbed idx = {perturbed_idx} \n'
        if not is_close:
            all_close = False
            failed += 1
            # some logic for logging the failed case

    return (all_close, log)
        

In [46]:
test_shape, other_shapes = (2, 3), (3,2)
test_fn = lambda test, other: (test.relu()).sum()
all_close, log = test_fn_random_inputs(test_fn, test_shape, other_shapes)

print(all_close)
print(log)

True
test_fn = lambda test, other: (test.relu()).sum()

test input 
 [[ 4.36584424 -6.89218445  0.72963676]
 [ 4.53202843 -7.43184834 -4.14504382]] 
other inputs 
 [-9.61942403  7.20280344 -9.81055221] 
 [-2.68631656  4.12879115] 
test passed: abs err = 0.0, rel err = 0.0, perturbed idx = (np.int64(1), np.int64(2)) 
test passed: abs err = 1.321722731262298e-10, rel err = 1.3217214095408887e-10, perturbed idx = (np.int64(0), np.int64(2)) 
test passed: abs err = 2.2529045295982542e-10, rel err = 2.2529022766959776e-10, perturbed idx = (np.int64(1), np.int64(0)) 
test passed: abs err = 3.765547873513242e-10, rel err = 3.765544107969134e-10, perturbed idx = (np.int64(0), np.int64(0)) 
test passed: abs err = 0.0, rel err = 0.0, perturbed idx = (np.int64(1), np.int64(1)) 
test passed: abs err = 0.0, rel err = 0.0, perturbed idx = (np.int64(0), np.int64(1)) 



In [50]:
feedforward = Sequential([Affine(50, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10)])
test_shape, other_shapes = (50, 50), (1,)
test_fn = lambda test, other: feedforward(test).sum()
all_close, log = test_fn_random_inputs(test_fn, test_shape, other_shapes)

print(all_close)
print(log)

True
test_fn = lambda test, other: feedforward(test).sum()

test input 
 [[ 5.24869302 -7.32949073 -4.59527146 ...  0.41753385  3.62603279
   4.30435232]
 [ 6.13197392  0.2182355  -8.03118554 ...  3.53622372 -2.29636021
   2.96845961]
 [ 5.90608006  1.59010563  4.56163719 ...  1.71477364  2.18141702
   3.83171408]
 ...
 [ 1.68096585 -9.90775878 -3.22444042 ... -0.40164268 -3.23570823
  -8.56105366]
 [-3.28470248  7.18695545 -9.25924831 ...  7.26671642 -2.33127542
  -7.11449082]
 [-2.7405392  -6.70604766  1.11649571 ...  5.16444245  7.65561468
  -8.81066537]] 
other inputs 
 [1.57497144] 
test passed: abs err = 1.1500378649387244e-08, rel err = 2.763515178640418e-08, perturbed idx = (np.int64(27), np.int64(8)) 
test passed: abs err = 3.767702483337132e-08, rel err = 2.9018792550391e-08, perturbed idx = (np.int64(14), np.int64(33)) 
test passed: abs err = 2.928049241956998e-08, rel err = 1.2276544279545665e-07, perturbed idx = (np.int64(4), np.int64(19)) 
test passed: abs err = 4.2819723

In [17]:
A = Tensor.random((2, 50))
feedforward = Sequential([Affine(50, 100), Relu()])
B = feedforward(A)
print(B)

tensor shape: (2, 100), op:Relu
