In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join
import math

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from operations import *
from fcn_revised import *

In [3]:
class DataLoader():
    """Iterate over a dataset in consecutive, fixed-size batches.

    Args:
        dataset: Any object that supports ``len()`` and slice indexing.
        batch_size: Number of items per batch; must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    def __init__(self, dataset, batch_size):
        # Fail early: a zero step only makes range() raise once iteration starts,
        # and a negative step would silently produce no batches at all.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        """Yield ``dataset[i: i + batch_size]`` for every batch start ``i``."""
        for i in range(0, len(self.dataset), self.batch_size):
            yield self.dataset[i: i+self.batch_size]

In [None]:
class Dataset():
    """Pairs a collection of inputs with its targets so both are indexed together."""
    def __init__(self, x_chunk, y_chunk):
        self.x_chunk, self.y_chunk = x_chunk, y_chunk

    def __len__(self):
        # Length is taken from the inputs only; the targets are not consulted.
        return len(self.x_chunk)

    def __getitem__(self, i):
        # The same index (or slice) is applied to inputs first, then targets.
        inputs = self.x_chunk[i]
        targets = self.y_chunk[i]
        return inputs, targets

In [None]:
#export
def softmax(inp):
    """Softmax over the last dimension of ``inp``.

    The maximum along the last dimension is subtracted before exponentiating.
    This leaves the result mathematically unchanged but keeps ``exp`` from
    overflowing to ``inf`` (and the ratio from becoming ``nan``) on large scores.

    Args:
        inp: Tensor of scores, normalized along its last axis.

    Returns:
        Tensor of the same shape whose entries along the last axis sum to 1.
    """
    shifted = inp - inp.max(-1, keepdim=True)[0]
    # Exponentiate once and reuse for both numerator and denominator.
    exps = shifted.exp()
    return exps / exps.sum(-1, keepdim=True)

def log_sum_exp(inp):
    """Numerically stable ``log(sum(exp(inp)))`` over the last dimension.

    The maximum along the last axis is factored out before exponentiating, so
    the largest exponent is 0. ``keepdim=True`` lets that maximum broadcast
    against ``inp`` for any number of leading dimensions, not only 2-D input.

    Args:
        inp: Tensor with at least one dimension.

    Returns:
        Tensor with the last dimension of ``inp`` reduced away.
    """
    peak = inp.max(-1, keepdim=True)[0]
    return peak.squeeze(-1) + (inp - peak).exp().sum(-1).log()
    
def log_softmax(inp):
    """Log-softmax of ``inp`` along the last dimension.

    Computed as ``inp`` minus its log-sum-exp, which sidesteps the floating
    point trouble of exponentiating and then taking a logarithm.
    """
    normalizer = log_sum_exp(inp)
    # Re-append the reduced axis so the subtraction broadcasts per row.
    return inp - normalizer[..., None]

def nll_loss(pre, tar):
    """Negated mean of the entries of ``pre`` selected by ``tar``.

    For every row ``r`` the value ``pre[r, tar[r]]`` is picked; when ``pre``
    holds log-probabilities (as it does when called from ``cross_entropy``)
    the result is the negative log-likelihood averaged over rows.
    """
    row_ids = range(tar.shape[0])
    # Pairing a row index list with a column index tensor picks one entry per row.
    picked = pre[row_ids, tar]
    return -picked.mean()

def cross_entropy(inp, tar):
    """Cross-entropy of raw scores ``inp`` against integer targets ``tar``.

    Composition of the two helpers: ``log_softmax`` followed by ``nll_loss``.
    """
    log_probs = log_softmax(inp)
    return nll_loss(log_probs, tar)

def accuracy(pre, tar):
    """Fraction of rows whose highest-scoring column (along dim 1) equals ``tar``."""
    predicted = torch.argmax(pre, dim=1)
    # Boolean matches are cast to float so that their mean is a proportion.
    hits = (predicted == tar).float()
    return hits.mean()

In [None]:
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear.

    Args:
        n_in: Number of input features.
        nh: Number of hidden units.
        n_out: Number of output units.
    """
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList registers every layer as a submodule. A plain Python
        # list hides them, leaving parameters() and state_dict() empty.
        self.layers = nn.ModuleList([nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the final output."""
        # Defining forward (rather than overriding __call__) keeps nn.Module's
        # own __call__ machinery, such as hooks, in effect; model(x) still works.
        for l in self.layers:
            x = l(x)
        return x

# Build the network and run a single forward pass on `x_train`.
# NOTE(review): none of `m`, `nh`, `c`, `x_train` is defined in this cell. `m`, `c`
# and `x_train` are assigned in the data-setup cell (In [4]), which sits further
# down the notebook; `nh` has no visible assignment here — presumably it comes
# from one of the star-imports; confirm before a Restart & Run All.
model = Model(m, nh, c)
pred = model(x_train)

In [None]:
# Compute the loss for `pred` four ways: the hand-written pieces composed
# manually, the hand-written `cross_entropy` wrapper, and the two
# torch.nn.functional counterparts.
loss1 = nll_loss(log_softmax(pred), y_train)
loss2 = cross_entropy(pred, y_train)
loss3 = F.nll_loss(F.log_softmax(pred, -1), y_train)
loss4 = F.cross_entropy(pred, y_train)

# Pairwise chain of comparisons across the four values.
# NOTE(review): `test_near` is not defined in this notebook — presumably it is
# supplied by one of the star-imports; confirm.
test_near(loss1, loss2)
test_near(loss2, loss3)
test_near(loss3, loss4)

In [5]:
class Optimizer():
    """Forwards update and gradient-reset calls to a fixed collection of parameters.

    Every parameter object must provide ``step(learning_rate)`` and
    ``zero_grad()``; this class only fans the calls out.
    """
    def __init__(self, parameters, learning_rate):
        self.learning_rate = learning_rate
        # Materialize into a list so a generator argument can be walked repeatedly.
        self.parameters = [*parameters]

    def zero_grad(self):
        """Ask each parameter to clear its gradient."""
        for p in self.parameters:
            p.zero_grad()

    def step(self):
        """Ask each parameter to update itself with the stored learning rate."""
        for p in self.parameters:
            p.step(self.learning_rate)

In [None]:
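# TODO: unfinished sketch of the training loop; kept commented out until the
# inner batch loop is completed.
# def fit(num_epochs, model, optim, loss_func, ds_train, ds_valid):
#     for epoch in range(num_epochs):
#         for batch in range(math.ceil(len))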

In [None]:
# Bind `model` to a Sequential stack (Linear -> ReLU -> Linear) and create an
# Optimizer over its parameters.
# NOTE(review): `Sequential`, `Linear`, `ReLU`, `in_dim`, `nh` and `out_dim` are
# not defined anywhere in this notebook — presumably they come from the
# `operations` / `fcn_revised` star-imports; confirm. The meaning of the third
# positional argument (`True`) to the last Linear is not visible from here.
model = Sequential(Linear(in_dim, nh), ReLU(), Linear(nh, out_dim, True))
# 0.5 is passed as the Optimizer's learning_rate (its second positional argument).
optimizer = Optimizer(model.parameters(), 0.5)

In [4]:
# Fetch the train/validation splits and normalize the inputs.
# NOTE(review): `get_mnist_data` and `normalize_data` are not defined in this
# notebook — presumably they come from the star-imports; confirm.
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, x_valid = normalize_data(x_train, x_valid)

# Settings for the notebook.
batch_size = 64
num_hidden = 50 # hidden cells
learning_rate = 0.5
# n, m: the two dimensions of x_train.shape; c: largest value in y_train plus one.
(n, m), c = x_train.shape, int(y_train.max() + 1)