In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai import datasets
from pathlib import Path
from IPython.core.debugger import set_trace
import pickle, gzip, math, torch, re, matplotlib as mpl, matplotlib.pyplot as plt
from functools import partial

from typing import Iterable
from torch import tensor, nn, optim
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
import torch.nn.functional as F

In [3]:
class Dataset():
    def __init__(self, x_ds, y_ds):
        self.x_dataset = x_ds
        self.y_dataset = y_ds
    
    def __len__(self):
        return len(self.x_dataset)
    
    def __getitem__(self, i):
        return self.x_dataset[i],self.y_dataset[i]

class Callback():
    def begin_fit(self, model, optimizer, loss_func, train_data, valid_data):
        self.model = model
        self.opt = optimizer
        self.loss_function = loss_func
        self.train_dl = train_data
        self.valid_dl = valid_data
        self.stop = False
        return True
    
    def after_fit(self):
        return True
    
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    
    def begin_validate(self):
        return True
    
    def after_epoch(self):
        return True
    
    def begin_batch(self, x_batch, y_batch):
        self.x_mini_batch = x_batch
        self.y_mini_batch = y_batch
        return True

    def after_loss(self, loss):
        self.loss = loss
        return True
    
    def after_backward(self):
        return True
    def after_step(self):
        return True
    
class CallbackHandler():
    def __init__(self,cbs=None):
        self.cbs = cbs if cbs else []

    def begin_fit(self, model,optimizer, loss_func, train_data, valid_data):
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_func
        self.train_dl = train_data
        self.valid_dl = valid_data
        self.stop = False
        self.in_train = True
        res = True
        for cb in self.cbs:
                res = res and cb.begin_fit(model,optimizer, loss_func, train_dl, valid_dl)
        return res

    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res
    
    def begin_epoch(self, epoch):
        self.model.train()
        self.in_train=True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        self.model.eval()
        self.in_train=False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res
    
    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs: 
            res = res and cb.after_step()
            self.stop = cb.stop
            if self.stop is False:
                break
        return res
    
    def do_stop(self): #signalled by a call back
        try:     return self.stop
        finally: self.stop = False

In [4]:
#redefining fit()
def one_batch(x_minib, y_minib, callbacks):
    if not callbacks.begin_batch(x_minib,y_minib):
        return
    loss = callbacks.loss_function(callbacks.model(x_minib), y_minib)
    if not callbacks.after_loss(loss):
        return
    loss.backward()
    if callbacks.after_backward(): 
        callbacks.optimizer.step()
    if callbacks.after_step():
        callbacks.optimizer.zero_grad()
    
def all_batches(dataloader, callbacks):
    for x_minib, y_minib in dataloader:
        one_batch(x_minib, y_minib, callbacks)
        if callbacks.do_stop():
            return
    
def fit(num_epochs, model, optimizer, loss_func, train_dataloader, valid_dataloader, callbacks):
    if not callbacks.begin_fit(model, optimizer, loss_func, train_dataloader, valid_dataloader):
        return
    for epoch in range(num_epochs):
        if not callbacks.begin_epoch(epoch):
            continue
        all_batches(train_dataloader, callbacks)
        
        if callbacks.begin_validate():
            with torch.no_grad():
                all_batches(valid_dataloader, callbacks)
        if callbacks.do_stop() or not callbacks.after_epoch():
            break
        callbacks.after_fit()

In [5]:
class TestCallback(Callback):
    def begin_fit(self,model, opt, loss_func, train_dl, valid_dl):
        super().begin_fit(model, opt, loss_func, train_dl, valid_dl)
        self.n_iters = 0
        return True
        
    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters>=10: 
            self.stop = True
        return True

In [6]:
#create simple 3 layer example model

def get_model(training_data, lr=0.5, nh=50):
    m = training_data.x_dataset.shape[1]
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,categories))
    return model, optim.SGD(model.parameters(), lr=lr)


MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train,y_train,x_valid,y_valid))


In [7]:
number_hidden = 50
batch_size = 64
loss_func = F.cross_entropy
x_train,y_train,x_valid,y_valid = get_data()

#setup data
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)
train_dl = DataLoader(train_ds, batch_size, shuffle=True, drop_last=True)
valid_dl = DataLoader(valid_ds, batch_size, shuffle=False)
categories = y_train.max().item()+1
model, optimizer = get_model(train_ds)

In [8]:
fit(1, model, optimizer, loss_func, train_dl, valid_dl, callbacks=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


---
# Stream-lining the Callback Interface:

### Helper Functions

In [9]:
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()

def convert_to_list(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

#export
def accuracy(out, yb): return (torch.argmax(out, dim=1)==yb).float().mean()

def get_model(training_dl, lr=0.5, nh=50):
    m = training_dl.x_dataset.shape[1]
    categories = training_dl.y_dataset.max().item()+1
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,categories))
    return model, optim.SGD(model.parameters(), lr=lr)

### The Refactored Callback

In [10]:
class Callback():
    _order=0
    def set_runner(self, runner): self.run=runner
    def __getattr__(self, k): return getattr(self.run, k)
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

#callback included in every runner, tracks iterations and epochs
class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs=0.
        self.run.n_iter=0
    
    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter   += 1
        
    def begin_epoch(self):
        self.run.n_epochs=self.epoch
        self.model.train()
        self.run.in_train=True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False

#new version of test callback
class TestCallback(Callback):
    def after_step(self):
        if self.train_eval.n_iters>=10:
            return True #stops training with a True flag


### Tracking Statistics using Callbacks

In [11]:
class AvgStats():
    def __init__(self, metrics, in_train):
        self.metrics = convert_to_list(metrics)
        self.in_train = in_train
        
    def reset(self):
        self.total_loss = 0.
        self.count = 0
        self.total_metrics = [0.] * len(self.metrics)
        
    @property
    def all_stats(self): 
        return [self.total_loss.item()] + self.total_metrics
    @property
    def avg_stats(self):
        return [o/self.count for o in self.all_stats]
    
    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        batch_size = run.xb.shape[0]
        self.total_loss += run.loss * batch_size
        self.count += batch_size
        for i,m in enumerate(self.metrics):
            self.total_metrics[i] += m(run.pred, run.yb) * batch_size

class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics,True)
        self.valid_stats = AvgStats(metrics,False)
        
    def begin_epoch(self):        
        self.train_stats.reset()
        self.valid_stats.reset()
        
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run)
    
    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)

### Runner: a class to wrap the core classes together and train the model

In [12]:
class Runner():
    def __init__(self, cbs=None, cb_functions=None):
        callbacks = convert_to_list(cbs)
        for cbf in convert_to_list(cb_functions): #convert functions to callbacks
            cb = cbf()
            setattr(self, cb.name, cb)
            callbacks.append(cb)
        self.stop = False
        self.callbacks = [TrainEvalCallback()]+callbacks

    def one_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()

    def all_batches(self, dl):
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop=False

    #the new training loop
    def fit(self, epochs_in, model_in, optimizer_in, loss_function_in, train_dl_in, valid_dl_in):
        self.epochs = epochs_in
        self.model = model_in
        self.opt = optimizer_in
        self.loss_func = loss_function_in
        self.train_dl = train_dl_in
        self.valid_dl = valid_dl_in
        try:
            for cb in self.callbacks: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(self.epochs):
                self.epoch = epoch
                #training
                if not self('begin_epoch'): #callbacks default to false returns
                    self.all_batches(self.train_dl)
                
                #validation 
                with torch.no_grad(): 
                    if not self('begin_validate'): 
                        self.all_batches(self.valid_dl)
                if self('after_epoch'): break
            
        finally:
            self('after_fit')
            self.model = None
            self.opt = None
            self.loss_func = None
            self.train_dl = None
            self.valid_dl = None

    def __call__(self, cb_name):
        for cb in sorted(self.callbacks, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)
            if f and f():
                return True
        return False
    

---
Training Example

In [13]:
training_ds = Dataset(x_train, y_train)
training_dl = DataLoader(training_ds, shuffle = True)
validation_ds = Dataset(x_valid, y_valid)
validation_dl = DataLoader(valid_ds, shuffle = False)

mod, opt = get_model(training_ds)
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)
run.fit(2, mod, opt, F.cross_entropy, training_dl, validation_dl)

train: [2.424024375, tensor(0.0993)]
valid: [2.387219921875, tensor(0.1064)]
train: [2.4178259375, tensor(0.1007)]
valid: [2.4089341796875, tensor(0.1064)]


## Employing the use of Python Partials

In [14]:
accuracy_callback_func = partial(AvgStatsCallback, accuracy) #makes use of Runner's callback function conversion
run = Runner(cb_functions = accuracy_callback_func)
run.fit(1, mod, opt, F.cross_entropy, training_dl, validation_dl)

train: [2.4177571875, tensor(0.1011)]
valid: [2.434072265625, tensor(0.0961)]


In [15]:
run.avg_stats.valid_stats.avg_stats

[2.434072265625, tensor(0.0961)]