In [1]:
import numpy as np
from typing import Union, Tuple, Self, Iterable
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import os
import wandb
import time
import psutil
from typing import Optional, Literal
import functools, time

In [None]:
# regular constants
RNG = np.random.default_rng()
DTYPE = 'float32' 

proc  = psutil.Process(os.getpid())

# testing constants
if DTYPE=='float64':
    EPS, ATOL, RTOL = 1e-6, 1e-5, 1e-3
else:
    EPS, ATOL, RTOL = 1e-4, 1e-4, 1e-2
K = 20

dtype_eps = {'float16': 1e-4,
             'float32': 1e-7,
             'float64': 1e-15}[DTYPE]

Mode = Literal['train', 'eval']

In [5]:
class Tensor():
    def __init__(self, data, requires_grad=False, children=(), op=''):
        self.data: np.ndarray = np.array(data, dtype=DTYPE)
        self.grad = np.zeros_like(data, dtype=DTYPE)
        self.requires_grad = requires_grad
        self._prev = set(children)
        self._backward = lambda : None
        self._op = op

    @property
    def shape(self) -> Tuple[int]:
        return self.data.shape
    
    @property
    def size(self) -> int: 
        return self.data.size
    
    def zero_grad(self) -> None:
        self.grad = np.zeros_like(self.data, dtype=DTYPE)

    def item(self) -> np.ndarray:
        return self.data
    
    def _unbroadcast(self, grad: np.ndarray) -> np.ndarray:
        dims_to_remove = tuple(i for i in range(len(grad.shape) - len(self.shape))) 
        # remove prepended padding dimensions
        grad = np.sum(grad, axis=dims_to_remove, keepdims=False) 
        dims_to_reduce = tuple(i for i, (d1,d2) in enumerate(zip(grad.shape, self.shape)) if d1!=d2)
        # reduce broadcasted dimensions
        return np.sum(grad, axis=dims_to_reduce, keepdims=True)

    # need to build topo graph and then go through it and call backwards on each of the tensors
    def backward(self) -> None:
        self.grad = np.ones_like(self.data)
        topo = []
        visited = set()

        # do DFS on un-visited nodes, add node to topo-when all its children have been visited
        def build_topo(node):
            if node not in visited:
                visited.add(node)
                for child in node._prev:
                    build_topo(child)
                topo.append(node)
        build_topo(self)

        for node in reversed(topo):
            node._backward()
            node._prev = set(())
            node._backward = lambda : None

    def __getitem__(self, indexes):
        out = Tensor(self.data[indexes], self.requires_grad, (self), 'slice')

        def _backward():
            if self.requires_grad:
                pass
        out._backward = _backward
        return out
            
    def __add__(self, rhs) -> "Tensor":
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data + rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), '+')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad)
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad)
        out._backward = _backward
        return out
    
    def __neg__(self) -> "Tensor":
        out = Tensor(-self.data, self.requires_grad, (self,), 'neg')

        def _backward():
            if self.requires_grad:
                self.grad += -out.grad
        out._backward = _backward
        return out
    
    def __sub__(self, rhs) -> "Tensor":
        return self + (-rhs)

    def __mul__(self, rhs) -> "Tensor":
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data*rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), f'*')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad * rhs.data)
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad * self.data)
        out._backward = _backward
        return out
        
    def __truediv__(self, rhs) -> "Tensor":
        return self * (rhs**-1)
    
    # TODO add check for rhs, if epxponent if negative the gradient is undefined
    def __pow__(self, rhs) -> "Tensor": 
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        lhs_is_neg = self.data < 0
        rhs_is_frac = ~np.isclose(rhs.data % 1, 0)
        if np.any(lhs_is_neg & rhs_is_frac):
            raise ValueError('cannot raise negative value to a decimal power')
        
        out = Tensor(self.data**rhs.data, self.requires_grad or rhs.requires_grad, (self,), f'**')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad * ((rhs.data)*(self.data**(rhs.data-1))))
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad * (self.data ** rhs.data) * np.log(self.data))
        out._backward = _backward
        return out
    
    '''data shape: (da, ..., d2, d1, n, k) rhs shape: (ob, ..., o2, o1, k, m)
       inputs are broadcast so that they have the same shape by expanding along
       dimensions if possible, out shape: (tc, ..., t2, t1, n, m), where ti = max(di, oi)
       if di or oi does not exist it is treated as 1, and c = max d, a
       if self is 1d shape is prepended with a 1, for rhs it would be appended'''
    def __matmul__(self, rhs) -> "Tensor":
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data @ rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), '@')

        def _backward():
            A, B, = self.data, rhs.data
            g = out.grad
            # broadcast 1d arrays to be 2d 
            A2 = A.reshape(1, -1) if len(A.shape) == 1 else A
            B2 = B.reshape(-1, 1) if len(B.shape) == 1 else B
            # extend g to have reduced dims
            g = np.expand_dims(g, -1) if len(B.shape) == 1 else g
            g = np.expand_dims(g, -2) if len(A.shape) == 1 else g
            # transpose last 2 dimensions, as matmul treats tensors as batched matricies
            if self.requires_grad:
                self.grad += self._unbroadcast(g @ B2.swapaxes(-2, -1))
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(A2.swapaxes(-2, -1) @ g)
        out._backward = _backward
        return out

    def relu(self) -> "Tensor":
        out = Tensor((self.data > 0) * self.data, self.requires_grad, (self,), 'Relu')

        def _backward():
            if self.requires_grad:
                self.grad += (self.data > 0) * out.grad
        out._backward = _backward
        return out
    
    def log(self) -> "Tensor":
        if np.any(self.data < 0):
            raise ValueError('cannot log negative values')
        out = Tensor(np.log(self.data), self.requires_grad, (self,), 'log')

        def _backward():
            if self.requires_grad:
                self.grad += (1 / self.data) * out.grad
        out._backward = _backward
        return out
    
    def exp(self) -> "Tensor":
        out = Tensor(np.exp(self.data), self.requires_grad, (self,), 'exp')

        def _backward():
            if self.requires_grad:
                self.grad += np.exp(self.data) * out.grad
        out._backward = _backward
        return out
    
    def sum(self, axis=None, keepdims=False) -> "Tensor":
        out = Tensor(np.sum(self.data, axis=axis, keepdims=keepdims), self.requires_grad, (self,), 'sum')

        def _backward():
            if self.requires_grad:
                g = np.expand_dims(out.grad, axis) if (axis is not None and not keepdims) else out.grad
                self.grad += g
        out._backward = _backward
        return out

    def mean(self, axis=None) -> "Tensor":
        out = Tensor(np.mean(self.data, axis=axis), self.requires_grad, (self,), 'mean')

        def _backward():
            if self.requires_grad:
                N =  self.size // out.size 
                g = np.expand_dims(out.grad, axis) if axis is not None else out.grad
                self.grad += g / N
        out._backward = _backward
        return out
    
    def clamp(self, a_min=None, a_max=None):
        out = Tensor(np.clip(self.data, a_min=a_min, a_max=a_max), self.requires_grad, (self,), 'clamp')

        def _backward():
            if self.requires_grad:
                mask = (self.data > a_min) if a_min is not None else np.ones_like(self.data)
                mask = mask & (self.data < a_max) if a_max is not None else mask
                self.grad += out.grad * mask
        out._backward = _backward
        return out
    
    def __radd__(self, lhs) -> "Tensor":
        return self + lhs
    
    def __rsub__(self, lhs) -> "Tensor":
        return self + lhs
    
    def __rmul__(self, lhs) -> "Tensor":
        return self * lhs
    
    def __rtruediv__(self, lhs) -> "Tensor":
        return Tensor(lhs) / self
    
    def __rpow__(self, lhs) -> "Tensor":
        return Tensor(lhs) ** self
    
    def __rmatmul__(self, lhs) -> "Tensor":
        return Tensor(lhs) @ self
    
    @classmethod
    def random(cls, shape: tuple, bounds = (0,1), requires_grad=False) -> "Tensor":
        lower, upper = bounds
        data = RNG.random(shape, dtype=DTYPE)*(upper-lower) + lower
        return cls(data, requires_grad=requires_grad)
    
    def __repr__(self) -> str:
        return f'tensor shape: {self.shape}, op:{self._op}'        


In [6]:
class Parameter(Tensor):
    def __init__(self, data):
        super().__init__(data, requires_grad=True)
    
    @classmethod
    def kaiming(cls, fan_in, shape):
        std = np.sqrt(2/fan_in)
        weights = (RNG.standard_normal(shape, dtype='float64')*std).astype(dtype=DTYPE)
        return cls(weights)
    
    @classmethod
    def zeros(cls, shape):
        return cls(np.zeros(shape, dtype=DTYPE))
    
    def __repr__(self) -> str:
        return f'parameter shape: {self.shape}, size: {self.size}' 

In [7]:
from abc import abstractmethod

class Module():
    
    def __call__(self, input: Tensor) -> Tensor:
        return self.forward(input)
    
    @property
    def modules(self) -> list["Module"]:
        modules: list[Module] = []
        for value in self.__dict__.values():
            if isinstance(value, Module):
                modules.append(value)

            elif isinstance(value, dict):
                for v in value.values():
                    if isinstance(v, Module):
                        modules.append(v)

            elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
                for v in value:
                    if isinstance(v, Module):
                        modules.append(v)
                    
        return modules
    
    @property
    def params(self) -> list[Parameter]:
        immediate_params = [attr for attr in self.__dict__.values() 
                                    if isinstance(attr, Parameter)]
        modules_params = [param for module in self.modules 
                                    for param in module.params]
        return immediate_params + modules_params
    
    @abstractmethod
    def forward(self, input: Tensor) -> Tensor:
        pass
    
    def zero_grad(self) -> None:
        for param in self.params:
            param.zero_grad()

    def train(self) -> None:
        for param in self.params:
            param.requires_grad = True

        for module in self.modules:
            if isinstance(module, DynamicModule):
                module.mode = 'train'
        
    def eval(self) -> None:
        for param in self.params:
            param.requires_grad = False

        for module in self.modules:
            if isinstance(module, DynamicModule):
                module.mode = 'eval'

class Sequential(Module):
    def __init__(self, layers):
        self.layers = layers
    
    def forward(self, input: Tensor) -> Tensor:
        x = input
        for layer in self.layers:
            x = layer(x)
        return x
    
class Affine(Module):
    def __init__(self, in_dim, out_dim):
        self.A = Parameter.kaiming(in_dim, (in_dim, out_dim))
        self.b = Parameter.zeros((out_dim))

    def forward(self, input: Tensor):
        x = input
        # x: (B, in), A : (in, out), B: out
        return (x @ self.A) + self.b
    
class DynamicModule(Module):
    def __init__(self) -> None:
        self.mode = 'train'

class DropOut(DynamicModule):
    def __init__(self, p):
        self.mode = 'train'
        self.p = p

    def forward(self, x: Tensor):
        if self.mode == 'eval':
            return x * (1-self.p) # have to rescale during inference
        mask_idx_nums = RNG.choice(x.size, size=int(x.size*(1-self.p)), replace=False)
        mask_idxs = np.unravel_index(mask_idx_nums, x.shape)
        mask = np.zeros_like(x.data)
        mask[mask_idxs] = 1
        
        return x * mask
    
    '''TODO: can implement in the furture to make it faster once __getitem__ is implemented'''
    # def forward(self, x: Tensor):
    #     if self.mode == 'eval':
    #         return x * (1-self.p) # have to rescale during inference
    #     mask_idx_nums = RNG.choice(x.size, size=int(x.size*self.p), replace=False)
    #     mask_idxs = np.unravel_index(mask_idx_nums, x.shape)
    #     x.data[mask_idxs] = 0
        
    #     return x


class Relu(Module):
    def forward(self, x: Tensor):
        return x.relu()
    
class SoftMax(Module):
    def forward(self, x: Tensor):
        # temporary as max is not an implemented op
        x = x - np.max(x.data, axis=-1, keepdims=True) # for numerical stability 
        x = x.exp()
        norm_c = x.sum(axis=-1, keepdims=True)
        return x / norm_c

In [8]:
do = DropOut(0.5)
x = Tensor.random((2,3))
# print(A, do(A))
mask_idx_nums = RNG.choice(x.size, size=int(x.size*(1-0.5)), replace=False)
mask_idxs = np.unravel_index(mask_idx_nums, x.shape)
mask = np.zeros_like(x.data)
print(mask.shape)

(2, 3)


In [9]:
def one_hot_encode(array, num_c):
    one_hot = np.zeros(shape=(array.size, num_c))
    for idx, i in enumerate(array):
        one_hot[idx, i] = 1
    return one_hot

class Loss_Fn():
    def __call__(self, *args, **kwargs) -> Tensor:
        raise NotImplementedError("Loss function must implement __call__ method")
    
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
    
    def __str__(self) -> str:
        return self.__repr__()

class SoftMaxCrossEntropy(Loss_Fn):
    def __call__(self, z: Tensor, y) -> Tensor:
        '''logits z, shape (B, C), true integer lables y, shape (B)'''
        # TODO change from manual one hot encoding when getitem is implemented in tensor
        y = Tensor(one_hot_encode(y, z.shape[-1])) #shape (B, C)
        z = z - np.max(z.data, axis=-1, keepdims=True) # for numerical stability 
        loss = (-(z * y).sum(axis=-1) + ((z.exp()).sum(axis=-1)).log()).mean()
        return loss

class CrossEntropy(Loss_Fn):
    def __call__(self, q: Tensor, y) -> Tensor:
        '''pred q, shape (B, C), true integer lables y, shape (B)'''
        # TODO change from manual one hot encoding when getitem is implemented in tensor
        y = Tensor(one_hot_encode(y, q.shape[-1])) #shape (B, C)
        loss = -(y * (q+dtype_eps).log()).sum(axis=-1).mean()
        
        return loss
    
class MeanSquaredError():
    def __call__(self, q: Tensor, y) -> Tensor:
        '''pred q, shape (B, C), true values y, shape (B, C)'''
        loss = ((q - y) ** 2).sum(axis=-1).mean()
        return loss

class optimiser():
    def __init__(self, params: list[Parameter], lr: float=0.005):
        self.lr = lr
        self.params = params
    
    @abstractmethod
    def step(self) -> None:
        pass

    def zero_grad(self) -> None:
        for param in self.params:
            param.zero_grad()

    def train(self) -> None:
        for param in self.params:
            param.requires_grad = True
    
    def eval(self) -> None:
        for param in self.params:
            param.requires_grad = False
        
    
class SGD(optimiser):
    
    def step(self) -> None:
        for param in self.params:
            if not param.requires_grad:
                continue 
            param.data += -self.lr * param.grad

class Adam(optimiser):
    def __init__(self, params: list[Parameter], lr: float=0.005, 
                 betas: Tuple[float, float]=(0.9, 0.999), eps: float=1e-8):
        super().__init__(params, lr)
        self.b1 , self.b2 = betas
        self.eps = eps
        self.time_step = 0
        self.m = [np.zeros_like(param.data, dtype=DTYPE) for param in params]
        self.v = [np.zeros_like(param.data, dtype=DTYPE) for param in params]
    
    def step(self) -> None:
        self.time_step += 1
        for i, p in enumerate(self.params):
            if not p.requires_grad:
                continue 

            g = p.grad
            self.m[i] = self.b1*self.m[i] + (1-self.b1)*g
            self.v[i] = self.b2*self.v[i] + (1-self.b2)*(g**2)
            m_hat = self.m[i]/(1-self.b1**self.time_step)
            v_hat = self.v[i]/(1-self.b2**self.time_step)

            p.data += -self.lr * m_hat / (v_hat ** 0.5 + self.eps)

In [10]:
from math import ceil

class DataLoader():
    def __init__(self, input_data, true_data, batch_size, shuffle=False, rng: np.random.Generator=RNG):
        assert input_data.shape[0] == true_data.shape[0], 'must have the same number of inputs and true outputs'
        self.X = input_data
        self.y = true_data
        self.N = batch_size
        self.shuffle = shuffle
        self.rng = rng

    def __iter__(self):
        X, y = self.X, self.y
        if self.shuffle:
            permutation = self.rng.permutation(X.shape[0])
            X = X[permutation]
            y = y[permutation]
        splits = np.arange(self.N, X.shape[0], self.N)
        X = np.split(X, splits, axis=0)
        X = [Tensor(x, requires_grad=False) for x in X]
        y = np.split(y, splits, axis=0)
        return zip(X, y)

    def __len__(self):
        # samples/batch size rounded up
        return ceil(self.X.shape[0]/self.N)
    

In [11]:
# import urllib.request, numpy as np
# import os

# os.makedirs('datasets')

# url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
local_path = r"datasets/mnist.npz"

# urllib.request.urlretrieve(url, local_path)   # â‡¦ makes a real file
data = np.load(local_path)

# im = X_train[0:3]
# print(type(im))
# plt.imshow(im, cmap='grey')
# plt.show()

In [12]:
# X_train, y_train = data["x_train"][0:12].reshape((-1,784)) / 255, data["y_train"][0:12]
# print(y_train.shape)
# train_loader = DataLoader(X_train, y_train, 4, shuffle=True)
# for X, y in train_loader:
#     print(y)
#     im = X[0].reshape((28,28))
#     print(type(im))
#     plt.imshow(im, cmap='grey')
#     plt.show()
#     continue

In [13]:
def one_hot_encode(array, num_c):
    one_hot = np.zeros(shape=(array.size, num_c))
    for idx, i in enumerate(array):
        one_hot[idx, i] = 1

    return one_hot

In [14]:
nn = Sequential([Affine(784, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10), SoftMax()])
nn.params

[parameter shape: (784, 100), size: 78400,
 parameter shape: (100,), size: 100,
 parameter shape: (100, 200), size: 20000,
 parameter shape: (200,), size: 200,
 parameter shape: (200, 10), size: 2000,
 parameter shape: (10,), size: 10]

In [None]:
class MultiAverageMeter():
    def __init__(self):
        self._means = {}
        self._counts = {}

    def update_one(self, metric: str, mean: float, n: int = 1):  
        if metric in self._means.keys():
            k = self._counts[metric]
            self._means[metric] += n*(mean - self._means[metric])/(k+n)
            self._counts[metric] += n
        else:
            self._means[metric] = mean
            self._counts[metric] = n

    def update_many(self, metric_dict: dict[str, float|list]):
        for metric, val in metric_dict.items():
            if isinstance(val, float):
                mean = val
                n = 1
            elif isinstance(val, list):
                mean = val[0]
                n = val[1]
            self.update_one(metric, mean, n)
    
    def get_metric(self, metric):
        if metric in self._means:
            return  self._means[metric]
        raise KeyError(f'{metric} not found')
    
    def dump_metrics(self):
        return self._means

    def reset(self, metric=None):
        if metric is None:
            self._means = {}
            self._counts = {}
        else:
            del self._means[metric]
            del self._counts[metric]

    def get_log_str(self, metrics=None):
        log_str = ''
        metrics = self._means.keys() if metrics is None else metrics
        for metric in metrics:
            if metric not in self._means:
                continue
            log_str += f'{metric} : {self._means[metric]:.4f} '
        return log_str

    def __getitem__(self, key):  
        return self.get_metric(key)
    
    def __contains__(self, key): 
        return key in self._means
    
    def __iter__(self):          
        return iter(self._means)
    
    def items(self):            
        return self._means.items()
    
    def __len__(self):           
        return len(self._means)
    
    def __repr__(self):          
        return f"MultiAverageMeter({self._means})"

In [17]:
wandb_config = {'project': 'torch from scratch testing',
                'name': 'dropout_tests',
                'config': {'optimiser':'adam', 'lr':0.05},
                'group': 'mnist tests',}

In [None]:
'''TODO:
 - early stopping
 - overfit batch
 '''

Mode = Literal['train', 'eval']

class Callback:
    
    def on_train_start(self, trainer): pass
    def on_epoch_start(self, trainer): pass
    def on_batch_start(self, trainer): pass
    def on_batch_end(self, trainer): pass
    def on_epoch_end(self, trainer): pass
    def on_train_end(self, trainer): pass

class CallbackList:
    def __init__(self, callbacks):
        self._callbacks = callbacks
    
    def __getattr__(self, name):
        def callbacks(*args, **kwargs):
            for cb in self._callbacks:
                method = getattr(cb, name, None)
                if callable(method):
                    method(*args, **kwargs)
        return callbacks
    
def track_runtime(name=None):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self: Trainer, *args, **kwargs):
            start = time.perf_counter()
            result = func(self, *args, **kwargs)
            end = time.perf_counter()
            metric_name = name or f'time/{func.__name__}'
            self._meter.update_one(metric_name, end - start)
            return result
        return wrapper
    return decorator


class Trainer():
    def __init__(
            self, 
            model: Module, 
            optimiser: optimiser, 
            loss_fn: Loss_Fn, 
            train_loader: DataLoader, 
            validation_loader: DataLoader, 
            test_loader: DataLoader, 
            logger, 
            callbacks: list[Callback]):
        
        self.model = model
        self.optimiser = optimiser
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader
        self._epoch = 1
        self.logger = logger
        self._meter = MultiAverageMeter()
        self._mode: Mode = 'train'
        self._callbacks = CallbackList(callbacks) 

    @track_runtime()
    def train_batch(self, X, y):
        self.optimiser.zero_grad()
        output = self.model(X)
        loss = self.loss_fn(output, y)
        loss.backward()
        self.optimiser.step()

        return (loss, output)

    @track_runtime()
    def train_epoch(self):
        self.model.train()
        self._callbacks.on_epoch_start(self)

        for X, y in self.train_loader:

            self._callbacks.on_batch_start(self)

            out = self.evaluate_batch(X, y)
            metrics_dict = {'loss': (out[0].item(), X.shape[0])} 
            self._meter.update_many(metrics_dict)

            self._callbacks.on_batch_end()

        self._callbacks.on_epoch_end(self)
        return self._meter.get_log_str()

    @track_runtime()
    def evaluate_batch(self, X, y):
        self.optimiser.zero_grad()
        output = self.model(X)
        loss = self.loss_fn(output, y)

        return (loss, output)

    @track_runtime
    def evaluate(self):
        self.model.eval()

        self._callbacks.on_epoch_start(self)

        for X, y in self.validation_loader:

            self._callbacks.on_batch_start(self)

            out = self.evaluate_batch(X, y)
            metrics_dict = {'loss': (out[0].item(), X.shape[0])} 
            self._meter.update_many(metrics_dict)

            self._callbacks.on_batch_end(self)
        
        self._callbacks.on_epoch_end(self)
        return self._meter.get_log_str()
    
    @track_runtime()
    def test_batch(self, X:Tensor, y):
        self.optimiser.zero_grad()
        output = self.model(X)
        loss = self.loss_fn(output, y)

        return (loss, output)   

    @track_runtime()
    def test(self):
        self.model.eval()

        self._callbacks.on_epoch_start(self)

        for X, y in self.test_loader:

            self._callbacks.on_batch_start(self)

            out = self.evaluate_batch(X, y)
            metrics_dict = {'loss': (out[0].item(), X.shape[0])} 
            self._meter.update_many(metrics_dict)

            self._callbacks.on_batch_end(self)
        
        self._callbacks.on_epoch_end(self)
        return self._meter.get_log_str()

    @track_runtime()
    def train(self, epochs: int):

        self._callbacks.on_train_start(self)

        for t in range(epochs):
            print(f'epoch: {t}')
            self._meter.reset()
            log = self.train_epoch()
            print('train: ' + log)


            self._meter.reset()
            log = self.evaluate()
            print('eval: ' + log)

        if self.test_loader is not None:
            self._meter.reset()
            log = self.test()
            print('eval: ' + log)

        self._callbacks.on_train_end(self)


class WandBCallback(Callback):
    def __init__(self, api_key:Optional[str]=None):

        self._api_key = api_key if api_key is not None else self.get_api_key()
        assert api_key is not None, 'api key required'

        self._mode: Mode = 'train'

    def get_api_key(self):
        load_dotenv()
        return os.getenv("WANDB_API_KEY")

    def on_train_start(self, trainer: Trainer):
        self._meter = trainer._meter
        wandb.login(key=self._api_key)
        self._wandb_run = wandb.init(**wandb_config)

    def on_epoch_start(self, trainer: Trainer): 
        self._mode = trainer._mode

    def on_epoch_end(self, trainer): 
        metrics = {}
        for name, val in self._meter.items():
            metrics[name + f'/{self._mode}'] = val
        metrics = {
            'sys/ram_gb' : proc.memory_info().rss / 1_073_741_824
        }
        self._wandb_run.log(metrics, step=trainer._epoch)

    def on_train_end(self, trainer): 
        self._wandb_run.finish()

        

In [19]:
def train_test_step(meter:MultiAverageMeter, train, nn, loader, loss_fn, optimiser):
    start = time.time()
    if train:
        nn.train()
    else:
        nn.eval()

    for X, y in loader:

        nn.zero_grad()
        out = nn(X)
        loss = loss_fn(out, y)
        if train:
            loss.backward()
            optimiser.step()
        
        preds = np.argmax(out.item(), axis=-1)
        acc = np.sum(preds == y) / preds.size 
        
        meter.update('CE', loss.item(), y.shape[0])
        meter.update('accuracy', acc, y.shape[0])
    end = time.time()
    meter.update('speed/epoch_sec', end - start)
    meter.update('speed/samples_per_sec', len(loader) / (end - start))
    return meter.get_log_str()

def train_nn(epochs, p=0.0):

    meter = MultiAverageMeter()
    wandb_logger = WandBLogger(meter, wandb_config)

    X_train, y_train = data["x_train"].reshape((-1,784)) / 255, data["y_train"]
    X_test, y_test = data["x_test"].reshape((-1,784)) / 255, data["y_test"]

    train_loader = DataLoader(X_train, y_train, 256, shuffle=True)
    test_loader = DataLoader(X_test, y_test, 256, shuffle=False)

    nn = Sequential([DropOut(p), Affine(784, 200), DropOut(p), Relu(), Affine(200, 100), Relu(), Affine(100, 50), Relu(), Affine(50, 10), SoftMax()])
    # nn = Sequential([Affine(784, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10), SoftMax()])
    loss_fn = CrossEntropy()
    optimiser = Adam(nn.params)
    # optimiser = SGD(nn.params, lr=0.05)

    for t in range(epochs):
        print(f'epoch: {t}')
        wandb_logger.train()
        meter.reset()
        log = train_test_step(meter, True, nn, train_loader, loss_fn, optimiser)
        print('train: ' + log)
        wandb_logger.log_epoch(t)

        wandb_logger.eval()
        meter.reset()
        log = train_test_step(meter, False, nn, test_loader, loss_fn, optimiser)
        print('test: ' + log)
        wandb_logger.log_epoch(t)

    wandb_logger.finish()   


In [20]:
train_nn(40)
train_nn(40, p=0.05)
train_nn(40, p=0.1)
train_nn(40, p=0.15)
train_nn(40, p=0.2)
train_nn(40, p=0.25)
train_nn(40, p=0.3)

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnikiwillems9[0m ([33mnikiwillems9-university-of-bristol[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


epoch: 0
train: CE : 0.2510 accuracy : 0.9238 speed/epoch_sec : 4.7142 speed/samples_per_sec : 49.8497 
test: CE : 0.1269 accuracy : 0.9603 speed/epoch_sec : 0.1202 speed/samples_per_sec : 332.7743 
epoch: 1
train: CE : 0.1025 accuracy : 0.9684 speed/epoch_sec : 5.8914 speed/samples_per_sec : 39.8886 
test: CE : 0.0928 accuracy : 0.9706 speed/epoch_sec : 0.1649 speed/samples_per_sec : 242.5838 
epoch: 2
train: CE : 0.0685 accuracy : 0.9787 speed/epoch_sec : 5.7389 speed/samples_per_sec : 40.9489 
test: CE : 0.1112 accuracy : 0.9683 speed/epoch_sec : 0.1859 speed/samples_per_sec : 215.1455 
epoch: 3
train: CE : 0.0537 accuracy : 0.9836 speed/epoch_sec : 5.8416 speed/samples_per_sec : 40.2288 
test: CE : 0.0917 accuracy : 0.9736 speed/epoch_sec : 0.1795 speed/samples_per_sec : 222.8867 
epoch: 4
train: CE : 0.0514 accuracy : 0.9838 speed/epoch_sec : 22.2729 speed/samples_per_sec : 10.5509 
test: CE : 0.0944 accuracy : 0.9734 speed/epoch_sec : 0.6020 speed/samples_per_sec : 66.4441 
epoch

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7299b4aa20f0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7299b4adbb60, execution_count=20 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7299b4adbad0, raw_cell="train_nn(40)
train_nn(40, p=0.05)
train_nn(40, p=0.." transformed_cell="train_nn(40)
train_nn(40, p=0.05)
train_nn(40, p=0.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/nik/personal_projects/dqn_from_scratch/test.ipynb#X23sZmlsZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe