In [170]:
import numpy as np
from typing import Union, Tuple, Self, Iterable
import inspect

In [171]:
# regular constants
RNG = np.random.default_rng()
DTYPE = 'float64' 

# testing constants
if DTYPE=='float64':
    EPS, ATOL, RTOL = 1e-6, 1e-5, 1e-3
else:
    EPS, ATOL, RTOL = 1e-4, 1e-4, 1e-2
K = 20

In [None]:
'''TODO 
- loss functions, implement and check correctness (CE/+softmax, mse)
- implement adam 
- make constants/ rngs be in a seperate file
- add in requires grad functionality (enable grad, context manager etc)
- add in more layers convolutions, softmax, dropout, batch norm (possibly)
- todo implement dataloaders 
- add logging, WandB and also terminal logging
- add in auto grad visulisation
'''

class Tensor():
    def __init__(self, data, requires_grad=False, children=(), op=''):
        self.data: np.ndarray = np.array(data, dtype=DTYPE)
        self.grad = np.zeros_like(data, dtype=DTYPE)
        self.requires_grad = requires_grad
        self._prev = set(children)
        self._backward = lambda : None
        self._op = op

    @property
    def shape(self) -> Tuple[int]:
        return self.data.shape
    
    @property
    def size(self) -> int: 
        return self.data.size
    
    def zero_grad(self) -> None:
        self.grad = np.zeros_like(self.data, dtype=DTYPE)

    def item(self) -> np.ndarray:
        return self.data
    
    def _unbroadcast(self, grad: np.ndarray) -> Self:
        dims_to_remove = tuple(i for i in range(len(grad.shape) - len(self.shape))) 
        # remove prepended padding dimensions
        grad = np.sum(grad, axis=dims_to_remove, keepdims=False) 
        dims_to_reduce = tuple(i for i, (d1,d2) in enumerate(zip(grad.shape, self.shape)) if d1!=d2)
        # reduce broadcasted dimensions
        return np.sum(grad, axis=dims_to_reduce, keepdims=True)

    # need to build topo graph and then go through it and call backwards on each of the tensors
    def backward(self) -> None:
        self.grad = np.ones_like(self.data)
        topo = []
        visited = set()

        # do DFS on un-visited nodes, add node to topo-when all its children have been visited
        def build_topo(node):
            if node not in visited:
                visited.add(node)
                for child in node._prev:
                    build_topo(child)
                topo.append(node)
        build_topo(self)

        for node in reversed(topo):
            node._backward()
            
    def __add__(self, rhs) -> Self:
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data + rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), '+')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad)
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad)
        out._backward = _backward
        return out
    
    def __neg__(self) -> Self:
        out = Tensor(-self.data, self.requires_grad, (self,), 'neg')

        def _backward():
            if self.requires_grad:
                self.grad += -out.grad
        out._backward = _backward
        return out
    
    def __sub__(self, rhs) -> Self:
        return self + (-rhs)

    def __mul__(self, rhs) -> Self:
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data*rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), f'*')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad * rhs.data)
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad * self.data)
        out._backward = _backward
        return out
        
    def __truediv__(self, rhs) -> Self:
        return self * (rhs**-1)
    
    # TODO add check for rhs, if epxponent if negative the gradient is undefined
    def __pow__(self, rhs) -> Self: 
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        lhs_is_neg = self.data < 0
        rhs_is_frac = ~np.isclose(rhs.data % 1, 0)
        if np.any(lhs_is_neg & rhs_is_frac):
            raise ValueError('cannot raise negative value to a decimal power')
        
        out = Tensor(self.data**rhs.data, self.requires_grad or rhs.requires_grad, (self,), f'**')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad * ((rhs.data)*(self.data**(rhs.data-1))))
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad * (self.data ** rhs.data) * np.log(self.data))
        out._backward = _backward
        return out
    
    '''data shape: (da, ..., d2, d1, n, k) rhs shape: (ob, ..., o2, o1, k, m)
       inputs are broadcast so that they have the same shape by expanding along
       dimensions if possible, out shape: (tc, ..., t2, t1, n, m), where ti = max(di, oi)
       if di or oi does not exist it is treated as 1, and c = max d, a
       if self is 1d shape is prepended with a 1, for rhs it would be appended'''
    def __matmul__(self, rhs) -> Self:
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data @ rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), '@')

        def _backward():
            A, B, = self.data, rhs.data
            g = out.grad
            # broadcast 1d arrays to be 2d 
            A2 = A.reshape(1, -1) if len(A.shape) == 1 else A
            B2 = B.reshape(-1, 1) if len(B.shape) == 1 else B
            # extend g to have reduced dims
            g = np.expand_dims(g, -1) if len(B.shape) == 1 else g
            g = np.expand_dims(g, -2) if len(A.shape) == 1 else g
            # transpose last 2 dimensions, as matmul treats tensors as batched matricies
            if self.requires_grad:
                self.grad += self._unbroadcast(g @ B2.swapaxes(-2, -1))
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(A2.swapaxes(-2, -1) @ g)
        out._backward = _backward
        return out

    def relu(self) -> Self:
        out = Tensor((self.data > 0) * self.data, self.requires_grad, (self,), 'Relu')

        def _backward():
            if self.requires_grad:
                self.grad += (self.data > 0) * out.grad
        out._backward = _backward
        return out
    
    # need to check inp is non-negative
    def log(self) -> Self:
        if np.any(self.data < 0):
            raise ValueError('cannot log negative values')
        out = Tensor(np.log(self.data), self.requires_grad, (self,), 'log')

        def _backward():
            if self.requires_grad:
                self.grad += 1 / self.data 
        out._backward = _backward
        return out
    
    def exp(self) -> Self:
        out = Tensor(np.exp(self.data), self.requires_grad, (self,), 'exp')

        def _backward():
            if self.requires_grad:
                self.grad += np.exp(self.data)
        out._backward = _backward
        return out
    
    def sum(self, axis=None) -> Self:
        out = Tensor(np.sum(self.data, axis=axis), self.requires_grad, (self,), 'sum')

        def _backward():
            if self.requires_grad:
                g = np.expand_dims(out.grad, axis) if axis is not None else out.grad
                self.grad += g
        out._backward = _backward
        return out

    def mean(self, axis=None) -> Self:
        out = Tensor(np.mean(self.data, axis=axis), self.requires_grad, (self,), 'mean')

        def _backward():
            if self.requires_grad:
                N =  self.size // out.size 
                g = np.expand_dims(out.grad, axis) if axis is not None else out.grad
                self.grad += g / N
        out._backward = _backward
        return out
    
    def __radd__(self, lhs) -> Self:
        return self + lhs
    
    def __rsub__(self, lhs) -> Self:
        return self + lhs
    
    def __rmul__(self, lhs) -> Self:
        return self * lhs
    
    def __rtruediv__(self, lhs) -> Self:
        try:
            lhs = Tensor(lhs)
        except TypeError:
            return NotImplementedError
        return lhs / self
    
    def __rpow__(self, lhs) -> Self:
        try:
            lhs = Tensor(lhs)
        except TypeError:
            return NotImplementedError
        return lhs ** self
    
    def __rmatmul__(self, lhs) -> Self:
        try:
            lhs = Tensor(lhs)
        except TypeError:
            return NotImplementedError
        return lhs @ self
    
    @classmethod
    def random(cls, shape: tuple, bounds = (0,1), requires_grad=False) -> Self:
        lower, upper = bounds
        data = RNG.random(shape, dtype=DTYPE)*(upper-lower) + lower
        return cls(data, requires_grad=requires_grad)
    
    def __repr__(self) -> str:
        return f'tensor shape: {self.shape}, op:{self._op}'        


In [173]:
class Parameter(Tensor):
    def __init__(self, data):
        super().__init__(data, requires_grad=True)
    
    @classmethod
    def kaiming(cls, fan_in, shape):
        std = np.sqrt(2/fan_in)
        weights = RNG.standard_normal(shape, dtype=DTYPE)*std
        return cls(weights)
    
    @classmethod
    def zeros(cls, shape):
        return cls(np.zeros(shape, dtype=DTYPE))
    
    def __repr__(self) -> str:
        return f'parameter shape: {self.shape}, size: {self.size}' 

In [None]:
from abc import ABC, abstractmethod

class Module(ABC):
    
    def __call__(self, input: Tensor) -> Tensor:
        return self.forward(input)
    
    @property
    def modules(self) -> list[Self]:
        modules: list[Self] = []
        for value in self.__dict__.values():
            if isinstance(value, Module):
                modules.append(value)

            elif isinstance(value, dict):
                for v in value.values():
                    if isinstance(v, Module):
                        modules.append(v)

            elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
                for v in value:
                    if isinstance(v, Module):
                        modules.append(v)
                    
        return modules
    
    @property
    def params(self) -> list[Parameter]:
        immediate_params = [attr for attr in self.__dict__.values() 
                                    if isinstance(attr, Parameter)]
        modules_params = [param for module in self.modules 
                                    for param in module.params]
        return immediate_params + modules_params
    
    @abstractmethod
    def forward(self, input: Tensor) -> Tensor:
        pass
    
    def zero_grad(self) -> None:
        for param in self.params:
            param.zero_grad()

class Sequential(Module):
    def __init__(self, layers):
        self.layers = layers
    
    def forward(self, input: Tensor) -> Tensor:
        x = input
        for layer in self.layers:
            x = layer(x)
        return x
    
class Affine(Module):
    def __init__(self, in_dim, out_dim):
        self.A = Parameter.kaiming(in_dim, (in_dim, out_dim))
        self.b = Parameter.zeros((out_dim))

    def forward(self, x: Tensor):
        # x: (B, in), A : (in, out), B: out
        return (x @ self.A) + self.b

class Relu(Module):
    def forward(self, x: Tensor):
        return x.relu()
    
class SoftMax(Module):
    def forward(self, x: Tensor):
        x = x.exp()
        norm_c = x.sum()
        return x / norm_c

In [175]:
class SoftMaxCrossEntropy():

    def __call__(z: Tensor, y) -> Tensor:
        '''logits z, shape (B, C), true lables y, shape (B, C)'''
        loss = ((z * y).sum(axis=-1) + ((z.exp()).sum(axis=-1)).log()).mean()
        return loss

class SGD():
    def __init__(self, params: list[Parameter], lr: float):
        self.lr = lr
        self.params = params
    
    def step(self) -> None:
        for param in self.params:
            param.data += -self.lr * param.grad

In [176]:
class Trainer():
    def __init__(self, model, optimiser, loss, train_loader, test_loader, logger, wandb_run = None):
        self.model = model
        self.optimiser = optimiser
        self.loss = loss
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epoch = 1
        self.logger = logger
        self.wandb_run = wandb_run

    def train_epoch():
        pass

    def validate():
        pass
    
    def fit():
        pass
    
    def log_metrics():
        pass

In [177]:
feedforward = Sequential([Affine(50, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10)])
feedforward.params

[parameter shape: (50, 100), size: 5000,
 parameter shape: (100,), size: 100,
 parameter shape: (100, 200), size: 20000,
 parameter shape: (200,), size: 200,
 parameter shape: (200, 10), size: 2000,
 parameter shape: (10,), size: 10]

In [178]:
''' auto-grad testing suite
    TODO:
    - test all of the auto-grad primatives, 
    - test using central differences
    - test by modifying each parameter individually i.e. only do scalar pertubations
'''

def compute_central_diff_error(test_fn, test_input, 
            other_inputs, eps, perturbed_idx, tols):
    '''verify auto-grad of funciton f: R^n -> R'''
    atol, rtol = tols

    # rescale epsilon and convert to tensor
    perturbed_val = test_input.data[perturbed_idx]
    eps = eps * (1 + abs(perturbed_val))
    pertubation_tensor = np.zeros_like(test_input.data, dtype=DTYPE)
    pertubation_tensor[perturbed_idx] += eps 
    pertubation_tensor = Tensor(pertubation_tensor)

    # Compute grad
    for tensor in [test_input, *other_inputs]:
        tensor.zero_grad()
    clean_out = test_fn(test_input, other_inputs)
    clean_out.backward()
    auto_grad = test_input.grad[perturbed_idx]

    # Compute central diff Grad approximaiton
    test_forward = test_input + pertubation_tensor
    forward_out = test_fn(test_forward, other_inputs).item()
    test_back = test_input - pertubation_tensor
    back_out = test_fn(test_back, other_inputs).item()
    approx_grad = (forward_out - back_out) / (2*eps)


    abs_err = abs(approx_grad - auto_grad)
    rel_err = abs_err / (abs(auto_grad) + atol)
    is_close = abs_err <= atol + rtol*abs(auto_grad)

    return is_close, abs_err, rel_err, clean_out.item(), forward_out, back_out

# need to generate inputs, compute cd err and output/format test result, to log file maybe?
def test_fn_random_inputs(test_fn, test_shape, other_shapes=[], input_bounds=(-5, 5),
                          num_samples=K, eps=EPS, tols=(ATOL, RTOL)):
    
    test_input = Tensor.random(test_shape, input_bounds, requires_grad=True)
    other_inputs = [Tensor.random(shape, input_bounds) for shape in other_shapes]

    num_samples = min(test_input.size, num_samples)
    pertubation_nums = RNG.choice(test_input.size, size=num_samples, replace=False)
    pretubation_idxs = np.unravel_index(pertubation_nums, test_shape)

    all_close = True
    failed = 0
    # log = inspect.getsource(test_fn) + '\n' 
    log = ''
    log += f'test input \n {test_input.data} \nother inputs \n'
    for other_input in other_inputs:
        log += f' {other_input.data} \n'
    for sample_i in range(num_samples):
        perturbed_idx = tuple(int(pert_dim[sample_i]) for pert_dim in pretubation_idxs)
        is_close, abs_err, rel_err, clean_out, forward_out, back_out = compute_central_diff_error(
                                        test_fn, test_input, other_inputs, eps, perturbed_idx, tols)
        log += f'test {'passed' if is_close else 'failed'}: abs err = {abs_err:.4f}, rel err = {rel_err:.4f}, perturbed idx = {perturbed_idx} \n'
        log += f'clean_out: {clean_out} forward_out: {forward_out} back_out: {back_out} \n'
        if not is_close:
            all_close = False
            failed += 1
            # some logic for logging the failed case

    return all_close, log
        

In [179]:
bin_ufuncs = {'add' : lambda test_inp, other_inps: (test_inp+other_inps[0]).sum(),
              'radd': lambda test_inp, other_inps: (other_inps[0]+test_inp).sum(),
              'sub' : lambda test_inp, other_inps: (test_inp-other_inps[0]).sum(),
              'rsub': lambda test_inp, other_inps: (other_inps[0]-test_inp).sum(),
              'mul' : lambda test_inp, other_inps: (test_inp*other_inps[0]).sum(),
              'rmul': lambda test_inp, other_inps: (other_inps[0]*test_inp).sum(),
              'pow' : lambda test_inp, other_inps: (test_inp**other_inps[0]).sum(),
              'rpow': lambda test_inp, other_inps: (other_inps[0]**test_inp).sum(),
              'truediv' : lambda test_inp, other_inps: (test_inp/other_inps[0]).sum(),
              'rtruediv': lambda test_inp, other_inps: (other_inps[0]/test_inp).sum(),}

matmul_fns = {'matmul': lambda test_inp, other_inps: (test_inp@other_inps[0]).sum(),
              'rmatmul': lambda test_inp, other_inps: (other_inps[0]@test_inp).sum(),}

unary_ufunc = {'relu': lambda test_inp, other_inps: (test_inp.relu()).sum(),
            'log': lambda test_inp, other_inps: (test_inp.log()).sum(),
            'exp': lambda test_inp, other_inps: (test_inp.exp()).sum(),
            'sum': lambda test_inp, other_inps: test_inp.sum(),
            'mean': lambda test_inp, other_inps: test_inp.mean(),}
                      

In [180]:
for func_name, test_fn in unary_ufunc.items():
    test_shape, other_shapes = (2, 3), [(3,2)]
    input_bounds = (1, 10) if func_name == 'log' else (-5, 5)
    all_close, log = test_fn_random_inputs(test_fn, test_shape, other_shapes, input_bounds=input_bounds)
    print(f'function: {func_name} {'passed' if all_close else 'failed'}')
    if not all_close:
        print(log)

for func_name, test_fn in matmul_fns.items():
    test_shape = (2, 3) if func_name == 'matmul' else (3, 2)
    other_shapes = [(3, 2)] if func_name == 'matmul' else [(2, 3)]
    all_close, log = test_fn_random_inputs(test_fn, test_shape, other_shapes, input_bounds=input_bounds)
    print(f'function: {func_name} {'passed' if all_close else 'failed'}')
    if not all_close:
        print(log)

for func_name, test_fn in bin_ufuncs.items():
    test_shape, other_shapes = (2, 3), [(2,3)]
    input_bounds = (1, 5) if (func_name == 'pow' or func_name == 'rpow') else (-5, 5)
    all_close, log = test_fn_random_inputs(test_fn, test_shape, other_shapes, input_bounds=input_bounds)
    print(f'function: {func_name} {'passed' if all_close else 'failed'}')
    if not all_close:
        print(log)

function: relu passed
function: log passed
function: exp passed
function: sum passed
function: mean passed
function: matmul passed
function: rmatmul passed
function: add passed
function: radd passed
function: sub passed
function: rsub passed
function: mul passed
function: rmul passed
function: pow passed
function: rpow passed
function: truediv passed
function: rtruediv passed


In [199]:
feedforward = Sequential([Affine(50, 1000), Relu(), Affine(1000, 2000), Relu(), Affine(2000, 10)])
feedforward.params

[parameter shape: (50, 1000), size: 50000,
 parameter shape: (1000,), size: 1000,
 parameter shape: (1000, 2000), size: 2000000,
 parameter shape: (2000,), size: 2000,
 parameter shape: (2000, 10), size: 20000,
 parameter shape: (10,), size: 10]

In [202]:
test_shape, other_shapes = (1000, 50), (1,)
test_fn = lambda test, other: feedforward(test).sum()
all_close, log = test_fn_random_inputs(test_fn, test_shape, other_shapes)

print(all_close)
print(log)

True
test input 
 [[ 3.75026808  3.79064791  3.35255986 ...  2.32166393  0.74594803
  -0.46479737]
 [-4.82424674  3.04718977 -0.47803226 ... -0.06946837 -4.47902869
  -4.57079004]
 [-1.10095121  4.44193048 -2.13587614 ...  2.43866997 -1.71638962
  -1.90082661]
 ...
 [-1.16603229 -1.89337666 -3.99321261 ...  1.50069455 -3.46158155
   0.05832781]
 [ 0.54550074  0.25045465 -4.41086695 ...  4.59098623  2.6388176
   4.59158431]
 [-1.03219612  0.01671006  1.78371654 ... -0.51904045 -4.09778641
  -4.67551704]] 
other inputs 
 [3.17973245] 
test passed: abs err = 0.0000, rel err = 0.0000, perturbed idx = (250, 23) 
clean_out: 5613.615594901011 forward_out: 5613.615592911575 back_out: 5613.615596890448 
test passed: abs err = 0.0000, rel err = 0.0000, perturbed idx = (734, 37) 
clean_out: 5613.615594901011 forward_out: 5613.615592760412 back_out: 5613.615597041609 
test passed: abs err = 0.0000, rel err = 0.0000, perturbed idx = (176, 8) 
clean_out: 5613.615594901011 forward_out: 5613.615594970

In [195]:
z = Tensor(np.array([[-100, 100], [100, -100]]), requires_grad=True)
y = Tensor(np.array([[1.0, 0], [0.0, 1.0]]), requires_grad=True)
crossentropy = SoftMaxCrossEntropy()
b = ((z * y).sum(axis=-1) + ((z.exp()).sum(axis=-1)).log()).mean()
# b = crossentropy(a, c)
b.backward()
print(b.data)
print(z.grad)
print(y.grad)

0.0
[[5.00000000e-01 2.68811714e+43]
 [2.68811714e+43 5.00000000e-01]]
[[-50.  50.]
 [ 50. -50.]]
