In [1]:
# Load the autoreload extension and use mode 2, so imported modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
from pathlib import Path
from IPython.core.debugger import set_trace
from fastai import datasets
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor, nn, optim
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
import torch.nn.functional as F
# from data_utilities import Dataset

In [52]:
class Dataset():
    """Pair two parallel indexable collections so they are indexed together.

    Note: this definition shadows the ``Dataset`` name imported from
    ``torch.utils.data`` earlier in the notebook.
    """

    def __init__(self, x_ds, y_ds):
        # Keep references to both collections; nothing is copied.
        self.x_dataset, self.y_dataset = x_ds, y_ds

    def __len__(self):
        """Return the length of the inputs (the targets are not consulted)."""
        inputs = self.x_dataset
        return len(inputs)

    def __getitem__(self, i):
        """Return the pair ``(x_dataset[i], y_dataset[i])``."""
        item_x = self.x_dataset[i]
        item_y = self.y_dataset[i]
        return item_x, item_y

class Callback():
    """Base class for training-loop callbacks.

    Every hook returns ``True``. Hooks that receive arguments also store them
    on the instance so subclasses can read them later (``model``, ``opt``,
    ``loss_function``, ``train_dl``, ``valid_dl``, ``epoch``, ``x_mini_batch``,
    ``y_mini_batch``, ``loss``).

    Subclasses request an early stop by setting ``self.stop = True``.
    """

    # CallbackHandler.after_step reads ``cb.stop`` on every callback, so the
    # attribute must exist even on callbacks that never ask to stop.
    stop = False

    def begin_fit(self, model, optimizer, loss_func, train_data, valid_data):
        """Remember the training objects and clear any earlier stop request."""
        self.model = model
        self.opt = optimizer
        self.loss_function = loss_func
        self.train_dl = train_data
        self.valid_dl = valid_data
        # A stop flag left over from a previous fit must not leak into this one.
        self.stop = False
        return True

    def after_fit(self):
        return True

    def begin_epoch(self, epoch):
        """Remember the current epoch index."""
        self.epoch = epoch
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    def begin_batch(self, x_batch, y_batch):
        """Remember the current mini-batch."""
        self.x_mini_batch = x_batch
        self.y_mini_batch = y_batch
        return True

    def after_loss(self, loss):
        """Remember the loss computed for the current mini-batch."""
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

    
class CallbackHandler():
    """Fan each training-loop event out to a list of callbacks.

    Each hook combines the callbacks' return values with a short-circuiting
    ``and``: once one callback returns a falsy value, the same hook is not
    invoked on the remaining callbacks and the handler returns that falsy
    value. The training loop uses the result to decide whether to continue
    with the corresponding stage.
    """

    def __init__(self,cbs=None):
        self.cbs = cbs if cbs else []
        # Defined here so do_stop()/after_step() are safe before begin_fit().
        self.stop = False

    def begin_fit(self, model,optimizer, loss_func, train_data, valid_data):
        """Store the training objects, reset state and notify every callback."""
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_func
        self.train_dl = train_data
        self.valid_dl = valid_data
        self.stop = False
        self.in_train = True
        res = True
        for cb in self.cbs:
            # Forward this call's own arguments rather than notebook globals.
            res = res and cb.begin_fit(model, optimizer, loss_func, train_data, valid_data)
        return res

    def after_fit(self):
        # Starts from "not in training", i.e. truthy only if the last phase
        # entered was validation.
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        """Switch the model to training mode and notify every callback."""
        self.model.train()
        self.in_train=True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        """Switch the model to evaluation mode and notify every callback."""
        self.model.eval()
        self.in_train=False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res

    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        # Falsy during validation, so callers skip the backward pass and the
        # optimizer step for validation batches.
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res

    def after_step(self):
        """Notify every callback and latch a stop request from any of them."""
        res = True
        for cb in self.cbs:
            res = res and cb.after_step()
            # A callback asks to stop by setting ``stop = True`` on itself.
            # Callbacks that never define the attribute count as "no request",
            # and one callback cannot cancel another's request.
            if getattr(cb, 'stop', False):
                self.stop = True
        return res

    def do_stop(self): #signalled by a call back
        """Return the pending stop request and clear it."""
        try:     return self.stop
        finally: self.stop = False

In [22]:
#redefining fit()
def one_batch(x_minib, y_minib, callbacks):
    """Process one mini-batch, with every stage gated by a hook on ``callbacks``.

    A falsy ``begin_batch`` or ``after_loss`` abandons the batch. A falsy
    ``after_backward`` skips only the optimizer step, and a falsy
    ``after_step`` skips only the gradient reset. Always returns ``None``.
    """
    if callbacks.begin_batch(x_minib, y_minib):
        criterion = callbacks.loss_function
        batch_loss = criterion(callbacks.model(x_minib), y_minib)
        if callbacks.after_loss(batch_loss):
            batch_loss.backward()
            step_allowed = callbacks.after_backward()
            if step_allowed:
                callbacks.optimizer.step()
            reset_allowed = callbacks.after_step()
            if reset_allowed:
                callbacks.optimizer.zero_grad()

    
def all_batches(dataloader, callbacks):
    """Feed every ``(x, y)`` pair from ``dataloader`` to ``one_batch``.

    After each batch ``callbacks.do_stop()`` is consulted; a truthy answer
    ends the iteration early. Always returns ``None``.
    """
    for batch in dataloader:
        inputs, targets = batch
        one_batch(inputs, targets, callbacks)
        stop_requested = callbacks.do_stop()
        if stop_requested:
            break

    
def fit(num_epochs, model, optimizer, loss_func, train_dataloader, valid_dataloader, callbacks):
    """Run the training loop for ``num_epochs`` epochs, driven by ``callbacks``.

    Args:
        num_epochs: number of epochs to iterate over.
        model, optimizer, loss_func: handed to ``callbacks.begin_fit``.
        train_dataloader: iterable of ``(x, y)`` batches used for training.
        valid_dataloader: iterable of ``(x, y)`` batches used for validation.
        callbacks: handler object exposing the ``begin_*`` / ``after_*`` hooks
            and ``do_stop``.

    Nothing runs (not even ``after_fit``) when ``begin_fit`` returns a falsy
    value. Otherwise ``after_fit`` is called exactly once after the epoch
    loop, whether the loop finished, was cut short, or had zero iterations.
    """
    if not callbacks.begin_fit(model, optimizer, loss_func, train_dataloader, valid_dataloader):
        return
    for epoch in range(num_epochs):
        if not callbacks.begin_epoch(epoch):
            continue
        all_batches(train_dataloader, callbacks)

        if callbacks.begin_validate():
            # No gradients are needed while iterating the validation batches.
            with torch.no_grad():
                all_batches(valid_dataloader, callbacks)
        if callbacks.do_stop() or not callbacks.after_epoch():
            break
    # Deliberately outside the loop: this is an end-of-training hook, not a
    # per-epoch one, and it must also run after ``break``/``continue``.
    callbacks.after_fit()

In [23]:
class TestCallback(Callback):
    """Demo callback: prints a running step count and asks to stop at 10 steps."""

    def begin_fit(self,model, opt, loss_func, train_dl, valid_dl):
        """Store the training objects and reset the step counter and stop flag."""
        super().begin_fit(model, opt, loss_func, train_dl, valid_dl)
        self.n_iters = 0
        # CallbackHandler.after_step reads ``cb.stop`` after every step, so the
        # flag has to exist before the 10th step is reached.
        self.stop = False
        return True

    def after_step(self):
        """Count one optimizer step; raise the stop flag from the 10th step on."""
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters>=10:
            self.stop = True
        return True

In [46]:
#create simple 3 layer example model

def get_model(training_data, lr=0.5, nh=50, n_out=None):
    """Build a Linear -> ReLU -> Linear classifier and an SGD optimizer for it.

    Args:
        training_data: object exposing ``x_dataset`` (its ``shape[1]`` gives
            the input width) and ``y_dataset`` (labels).
        lr: learning rate for ``optim.SGD``.
        nh: width of the hidden layer.
        n_out: width of the output layer. When omitted it is derived from the
            labels as ``y_dataset.max() + 1`` instead of relying on a
            notebook-level global.

    Returns:
        ``(model, optimizer)``.
    """
    n_in = training_data.x_dataset.shape[1]
    if n_out is None:
        # assumes labels are integer class ids starting at 0 — TODO confirm
        n_out = int(training_data.y_dataset.max().item()) + 1
    model = nn.Sequential(nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out))
    return model, optim.SGD(model.parameters(), lr=lr)


MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
def get_data():
    """Download the pickled MNIST archive and return its train/valid arrays as tensors.

    Returns a ``map`` iterator yielding, in order, ``x_train``, ``y_train``,
    ``x_valid``, ``y_valid``, each passed through ``tensor``. The third split
    stored in the archive is discarded.
    """
    archive_path = datasets.download_data(MNIST_URL, ext='.gz')
    # NOTE(review): pickle.load can execute arbitrary code, and this file is
    # fetched over plain HTTP — only use with a source you trust.
    with gzip.open(archive_path, 'rb') as archive:
        train_split, valid_split, _ = pickle.load(archive, encoding='latin-1')
    x_train, y_train = train_split
    x_valid, y_valid = valid_split
    return map(tensor, (x_train, y_train, x_valid, y_valid))


In [53]:
# Hyper-parameters and loss used for the run below.
number_hidden = 50
batch_size = 64
loss_func = F.cross_entropy
x_train,y_train,x_valid,y_valid = get_data()

#setup data
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)
# Training batches are shuffled and the final incomplete batch is dropped;
# validation batches keep their original order.
train_dl = DataLoader(train_ds, batch_size, shuffle=True, drop_last=True)
valid_dl = DataLoader(valid_ds, batch_size, shuffle=False)
# Number of classes = largest label + 1; must be defined before get_model runs.
categories = y_train.max().item()+1
# Pass the hidden width explicitly so number_hidden actually controls the model.
model, optimizer = get_model(train_ds, nh=number_hidden)

In [54]:
# Train for a single epoch with TestCallback attached: it prints a running
# step count and raises its stop flag once that count reaches 10.
fit(1, model, optimizer, loss_func, train_dl, valid_dl, callbacks=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10
