# Customize Training with Callbacks
---

In [1]:
# Hyperparameters used throughout the notebook:
#   epochs - number of epochs handed to train()
#   lr     - learning rate handed to the SGD optimizer
#   bs     - value handed to utils.get_dataloaders (presumably the batch size)
config = dict(epochs=10, lr=0.01, bs=128)

## Import Libraries

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import utils

## Load Data

In [4]:
# Fetch the train/test arrays, then wrap them in dataloaders built with the
# configured 'bs' value.
x_train, y_train, x_test, y_test = utils.get_data()
train_dl, test_dl = utils.get_dataloaders(
    x_train,
    y_train,
    x_test,
    y_test,
    config['bs'],
)

In [5]:
# Bundle both dataloaders together with the number of target classes.
data = utils.Databunch(
    train_dl,
    test_dl,
    n_classes=10,
)

## Initialize Learner

In [6]:
# Two-layer perceptron: 784 inputs -> 300 hidden units (ReLU) -> 10 outputs.
model = nn.Sequential(
    nn.Linear(784, 300),
    nn.ReLU(),
    nn.Linear(300, 10),
)

In [7]:
# Plain SGD over all model parameters at the configured learning rate.
optim = torch.optim.SGD(
    model.parameters(),
    lr=config['lr'],
)

In [8]:
# Tie the model, data, optimizer and cross-entropy loss into one Learner.
learner = utils.Learner(
    model,
    data,
    optim,
    F.cross_entropy,
)

# Basic Callback System
---
**WARNING: Bad Smelly Code Ahead!**

## Abstract Callback Class

In [9]:
class Callback:
    """
    Base class that every callback derives from.

    Each hook returns True unless a subclass overrides it, so a
    callback never interrupts the training loop by default.
    """

    def before_train(self, learner):
        """Remember the learner passed in at the start of training."""
        self.learner = learner
        return True

    def before_epoch(self, epoch):
        """Remember the index of the epoch that is about to run."""
        self.epoch = epoch
        return True

    def before_batch(self, xb, yb):
        """Remember the inputs and targets of the upcoming batch."""
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        """Remember the loss computed for the current batch."""
        self.loss = loss
        return True

    # The remaining hooks keep no state; they only signal "carry on".

    def after_backward(self):
        return True

    def after_step(self):
        return True

    def before_validate(self):
        return True

    def after_epoch(self):
        return True

    def after_train(self):
        return True

## Test Callback

In [10]:
class Test(Callback):
    """Test the callback system by requesting a stop at a given iteration.

    Args:
        stop_train_at: number of `after_step` calls after which
            `learner.stop` is set to True.
    """

    def __init__(self, stop_train_at=10):
        self.stop_at = stop_train_at

    def before_train(self, learner):
        """Store the learner and restart the iteration counter."""
        super().before_train(learner)
        self.iters = 0
        return True

    def after_step(self):
        """Count the iteration, print it, and flag a stop once the limit is hit."""
        self.iters += 1
        print(self.iters)
        # `>=` requests the stop on the `stop_at`-th iteration itself; the
        # previous `>` comparison let one extra iteration run first.
        if self.iters >= self.stop_at:
            self.learner.stop = True
        return True

## Callback Controller

In [11]:
class CallbackController:
    """Runs the registered callbacks during a training loop.

    Every hook forwards to the same-named method of each registered callback,
    in registration order, and returns a truthy value only if all of them did.
    As soon as one callback returns a falsy value, the remaining callbacks are
    skipped for that hook.
    """

    def __init__(self, callbacks):
        self.cbs = callbacks if callbacks else []

    def _run(self, hook, *args, carry_on=True):
        """Call `hook(*args)` on each callback while `carry_on` stays truthy."""
        for cb in self.cbs:
            carry_on = carry_on and getattr(cb, hook)(*args)
        return carry_on

    def before_train(self, learner):
        """Attach `learner`, clear its stop flag and notify the callbacks."""
        self.learner = learner
        self.in_train = True
        self.learner.stop = False
        # Hand the callbacks the learner itself (not the global `model`), so
        # that a callback setting `self.learner.stop` sets the flag that
        # `do_stop` reads.
        return self._run('before_train', learner)

    def after_train(self):
        # Callbacks are only notified when the controller is not in the
        # training phase (i.e. `before_validate` ran last).
        return self._run('after_train', carry_on=not self.in_train)

    def before_epoch(self, epoch):
        """Switch the model to training mode and notify the callbacks."""
        self.learner.model.train()
        self.in_train = True
        return self._run('before_epoch', epoch)

    def after_epoch(self):
        return self._run('after_epoch')

    def before_batch(self, xb, yb):
        return self._run('before_batch', xb, yb)

    def before_validate(self):
        """Switch the model to evaluation mode and notify the callbacks."""
        self.learner.model.eval()
        self.in_train = False
        return self._run('before_validate')

    def after_loss(self, loss):
        # Outside the training phase this returns False without calling any
        # callback, which makes the batch loop skip backward/step.
        return self._run('after_loss', loss, carry_on=self.in_train)

    def after_backward(self):
        return self._run('after_backward')

    def after_step(self):
        return self._run('after_step')

    def do_stop(self):
        """Return the learner's stop flag and reset it to False."""
        try:
            return self.learner.stop
        finally:
            # Reading the flag consumes it: a second call returns False
            # unless a callback has set it again in between.
            self.learner.stop = False

## Train with Callbacks

In [17]:
def train_one_batch(xb, yb, cb):
    """Process a single batch, consulting the callbacks between each stage.

    Stages: forward pass and loss, then (if `after_loss` allows) backward,
    an optimizer step gated by `after_backward`, and a gradient reset gated
    by `after_step`.
    """
    proceed = cb.before_batch(xb, yb)
    if not proceed:
        return None

    predictions = cb.learner.model(xb)
    batch_loss = cb.learner.loss_func(predictions, yb)

    if cb.after_loss(batch_loss):
        batch_loss.backward()
        # `after_step` is consulted even when the optimizer step was vetoed.
        if cb.after_backward():
            cb.learner.opt.step()
        if cb.after_step():
            cb.learner.opt.zero_grad()

def train_all_batches(dl, cb):
    """Run every batch of `dl` through `train_one_batch`.

    Args:
        dl: iterable yielding `(xb, yb)` pairs.
        cb: callback controller; `cb.do_stop()` is polled after each batch.

    Returns:
        True if the loop ended early because `cb.do_stop()` was truthy,
        False if the whole loader was consumed.
    """
    for xb, yb in dl:
        train_one_batch(xb, yb, cb)
        # Call the method: the bare attribute `cb.do_stop` is a bound method
        # and therefore always truthy, which ended the loop after one batch.
        if cb.do_stop():
            return True
    return False


def train(epochs, learner, cb):
    """Train for up to `epochs` epochs, validating after each one.

    Args:
        epochs: maximum number of epochs to run.
        learner: object exposing `data.train_dl` and `data.test_dl`.
        cb: callback controller driving the hooks and the stop flag.

    Training ends early when a stop is requested during either batch loop,
    when `cb.do_stop()` is truthy after validation, or when `cb.after_epoch()`
    returns a falsy value. `cb.after_train()` is called in all of these cases.
    """
    if not cb.before_train(learner):
        return
    for epoch in range(epochs):
        if not cb.before_epoch(epoch):
            continue
        # With CallbackController, do_stop() resets the flag as it reads it,
        # so the batch loop has to report the stop itself; polling do_stop()
        # again here would always see False.
        if train_all_batches(learner.data.train_dl, cb):
            break

        stopped = False
        if cb.before_validate():
            with torch.no_grad():
                stopped = train_all_batches(learner.data.test_dl, cb)
        # `break` rather than `return`, so after_train() still runs.
        if stopped or cb.do_stop() or not cb.after_epoch():
            break
    cb.after_train()

In [18]:
# Train for the configured number of epochs with a single Test callback
# registered on the controller.
train(
    config['epochs'],
    learner,
    CallbackController([Test(stop_train_at=10)]),
)

1
2
3
4
5
6
7
8
9
10


# Improved Callback System
---
**Cleaning up the mess**