In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:

from exp.nb_03 import *

# Get data

In [3]:
x_train,y_train,x_valid,y_valid= get_data_normalized()

In [4]:
# Wrap the normalized tensors in Dataset objects for the dataloaders.
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
# bs is the batch size handed to get_dls below; nh is presumably the hidden-layer width
# (get_simple_model has its own nh=50 default) -- TODO confirm it is still needed here.
nh,bs = 50,64
# Number of classes, derived as (largest label + 1); assumes labels are 0-based integers.
c = y_train.max().item()+1
loss_func = F.cross_entropy

In [5]:
bs,nh,c

(64, 50, 10)

# Databunch: gather train and valid dataloader together

In [6]:
class DataBunch():
    """Hold the training and validation dataloaders together, plus an optional class count `c`."""
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c

    @property
    def train_ds(self):
        """The dataset wrapped by the training dataloader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """The dataset wrapped by the validation dataloader."""
        return self.valid_dl.dataset

In [7]:
# get_dls presumably returns the (train, valid) dataloader pair; it is unpacked into
# DataBunch's first two arguments, and c (the class count) is passed as the third.
db = DataBunch(*get_dls(train_ds,valid_ds,bs),c)

# Learner: get model, optimizer, loss func and data together (factoring into packages)

In [8]:
def get_simple_model(data,lr=0.1,nh=50):
    """Build a one-hidden-layer classifier for `data` and an SGD optimizer over its parameters.

    The input width is read from `data.train_ds.x.shape[1]` and the output width from `data.c`;
    `nh` is the hidden-layer width and `lr` the SGD learning rate.
    Returns the tuple `(model, optimizer)`.
    """
    in_dim = data.train_ds.x.shape[1]
    layers = [nn.Linear(in_dim, nh), nn.ReLU(), nn.Linear(nh, data.c)]
    model = nn.Sequential(*layers)
    opt = optim.SGD(model.parameters(), lr=lr)
    return model, opt

In [9]:
class Learner():
    """Plain container tying together a model, its optimizer, the loss function and the data."""
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [10]:
loss_func

<function torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')>

In [11]:
learn = Learner(*get_simple_model(db),loss_func,db)

In [12]:
def fit(epochs, learn):
    """Plain training loop: train on every batch, then print validation loss/accuracy each epoch.

    The numbered comments mark the places where callback hooks are inserted later in the notebook.
    Returns the per-batch mean validation loss and accuracy of the last epoch.
    """
    # 0 begin fit
    for epoch in range(epochs):
        learn.model.train()
        # 1 begin epoch
        for xb,yb in learn.data.train_dl:
            # 2 begin batch: blank

            # 3 begin loss
            batch_loss = learn.loss_func(learn.model(xb), yb)
            # 3 after loss: blank

            # 4 begin backward
            batch_loss.backward()
            # 4 after backward: blank

            # 5 begin step
            learn.opt.step()
            # 5 after step
            learn.opt.zero_grad()

            # 2 after batch: blank

        # 6 begin validate
        learn.model.eval()
        with torch.no_grad():
            sum_loss,sum_acc = 0.,0.
            for xb,yb in learn.data.valid_dl:
                out = learn.model(xb)
                sum_loss = sum_loss + learn.loss_func(out, yb)
                sum_acc  = sum_acc + accuracy(out, yb)
        # Averages are taken over the number of validation batches, not samples
        n_batches = len(learn.data.valid_dl)
        mean_loss,mean_acc = sum_loss/n_batches,sum_acc/n_batches
        print(epoch, mean_loss, mean_acc)
        # 6 after validate: blank

        # 1 after epoch: blank

    # 0 after fit
    return mean_loss, mean_acc

In [13]:
loss,acc = fit(1, learn)

0 tensor(0.5808) tensor(0.8209)


```
def one_batch(xb,yb):
    loss = loss_func(model(xb), yb)
    loss.backward()
    opt.step()
    opt.zero_grad()

def fit():
    for epoch in range(epochs):
        for b in train_dl: one_batch(*b)
```

# Callbacks

In [14]:
def one_batch(xb, yb, cb):
    """Run one batch through the learner held by `cb`, consulting `cb` at every stage.

    Each hook returns a boolean meaning "carry on": a falsy `begin_batch` or `after_loss`
    abandons the batch (the latter is how validation skips the backward pass), a falsy
    `after_backward` skips the optimizer step, and a falsy `after_step` skips `zero_grad`.
    """
    if cb.begin_batch(xb, yb):
        loss_fn = cb.learn.loss_func
        loss = loss_fn(cb.learn.model(xb), yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward():
                cb.learn.opt.step()
            if cb.after_step():
                cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    """Feed every (xb, yb) batch of `dl` to one_batch, ending early once `cb.do_stop()` is true."""
    for batch in dl:
        xb, yb = batch
        one_batch(xb, yb, cb)
        if cb.do_stop():
            break

def fit(epochs, learn, cb):
    """Callback-driven training loop: train, then validate under no_grad, for each epoch.

    A falsy `begin_fit` aborts before anything runs (and `after_fit` is not called);
    a falsy `begin_epoch` skips that epoch; a falsy `begin_validate` skips validation;
    a pending stop or a falsy `after_epoch` ends training.
    """
    if cb.begin_fit(learn):
        for epoch in range(epochs):
            if cb.begin_epoch(epoch):
                all_batches(learn.data.train_dl, cb)

                if cb.begin_validate():
                    with torch.no_grad():
                        all_batches(learn.data.valid_dl, cb)
                if cb.do_stop(): break
                if not cb.after_epoch(): break
        cb.after_fit()


In [15]:
# callback funcs only return boolean
class Callback():
    """Base callback for the function-based loop: every hook answers True ("carry on").

    The hooks that receive arguments also record them on the instance so that
    subclasses can read `learn`, `epoch`, `xb`/`yb` and `loss` later.
    """
    # --- hooks that record what they are given ---
    def begin_fit(self, learn):
        self.learn = learn
        return True

    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True

    def begin_batch(self, xb, yb):
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        self.loss = loss
        return True

    # --- hooks that do nothing by default ---
    def after_fit(self):
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

# Callback handlers: to handle multiple callbacks

In [16]:
class CallbackHandler():
    """Fan each training-loop event out to a list of callbacks and combine their answers.

    Every hook returns a truthy value to mean "carry on". Answers are combined with
    `and`, so once one callback answers falsy the remaining callbacks are not called
    for that event. All learner access goes through `self.learn` (stored by
    `begin_fit`), never through a notebook-global variable.
    """
    def __init__(self,cbs=None):
        self.cbs = cbs if cbs else []

    def _run(self, hook, res, *args):
        """Call `hook(*args)` on each callback in order, and-ing the answers into `res`."""
        for cb in self.cbs: res = (res and getattr(cb, hook)(*args))
        return res

    def begin_fit(self, learn):
        self.learn,self.in_train = learn,True
        learn.stop = False
        return self._run('begin_fit', True, learn)

    def after_fit(self):
        # Starts from `not in_train`, i.e. truthy only if the last phase run was validation
        return self._run('after_fit', not self.in_train)

    def begin_epoch(self, epoch):
        # Switch the model held by this handler's learner to training mode
        self.learn.model.train()
        self.in_train=True
        return self._run('begin_epoch', True, epoch)

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train=False
        return self._run('begin_validate', True)

    def after_epoch(self):
        return self._run('after_epoch', True)

    def begin_batch(self, xb, yb):
        return self._run('begin_batch', True, xb, yb)

    def after_loss(self, loss):
        # in_train is False during validation, so this answers False there and
        # one_batch returns before the backward pass and optimizer step
        return self._run('after_loss', self.in_train, loss)

    def after_backward(self):
        return self._run('after_backward', True)

    def after_step(self):
        return self._run('after_step', True)

    def do_stop(self):
        """Return the learner's stop flag and clear it, so a stop request is consumed once."""
        try:     return self.learn.stop
        finally: self.learn.stop = False

## Example


In [17]:
class TestCallback(Callback):
    """Count optimizer steps, print the count, and request a stop once 10 steps have run."""
    def begin_fit(self,learn):
        super().begin_fit(learn) # stores the learner as self.learn
        self.n_iters = 0
        return True

    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        # Flag the learner this callback was given, not whatever the global `learn` happens to be
        if self.n_iters>=10: self.learn.stop = True
        return True

In [18]:
fit(1, learn, cb=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


```
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch=epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True
```

In [19]:
class TempCallback(Callback):
    """Tracing callback: prints a line at every hook and requests a stop after 5 optimizer steps."""
    def begin_fit(self,learn):
        super().begin_fit(learn) # set self.learn
        self.n_iters = 0
        print('begin_fit')
        return True
    def after_fit(self):
        print('after_fit')
        return True
    def begin_epoch(self,epoch):
        super().begin_epoch(epoch) # set self.epoch
        print(f'Begin epoch: {self.epoch}')
        return True
    def begin_validate(self):
        print(f'Begin validating at epoch {self.epoch}')
        return True
    def after_epoch(self):
        print('After epoch')
        return True
    def begin_batch(self, xb, yb):
        super().begin_batch(xb,yb) # set self.xb, self.yb
        print(f'Begin batch with shape: {xb.shape}')
        return True
    def after_loss(self, loss):
        super().after_loss(loss) # set self.loss
        print(f'After loss of 1 batch: Training loss {self.loss}')
        return True
    def after_backward(self):
        print('After backward')
        return True
    def after_step(self):
        self.n_iters += 1
        print(f'After opt step {self.n_iters}')
        # Flag the learner stored by begin_fit rather than the notebook-global `learn`
        if self.n_iters>=5: self.learn.stop = True
        return True

Run 1 epoch: training stops after 5 batches (the callback sets `stop` at the fifth optimizer step), then validation still runs over every validation batch.

In [20]:
fit(1, learn, cb=CallbackHandler([TempCallback()]))

begin_fit
Begin epoch: 0
Begin batch with shape: torch.Size([64, 784])
After loss of 1 batch: Training loss 0.3003041446208954
After backward
After opt step 1
Begin batch with shape: torch.Size([64, 784])
After loss of 1 batch: Training loss 0.2018508017063141
After backward
After opt step 2
Begin batch with shape: torch.Size([64, 784])
After loss of 1 batch: Training loss 0.217964768409729
After backward
After opt step 3
Begin batch with shape: torch.Size([64, 784])
After loss of 1 batch: Training loss 0.1688244789838791
After backward
After opt step 4
Begin batch with shape: torch.Size([64, 784])
After loss of 1 batch: Training loss 0.27026647329330444
After backward
After opt step 5
Begin validating at epoch 0
Begin batch with shape: torch.Size([128, 784])
Begin batch with shape: torch.Size([128, 784])
Begin batch with shape: torch.Size([128, 784])
Begin batch with shape: torch.Size([128, 784])
Begin batch with shape: torch.Size([128, 784])
Begin batch with shape: torch.Size([128, 7

# Refactor callbacks

```
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch=epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True
```

In [21]:
# quick example of __getattr__, getattr and setattr

class Dummy():
    """Toy class showing that __getattr__ is only consulted for attributes that do not exist."""
    def __init__(self,a,b):
        self.dum1 = a
        self.dum2 = b

    def dumb_func(self):
        """Return the two stored strings joined by a single space."""
        return ' '.join((self.dum1, self.dum2))

    def __getattr__(self, attr):
        # Reached only when normal lookup fails: answer with the upper-cased attribute name
        return attr.upper()
d = Dummy('dumb','dumber')
# __getattr__: only invoked for names that normal attribute lookup cannot find
print(d.does_not_exist) # 'DOES_NOT_EXIST'
print(d.what_about_this_one)  # 'WHAT_ABOUT_THIS_ONE'

# getattr: look up an existing attribute (or method) by its string name
print(getattr(d,'dum1'))
print(getattr(d,'dum2'))
print(getattr(d,'dumb_func')())

# setattr: assign an attribute by its string name
# existing attr: overwrites the current value
setattr(d,'dum2','dumbest')
print(getattr(d,'dum2'))
# new attr: creates it, so later lookups no longer fall through to __getattr__
setattr(d,'dum3','dumber')
print(getattr(d,'dum3'))

DOES_NOT_EXIST
WHAT_ABOUT_THIS_ONE
dumb
dumber
dumb dumber
dumbest
dumber


In [22]:
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert a CamelCase identifier to snake_case, e.g. 'TrainEvalCallback' -> 'train_eval_callback'."""
    # Pass 1 puts an underscore before each capitalized word that follows another character
    partly_split = _camel_re1.sub(r'\1_\2', name)
    # Pass 2 splits the remaining lowercase-or-digit / capital boundaries
    fully_split = _camel_re2.sub(r'\1_\2', partly_split)
    return fully_split.lower()

from typing import *

def listify(o):
    """Coerce `o` to a list.

    None -> []; a list is returned as-is (same object); a str or bytes value is
    wrapped as a single item; any other iterable is expanded with list(); everything
    else is wrapped in a one-element list.
    """
    if o is None: return []
    if isinstance(o, list): return o
    # Strings are iterable, but callers mean "one item", not "one item per character"
    if isinstance(o, (str, bytes)): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

In [23]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)


'train_eval_callback'

In [24]:
class Callback():
    """Base class for Runner callbacks.

    `_order` decides the call order among callbacks (lower runs first). Attributes
    that are not found on the callback itself are looked up on the attached runner,
    so a callback can read e.g. `self.model` or `self.in_train` directly.
    """
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k):
        # __getattr__ runs only after normal lookup failed. If `run` itself is what is
        # missing (no runner attached yet), reading self.run below would re-enter this
        # method forever, so fail with a regular AttributeError instead.
        if k == 'run': raise AttributeError(f"{type(self).__name__!r} has no runner attached yet")
        return getattr(self.run, k)
    @property
    def name(self):
        """snake_case class name without a trailing 'Callback' ('callback' if nothing is left)."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

In [25]:
# Responsible for switching the model between training and validation mode,
# and for maintaining the runner's iteration counters.
class TrainEvalCallback(Callback):
    """Keeps `n_epochs`, `n_iter` and `in_train` up to date on the runner."""
    def begin_fit(self):
        runner = self.run
        runner.n_epochs, runner.n_iter = 0., 0

    def after_batch(self):
        # Only training batches advance the counters
        if self.in_train:
            runner = self.run
            # fractional epoch progress: self.iters is the number of batches in the dataloader
            runner.n_epochs = runner.n_epochs + 1./self.iters
            # running count of training iterations
            runner.n_iter = runner.n_iter + 1

    # training phase
    def begin_epoch(self):
        # self.epoch is the loop index, i.e. `for epoch in range(epochs)`
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    # validation phase
    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False
        

In [26]:
# class TestCallback(Callback):
#     def begin_fit(self,learn):
#         super().begin_fit(learn)
#         self.n_iters = 0
#         return True
        
#     def after_step(self):
#         self.n_iters += 1
#         print(self.n_iters)
#         if self.n_iters>=10: learn.stop = True
#         return True

class TestCallback(Callback):
    """Ask the runner to stop iterating over batches once n_iter has reached 10."""
    _order=1
    def after_step(self):
        # Returning True from after_step would only make Runner.one_batch skip
        # opt.zero_grad(); the loop in Runner.all_batches breaks on the runner's
        # `stop` flag, so set that instead (and let zero_grad run as usual).
        # n_iter is advanced in after_batch, so this fires during the step that
        # follows the 10th completed training batch.
        if self.n_iter>=10: self.run.stop = True

In [27]:
TrainEvalCallback().name

'train_eval'

```
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb,yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): 
        return # will do this at validation step
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    for xb,yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)
        
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()

```

In [28]:
# put 3 funcs above into one class
class Runner():
    """Run the training loop and dispatch each named event to the registered callbacks.

    Hook convention (the opposite of the function-based loop): `self('event')` returns
    True as soon as one callback's handler for that event returns a truthy value, and
    True means "skip what would come next". A callback that does not define the
    handler, or whose handler returns None/False, lets the loop continue.
    """
    def __init__(self, cbs=None, cb_funcs=None):
        # Copy so that appending factory-built callbacks never mutates a list the caller passed in
        cbs = list(listify(cbs))
        for cbf in listify(cb_funcs): # cb_funcs are callback factories, e.g. partial(SomeCallback, <init args>)
            cb = cbf()
            # Expose the new callback on the runner under its snake_case name,
            # e.g. partial(AvgStatsCallback, accuracy) becomes runner.avg_stats
            setattr(self, cb.name, cb)
            cbs.append(cb)
        # TrainEvalCallback is always installed, ahead of the user's callbacks
        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs

    # Shortcuts into the learner; only valid while fit() is running (self.learn is cleared afterwards)
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        """Process one batch: forward pass, loss and, in training only, backward + optimizer step.

        The batch, prediction and loss are stored on the runner (xb, yb, pred, loss) so
        callbacks can read them. Any hook that returns True ends the batch at that point.
        """
        self.xb,self.yb = xb,yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        # in_train is maintained by TrainEvalCallback; validation ends after the loss
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()

    def all_batches(self, dl):
        """Run one_batch over every batch of `dl`, firing 'after_batch' after each one.

        Setting `self.stop` (e.g. from a callback) ends the loop before the next batch;
        the flag is cleared afterwards so it only affects the current pass over `dl`.
        """
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop=False

    def fit(self, epochs, learn):
        """Train `learn` for `epochs` epochs, validating under no_grad after each one.

        'after_fit' always fires, even when an exception escapes or 'begin_fit' cancels
        the run, and the reference to the learner is dropped afterwards.
        """
        self.epochs,self.learn,self.loss = epochs,learn,tensor(0.)

        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break

        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        """Fire event `cb_name` on the callbacks in `_order`; True as soon as one handler returns truthy."""
        for cb in sorted(self.cbs, key=lambda x: x._order):
            # A callback without a handler for this event yields None and is skipped
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False

In [36]:
type(db.train_dl.dataset)

exp.nb_03.Dataset

## Few more callback classes

In [42]:
class AvgStats():
    """Accumulate batch-size-weighted loss and metric totals for one phase (training or validation)."""
    def __init__(self, metrics, in_train):
        self.metrics,self.in_train = listify(metrics),in_train
        # Start from zeroed totals so repr() and all_stats work before the first reset()
        self.reset()

    def reset(self):
        """Zero the running totals and the sample count."""
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """[total loss, total metric 1, ...], each summed over samples."""
        # tot_loss stays a plain float until a tensor loss is accumulated, and floats have no .item()
        loss = self.tot_loss.item() if hasattr(self.tot_loss, 'item') else float(self.tot_loss)
        return [loss] + self.tot_mets
    @property
    def avg_stats(self):
        """all_stats divided by the number of samples seen; raises ZeroDivisionError if none were."""
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add the runner's current batch: its loss and each metric, weighted by the batch size."""
        bn = run.xb.shape[0] # batch size = first dimension of the input batch
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

# Callback that tracks average loss and metrics for the training and validation phases
class AvgStatsCallback(Callback):
    """Keep one AvgStats per phase, reset both every epoch and print both when the epoch ends."""
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def begin_epoch(self):
        for phase_stats in (self.train_stats, self.valid_stats):
            phase_stats.reset()

    def after_loss(self):
        # in_train is read from the runner; it selects which accumulator gets this batch
        if self.in_train: stats = self.train_stats
        else:             stats = self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        for phase_stats in (self.train_stats, self.valid_stats):
            print(phase_stats)

In [39]:
loss_func = F.cross_entropy

In [40]:
learn = Learner(*get_simple_model(db), loss_func, db)

In [45]:
AvgStatsCallback([accuracy]).name

'avg_stats'

## Pass a callback instance (no partial): the callback is not saved as a runner attribute

In [43]:
# Build the callback ourselves and pass the instance via `cbs`; keeping the `stats`
# reference is how the results are read back after fit, since Runner only registers
# callbacks built from `cb_funcs` as attributes.
# accuracy is listed twice, so each stats line shows the same metric in two columns.
stats = AvgStatsCallback([accuracy,accuracy])
run = Runner(cbs=stats)

In [46]:
# Runner stores its callbacks in `cbs`; it has no `cb` attribute
run.cbs

[<__main__.TrainEvalCallback at 0x7fcee8e6a7b8>,
 <__main__.AvgStatsCallback at 0x7fcee8e6a908>]

In [48]:
run.fit(2, learn)

train: [0.4457242513020833, tensor(0.8655), tensor(0.8655)]
valid: [0.60572392578125, tensor(0.8120), tensor(0.8120)]
train: [0.2547105794270833, tensor(0.9243), tensor(0.9243)]
valid: [0.536561962890625, tensor(0.8390), tensor(0.8390)]


In [51]:
loss,metric1,metric2=stats.valid_stats.avg_stats # use callback obj directly to get stats

In [52]:
loss,metric1,metric2

(0.536561962890625, tensor(0.8390), tensor(0.8390))

# Use a callback created from a partial (the callback object becomes a runner attribute under its snake_case name)

In [53]:
from functools import partial

# partial defers construction: Runner calls acc_cbf() itself to build the callback,
# then stores it on the runner under the callback's snake_case name
acc_cbf = partial(AvgStatsCallback,accuracy)

run = Runner(cb_funcs=acc_cbf)

In [54]:
acc_cbf().name

'avg_stats'

In [55]:
run.cbs

[<__main__.TrainEvalCallback at 0x7fcee8e849e8>,
 <__main__.AvgStatsCallback at 0x7fcee8e84be0>]

In [57]:
run.fit(1, learn)

train: [0.203928955078125, tensor(0.9392)]
valid: [0.52443955078125, tensor(0.8497)]


## Get callback obj from runner

In [58]:
run.avg_stats

<__main__.AvgStatsCallback at 0x7fcee8e84be0>

In [64]:
run.avg_stats.train_stats # use runner to call callback obj to get stats

train: [0.203928955078125, tensor(0.9392)]

In [65]:
run.avg_stats.valid_stats

valid: [0.52443955078125, tensor(0.8497)]