In [1]:
import numpy as np
from typing import Union, Tuple, Self, Iterable
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import os
import wandb
import time
import psutil

In [42]:
# regular constants
RNG = np.random.default_rng()
DTYPE = 'float32' 

proc  = psutil.Process(os.getpid())

# testing constants
if DTYPE=='float64':
    EPS, ATOL, RTOL = 1e-6, 1e-5, 1e-3
else:
    EPS, ATOL, RTOL = 1e-4, 1e-4, 1e-2
K = 20

dtype_eps = {'float16': 1e-4,
             'float32': 1e-7,
             'float64': 1e-15}[DTYPE]

In [43]:
class Tensor():
    def __init__(self, data, requires_grad=False, children=(), op=''):
        self.data: np.ndarray = np.array(data, dtype=DTYPE)
        self.grad = np.zeros_like(data, dtype=DTYPE)
        self.requires_grad = requires_grad
        self._prev = set(children)
        self._backward = lambda : None
        self._op = op

    @property
    def shape(self) -> Tuple[int]:
        return self.data.shape
    
    @property
    def size(self) -> int: 
        return self.data.size
    
    def zero_grad(self) -> None:
        self.grad = np.zeros_like(self.data, dtype=DTYPE)

    def item(self) -> np.ndarray:
        return self.data
    
    def _unbroadcast(self, grad: np.ndarray) -> np.ndarray:
        dims_to_remove = tuple(i for i in range(len(grad.shape) - len(self.shape))) 
        # remove prepended padding dimensions
        grad = np.sum(grad, axis=dims_to_remove, keepdims=False) 
        dims_to_reduce = tuple(i for i, (d1,d2) in enumerate(zip(grad.shape, self.shape)) if d1!=d2)
        # reduce broadcasted dimensions
        return np.sum(grad, axis=dims_to_reduce, keepdims=True)

    # need to build topo graph and then go through it and call backwards on each of the tensors
    def backward(self) -> None:
        self.grad = np.ones_like(self.data)
        topo = []
        visited = set()

        # do DFS on un-visited nodes, add node to topo-when all its children have been visited
        def build_topo(node):
            if node not in visited:
                visited.add(node)
                for child in node._prev:
                    build_topo(child)
                topo.append(node)
        build_topo(self)

        for node in reversed(topo):
            node._backward()
            node._prev = set(())
            node._backward = lambda : None

    def __getitem__(self, indexes):
        out = Tensor(self.data[indexes], self.requires_grad, (self), 'slice')

        def _backward():
            if self.requires_grad:
                pass
        out._backward = _backward
        return out
            
    def __add__(self, rhs) -> "Tensor":
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data + rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), '+')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad)
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad)
        out._backward = _backward
        return out
    
    def __neg__(self) -> "Tensor":
        out = Tensor(-self.data, self.requires_grad, (self,), 'neg')

        def _backward():
            if self.requires_grad:
                self.grad += -out.grad
        out._backward = _backward
        return out
    
    def __sub__(self, rhs) -> "Tensor":
        return self + (-rhs)

    def __mul__(self, rhs) -> "Tensor":
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data*rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), f'*')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad * rhs.data)
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad * self.data)
        out._backward = _backward
        return out
        
    def __truediv__(self, rhs) -> "Tensor":
        return self * (rhs**-1)
    
    # TODO add check for rhs, if epxponent if negative the gradient is undefined
    def __pow__(self, rhs) -> "Tensor": 
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        lhs_is_neg = self.data < 0
        rhs_is_frac = ~np.isclose(rhs.data % 1, 0)
        if np.any(lhs_is_neg & rhs_is_frac):
            raise ValueError('cannot raise negative value to a decimal power')
        
        out = Tensor(self.data**rhs.data, self.requires_grad or rhs.requires_grad, (self,), f'**')

        def _backward():
            if self.requires_grad:
                self.grad += self._unbroadcast(out.grad * ((rhs.data)*(self.data**(rhs.data-1))))
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(out.grad * (self.data ** rhs.data) * np.log(self.data))
        out._backward = _backward
        return out
    
    '''data shape: (da, ..., d2, d1, n, k) rhs shape: (ob, ..., o2, o1, k, m)
       inputs are broadcast so that they have the same shape by expanding along
       dimensions if possible, out shape: (tc, ..., t2, t1, n, m), where ti = max(di, oi)
       if di or oi does not exist it is treated as 1, and c = max d, a
       if self is 1d shape is prepended with a 1, for rhs it would be appended'''
    def __matmul__(self, rhs) -> "Tensor":
        rhs = rhs if isinstance(rhs, Tensor) else Tensor(rhs)
        out = Tensor(self.data @ rhs.data, self.requires_grad or rhs.requires_grad, (self, rhs), '@')

        def _backward():
            A, B, = self.data, rhs.data
            g = out.grad
            # broadcast 1d arrays to be 2d 
            A2 = A.reshape(1, -1) if len(A.shape) == 1 else A
            B2 = B.reshape(-1, 1) if len(B.shape) == 1 else B
            # extend g to have reduced dims
            g = np.expand_dims(g, -1) if len(B.shape) == 1 else g
            g = np.expand_dims(g, -2) if len(A.shape) == 1 else g
            # transpose last 2 dimensions, as matmul treats tensors as batched matricies
            if self.requires_grad:
                self.grad += self._unbroadcast(g @ B2.swapaxes(-2, -1))
            if rhs.requires_grad:
                rhs.grad += rhs._unbroadcast(A2.swapaxes(-2, -1) @ g)
        out._backward = _backward
        return out

    def relu(self) -> "Tensor":
        out = Tensor((self.data > 0) * self.data, self.requires_grad, (self,), 'Relu')

        def _backward():
            if self.requires_grad:
                self.grad += (self.data > 0) * out.grad
        out._backward = _backward
        return out
    
    def log(self) -> "Tensor":
        if np.any(self.data < 0):
            raise ValueError('cannot log negative values')
        out = Tensor(np.log(self.data), self.requires_grad, (self,), 'log')

        def _backward():
            if self.requires_grad:
                self.grad += (1 / self.data) * out.grad
        out._backward = _backward
        return out
    
    def exp(self) -> "Tensor":
        out = Tensor(np.exp(self.data), self.requires_grad, (self,), 'exp')

        def _backward():
            if self.requires_grad:
                self.grad += np.exp(self.data) * out.grad
        out._backward = _backward
        return out
    
    def sum(self, axis=None, keepdims=False) -> "Tensor":
        out = Tensor(np.sum(self.data, axis=axis, keepdims=keepdims), self.requires_grad, (self,), 'sum')

        def _backward():
            if self.requires_grad:
                g = np.expand_dims(out.grad, axis) if (axis is not None and not keepdims) else out.grad
                self.grad += g
        out._backward = _backward
        return out

    def mean(self, axis=None) -> "Tensor":
        out = Tensor(np.mean(self.data, axis=axis), self.requires_grad, (self,), 'mean')

        def _backward():
            if self.requires_grad:
                N =  self.size // out.size 
                g = np.expand_dims(out.grad, axis) if axis is not None else out.grad
                self.grad += g / N
        out._backward = _backward
        return out
    
    def clamp(self, a_min=None, a_max=None):
        out = Tensor(np.clip(self.data, a_min=a_min, a_max=a_max), self.requires_grad, (self,), 'clamp')

        def _backward():
            if self.requires_grad:
                mask = (self.data > a_min) if a_min is not None else np.ones_like(self.data)
                mask = mask & (self.data < a_max) if a_max is not None else mask
                self.grad += out.grad * mask
        out._backward = _backward
        return out
    
    def __radd__(self, lhs) -> "Tensor":
        return self + lhs
    
    def __rsub__(self, lhs) -> "Tensor":
        return self + lhs
    
    def __rmul__(self, lhs) -> "Tensor":
        return self * lhs
    
    def __rtruediv__(self, lhs) -> "Tensor":
        return Tensor(lhs) / self
    
    def __rpow__(self, lhs) -> "Tensor":
        return Tensor(lhs) ** self
    
    def __rmatmul__(self, lhs) -> "Tensor":
        return Tensor(lhs) @ self
    
    @classmethod
    def random(cls, shape: tuple, bounds = (0,1), requires_grad=False) -> "Tensor":
        lower, upper = bounds
        data = RNG.random(shape, dtype=DTYPE)*(upper-lower) + lower
        return cls(data, requires_grad=requires_grad)
    
    def __repr__(self) -> str:
        return f'tensor shape: {self.shape}, op:{self._op}'        


In [44]:
class Parameter(Tensor):
    def __init__(self, data):
        super().__init__(data, requires_grad=True)
    
    @classmethod
    def kaiming(cls, fan_in, shape):
        std = np.sqrt(2/fan_in)
        weights = (RNG.standard_normal(shape, dtype='float64')*std).astype(dtype=DTYPE)
        return cls(weights)
    
    @classmethod
    def zeros(cls, shape):
        return cls(np.zeros(shape, dtype=DTYPE))
    
    def __repr__(self) -> str:
        return f'parameter shape: {self.shape}, size: {self.size}' 

In [45]:
from abc import abstractmethod

class Module():
    
    def __call__(self, input: Tensor) -> Tensor:
        return self.forward(input)
    
    @property
    def modules(self) -> list["Module"]:
        modules: list[Module] = []
        for value in self.__dict__.values():
            if isinstance(value, Module):
                modules.append(value)

            elif isinstance(value, dict):
                for v in value.values():
                    if isinstance(v, Module):
                        modules.append(v)

            elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
                for v in value:
                    if isinstance(v, Module):
                        modules.append(v)
                    
        return modules
    
    @property
    def params(self) -> list[Parameter]:
        immediate_params = [attr for attr in self.__dict__.values() 
                                    if isinstance(attr, Parameter)]
        modules_params = [param for module in self.modules 
                                    for param in module.params]
        return immediate_params + modules_params
    
    @abstractmethod
    def forward(self, input: Tensor) -> Tensor:
        pass
    
    def zero_grad(self) -> None:
        for param in self.params:
            param.zero_grad()

    def train(self) -> None:
        for param in self.params:
            param.requires_grad = True

        for module in self.modules:
            if isinstance(module, DynamicModule):
                module.mode = 'train'
        
    def eval(self) -> None:
        for param in self.params:
            param.requires_grad = False

        for module in self.modules:
            if isinstance(module, DynamicModule):
                module.mode = 'eval'

class Sequential(Module):
    def __init__(self, layers):
        self.layers = layers
    
    def forward(self, input: Tensor) -> Tensor:
        x = input
        for layer in self.layers:
            x = layer(x)
        return x
    
class Affine(Module):
    def __init__(self, in_dim, out_dim):
        self.A = Parameter.kaiming(in_dim, (in_dim, out_dim))
        self.b = Parameter.zeros((out_dim))

    def forward(self, input: Tensor):
        x = input
        # x: (B, in), A : (in, out), B: out
        return (x @ self.A) + self.b
    
class DynamicModule(Module):
    def __init__(self) -> None:
        self.mode = 'train'

class DropOut(DynamicModule):
    def __init__(self, p):
        self.mode = 'train'
        self.p = p

    def forward(self, x: Tensor):
        if self.mode == 'eval':
            return x * (1-self.p) # have to rescale during inference
        mask_idx_nums = RNG.choice(x.size, size=int(x.size*(1-self.p)), replace=False)
        mask_idxs = np.unravel_index(mask_idx_nums, x.shape)
        mask = np.zeros_like(x.data)
        mask[mask_idxs] = 1
        
        return x * mask
    
    '''TODO: can implement in the furture to make it faster once __getitem__ is implemented'''
    # def forward(self, x: Tensor):
    #     if self.mode == 'eval':
    #         return x * (1-self.p) # have to rescale during inference
    #     mask_idx_nums = RNG.choice(x.size, size=int(x.size*self.p), replace=False)
    #     mask_idxs = np.unravel_index(mask_idx_nums, x.shape)
    #     x.data[mask_idxs] = 0
        
    #     return x


class Relu(Module):
    def forward(self, x: Tensor):
        return x.relu()
    
class SoftMax(Module):
    def forward(self, x: Tensor):
        # temporary as max is not an implemented op
        x = x - np.max(x.data, axis=-1, keepdims=True) # for numerical stability 
        x = x.exp()
        norm_c = x.sum(axis=-1, keepdims=True)
        return x / norm_c

In [46]:
do = DropOut(0.5)
x = Tensor.random((2,3))
# print(A, do(A))
mask_idx_nums = RNG.choice(x.size, size=int(x.size*(1-0.5)), replace=False)
mask_idxs = np.unravel_index(mask_idx_nums, x.shape)
mask = np.zeros_like(x.data)
print(mask.shape)

(2, 3)


In [47]:
def one_hot_encode(array, num_c):
    one_hot = np.zeros(shape=(array.size, num_c))
    for idx, i in enumerate(array):
        one_hot[idx, i] = 1
    return one_hot

class Loss_Fn():
    def __call__(self, *args, **kwargs) -> Tensor:
        raise NotImplementedError("Loss function must implement __call__ method")
    
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
    
    def __str__(self) -> str:
        return self.__repr__()

class SoftMaxCrossEntropy(Loss_Fn):
    def __call__(self, z: Tensor, y) -> Tensor:
        '''logits z, shape (B, C), true integer lables y, shape (B)'''
        # TODO change from manual one hot encoding when getitem is implemented in tensor
        y = Tensor(one_hot_encode(y, z.shape[-1])) #shape (B, C)
        z = z - np.max(z.data, axis=-1, keepdims=True) # for numerical stability 
        loss = (-(z * y).sum(axis=-1) + ((z.exp()).sum(axis=-1)).log()).mean()
        return loss

class CrossEntropy(Loss_Fn):
    def __call__(self, q: Tensor, y) -> Tensor:
        '''pred q, shape (B, C), true integer lables y, shape (B)'''
        # TODO change from manual one hot encoding when getitem is implemented in tensor
        y = Tensor(one_hot_encode(y, q.shape[-1])) #shape (B, C)
        loss = -(y * (q+dtype_eps).log()).sum(axis=-1).mean()
        
        return loss
    
class MeanSquaredError():
    def __call__(self, q: Tensor, y) -> Tensor:
        '''pred q, shape (B, C), true values y, shape (B, C)'''
        loss = ((q - y) ** 2).sum(axis=-1).mean()
        return loss

class optimiser():
    def __init__(self, params: list[Parameter], lr: float=0.005):
        self.lr = lr
        self.params = params
    
    @abstractmethod
    def step(self) -> None:
        pass

    def zero_grad(self) -> None:
        for param in self.params:
            param.zero_grad()

    def train(self) -> None:
        for param in self.params:
            param.requires_grad = True
    
    def eval(self) -> None:
        for param in self.params:
            param.requires_grad = False
        
    
class SGD(optimiser):
    
    def step(self) -> None:
        for param in self.params:
            if not param.requires_grad:
                continue 
            param.data += -self.lr * param.grad

class Adam(optimiser):
    def __init__(self, params: list[Parameter], lr: float=0.005, 
                 betas: Tuple[float, float]=(0.9, 0.999), eps: float=1e-8):
        super().__init__(params, lr)
        self.b1 , self.b2 = betas
        self.eps = eps
        self.time_step = 0
        self.m = [np.zeros_like(param.data, dtype=DTYPE) for param in params]
        self.v = [np.zeros_like(param.data, dtype=DTYPE) for param in params]
    
    def step(self) -> None:
        self.time_step += 1
        for i, p in enumerate(self.params):
            if not p.requires_grad:
                continue 

            g = p.grad
            self.m[i] = self.b1*self.m[i] + (1-self.b1)*g
            self.v[i] = self.b2*self.v[i] + (1-self.b2)*(g**2)
            m_hat = self.m[i]/(1-self.b1**self.time_step)
            v_hat = self.v[i]/(1-self.b2**self.time_step)

            p.data += -self.lr * m_hat / (v_hat ** 0.5 + self.eps)

In [48]:
from math import ceil

class DataLoader():
    def __init__(self, input_data, true_data, batch_size, shuffle=False, rng: np.random.Generator=RNG):
        assert input_data.shape[0] == true_data.shape[0], 'must have the same number of inputs and true outputs'
        self.X = input_data
        self.y = true_data
        self.N = batch_size
        self.shuffle = shuffle
        self.rng = rng

    def __iter__(self):
        X, y = self.X, self.y
        if self.shuffle:
            permutation = self.rng.permutation(X.shape[0])
            X = X[permutation]
            y = y[permutation]
        splits = np.arange(self.N, X.shape[0], self.N)
        X = np.split(X, splits, axis=0)
        X = [Tensor(x, requires_grad=False) for x in X]
        y = np.split(y, splits, axis=0)
        return zip(X, y)

    def __len__(self):
        # samples/batch size rounded up
        return ceil(self.X.shape[0]/self.N)
    

In [None]:
# import urllib.request, numpy as np
# import os

# os.makedirs('datasets')

# url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
local_path = r"datasets/mnist.npz"

# urllib.request.urlretrieve(url, local_path)   # ⇦ makes a real file
data = np.load(local_path)

# im = X_train[0:3]
# print(type(im))
# plt.imshow(im, cmap='grey')
# plt.show()

In [50]:
# X_train, y_train = data["x_train"][0:12].reshape((-1,784)) / 255, data["y_train"][0:12]
# print(y_train.shape)
# train_loader = DataLoader(X_train, y_train, 4, shuffle=True)
# for X, y in train_loader:
#     print(y)
#     im = X[0].reshape((28,28))
#     print(type(im))
#     plt.imshow(im, cmap='grey')
#     plt.show()
#     continue

In [51]:
def one_hot_encode(array, num_c):
    one_hot = np.zeros(shape=(array.size, num_c))
    for idx, i in enumerate(array):
        one_hot[idx, i] = 1

    return one_hot

In [52]:
nn = Sequential([Affine(784, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10), SoftMax()])
nn.params

[parameter shape: (784, 100), size: 78400,
 parameter shape: (100,), size: 100,
 parameter shape: (100, 200), size: 20000,
 parameter shape: (200,), size: 200,
 parameter shape: (200, 10), size: 2000,
 parameter shape: (10,), size: 10]

In [53]:
class MultiAverageMeter():
    def __init__(self):
        self._metrics = {}
        self._counts = {}
    
    def update(self, metric, mean, n=1):
        if metric in self._metrics.keys():
            k = self._counts[metric]
            self._metrics[metric] += n*(mean - self._metrics[metric])/(k+n)
            self._counts[metric] += n
        else:
            self._metrics[metric] = mean
            self._counts[metric] = n
    
    def get_metric(self, metric):
        if metric in self._metrics:
            return  self._metrics[metric]
        raise KeyError(f'{metric} not found')
    
    def dump_metrics(self):
        return self._metrics

    def reset(self, metric=None):
        if metric is None:
            self._metrics = {}
            self._counts = {}
        else:
            del self._metrics[metric]
            del self._counts[metric]

    def get_log_str(self, metrics=None):
        log_str = ''
        metrics = self._metrics.keys() if metrics is None else metrics
        for metric in metrics:
            if metric not in self._metrics:
                continue
            log_str += f'{metric} : {self._metrics[metric]:.4f} '
        return log_str

    def __getitem__(self, key):  
        return self.get_metric(key)
    
    def __contains__(self, key): 
        return key in self._metrics
    
    def __iter__(self):          
        return iter(self._metrics)
    
    def items(self):            
        return self._metrics.items()
    
    def __len__(self):           
        return len(self._metrics)
    
    def __repr__(self):          
        return f"MultiAverageMeter({self._metrics})"

In [54]:
def get_wandb_run(wandb_config:dict, api_key:str|None=None):
    if api_key is None:
        load_dotenv()
        api_key = os.getenv("WANDB_API_KEY")
    assert api_key is not None, 'api key required'

    wandb.login(key=api_key)
    return wandb.init(**wandb_config)


class WandBLogger():
    def __init__(self, meter:MultiAverageMeter, wandb_config:dict, api_key:str|None=None):
        self._meter = meter
        self._wandb_run = get_wandb_run(wandb_config, api_key)
        self._mode = 'train'

    def eval(self):
        self._mode = 'eval'

    def train(self):
        self._mode = 'train'

    def test(self):
        self._mode = 'test'

    def log_epoch(self, epoch):
        metrics = {}
        for name, val in self._meter.items():
            metrics[name + f'/{self._mode}'] = val
        metrics = metrics | self.get_system_metrics()
        self._wandb_run.log(metrics, step=epoch)

    def get_system_metrics(self):
        metrics = {
            'sys/ram_gb' : proc.memory_info().rss / 1_073_741_824
        }
        return metrics
    
    def finish(self):
        self._wandb_run.finish()

In [55]:
wandb_config = {'project': 'torch from scratch testing',
                'name': 'dropout_tests',
                'config': {'optimiser':'adam', 'lr':0.05},
                'group': 'mnist tests',}

In [56]:
'''TODO:
 - early stopping
 - overfit batch
 '''

class Trainer():
    def __init__(
            self, 
            model: Module, 
            optimiser: optimiser, 
            loss_fn: Loss_Fn, 
            train_loader: DataLoader, 
            validation_loader: DataLoader, 
            test_loader: DataLoader, 
            logger, 
            wandb_logger: WandBLogger | None = None):
        
        self.model = model
        self.optimiser = optimiser
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader
        self.epoch = 1
        self.logger = logger
        self.wandb_logger = wandb_logger
        self.meter = MultiAverageMeter()

    def train_epoch(self):
        self.model.train()

        for X, y in self.train_loader:
            self.optimiser.zero_grad()
            output = self.model(X)
            loss = self.loss_fn(output, y)
            loss.backward()
            self.optimiser.step()

        return self.meter.get_log_str()

    def validate(self):
        self.model.eval()

        for X, y in self.validation_loader:
            self.optimiser.zero_grad()
            output = self.model(X)
            loss = self.loss_fn(output, y)
        
        return self.meter.get_log_str()
    
    def test(self):
        self.model.eval()

        for X, y in self.validation_loader:
            self.optimiser.zero_grad()
            output = self.model(X)
            loss = self.loss_fn(output, y)

    def train(self, epochs: int):
        for t in range(epochs):
            print(f'epoch: {t}')
            if self.wandb_logger is not None: self.wandb_logger.train()
            self.meter.reset()
            
            log = self.train_epoch()
            print('train: ' + log)
            if self.wandb_logger is not None: self.wandb_logger.log_epoch(t)


            if self.wandb_logger is not None: self.wandb_logger.eval()
            self.meter.reset()
            log = self.validate()
            print('test: ' + log)
            if self.wandb_logger is not None: self.wandb_logger.log_epoch(t)
    
    def log_metrics(self):
        pass

In [57]:
def train_test_step(meter:MultiAverageMeter, train, nn, loader, loss_fn, optimiser):
    start = time.time()
    if train:
        nn.train()
    else:
        nn.eval()

    for X, y in loader:

        nn.zero_grad()
        out = nn(X)
        loss = loss_fn(out, y)
        if train:
            loss.backward()
            optimiser.step()
        
        preds = np.argmax(out.item(), axis=-1)
        acc = np.sum(preds == y) / preds.size 
        
        meter.update('CE', loss.item(), y.shape[0])
        meter.update('accuracy', acc, y.shape[0])
    end = time.time()
    meter.update('speed/epoch_sec', end - start)
    meter.update('speed/samples_per_sec', len(loader) / (end - start))
    return meter.get_log_str()

def train_nn(epochs, p=0.0):

    meter = MultiAverageMeter()
    wandb_logger = WandBLogger(meter, wandb_config)

    X_train, y_train = data["x_train"].reshape((-1,784)) / 255, data["y_train"]
    X_test, y_test = data["x_test"].reshape((-1,784)) / 255, data["y_test"]

    train_loader = DataLoader(X_train, y_train, 256, shuffle=True)
    test_loader = DataLoader(X_test, y_test, 256, shuffle=False)

    nn = Sequential([DropOut(p), Affine(784, 200), DropOut(p), Relu(), Affine(200, 100), Relu(), Affine(100, 50), Relu(), Affine(50, 10), SoftMax()])
    # nn = Sequential([Affine(784, 100), Relu(), Affine(100, 200), Relu(), Affine(200, 10), SoftMax()])
    loss_fn = CrossEntropy()
    optimiser = Adam(nn.params)
    # optimiser = SGD(nn.params, lr=0.05)

    for t in range(epochs):
        print(f'epoch: {t}')
        wandb_logger.train()
        meter.reset()
        log = train_test_step(meter, True, nn, train_loader, loss_fn, optimiser)
        print('train: ' + log)
        wandb_logger.log_epoch(t)

        wandb_logger.eval()
        meter.reset()
        log = train_test_step(meter, False, nn, test_loader, loss_fn, optimiser)
        print('test: ' + log)
        wandb_logger.log_epoch(t)

    wandb_logger.finish()   


In [58]:
train_nn(40)
train_nn(40, p=0.05)
train_nn(40, p=0.1)
train_nn(40, p=0.15)
train_nn(40, p=0.2)
train_nn(40, p=0.25)
train_nn(40, p=0.3)



epoch: 0
train: CE : 0.2515 accuracy : 0.9229 speed/epoch_sec : 1.1488 speed/samples_per_sec : 204.5629 
test: CE : 0.1167 accuracy : 0.9630 speed/epoch_sec : 0.0506 speed/samples_per_sec : 790.2152 
epoch: 1
train: CE : 0.1022 accuracy : 0.9692 speed/epoch_sec : 1.1814 speed/samples_per_sec : 198.9110 
test: CE : 0.1147 accuracy : 0.9672 speed/epoch_sec : 0.0303 speed/samples_per_sec : 1319.0880 
epoch: 2
train: CE : 0.0793 accuracy : 0.9757 speed/epoch_sec : 1.4261 speed/samples_per_sec : 164.7879 
test: CE : 0.0944 accuracy : 0.9710 speed/epoch_sec : 0.0382 speed/samples_per_sec : 1046.7638 
epoch: 3
train: CE : 0.0593 accuracy : 0.9810 speed/epoch_sec : 1.4965 speed/samples_per_sec : 157.0364 
test: CE : 0.0901 accuracy : 0.9748 speed/epoch_sec : 0.0319 speed/samples_per_sec : 1253.2469 
epoch: 4
train: CE : 0.0465 accuracy : 0.9848 speed/epoch_sec : 1.4309 speed/samples_per_sec : 164.2300 
test: CE : 0.1061 accuracy : 0.9730 speed/epoch_sec : 0.0334 speed/samples_per_sec : 1197.95

0,1
CE/eval,▄▄▂▂▃▁▁▁▂▃▃▂▃▂▅▄▃▃▂▅▄▃▄▅▃▅▄▄▄▅▄▅▃▆▆▅▅▅█▄
CE/train,█▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▃▄▆▅▆▇▆▆▆▆▇▆▇▅▆▇▇▇▆▇▇█▆█▆▇█▇▇▇▇█▆▇█▇▇▆█
accuracy/train,▁▅▆▆▇▇▇▇▇▇█▇▇███████████████████████████
speed/epoch_sec/eval,▆▂▄▂▂▃▂▁▂▂▁▁▃▂▅▂▁▁▁█▁▁▁▄▁▃▂▂▁▁▇▁▁▂▁▂▃▂▁▁
speed/epoch_sec/train,▁▂▅▆▅▇▅▂▆▅▃▄▂▅█▅▃▁▂▃▃▁▁▂▃▂▂▂▂▄▄▂█▆▆▂▄▄▃▁
speed/samples_per_sec/eval,▂▇▄▆▆▅▇█▆▆▇▇▅▇▃▆▇█▇▁███▃▇▅▇▇▇█▂█▇▆█▆▅▆██
speed/samples_per_sec/train,▇▇▃▃▃▂▄▆▃▄▅▄▆▃▁▃▆▇▇▅▅▇█▆▆▇▆▇▆▅▅▆▁▃▂▇▅▄▅█
sys/ram_gb,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▃▅▅▅▅▅▅▅▅▅▅▅▅███

0,1
CE/eval,0.11896
CE/train,0.01309
accuracy/eval,0.9795
accuracy/train,0.99647
speed/epoch_sec/eval,0.02742
speed/epoch_sec/train,1.11255
speed/samples_per_sec/eval,1458.62199
speed/samples_per_sec/train,211.22659
sys/ram_gb,1.79058


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc


epoch: 0
train: CE : 0.2632 accuracy : 0.9183 speed/epoch_sec : 1.1836 speed/samples_per_sec : 198.5445 
test: CE : 0.1194 accuracy : 0.9624 speed/epoch_sec : 0.0299 speed/samples_per_sec : 1339.7122 
epoch: 1
train: CE : 0.1131 accuracy : 0.9645 speed/epoch_sec : 1.1518 speed/samples_per_sec : 204.0339 
test: CE : 0.1140 accuracy : 0.9651 speed/epoch_sec : 0.0267 speed/samples_per_sec : 1497.3507 
epoch: 2
train: CE : 0.0871 accuracy : 0.9725 speed/epoch_sec : 1.2111 speed/samples_per_sec : 194.0449 
test: CE : 0.0850 accuracy : 0.9733 speed/epoch_sec : 0.0278 speed/samples_per_sec : 1439.8201 
epoch: 3
train: CE : 0.0760 accuracy : 0.9765 speed/epoch_sec : 1.3593 speed/samples_per_sec : 172.8878 
test: CE : 0.1061 accuracy : 0.9685 speed/epoch_sec : 0.0299 speed/samples_per_sec : 1337.0430 
epoch: 4
train: CE : 0.0605 accuracy : 0.9809 speed/epoch_sec : 1.1408 speed/samples_per_sec : 205.9868 
test: CE : 0.0827 accuracy : 0.9769 speed/epoch_sec : 0.0308 speed/samples_per_sec : 1299.1

0,1
CE/eval,█▇▃▆▃▄▂▂▃▁▂▁▄▃▃▄▂▃▃▂▂▄▆▂▃▄▃▃▄▄▄▆▆▆▄▅▅▆▆▆
CE/train,█▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▂▅▃▆▆▆▇▆▇▇▇▆▇▇▇█▇▇█▇▇▇██▇▇▇█▇▇▇▆▇███▇▇▇
accuracy/train,▁▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████████████████████
speed/epoch_sec/eval,▂▁▁▂▂▆▁▁▁▁▁▁▂▁▁▂▄▂▂█▃▂▂▁▂▁▂▃▂▂▂▂▂▂▁▁▂▂▂▁
speed/epoch_sec/train,▂▂▂▃▂▃▂▂▁▂▁▁▄▂▁▁▂▂▂██▂▂▁▂▂▂▂▃▂▂▂▃▃▁▁▄▃▂▂
speed/samples_per_sec/eval,▇█▇▇▆▂▇█▇▇██▇█▇▇▄▆▇▁▅▇▇▇▇▇▇▅▆▆▆▆▇▇▇█▆▇▆▇
speed/samples_per_sec/train,▆▇▆▅▇▆▇▆█▆██▄▆▇▇▆▇▆▁▁▆▇▇▆▇▇▇▅▆▆▇▅▆▇▇▄▅▇▇
sys/ram_gb,▇▇▇██▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▃▂

0,1
CE/eval,0.10614
CE/train,0.02178
accuracy/eval,0.9806
accuracy/train,0.99372
speed/epoch_sec/eval,0.02874
speed/epoch_sec/train,1.141
speed/samples_per_sec/eval,1391.86116
speed/samples_per_sec/train,205.95972
sys/ram_gb,1.75604


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc


epoch: 0
train: CE : 0.2819 accuracy : 0.9127 speed/epoch_sec : 1.0989 speed/samples_per_sec : 213.8437 
test: CE : 0.1112 accuracy : 0.9655 speed/epoch_sec : 0.0270 speed/samples_per_sec : 1480.8696 
epoch: 1
train: CE : 0.1276 accuracy : 0.9606 speed/epoch_sec : 1.1742 speed/samples_per_sec : 200.1358 
test: CE : 0.0943 accuracy : 0.9712 speed/epoch_sec : 0.0378 speed/samples_per_sec : 1059.2412 
epoch: 2
train: CE : 0.1028 accuracy : 0.9673 speed/epoch_sec : 1.3224 speed/samples_per_sec : 177.7126 
test: CE : 0.0746 accuracy : 0.9759 speed/epoch_sec : 0.0304 speed/samples_per_sec : 1315.7259 
epoch: 3
train: CE : 0.0839 accuracy : 0.9737 speed/epoch_sec : 1.3679 speed/samples_per_sec : 171.7916 
test: CE : 0.0825 accuracy : 0.9744 speed/epoch_sec : 0.0311 speed/samples_per_sec : 1285.2165 
epoch: 4
train: CE : 0.0820 accuracy : 0.9739 speed/epoch_sec : 1.1963 speed/samples_per_sec : 196.4465 
test: CE : 0.0753 accuracy : 0.9776 speed/epoch_sec : 0.0429 speed/samples_per_sec : 932.94

0,1
CE/eval,█▅▂▄▃▃▂▃▂▃▃▃▂▃▁▃▃▅▄▂▂▅▅▃▃▂▅▄▂▃▂▁▃▃▂▅▃▃▄▃
CE/train,█▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▃▅▄▅▅▆▆▆▅▆▆▇▇▇▆▆▅▆▇█▆▆▇▇▇▇▆▇▇▇█▇▇▇▆▇▇▇▇
accuracy/train,▁▅▆▆▆▇▇▇▇▇▇▇▇▇██████████████████████████
speed/epoch_sec/eval,▁▃▂▂▃▁▂▁▅▂▁▂▁▁▁▁▁▁▃▃▂▁▅▂▁▁▁▂▁▁▂▁█▁▁▁▁▃▁▁
speed/epoch_sec/train,▂▃▅▆▃▃▂▂▃▂▂▂▆▃▂▂▃▂▃▂▄▁▂▄▁▁▂▆▂█▅▃▄▅▂▄▅▆▆▄
speed/samples_per_sec/eval,█▅▆▆▄█▅▇▃▅▇▆▇▇▇███▄▄▆▇▃▇███▆▇▇▆▇▁▇▇▇▇▄▇█
speed/samples_per_sec/train,▇▅▃▃▅▆▆▆▆▇▇▇▃▆▆▇▅▆▆▆▄▇▇▄██▇▂▇▁▄▆▄▃▇▄▄▂▃▄
sys/ram_gb,▂▁▁▂▁▃▂▂▄▂▅▂█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
CE/eval,0.08123
CE/train,0.02493
accuracy/eval,0.9822
accuracy/train,0.99247
speed/epoch_sec/eval,0.02691
speed/epoch_sec/train,1.23797
speed/samples_per_sec/eval,1486.53795
speed/samples_per_sec/train,189.82702
sys/ram_gb,1.74598


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc


epoch: 0
train: CE : 0.3185 accuracy : 0.9001 speed/epoch_sec : 1.3825 speed/samples_per_sec : 169.9822 
test: CE : 0.1305 accuracy : 0.9583 speed/epoch_sec : 0.0272 speed/samples_per_sec : 1473.1976 
epoch: 1
train: CE : 0.1468 accuracy : 0.9551 speed/epoch_sec : 1.1688 speed/samples_per_sec : 201.0554 
test: CE : 0.0954 accuracy : 0.9722 speed/epoch_sec : 0.0443 speed/samples_per_sec : 901.9766 
epoch: 2
train: CE : 0.1162 accuracy : 0.9643 speed/epoch_sec : 1.1777 speed/samples_per_sec : 199.5370 
test: CE : 0.0872 accuracy : 0.9734 speed/epoch_sec : 0.0653 speed/samples_per_sec : 612.9268 
epoch: 3
train: CE : 0.1022 accuracy : 0.9680 speed/epoch_sec : 1.2234 speed/samples_per_sec : 192.0900 
test: CE : 0.0870 accuracy : 0.9749 speed/epoch_sec : 0.0416 speed/samples_per_sec : 962.3992 
epoch: 4
train: CE : 0.0916 accuracy : 0.9711 speed/epoch_sec : 1.4599 speed/samples_per_sec : 160.9674 
test: CE : 0.0767 accuracy : 0.9781 speed/epoch_sec : 0.0293 speed/samples_per_sec : 1364.3230

0,1
CE/eval,█▄▃▃▂▂▂▂▂▂▂▃▁▁▂▁▁▁▃▁▂▁▁▂▁▁▂▂▂▃▂▂▁▂▂▂▂▃▂▁
CE/train,█▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▅▅▆▆▆▆▆▆▇▇▇▇▇▇██▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█▇▇▇▇█
accuracy/train,▁▅▆▆▇▇▇▇▇▇▇▇▇▇██████████████████████████
speed/epoch_sec/eval,▁▄█▄▂▂▅▂▂▂▂▂▃▂▂▂▁▂▂▃▂▁▃▂▂▂▁▆▄▄▇▂▁▂▁▁▁▂▂▁
speed/epoch_sec/train,▅▂▂▃▆▄▆▃▃▃▃▅█▃▃▂▁▃▄▄▆▁▄▂▂▅▇▄▃▃▆▆▂▂▃▂▂▂▂▁
speed/samples_per_sec/eval,▇▃▁▄▇▇▃▆▆▆▆▆▄▅▆▆▇▆▆▄▅▇▅▆▆▇▇▂▃▃▂▆▇▇▇█▇▇▇▇
speed/samples_per_sec/train,▃▆▆▅▂▄▂▅▆▆▅▃▁▆▅▇█▅▄▅▃▇▅▆▆▃▂▄▅▆▃▂▆▇▆▇▇▇▆▇
sys/ram_gb,▆▆▆▇▇▇▇▆▇▇▆▇▆▇▇▇█▇█▇▆▆▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄

0,1
CE/eval,0.06743
CE/train,0.037
accuracy/eval,0.9838
accuracy/train,0.98878
speed/epoch_sec/eval,0.02768
speed/epoch_sec/train,1.09443
speed/samples_per_sec/eval,1445.06598
speed/samples_per_sec/train,214.72407
sys/ram_gb,1.6406


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc


epoch: 0
train: CE : 0.3295 accuracy : 0.8973 speed/epoch_sec : 1.1908 speed/samples_per_sec : 197.3511 
test: CE : 0.1267 accuracy : 0.9599 speed/epoch_sec : 0.0305 speed/samples_per_sec : 1310.4743 
epoch: 1
train: CE : 0.1617 accuracy : 0.9501 speed/epoch_sec : 1.2461 speed/samples_per_sec : 188.5888 
test: CE : 0.1009 accuracy : 0.9682 speed/epoch_sec : 0.0316 speed/samples_per_sec : 1266.1094 
epoch: 2
train: CE : 0.1301 accuracy : 0.9592 speed/epoch_sec : 1.5102 speed/samples_per_sec : 155.6132 
test: CE : 0.0928 accuracy : 0.9719 speed/epoch_sec : 0.0595 speed/samples_per_sec : 672.0133 
epoch: 3
train: CE : 0.1132 accuracy : 0.9643 speed/epoch_sec : 1.2265 speed/samples_per_sec : 191.6088 
test: CE : 0.0848 accuracy : 0.9749 speed/epoch_sec : 0.0291 speed/samples_per_sec : 1375.6891 
epoch: 4
train: CE : 0.1027 accuracy : 0.9678 speed/epoch_sec : 1.1676 speed/samples_per_sec : 201.2646 
test: CE : 0.0869 accuracy : 0.9751 speed/epoch_sec : 0.0302 speed/samples_per_sec : 1322.67

0,1
CE/eval,█▅▄▄▄▂▃▃▂▂▂▂▂▁▂▂▂▂▂▂▁▂▁▂▂▁▁▂▂▂▂▂▁▁▂▁▂▁▁▂
CE/train,█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▃▄▅▅▆▆▆▆▇▇▆▇▇▇▆▇▇▇▇▇▇▇▇▇██████▇████▇███
accuracy/train,▁▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████████
speed/epoch_sec/eval,▁▂▆▁▁▁█▄▁▄▆▁▁▂▁▁▁▃▂▁▁▂▁▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▄▁
speed/epoch_sec/train,▃▄█▄▃▄▇▄▆▂▄▅▂▃▁▂▁▂▁▂▃▃▁▁▁▁▂▂▂▄▁▁▁▂▂▁▁▁▅▂
speed/samples_per_sec/eval,▇▇▂▇▇▇▁▃█▃▂▇▇▆▇██▅▆▇▇▇▇▇█▇▇▇▇▇████▇▇▇▇▃▇
speed/samples_per_sec/train,▅▄▁▄▅▅▂▄▃▇▅▃▇▅▇▇▇▇▇▇▅▆████▇▆▆▄▇██▆▇▇██▄▆
sys/ram_gb,▁▂▂▂▃▃▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆█▇██▆▆▆▆▆▆▆▆▆▆▆▆

0,1
CE/eval,0.06653
CE/train,0.04621
accuracy/eval,0.9828
accuracy/train,0.98637
speed/epoch_sec/eval,0.02895
speed/epoch_sec/train,1.11785
speed/samples_per_sec/eval,1381.71648
speed/samples_per_sec/train,210.22448
sys/ram_gb,1.74432


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc


epoch: 0
train: CE : 0.3565 accuracy : 0.8881 speed/epoch_sec : 1.0327 speed/samples_per_sec : 227.5601 
test: CE : 0.1335 accuracy : 0.9585 speed/epoch_sec : 0.0343 speed/samples_per_sec : 1167.3787 
epoch: 1
train: CE : 0.1843 accuracy : 0.9426 speed/epoch_sec : 1.1019 speed/samples_per_sec : 213.2756 
test: CE : 0.1009 accuracy : 0.9685 speed/epoch_sec : 0.0281 speed/samples_per_sec : 1424.1152 
epoch: 2
train: CE : 0.1500 accuracy : 0.9527 speed/epoch_sec : 1.0297 speed/samples_per_sec : 228.2309 
test: CE : 0.0939 accuracy : 0.9713 speed/epoch_sec : 0.0289 speed/samples_per_sec : 1381.9555 
epoch: 3
train: CE : 0.1342 accuracy : 0.9584 speed/epoch_sec : 0.9704 speed/samples_per_sec : 242.1700 
test: CE : 0.0934 accuracy : 0.9702 speed/epoch_sec : 0.0294 speed/samples_per_sec : 1361.4223 
epoch: 4
train: CE : 0.1251 accuracy : 0.9608 speed/epoch_sec : 1.0138 speed/samples_per_sec : 231.8050 
test: CE : 0.0827 accuracy : 0.9756 speed/epoch_sec : 0.0272 speed/samples_per_sec : 1468.7

0,1
CE/eval,█▅▅▄▄▄▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▃▃▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▁
CE/train,█▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▄▄▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▆▇███▇███▇▇▇██▇██
accuracy/train,▁▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████████
speed/epoch_sec/eval,▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▂█▁▂▁▁
speed/epoch_sec/train,▂▃▂▁▂▂▆▂▂▄▁▁▃▃▃▅▃▄▅▂▃▃▂▂▅▃▃▁▁▃▂▁▁▅█▅▃▅▅▄
speed/samples_per_sec/eval,▆▇▇▇██▇█▇███▇▅▆█▇▇▇▇▇█▆█▄██▇█▇█▇▇▇▆▁█▅█▇
speed/samples_per_sec/train,▇▆▇█▇▆▂▆▇▄██▅▆▅▄▆▄▄▆▅▅▇▆▃▅▆▇▇▅▆▇▇▃▁▃▆▃▄▄
sys/ram_gb,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▂▂▁▃▃▆▄▃█▅▃▃▃▃▃

0,1
CE/eval,0.05381
CE/train,0.05953
accuracy/eval,0.9855
accuracy/train,0.98178
speed/epoch_sec/eval,0.02834
speed/epoch_sec/train,1.26497
speed/samples_per_sec/eval,1411.35632
speed/samples_per_sec/train,185.77448
sys/ram_gb,1.76048


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nik/.netrc


epoch: 0
train: CE : 0.3837 accuracy : 0.8805 speed/epoch_sec : 1.1467 speed/samples_per_sec : 204.9344 
test: CE : 0.1246 accuracy : 0.9608 speed/epoch_sec : 0.0335 speed/samples_per_sec : 1192.8175 
epoch: 1
train: CE : 0.2037 accuracy : 0.9371 speed/epoch_sec : 1.1186 speed/samples_per_sec : 210.0756 
test: CE : 0.1069 accuracy : 0.9681 speed/epoch_sec : 0.0278 speed/samples_per_sec : 1438.5979 
epoch: 2
train: CE : 0.1749 accuracy : 0.9459 speed/epoch_sec : 1.0646 speed/samples_per_sec : 220.7309 
test: CE : 0.0908 accuracy : 0.9714 speed/epoch_sec : 0.0289 speed/samples_per_sec : 1383.7792 
epoch: 3
train: CE : 0.1518 accuracy : 0.9531 speed/epoch_sec : 1.2708 speed/samples_per_sec : 184.9285 
test: CE : 0.0845 accuracy : 0.9748 speed/epoch_sec : 0.0276 speed/samples_per_sec : 1451.2409 
epoch: 4
train: CE : 0.1410 accuracy : 0.9560 speed/epoch_sec : 1.1643 speed/samples_per_sec : 201.8397 
test: CE : 0.0811 accuracy : 0.9736 speed/epoch_sec : 0.0454 speed/samples_per_sec : 880.54

0,1
CE/eval,█▆▅▄▃▃▃▃▂▂▂▂▂▂▁▂▂▂▂▂▁▁▁▂▁▁▂▂▂▁▁▁▁▁▁▁▁▂▁▁
CE/train,█▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
accuracy/eval,▁▃▄▅▅▆▆▅▇▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇██████▇█▇
accuracy/train,▁▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇██████████████████████
speed/epoch_sec/eval,▂▁▁▁▄▁▁▁▂▁▄▅▂▂▁▅▁▁▁▁▁▁▁▁▁▂█▄▁▁▁▂▁▁▁▂▁▁▁▁
speed/epoch_sec/train,▃▃▂▄▃▂▄▄▇▃▄▂▃▂▄▃▂▄▄▄▁▁▂▄▂▃█▅▂▂▃▂▂▅▄▁▂▂▃▂
speed/samples_per_sec/eval,▆█▇█▃▇▇█▆▇▃▂▆▅▇▃▇▇▇▇█████▆▁▃▇▇█▅█▇█▇▇▇▇█
speed/samples_per_sec/train,▅▅▆▄▅▇▄▄▂▆▄▆▆▆▄▅▆▄▄▄██▆▄▆▅▁▃▆▇▅▆▇▃▄▇▆▆▅▇
sys/ram_gb,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁███

0,1
CE/eval,0.05982
CE/train,0.07463
accuracy/eval,0.982
accuracy/train,0.977
speed/epoch_sec/eval,0.02712
speed/epoch_sec/train,0.98606
speed/samples_per_sec/eval,1474.92009
speed/samples_per_sec/train,238.32223
sys/ram_gb,1.83504
