In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join
import math

sys.path.insert(1, join(sys.path[0], '..'))
from generated.operations import *

In [3]:
def lin(x, w, b):
    """Affine (fully connected) layer.

    Returns ``x @ w + b``: ``x`` is multiplied on the right by ``w`` and the bias
    ``b`` is added (broadcast over the leading axis).
    """
    projected = x @ w
    return projected + b

def lin_grad(inp, out, w, b):
    """Backward pass of ``lin``: reads ``out.g`` and writes ``inp.g``, ``w.g`` and ``b.g``.

    For ``out = inp @ w + b``:
      * d/d inp = out.g @ w^T
      * d/d w   = inp^T @ out.g
      * d/d b   = out.g summed over the leading (batch) axis
    """
    grad_out = out.g
    inp.g = grad_out @ w.t()
    w.g = inp.t() @ grad_out
    b.g = grad_out.sum(0)
    
def relu(x):
    """Shifted ReLU: ``max(x, 0) - 0.5``.

    The constant ``-0.5`` shift is presumably there to pull the activations' mean
    back towards zero; it does not change the derivative (see ``relu_grad``).
    """
    rectified = x.clamp_min(0.)
    return rectified - 0.5

def relu_grad(inp, out):
    """Backward pass of ``relu``: ``inp.g = out.g`` where ``inp > 0``, else 0."""
    positive_mask = (inp > 0).float()
    inp.g = positive_mask * out.g

def Flatten(inp):
    """Return a 1-D view of ``inp`` with every axis collapsed (no copy, shares storage).

    NOTE(review): a class also named ``Flatten`` is defined in a later cell and
    rebinds this name, so this function is unreachable once that cell has run.
    """
    flat_view = inp.view(-1)
    return flat_view
    
def mse(pred, y):
    """Mean squared error between ``pred`` and ``y``.

    Only the trailing axis of ``pred`` is squeezed, so an ``(n, 1)`` column of
    predictions lines up with an ``(n,)`` target vector.
    """
    residual = pred.squeeze(-1) - y
    return residual.pow(2).mean()

def mse_grad(inp, tar):
    """Backward pass of ``mse``: write d mse(inp, tar) / d inp into ``inp.g``.

    Mirrors ``mse`` exactly:
      * only the trailing axis of ``inp`` is squeezed (``squeeze()`` would also drop
        a size-1 batch axis);
      * the gradient of ``.mean()`` is ``2 * residual / residual.numel()``, i.e. it is
        divided by the same element count ``.mean()`` divided by.
    The gradient is reshaped to ``inp.shape`` so it lines up with ``inp`` for ``(n, 1)``
    columns, 1-D predictions and multi-column predictions alike.

    Raises:
        RuntimeError: if ``tar`` broadcasts the residual to a different number of
            elements than ``inp`` has (a shape mismatch between prediction and target).
    """
    diff = inp.squeeze(-1) - tar
    inp.g = (2. * diff / diff.numel()).reshape(inp.shape)

In [4]:
#export
def he_init(m, n):
    """He (Kaiming) normal initialisation for an ``(m, n)`` weight used as ``x @ w``.

    Draws standard-normal values and scales them by ``sqrt(2 / m)``, where ``m`` is
    the fan-in.
    """
    gain = (2. / m) ** 0.5
    return torch.randn(m, n) * gain

def init(m, n, relu):
    """Random ``(m, n)`` weight matrix scaled for the layer that follows it.

    Args:
        m: fan-in (rows of the weight).
        n: fan-out (columns of the weight).
        relu: truthy when the layer feeds a ReLU; then ``he_init`` (scale
            ``sqrt(2/m)``) is used, otherwise the scale is ``sqrt(1/m)``.
    """
    if not relu:
        return torch.randn(m, n) * (1. / m) ** 0.5
    return he_init(m, n)

In [5]:
def forward_backward(inp, tar):
    """One forward + hand-written backward pass of the 2-layer MLP in globals ``w1, b1, w2, b2``.

    Gradients are written to the ``.g`` attribute of ``inp``, of each weight/bias and
    of the intermediate activations. Returns the MSE loss of the forward pass
    (callers that ignored the former ``None`` return are unaffected).
    """
    # forward
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    loss = mse(out, tar)
    # backward: visit the graph in reverse; each *_grad reads the .g set by the step before it
    mse_grad(out, tar)
    lin_grad(l2, out, w2, b2)
    relu_grad(l1, l2)
    lin_grad(inp, l1, w1, b1)
    return loss

In [6]:
#export
class Module():
    """Base class for layers with hand-written forward and backward passes.

    Subclasses implement:
      * ``fwd(*args)``: compute and return the output;
      * ``bwd(out, *args)``: given the stored output (whose ``.g`` holds the upstream
        gradient) and the original inputs, write gradients to the inputs' ``.g``.

    Calling a module records its arguments and output so that ``backward()`` can
    pass them to ``bwd`` later.
    """
    def __call__(self, *args):
        self.args = args
        self.out = self.fwd(*args)
        return self.out

    def fwd(self, *args):
        # Abstract: __call__ dispatches here, so a subclass without fwd fails loudly
        # with NotImplementedError rather than an AttributeError.
        raise NotImplementedError(type(self).__name__ + '.fwd is not implemented')

    def bwd(self, out, *args):
        # Abstract: backward() dispatches here.
        raise NotImplementedError(type(self).__name__ + '.bwd is not implemented')

    def forward(self, *args):
        # Kept for backward compatibility; delegates to the method __call__ actually uses.
        return self.fwd(*args)

    def backward(self):
        self.bwd(self.out, *self.args)

class ReLU(Module):
    """Shifted ReLU layer: ``max(inp, 0) - 0.5``."""
    def fwd(self, inp):
        rectified = inp.clamp_min(0.)
        return rectified - 0.5

    def bwd(self, out, inp):
        # The constant shift has zero derivative: pass out.g through where inp > 0.
        positive_mask = (inp > 0).float()
        inp.g = positive_mask * out.g

class Lin(Module):
    """Affine layer ``inp @ w + b`` whose gradients land in ``w.g``, ``b.g`` and ``inp.g``."""
    def __init__(self, w, b):
        self.w, self.b = w, b

    def fwd(self, inp):
        return inp @ self.w + self.b

    def bwd(self, out, inp):
        upstream = out.g
        inp.g = upstream @ self.w.t()
        self.w.g = inp.t() @ upstream
        self.b.g = upstream.sum(0)

class Flatten(Module):
    """Layer that collapses its input into a 1-D view (``x.view(-1)``).

    NOTE(review): ``view(-1)`` flattens every axis, including any leading batch axis.
    """
    def fwd(self, x):
        return x.view(-1)

    def bwd(self, out=None, inp=None):
        """Send the upstream gradient ``out.g`` back in the input's shape.

        ``Module.backward`` calls ``bwd(self.out, *self.args)``, so this must accept the
        output and the input. Both default to ``None`` so that a bare ``bwd()`` call
        stays a no-op.
        """
        if out is None or inp is None:
            return
        # Flattening only rearranges elements, so its gradient is a reshape back.
        inp.g = out.g.reshape(inp.shape)
    
class Mse(Module):
    """Mean-squared-error loss layer (terminal node of the graph).

    Only the trailing axis of ``inp`` is squeezed, so a single-sample ``(1, 1)`` batch
    keeps its batch axis (``squeeze()`` would drop every size-1 axis).
    """
    def fwd(self, inp, tar):
        return (inp.squeeze(-1) - tar).pow(2).mean()

    def bwd(self, out, inp, tar):
        # ``out`` is the scalar loss itself, so its upstream gradient is implicitly 1.
        # d mean(diff**2) / d inp = 2 * diff / diff.numel(): divide by the same element
        # count .mean() used, and reshape to inp.shape so the gradient matches inp for
        # (n, 1), 1-D and multi-column predictions.
        diff = inp.squeeze(-1) - tar
        inp.g = (2. * diff / diff.numel()).reshape(inp.shape)

class Model:
    """Sequential stack of ``Module`` layers followed by an ``Mse`` loss.

    Calling the model runs the forward pass and returns the loss; ``backward()``
    then pushes gradients through the loss and the layers in reverse order.
    """
    def __init__(self, layers):
        self.layers = layers
        self.loss = Mse()

    def __call__(self, x, tar):
        activation = x
        for layer in self.layers:
            activation = layer(activation)
        return self.loss(activation, tar)

    def backward(self):
        # The loss is the last node, so its gradient must be computed first.
        self.loss.backward()
        for layer in reversed(self.layers):
            layer.backward()

In [7]:
# Load MNIST and normalise the train/validation inputs with normalize_data.
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, x_valid = normalize_data(x_train, x_valid)

nh = 50                # number of hidden units
n, m = x_train.shape   # n rows (samples), m input features (fan-in of w1)
c = y_train.max() + 1  # number of label values; not used in this section (single-output model)

In [8]:
# Seed the RNG so the randomly initialised weights (and every gradient derived from
# them) are reproducible under Restart & Run All.
torch.manual_seed(42)
w1 = init(m, nh, relu=True)   # m -> nh, feeds the ReLU
b1 = torch.zeros(nh)
w2 = init(nh, 1, relu=False)  # nh -> 1 output
b2 = torch.zeros(1)

In [9]:
# Reference pass: fills .g on x_train, w1, b1, w2 and b2 with the hand-derived gradients.
forward_backward(inp=x_train, tar=y_train)

In [10]:
# Fresh leaf copies so the Module-based pass writes its own .g attributes without
# touching the tensors used by forward_backward.
xt2, w12, w22, b12, b22 = (
    t.clone().requires_grad_(True) for t in (x_train, w1, w2, b1, b2)
)

model = Model([Lin(w12, b12), ReLU(), Lin(w22, b22)])
loss = model(xt2, y_train)
model.backward()

# The Module-based gradients must match the functional ones.
for module_grad, functional_grad in [
    (w22.g, w2.g),
    (b22.g, b2.g),
    (w12.g, w1.g),
    (b12.g, b1.g),
    (xt2.g, x_train.g),
]:
    test_near(module_grad, functional_grad)

In [11]:
# Independent leaf copies for the autograd reference pass.
xt3, w13, w23, b13, b23 = (
    src.clone().requires_grad_(True) for src in (x_train, w1, w2, b1, b2)
)

def forward(x, targ):
    """Autograd forward pass of the same 2-layer MLP, on the w13/b13/w23/b23 copies.

    Returns the MSE loss; call ``.backward()`` on it to fill the copies' ``.grad``.
    """
    hidden = relu(lin(x, w13, b13))
    prediction = lin(hidden, w23, b23)
    return mse(prediction, targ)

loss = forward(xt3, y_train)
loss.backward()

# PyTorch autograd (.grad) must agree with the hand-derived gradients (.g).
for autograd_grad, manual_grad in [
    (w23.grad, w2.g),
    (b23.grad, b2.g),
    (w13.grad, w1.g),
    (b13.grad, b1.g),
    (xt3.grad, x_train.g),
]:
    test_near(autograd_grad, manual_grad)

In [12]:
!python3 ../notebook2script.py fully_connected.ipynb

Converted fully_connected.ipynb to ../generated/fully_connected.py
