In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *

## DataBunch/Learner

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=4799)

In [3]:
# Load the data splits and wrap each (x, y) pair in a Dataset (both helpers come from exp.nb_03).
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
nh,bs = 50,64  # nh matches get_model's default hidden size; bs is the batch size handed to get_dls
c = y_train.max().item()+1  # largest label + 1; passed to DataBunch as the number of final activations
loss_func = F.cross_entropy
train_ds  # last expression of the cell: displays the Dataset repr

<exp.nb_03.Dataset at 0x7fafe2049978>

Factor the connected pieces of information out of the `fit()` argument list:

`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`

Let's replace it with something that looks like this:

`fit(1, learn)`

This will allow us to tweak what's happening inside the training loop in other places of the code because the `Learner` object will be mutable, so changing any of its attributes elsewhere will be seen in our training loop.

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5363)

In [4]:
#export
class DataBunch():
    "Bundle the training and validation DataLoaders together with `c`."
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c  # c is the final number of activations

    @property
    def train_ds(self):
        "The dataset wrapped by the training DataLoader."
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        "The dataset wrapped by the validation DataLoader."
        return self.valid_dl.dataset

In [5]:
# get_dls returns the (train, valid) DataLoader pair, unpacked here into DataBunch's first two arguments.
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)
data.train_dl.dataset
get_dls(train_ds, valid_ds, bs)  # last expression of the cell: displays a freshly built DataLoader pair

(<torch.utils.data.dataloader.DataLoader at 0x7fafe2049710>,
 <torch.utils.data.dataloader.DataLoader at 0x7fafe2049828>)

In [6]:
#export
def get_model(data, lr=0.5, nh=50):
    """Build a one-hidden-layer network for `data` plus an SGD optimizer over its parameters.

    The input width is read from `data.train_ds.x.shape[1]`, the hidden width is `nh`
    and the output width is `data.c`. Returns the tuple `(model, optimizer)`.
    """
    n_in = data.train_ds.x.shape[1]
    layers = [nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, data.c)]
    net = nn.Sequential(*layers)
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd

class Learner():
    "Plain mutable container for the model, optimizer, loss function and data used in training."
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [7]:
# get_model returns (model, opt), which unpack into Learner's first two arguments.
learn = Learner(*get_model(data), loss_func, data)


In [8]:
def fit(epochs, learn):
    """Train `learn.model` for `epochs` epochs, validating after each one.

    Each epoch prints the epoch index with the validation loss and accuracy, both
    averaged over the number of validation batches. Returns the `(loss, accuracy)`
    pair of the last epoch.
    """
    for epoch in range(epochs):
        # Training pass: one optimizer step per batch.
        learn.model.train()
        for xb,yb in learn.data.train_dl:
            preds = learn.model(xb)
            loss = learn.loss_func(preds, yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        # Validation pass: accumulate per-batch loss and accuracy without gradients.
        learn.model.eval()
        tot_loss,tot_acc = 0.,0.
        with torch.no_grad():
            for xb,yb in learn.data.valid_dl:
                preds = learn.model(xb)
                tot_loss += learn.loss_func(preds, yb)
                tot_acc  += accuracy (preds,yb)
        nv = len(learn.data.valid_dl)  # number of validation batches
        avg_loss,avg_acc = tot_loss/nv, tot_acc/nv
        print(epoch, avg_loss, avg_acc)
    return avg_loss, avg_acc

In [9]:
loss,acc = fit(1, learn)

0 tensor(0.1720) tensor(0.9492)


## CallbackHandler

This was our training loop (without validation) from the previous notebook, with the inner loop contents factored out:

```python
def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
def fit():
    for epoch in range(epochs):
        for b in train_dl: one_batch(*b)
```

Add callbacks so we can remove complexity from loop, and make it flexible:

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5628)

In [10]:
def one_batch(xb, yb, cb):
    """Run one training batch `(xb, yb)`, consulting the handler `cb` at every stage.

    A falsy answer from `begin_batch` or `after_loss` abandons the batch; a falsy answer
    from `after_backward` / `after_step` skips only the optimizer step / gradient reset.
    """
    if cb.begin_batch(xb,yb):
        preds = cb.learn.model(xb)
        loss = cb.learn.loss_func(preds, yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward(): cb.learn.opt.step()
            if cb.after_step(): cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    "Feed every batch of `dl` to `one_batch`, returning early as soon as `cb.do_stop()` is true."
    print("In all batches function: " + str(dir(cb)))
    for batch in dl:
        xb,yb = batch
        one_batch(xb, yb, cb)
        if cb.do_stop(): return

def fit(epochs, learn, cb):
    """Run up to `epochs` epochs of training and validation, driven by the handler `cb`.

    Nothing runs (not even `after_fit`) when `begin_fit` answers falsy. An epoch whose
    `begin_epoch` answers falsy is skipped entirely. The loop ends early when
    `do_stop()` is true or `after_epoch` answers falsy.
    """
    print("Fit function: " +str(dir(cb)))
    if cb.begin_fit(learn):
        for epoch in range(epochs):
            if cb.begin_epoch(epoch):
                # learn.data is the DataBunch holding the train and valid DataLoaders
                all_batches(learn.data.train_dl, cb)

                if cb.begin_validate():
                    with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
                if cb.do_stop(): break
                if not cb.after_epoch(): break
        cb.after_fit()

In [11]:
class Callback():
    "Base callback: every hook records what it is given (if anything) and answers True."
    def begin_fit(self, learn):
        "Keep a reference to the learner so later hooks can reach it."
        self.learn = learn
        print(learn)
        return True

    def after_fit(self):
        return True

    def begin_epoch(self, epoch):
        "Remember the current epoch index."
        self.epoch = epoch
        return True

    def begin_validate(self):
        print("In the callback begin_validate")
        return True

    def after_epoch(self):
        return True

    def begin_batch(self, xb, yb):
        "Remember the current batch."
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        "Remember the loss just computed."
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

In [12]:
class CallbackHandler():
    """Forward each training-loop event to a list of callbacks and combine their answers.

    For every event the callbacks are asked in order; as soon as the running answer
    is falsy the remaining callbacks are not called and that answer is returned.
    """
    def __init__(self,cbs=None):
        self.cbs = cbs or []

    def begin_fit(self, learn):
        "Store the learner, clear its stop flag, then notify the callbacks."
        self.learn = learn
        self.in_train = True
        learn.stop = False
        res = True
        for cb in self.cbs:
            if not res: break
            res = cb.begin_fit(learn)
        return res

    def after_fit(self):
        "Callbacks are only consulted when the handler is not in training mode."
        res = not self.in_train
        print("In after fit call back handler")
        for cb in self.cbs:
            if not res: break
            res = cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        "Put the model in training mode before notifying the callbacks."
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs:
            if not res: break
            res = cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        "Put the model in evaluation mode before notifying the callbacks."
        self.learn.model.eval()
        self.in_train = False
        res = True
        print("In the call back hander begin_validate")
        for cb in self.cbs:
            if not res: break
            res = cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs:
            if not res: break
            res = cb.after_epoch()
        return res

    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs:
            if not res: break
            res = cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        "Starts from `in_train`, so during validation the answer is falsy and no callback is called."
        res = self.in_train
        for cb in self.cbs:
            if not res: break
            res = cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs:
            if not res: break
            res = cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs:
            if not res: break
            res = cb.after_step()
        return res

    def do_stop(self):
        "Report the learner's stop flag and reset it to False."
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False

In [13]:
class TestCallback(Callback):
    "Counts optimizer steps and raises the learner's stop flag once 10 have been taken."
    def begin_fit(self,learn):
        # Let the base class store the learner, then start counting from zero.
        super().begin_fit(learn)
        self.n_iters = 0
        return True

    def after_step(self):
        count = self.n_iters + 1
        self.n_iters = count
        print(count)
        if count>=10: self.learn.stop = True
        return True

In [14]:
# TestCallback raises learn.stop after 10 optimizer steps, so the training pass ends after 10 batches.
fit(1, learn, cb=CallbackHandler([TestCallback()]))

Fit function: ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'after_backward', 'after_epoch', 'after_fit', 'after_loss', 'after_step', 'begin_batch', 'begin_epoch', 'begin_fit', 'begin_validate', 'cbs', 'do_stop']
<__main__.Learner object at 0x7fafe1ff6208>
In all batches function: ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'after_backward', 'after_epoch', 'after_fit', 'after_loss', 'after_step', 

In [15]:
#%debug

This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint they should all be in the same class!

## Runner

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5811)

In [16]:
#export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    "Convert a CamelCase identifier to snake_case, e.g. 'TrainEvalCallback' -> 'train_eval_callback'."
    # First split before each capitalised word, then between a lowercase/digit and a capital.
    partly_split = _camel_re1.sub(r'\1_\2', name)
    fully_split = _camel_re2.sub(r'\1_\2', partly_split)
    return fully_split.lower()

class Callback():
    """Base class for Runner callbacks.

    Attributes not found on the callback are looked up on the runner it is attached to,
    so hooks can read shared state (for example `self.in_train`) directly. Callbacks are
    invoked in ascending `_order`.
    """
    _order=0
    def set_runner(self, run):
        "Attach the runner whose attributes this callback can read."
        self.run=run

    def __getattr__(self, k):
        # `run` only exists after set_runner; without this guard, looking up any missing
        # attribute on an unattached callback would re-enter __getattr__ for `run`
        # until the interpreter raises RecursionError.
        if k == 'run':
            raise AttributeError(f"{type(self).__name__} has no runner yet; call set_runner first")
        return getattr(self.run, k)

    @property
    def name(self):
        "Snake-case class name without a trailing 'Callback' ('callback' if nothing remains)."
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

This first callback is responsible for switching the model back and forth between training and validation mode, as well as maintaining a count of the iterations, or the percentage of iterations elapsed in the epoch.

In [17]:
#export
class TrainEvalCallback(Callback):
    "Switches the model between train and eval mode and keeps the runner's progress counters."
    def begin_fit(self):
        # Counters are stored on the runner so every callback can read them.
        run = self.run
        run.n_epochs = 0.
        run.n_iter = 0

    def after_batch(self):
        # Only training batches advance the counters.
        if self.in_train:
            self.run.n_epochs += 1./self.iters
            self.run.n_iter   += 1

    def begin_epoch(self):
        run = self.run
        run.n_epochs = self.epoch
        self.model.train()
        run.in_train = True
        run.n_iter = 0

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

We'll also re-create our TestCallback

In [18]:
class TestCallback(Callback):
    "Ask the runner to end the current batch loop once 10 training iterations have been counted."
    def after_step(self):
        # `n_iter` lives on the runner (TrainEvalCallback maintains it) and is reached
        # through Callback.__getattr__. The runner has no `train_eval` attribute and
        # nothing defines `n_iters`, so the previous lookup raised AttributeError.
        # Set the stop flag instead of returning True: a truthy return from this hook
        # only makes Runner.one_batch skip `opt.zero_grad()`; it does not end training.
        if self.n_iter>=10: self.run.stop = True

In [19]:
# camel2snake on its own keeps the 'Callback' suffix; Callback.name strips it before converting.
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [20]:
# `name` removes the trailing 'Callback' from the class name and snake-cases the rest.
TrainEvalCallback().name

'train_eval'

In [21]:
#export
from typing import *

def listify(o):
    """Return `o` as a list.

    None becomes `[]`, a list is returned unchanged (the same object), a string or any
    non-iterable value is wrapped in a one-element list, and other iterables are
    converted with `list()`.
    """
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str) or not isinstance(o, Iterable): return [o]
    return list(o)

In [51]:
#export
class Runner():
    """Training loop that owns the callbacks and the per-batch state they read.

    Each stage of the loop is announced by calling the runner with the stage name
    (for example `self('after_loss')`). A callback hook that returns a truthy value
    tells the runner to skip what would come next. A TrainEvalCallback is always
    installed ahead of the user's callbacks.
    """
    def __init__(self, cbs=None, cb_funcs=None):
        """
        cbs:      a callback instance, or a collection of instances.
        cb_funcs: zero-argument factories; each callback they build is also stored
                  on the runner as an attribute named after the callback's `name`.
        """
        # Copy first: listify hands a list argument back unchanged, so appending the
        # factory-built callbacks to it would modify the caller's own list.
        cbs = list(listify(cbs))
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb) # sets the callback as an attribute of the runner
            cbs.append(cb)
        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs

    # Shortcuts to the parts of the current learner (only valid while `fit` is running).
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        "Process one batch; any stage whose callbacks answer True ends the batch there."
        self.xb,self.yb = xb,yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return # no backward pass when validating
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()

    def all_batches(self, dl):
        "Run every batch of `dl` until `self.stop` is set, then clear the flag."
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop:
                break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop=False

    def fit(self, epochs, learn):
        """Train `learn` for up to `epochs` epochs.

        `after_fit` is always dispatched, even on an early return or an exception,
        and `self.learn` is reset to None afterwards.
        """
        self.epochs,self.learn = epochs,learn

        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                print("Epoch: " + str(epoch))
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break

        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        """Dispatch the event `cb_name` to the callbacks in ascending `_order`.

        Returns True as soon as one callback's hook returns a truthy value (later
        callbacks are then not called), otherwise False. Callbacks without a hook
        of that name are skipped.
        """
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False
    

Third callback: how to compute metrics.

In [52]:
#export
class AvgStats():
    "Running totals of the loss and of each metric, weighted by batch size."
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train
        print("Metrics:")
        print(metrics)

    def reset(self):
        "Zero the totals and the sample count."
        self.tot_loss = 0.
        self.count = 0
        self.tot_mets = [0. for _ in self.metrics]

    @property
    def all_stats(self):
        "Total loss (as a Python number) followed by the total of each metric."
        return [self.tot_loss.item(), *self.tot_mets]

    @property
    def avg_stats(self):
        "Each entry of `all_stats` divided by the number of samples seen."
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        phase = 'train' if self.in_train else 'valid'
        return f"{phase}: {self.avg_stats}"

    def accumulate(self, run):
        "Add the current batch of `run` to the totals, weighting by its batch size."
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            # any metric taking (predictions, targets) can be used
            self.tot_mets[i] += m(run.pred, run.yb) * bn

class AvgStatsCallback(Callback):
    "Tracks averaged loss and metrics separately for training and validation; prints both after each epoch."
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics,True)
        self.valid_stats = AvgStats(metrics,False)

    def begin_epoch(self):
        for stats in (self.train_stats, self.valid_stats): stats.reset()

    def after_loss(self):
        # Pick the accumulator matching the runner's current phase.
        if self.in_train: stats = self.train_stats
        else:             stats = self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run) # the runner carries xb, yb, pred and loss

    def after_epoch(self):
        for stats in (self.train_stats, self.valid_stats): print(stats)

In [53]:
# get_model builds a new model and optimizer, so the Runner below starts from a freshly initialised network.
learn = Learner(*get_model(data), loss_func, data)

In [54]:
# Callbacks passed through `cbs` are not attached to the Runner as attributes,
# so keep our own reference in `stats` to read the results afterwards.
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)
#run.cbs[1].after_loss
stats

Metrics:
[<function accuracy at 0x7fb0552187b8>]
Metrics:
[<function accuracy at 0x7fb0552187b8>]


<__main__.AvgStatsCallback at 0x7fafe1fc4ef0>

In [55]:
# Four epochs; AvgStatsCallback prints the train and valid averages after each one.
run.fit(4, learn)

Epoch: 0
train: [0.30987896484375, tensor(0.9048)]
valid: [0.1971385986328125, tensor(0.9402)]
Epoch: 1
train: [0.13820181640625, tensor(0.9580)]
valid: [0.15313585205078126, tensor(0.9530)]
Epoch: 2
train: [0.102556455078125, tensor(0.9685)]
valid: [0.12364180908203125, tensor(0.9621)]
Epoch: 3
train: [0.08278763671875, tensor(0.9746)]
valid: [0.1060583740234375, tensor(0.9689)]


In [56]:
# avg_stats is [average loss, average of each metric]; with the single accuracy metric it unpacks into two values.
loss,acc = stats.valid_stats.avg_stats
assert acc>0.9
loss,acc

(0.1060583740234375, tensor(0.9689))

In [57]:
#export
from functools import partial

In [64]:
# Runner calls each entry of cb_funcs with no arguments, so bind the metric up front with partial.
acc_cbf = partial(AvgStatsCallback,accuracy)
x = acc_cbf()  # quick check that the factory builds a callback

Metrics:
<function accuracy at 0x7fb0552187b8>
Metrics:
<function accuracy at 0x7fb0552187b8>


In [69]:
class earlyStop(Callback):
    """Cut each batch loop short and end training after a chosen epoch.

    stopIt:   index of the epoch after which training ends (`after_epoch` returns True).
    max_iter: once the runner's `n_iter` exceeds this value the runner's stop flag is
              set, which ends the current batch loop. Defaults to 10, the value that
              was previously hard-coded.
    """
    _order = 1
    def __init__(self, stopIt = 6, max_iter = 10):
        self.stopIt = stopIt
        self.max_iter = max_iter

    def after_batch(self):
        # n_iter is read from the runner through Callback.__getattr__.
        if self.n_iter > self.max_iter:
            self.run.stop = True
            return True

    def after_epoch(self):
        if self.epoch == self.stopIt:
            return True

# earlyStop's first positional argument is stopIt, so training ends after epoch index 3.
run = Runner(cb_funcs=[acc_cbf, partial(earlyStop, 3)])
#run = Runner(cbs = earlyStop(), cb_funcs=[acc_cbf])

Metrics:
<function accuracy at 0x7fb0552187b8>
Metrics:
<function accuracy at 0x7fb0552187b8>


In [70]:
# Asks for 10 epochs, but earlyStop returns True from after_epoch at epoch index 3, so only epochs 0-3 run.
run.fit(10, learn)

Epoch: 0
train: [0.07979269461198286, tensor(0.9787)]
valid: [0.07471542060375214, tensor(0.9609)]
Epoch: 1
train: [0.04883186383680864, tensor(0.9872)]
valid: [0.0653093233704567, tensor(0.9766)]
Epoch: 2
train: [0.05038662390275435, tensor(0.9872)]
valid: [0.07080596685409546, tensor(0.9766)]
Epoch: 3
train: [0.06881838495081122, tensor(0.9744)]
valid: [0.06597046554088593, tensor(0.9688)]


Using Jupyter means we can get tab-completion even for dynamic code like this! :)

In [32]:
run.

SyntaxError: invalid syntax (<ipython-input-32-153f5d03df9c>, line 1)

In [None]:
# Callbacks built from cb_funcs are attached to the Runner under their snake_case name, hence `run.avg_stats`.
run.avg_stats.valid_stats.avg_stats

## Export

In [None]:
# Runs notebook2script.py on this notebook; presumably it collects the `#export` cells into a module (confirm against the script).
!python notebook2script.py 04_callbacks.ipynb

In [None]:
# Inspect the optimizer's parameter groups; the Recorder callback reads the 'lr' entry of the last group.
learn.opt.param_groups

In [None]:
# Recorder class - 
class Recorder(Callback):
    "Records the learning rate and the loss after every training batch, with helpers to plot them."
    def begin_fit(self):
        # Start each fit with empty histories.
        self.lrs = []
        self.losses = []

    def after_batch(self):
        # Validation batches are not recorded.
        if self.in_train:
            last_group = self.opt.param_groups[-1]
            self.lrs.append(last_group['lr'])
            self.losses.append(self.loss.detach().cpu())

    def plot_lr(self):
        plt.plot(self.lrs)

    def plot_loss(self):
        plt.plot(self.losses)
        
rcd_cbf = Recorder  # alias for the factory; the Runner below is given Recorder directly

# A class is itself a zero-argument factory, so Recorder can go straight into cb_funcs.
run = Runner(cb_funcs=[acc_cbf, Recorder])
run.fit(2, learn)


In [None]:
# Runner.fit resets self.learn to None in its finally block, so this evaluates to None after a fit.
run.learn