In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *

# DataBunch / Learner

In [3]:
# Load the splits and wrap each one in a Dataset (get_data / Dataset come from the exp.nb_03 star import).
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
# nh: hidden size, bs: batch size. NOTE(review): get_model below uses its own nh default, not this global.
nh,bs = 50,64
# Class count; assumes the labels are integers starting at 0 -- TODO confirm
c = y_train.max().item()+1
loss_func = F.cross_entropy

We now have a lot of parameters to pass to fit the model:

`model`, `loss_func`, `opt`, `train_dl`, `valid_dl`

We will store all of them in a `Learner`.

The `DataBunch` holds the two dataloaders (and, through them, the datasets) together with the number of classes `c`:

In [4]:
class DataBunch():
    """Bundle the training and validation dataloaders with the class count `c`.

    The underlying datasets are exposed as read-only properties that go
    through each dataloader's `dataset` attribute.
    """

    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl = train_dl, valid_dl
        self.c = c

    @property
    def train_ds(self):
        """Dataset behind the training dataloader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """Dataset behind the validation dataloader."""
        return self.valid_dl.dataset

In [5]:
# get_dls (from exp.nb_03) presumably returns (train_dl, valid_dl); they are unpacked into DataBunch's first two arguments.
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [6]:
def get_model(data, lr=0.5, nh=50):
    """Build a one-hidden-layer network for `data` plus an SGD optimizer over its parameters.

    The input width is read from `data.train_ds.x.shape[1]`, the hidden width
    is `nh`, and the output width is `data.c`.

    Returns:
        A `(model, optimizer)` tuple.
    """
    n_in = data.train_ds.x.shape[1]
    layers = [nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, data.c)]
    model = nn.Sequential(*layers)
    opt = optim.SGD(model.parameters(), lr=lr)
    return model, opt

class Learner():
    """Plain container grouping the model, optimizer, loss function and data used for training."""

    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [7]:
# get_model returns (model, opt); the tuple is unpacked into Learner's first two arguments.
learn = Learner(*get_model(data), loss_func, data)

Now we access these parameters through the learner class:

In [8]:
def fit(epochs, learn):
    """Train `learn.model` for `epochs` epochs, validating and printing after each one.

    Each epoch runs one optimisation pass over `learn.data.train_dl`, then a
    gradient-free pass over `learn.data.valid_dl` whose per-batch loss and
    accuracy are averaged over the number of validation batches.

    Returns:
        `(mean validation loss, mean validation accuracy)` of the last epoch.
        With `epochs == 0` there is nothing to return and the final statement fails.
    """
    for epoch in range(epochs):
        # Training pass
        learn.model.train()
        for xb, yb in learn.data.train_dl:
            batch_loss = learn.loss_func(learn.model(xb), yb)
            batch_loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        # Validation pass
        learn.model.eval()
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb, yb in learn.data.valid_dl:
                pred = learn.model(xb)
                tot_loss += learn.loss_func(pred, yb)
                tot_acc += accuracy(pred, yb)

        # Averages are taken over batches, not over samples
        nv = len(learn.data.valid_dl)
        avg_loss, avg_acc = tot_loss/nv, tot_acc/nv
        print(epoch, avg_loss, avg_acc)
    return avg_loss, avg_acc

In [9]:
# One epoch; fit returns the last epoch's mean validation loss and accuracy.
loss, acc = fit(1, learn)

0 tensor(0.4097) tensor(0.8783)


# CallbackHandler

Our initial training loop looked like this:

In [10]:
def one_batch(xb,yb):
    """Run one optimisation step on a single batch.

    Illustrative only: `model`, `loss_func` and `opt` are read from the
    enclosing (global) scope rather than passed in.
    """
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    # Clear the gradients so the next batch starts from zero
    opt.zero_grad()
    
def fit():
    """Call one_batch on every batch of `train_dl`, `epochs` times.

    Illustrative only: `epochs` and `train_dl` are read from the enclosing (global) scope.
    """
    for epoch in range(epochs):
        for b in train_dl: one_batch(*b)

Now we add callbacks at ***each and every step*** of the training to be able to tweak the training loop.

The callback object keeps a reference to the learner, so every step can reach the model, optimizer, loss function and data through `cb.learn`:

In [11]:
def one_batch(xb, yb, cb):
    """Process one batch, letting the callback `cb` veto each stage.

    A falsy result from `begin_batch` or `after_loss` abandons the batch.
    `after_backward` gates the optimizer step and `after_step` gates the
    gradient reset; the two gates are independent of each other.
    """
    if cb.begin_batch(xb, yb):
        preds = cb.learn.model(xb)
        loss = cb.learn.loss_func(preds, yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward():
                cb.learn.opt.step()
            if cb.after_step():
                cb.learn.opt.zero_grad()
        
        
def all_batches(dl, cb):
    """Feed every batch of `dl` to one_batch, ending early once `cb.do_stop()` is truthy."""
    for batch in dl:
        xb, yb = batch
        one_batch(xb, yb, cb)
        if cb.do_stop():
            break

        
def fit(epochs, learn, cb):
    """Callback-driven training loop.

    Nothing runs (not even `after_fit`) when `begin_fit` is falsy. An epoch is
    skipped when `begin_epoch` is falsy, validation is skipped when
    `begin_validate` is falsy, and the loop ends when a stop is pending or
    `after_epoch` is falsy.
    """
    if cb.begin_fit(learn):
        for epoch in range(epochs):
            if cb.begin_epoch(epoch):
                all_batches(learn.data.train_dl, cb)
                if cb.begin_validate():
                    with torch.no_grad():
                        all_batches(learn.data.valid_dl, cb)
                # do_stop is checked first; after_epoch is not called when a stop is pending
                stop_pending = cb.do_stop()
                if stop_pending or not cb.after_epoch():
                    break
        cb.after_fit()

Each callback event returns `True` to let training continue, or `False` to skip the step it guards.

This is the `Callback` superclass; every other callback subclasses it and overrides only the events it cares about:

In [12]:
class Callback():
    """Do-nothing base callback: every event returns True so training proceeds.

    Events that receive arguments store them on the instance (`learn`,
    `epoch`, `xb`/`yb`, `loss`) so subclasses can read them later.
    """

    # --- fit level ---
    def begin_fit(self, learn):
        """Remember the learner being trained."""
        self.learn = learn
        return True

    def after_fit(self):
        return True

    # --- epoch level ---
    def begin_epoch(self, epoch):
        """Remember the index of the epoch that is starting."""
        self.epoch = epoch
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    # --- batch level ---
    def begin_batch(self, xb, yb):
        """Remember the current batch."""
        self.xb, self.yb = xb, yb
        return True

    def after_loss(self, loss):
        """Remember the loss of the current batch."""
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

Each callback dictates some kind of training behavior, and there can be many callbacks for the same training run, so we need a `CallbackHandler` to take care of the different callbacks:

In [13]:
class CallbackHandler():
    """Fan each training event out to a list of callbacks.

    Every event walks `self.cbs` in order and keeps calling callbacks only
    while the running result is truthy; the first falsy result is returned
    and the remaining callbacks are not called for that event.
    """

    def __init__(self,cbs=None):
        self.cbs = cbs or []

    def begin_fit(self, learn):
        """Store the learner, clear its stop flag and notify the callbacks."""
        self.learn = learn
        self.in_train = True
        learn.stop = False
        res = True
        for cb in self.cbs:
            if res: res = cb.begin_fit(learn)
        return res

    def after_fit(self):
        # Starts from False while still in training mode, so callbacks are then not called
        res = not self.in_train
        for cb in self.cbs:
            if res: res = cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        """Put the model in train mode before notifying the callbacks."""
        self.learn.model.train()
        self.in_train=True
        res = True
        for cb in self.cbs:
            if res: res = cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        """Put the model in eval mode before notifying the callbacks."""
        self.learn.model.eval()
        self.in_train=False
        res = True
        for cb in self.cbs:
            if res: res = cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs:
            if res: res = cb.after_epoch()
        return res

    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs:
            if res: res = cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        # Starts from False during validation, which makes one_batch skip the backward pass
        res = self.in_train
        for cb in self.cbs:
            if res: res = cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs:
            if res: res = cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs:
            if res: res = cb.after_step()
        return res

    def do_stop(self):
        """Return the learner's stop flag and reset it to False afterwards."""
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False

Let's try a basic callback: it keeps track of the number of training iterations, incrementing the count after each step, and requests a stop once 10 iterations have run:

In [14]:
class TestCallback(Callback):
    """Count optimizer steps and ask the learner to stop once 10 have been seen."""

    def begin_fit(self,learn):
        # The base class stores the learner on this callback
        super().begin_fit(learn)
        self.n_iters = 0
        return True

    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        limit_reached = self.n_iters>=10
        if limit_reached:
            # Picked up (and reset) by CallbackHandler.do_stop
            self.learn.stop = True
        return True

In [15]:
# One epoch, cut short by TestCallback once it has counted 10 steps.
fit(1, learn, cb=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


# the Callback class

The function `camel2snake` converts a CamelCase name (such as a class name) to snake_case:

In [16]:
import re

# Any character followed by a capitalised word (an upper-case letter, then lower-case letters)
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
# A lower-case letter or digit directly followed by an upper-case letter
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    """Convert a CamelCase `name` to snake_case, e.g. 'TrainEvalCallback' -> 'train_eval_callback'."""
    partly_split = _camel_re1.sub(r'\1_\2', name)
    fully_split = _camel_re2.sub(r'\1_\2', partly_split)
    return fully_split.lower()

In [1]:
class Callback():
    """Base class for callbacks driven by a runner.

    A callback is attached to a runner with `set_runner`. Any attribute that is
    not found on the callback itself is then looked up on that runner, so
    callback code can read runner state directly (e.g. `self.model`).
    """
    _order = 0 # callbacks are executed in ascending `_order`

    def set_runner(self, run):
        """Attach the runner whose attributes this callback may read."""
        self.run = run

    def __getattr__(self, k):
        """Fallback lookup: forward unknown attributes to the runner.

        Raises:
            AttributeError: if no runner has been attached yet, or the runner
                does not have attribute `k` either.
        """
        # __getattr__ only fires after normal lookup fails. If `run` itself is missing,
        # reading `self.run` below would re-enter this method without end (RecursionError),
        # which also breaks hasattr() and getattr(obj, name, default).
        if k == 'run':
            raise AttributeError(f"{self.__class__.__name__} has no runner yet; call set_runner() first")
        return getattr(self.run, k)

    @property
    def name(self):
        """snake_case class name without a trailing 'Callback' ('callback' for the base class)."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

## Train/Eval Callback 

This callback switches the model between train and eval modes and keeps track of progress (iterations and fractional epochs):

In [18]:
class TrainEvalCallback(Callback):
    """Switch the model between train and eval mode and keep progress counters on the runner.

    All state is written to `self.run`: `n_epochs` (fractional epochs done),
    `n_iter` (training batches done) and `in_train`.
    """

    def begin_fit(self):
        """Zero the progress counters."""
        self.run.n_epochs = 0
        self.run.n_iter = 0

    def after_batch(self):
        """Advance the counters, for training batches only."""
        if self.in_train:
            # self.iters is the batch count of the current dataloader, so each batch adds 1/iters of an epoch
            self.run.n_epochs += 1./self.iters
            self.run.n_iter   += 1

    def begin_epoch(self):
        """Snap n_epochs to the epoch index and enter train mode."""
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        """Enter eval mode."""
        self.model.eval()
        self.run.in_train = False

Above, the methods can read `self.model`, `self.epoch` or `self.in_train` because `Callback.__getattr__` forwards any attribute that is missing on the callback to the `Runner`, which in turn exposes the learner's attributes.

Let's re-create the callback that stops the train after 10 iters:

In [19]:
class TestCallback(Callback):
    """Cut the current pass over the dataloader short after 10 training iterations."""

    def after_step(self):
        # `n_iter` is kept on the runner by TrainEvalCallback and is read here through
        # Callback.__getattr__. It is only incremented in `after_batch`, which fires after
        # this event, so the step that just ran is number n_iter + 1.
        if self.n_iter + 1 >= 10:
            # Runner.all_batches checks this flag before starting the next batch.
            # Returning a truthy value here would only skip zero_grad, not stop anything.
            self.run.stop = True

In [20]:
# The raw conversion keeps the 'Callback' suffix; the `name` property strips it first.
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

Here is the callback's name as returned by the `.name` property:

In [21]:
# `name` is a class-level property, so it works without a runner attached.
TrainEvalCallback().name

'train_eval'

First we need a func to convert the callbacks passed into a list:

In [23]:
from typing import *

def listify(o):
    """
    Return `o` as a list.

    None becomes an empty list, a list is returned unchanged (same object),
    a string is wrapped as a single element, any other iterable is expanded
    with `list()`, and every other object is wrapped as a single element.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    # A str is iterable too, so it is excluded from the generic-iterable branch
    if not isinstance(o, str) and isinstance(o, Iterable):
        return list(o)
    return [o]

# The Runner Class 

In [3]:
class Runner():
    """
    Runs the training loop and dispatches every event to the callbacks.

    Callbacks are given at construction time, either as instances (`cbs`) or as
    zero-argument factories (`cb_funcs`). `fit` takes the Learner to train.
    Calling the runner with an event name (`self('after_loss')`) runs that event
    on every callback; a truthy result means "skip / stop what comes next".
    """

    def __init__(self, cbs = None, cb_funcs = None):
        """Collect the callbacks.

        Args:
            cbs: a callback instance or a collection of instances (or None).
            cb_funcs: a factory or a collection of factories (or None); each is
                called with no arguments and the resulting callback is also set
                as an attribute of the runner under its `name`.
        """
        # Copy: listify returns a list argument as-is, and appending to it below
        # would otherwise modify the list owned by the caller.
        cbs = list(listify(cbs))
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)

        self.stop = False

        # TrainEvalCallback is always present; the user callbacks follow it
        self.cbs  = [TrainEvalCallback()] + cbs

    # Learner attributes exposed on the runner (only valid while `fit` is running,
    # because `fit` clears `self.learn` when it finishes)
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        """Process one batch; any event returning a truthy value abandons the rest of the batch."""
        self.xb = xb
        self.yb = yb

        if self('begin_batch'): return

        self.pred = self.model(self.xb)
        if self('after_pred'): return

        self.loss = self.loss_func(self.pred, self.yb)
        # in_train is maintained by TrainEvalCallback; outside training the optimisation part is skipped
        if self('after_loss') or not self.in_train: return

        self.loss.backward()
        if self('after_backward'): return

        self.opt.step()
        if self('after_step'): return

        self.opt.zero_grad()

    def all_batches(self, dl):
        """Run one_batch over `dl`, stopping early when `self.stop` has been set."""
        self.iters = len(dl) # number of batches in this dataloader

        for xb, yb in dl:
            if self.stop : break
            self.one_batch(xb, yb)
            self('after_batch')
        # The stop request only applies to the pass it was raised in
        self.stop = False

    def fit(self, epochs, learn):
        """Train `learn` for `epochs` epochs.

        `after_fit` is always dispatched and `self.learn` is cleared afterwards,
        even when an event aborts the run or an exception is raised.
        """
        self.epochs = epochs
        self.learn = learn

        try:
            for cb in self.cbs: cb.set_runner(self) # give every callback access to the runner

            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'):
                    self.all_batches( self.data.train_dl )

                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches( self.data.valid_dl )

                if self('after_epoch'): break
        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        """Dispatch event `cb_name` to the callbacks in ascending `_order`.

        Callbacks that do not define the event are skipped. Returns True as soon
        as one callback returns a truthy value (later callbacks are then not
        called), otherwise False.
        """
        for cb in sorted(self.cbs, key= lambda x: x._order):

            # Look the event up on the callback; None when it is not defined
            f = getattr(cb, cb_name, None)

            if f and f(): return True
        return False
        

## Metrics recording Callbacks

In [2]:
class AvgStats():

    """
    Running totals of the loss and of each metric for one dataset split.

    Totals are weighted by batch size, so `avg_stats` is a per-sample average
    even when batches differ in size.
    """

    def __init__(self, metrics, in_train):
        """
        Args:
            metrics: a metric or collection of metrics, each called as `m(pred, yb)`.
            in_train: True for the training split, False for validation
                (only used to label the repr).
        """
        self.metrics  = listify(metrics)
        self.in_train = in_train
        # Start from zeroed totals so repr() and all_stats work before the first reset()
        self.reset()

    def reset(self):
        """Zero the loss total, the sample count and every metric total."""
        self.tot_loss = 0.
        self.count = 0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """ Returns the totals as a list: the loss first, then one entry per metric """
        # float() accepts both the plain 0. left by reset() and an accumulated one-element tensor
        return [float(self.tot_loss)] + self.tot_mets

    @property
    def avg_stats(self):
        """ Returns each total divided by the number of samples seen (fails while count is 0) """
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        # Nothing to report until at least one batch has been accumulated
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """ Add the current batch of `run` to the totals.

        `run.loss` and each metric are multiplied by the batch size before being
        summed, because they are later divided by the total number of samples;
        summing unweighted per-batch values would bias the result when batches
        are not all the same size. """

        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

In [26]:
class AvgStatsCallback(Callback):

    """
    Callback that keeps one AvgStats object per split (train / valid),
    updates the right one after every loss computation and prints both
    at the end of each epoch.
    """

    def __init__(self, metrics):
        """ Create the two AvgStats trackers, both using the same `metrics` """
        self.train_stats = AvgStats(metrics, in_train=True)
        self.valid_stats = AvgStats(metrics, in_train=False)

    def begin_epoch(self):
        """ Start every epoch from zeroed totals """
        for split_stats in (self.train_stats, self.valid_stats):
            split_stats.reset()

    def after_loss(self):
        """ Add the current batch to the tracker of the active split """
        # in_train is read from the runner, where TrainEvalCallback maintains it
        if self.in_train:
            stats = self.train_stats
        else:
            stats = self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        """ Print the train stats, then the valid stats """
        for split_stats in (self.train_stats, self.valid_stats):
            print(split_stats)

In [44]:
# Fresh model and optimizer, so the runs below do not continue from the earlier training.
learn = Learner(*get_model(data), loss_func, data)

In [45]:
stats = AvgStatsCallback([accuracy])
# A single callback instance is accepted: Runner passes `cbs` through listify.
run = Runner(cbs=stats)

In [46]:
# Two epochs; AvgStatsCallback prints the train and valid averages after each one.
run.fit(2, learn)

train: [0.3101871875, tensor(0.9056)]
valid: [0.18754737548828124, tensor(0.9377)]
train: [0.140584638671875, tensor(0.9575)]
valid: [0.14425281982421875, tensor(0.9564)]


In [47]:
# avg_stats is [loss, accuracy] here because accuracy was the only metric passed to AvgStatsCallback.
loss, acc = stats.valid_stats.avg_stats
assert acc > 0.9
loss, acc

(0.14425281982421875, tensor(0.9564))

In [48]:
from functools import partial

In [49]:
# Factory form: acc_cbf() builds AvgStatsCallback(accuracy); Runner's cb_funcs calls each factory with no arguments.
acc_cbf = partial(AvgStatsCallback, accuracy)

In [52]:
# Callbacks created from cb_funcs are also set as attributes of the runner under their `name`.
run = Runner(cb_funcs = acc_cbf)

In [53]:
# `learn` is the learner already trained above, so this epoch continues from its current weights.
run.fit(1, learn)

train: [0.10609912109375, tensor(0.9671)]
valid: [0.1065044189453125, tensor(0.9683)]


In [56]:
# AvgStatsCallback's `name` is 'avg_stats', which is the attribute it was registered under on the runner.
run.avg_stats.valid_stats.avg_stats

[0.1065044189453125, tensor(0.9683)]