In [1]:
#export
import sys
from os.path import join
import math

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from operations import *

In [2]:
#export
def init_weight_he(in_dim, num_hidden):
    """He/Kaiming-normal weight matrix of shape (in_dim, num_hidden).

    Entries are standard-normal draws scaled by sqrt(2 / in_dim).
    """
    scale = (2. / in_dim) ** 0.5
    weights = torch.randn(in_dim, num_hidden)
    return weights * scale

def init_weight_norm(in_dim, num_hidden):
    """Weight matrix of shape (in_dim, num_hidden) used when init_weight(..., end=True).

    Entries are standard-normal draws scaled by sqrt(2 / in_dim).
    NOTE(review): this is currently identical to init_weight_he. If a final
    layer not followed by a ReLU was meant to use sqrt(1 / in_dim), change the
    scale here — confirm intent.
    """
    scale = (2. / in_dim) ** 0.5
    weights = torch.randn(in_dim, num_hidden)
    return weights * scale

def init_weight(in_dim, num_hidden, end=False):
    """Build a (in_dim, num_hidden) weight matrix.

    end=True selects init_weight_norm (used for the output layer);
    otherwise init_weight_he is used.
    """
    initializer = init_weight_norm if end else init_weight_he
    return initializer(in_dim, num_hidden)

def init_bias_zero(num_hidden):
    """All-zero bias vector with num_hidden entries."""
    bias = torch.zeros(num_hidden)
    return bias

def init_bias_uni(num_hidden, low=0., high=1.):
    """Bias vector of length num_hidden with entries drawn from U(low, high).

    Args:
        num_hidden: number of bias entries.
        low, high: bounds of the uniform distribution (half-open [low, high)).
            NOTE(review): the default range [0, 1) mirrors torch.rand; pick a
            fan-in-scaled range if that is what was intended.
    Returns:
        A 1-D float tensor of shape (num_hidden,).
    """
    # Previously `torch.zeros()` ignored num_hidden entirely; honour the requested size.
    return torch.empty(num_hidden).uniform_(low, high)

def init_bias(num_hidden, zero=True):
    """Build a bias vector of length num_hidden.

    zero=True (default) uses init_bias_zero; zero=False uses init_bias_uni.
    """
    if not zero:
        return init_bias_uni(num_hidden)
    return init_bias_zero(num_hidden)

In [3]:
#export
class Parameter():
    """A trainable tensor paired with its most recent gradient.

    Hand-rolled container used instead of an autograd graph: gradients are
    written in explicitly via `update` (e.g. from a layer's backward pass) and
    applied with `step`.

    Args:
        data: initial tensor. Defaults to a fresh empty `torch.Tensor()` per
            instance (a tensor default argument would be shared by every
            Parameter created without `data`).
        requires_grad: when False, `step` leaves `data` unchanged.
    """
    def __init__(self, data=None, requires_grad=True):
        self.data = torch.Tensor() if data is None else data
        self.requires_grad = requires_grad
        self.grad = 0  # scalar 0 until a gradient is supplied via update()

    def __get__(self, instance, owner):
        # Descriptor hook: only fires when a Parameter is stored as a *class*
        # attribute; instance attributes bypass it, so callers use `.data`.
        return self.data

    def step(self, lr):
        """In-place SGD update `data -= lr * grad`; no-op for frozen parameters."""
        if not self.requires_grad:
            return
        self.data -= lr * self.grad

    def zero_data(self):
        """Reset the stored tensor to all zeros, in place."""
        self.data.zero_()

    def zero_grad(self):
        """Reset the stored gradient to scalar 0."""
        self.grad = 0

    def update(self, new_grad):
        """Replace (not accumulate) the stored gradient."""
        self.grad = new_grad

In [4]:
#export
class Sequential():
    """Chain of layers applied in order on the forward pass, reversed on backward.

    Each layer must be callable and provide `backward()` and `parameters()`
    (as the Module subclasses in this notebook do).
    """
    def __init__(self, *args):
        self.layers = list(args)
        self.train = True  # NOTE(review): not read in this cell; presumably a train/eval mode flag

    def __call__(self, x):
        """Feed `x` through every layer in order and return the final output."""
        for layer in self.layers:  # was `layers` (undefined name -> NameError)
            x = layer(x)
        return x

    def backward(self):
        """Run each layer's backward in reverse order.

        Reverse order matters: a layer's bwd reads the gradient of its output,
        which is written by the layer after it.
        """
        for layer in reversed(self.layers):
            layer.backward()

    def parameters(self):
        """Yield every layer's parameters, in layer order."""
        for layer in self.layers:
            yield from layer.parameters()

In [5]:
#export
class Module():
    """Base class for layers with a hand-written forward/backward pair.

    Subclasses implement `fwd(*inputs)` and `bwd(out, *inputs)`. Calling the
    module runs `fwd` and caches its inputs (`self.args`) and output
    (`self.out`) so that `backward()` can pass them to `bwd` later. Any
    `Parameter` assigned as an attribute is registered and yielded by
    `parameters()`.

    Args:
        name: optional label; defaults to 'unnamed_module'.
    """
    def __init__(self, name=None):
        self._parameters = {}
        self.name = name if name else 'unnamed_module'

    def __setattr__(self, k, v):
        # Auto-register Parameters by attribute name. Subclasses must call
        # super().__init__() before assigning Parameters so _parameters exists.
        if isinstance(v, Parameter):
            self._parameters[k] = v
        super().__setattr__(k, v)

    def __call__(self, *args):
        self.args = args
        self.out = self.fwd(*args)
        return self.out

    def parameters(self):
        """Yield the registered Parameters in assignment order."""
        yield from self._parameters.values()

    def fwd(self, *args):
        """Forward computation; subclasses must override."""
        raise NotImplementedError('Module.fwd')

    def bwd(self, out, *args):
        """Backward computation given the cached output and inputs; subclasses must override."""
        raise NotImplementedError('Module.bwd')

    def forward(self):
        # Kept for backward compatibility only: __call__ dispatches to fwd, not forward.
        raise NotImplementedError('Module.forward')

    def backward(self):
        """Call `bwd` with the output and inputs cached by the most recent call."""
        self.bwd(self.out, *self.args)

In [6]:
#export
class Linear(Module):
    """Fully connected layer computing `inp @ w + b`.

    Args:
        in_dim: size of the input's last dimension (rows of `w`).
        num_hidden: number of output features (columns of `w`, length of `b`).
        end: forwarded to `init_weight`; True selects `init_weight_norm`
            (used for the output layer).
        require_grad: forwarded as `requires_grad` to both weight and bias.
    """
    def __init__(self, in_dim, num_hidden, end=False, require_grad=True):
        super().__init__()
        self.w = Parameter(init_weight(in_dim, num_hidden, end), require_grad)
        self.b = Parameter(init_bias(num_hidden), require_grad)

    def fwd(self, inp):
        # Use `.data` on both: a Parameter is not a tensor, and its __get__
        # does not fire for instance attributes, so `+ self.b` raises TypeError.
        return inp @ self.w.data + self.b.data

    def bwd(self, out, inp):
        # assumes inp is 2-D (batch, in_dim): inp.t() below transposes it
        inp.g = out.g @ self.w.data.t()
        self.w.update(inp.t() @ out.g)
        self.b.update(out.g.sum(0))  # bias gradient: sum over the first (batch) axis

class ReLU(Module):
    """Shifted ReLU: max(inp, 0) - 0.5.

    NOTE(review): the -0.5 shift presumably recentres activations closer to
    zero mean; the shift is constant, so the gradient is the plain ReLU mask.
    """
    def fwd(self, inp):
        rectified = inp.clamp_min(0.)
        return rectified - 0.5

    def bwd(self, out, inp):
        # gradient passes through only where the input was positive
        mask = (inp > 0).float()
        inp.g = mask * out.g

In [7]:
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, x_valid = normalize_data(x_train, x_valid)

nh = 50  # number of hidden units
num_data, in_dim = x_train.shape  # assumes x_train is 2-D (samples, features)
# Labels are presumably class indices 0..K-1, so K = max + 1. int() turns the
# 0-d result of .max() into a plain Python int for use as a layer size.
out_dim = int(y_train.max()) + 1

In [8]:
# Two-layer MLP: in_dim -> nh -> (shifted) ReLU -> out_dim.
model = Sequential(
    Linear(in_dim, nh),
    ReLU(),
    Linear(nh, out_dim, end=True),  # end=True: output-layer weight init
)