# Customize Training with Callbacks
---

In [1]:
config = {
    'epochs': 20,
    'lr': 0.1,
    'bs': 128
}

## Import Libraries

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import re
import utils

## Load Data

In [4]:
x_train, y_train, x_test, y_test = utils.get_data()
train_dl, test_dl = utils.get_dataloaders(x_train, y_train, x_test, y_test, config['bs'])

In [5]:
data = utils.Databunch(train_dl, test_dl, n_classes=10)

## Initialize Learner

In [6]:
model = nn.Sequential(nn.Linear(784, 300), nn.ReLU(), nn.Linear(300,100), nn.ReLU(), nn.Linear(100,10))

In [7]:
optim = torch.optim.SGD(model.parameters(), lr = config['lr'])

In [8]:
learner = utils.Learner(model, data, optim, F.cross_entropy)

# Basic Callback System
---
**WARNING: Bad Smelly Code Ahead!**

## Abstract Callback Class

In [9]:
class Callback:
    """
    Abstract class inherited by all callbacks
  
    All methods return true by default
    to avoid interrupting the training loop    
    """
    
    def before_train(self, learner):
        self.learner = learner
        return True
    
    def before_epoch(self, epoch):
        self.epoch = epoch
        return True
    
    def before_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        return True
    
    def after_loss(self, loss):
        self.loss = loss
        return True
    
    def after_backward(self):  return True
    def after_step(self):      return True
    def before_validate(self): return True
    def after_epoch(self):     return True
    def after_train(self):     return True

## Test Callback

In [10]:
class Test(Callback):
    '''Test callback system to stop training at given iteration'''
    
    def __init__(self, stop_train_at=10):
        self.stop_at = stop_train_at
        
    def before_train(self, learner):
        super().before_train(learner)
        self.iters = 0
        return True
    
    def after_step(self):
        self.iters += 1
        print(self.iters)
        if self.iters > self.stop_at:
            self.learner.stop = True
        return True

## Callback Controller

In [11]:
class CallbackController:
    '''Runs the registered callbacks during a training loop'''
    
    def __init__(self, callbacks):
        self.cbs = callbacks if callbacks else []
    
    def before_train(self, learner):
        self.learner = learner
        self.in_train = True
        carry_on = True
        self.learner.stop = False
        for cb in self.cbs:
            carry_on = carry_on and cb.before_train(model)
        return carry_on
    
    def after_train(self):
        carry_on = not self.in_train
        for cb in self.cbs:
            carry_on = carry_on and cb.after_train()
        return carry_on
    
    def before_epoch(self, epoch):
        self.learner.model.train()
        self.in_train = True
        carry_on = True
        for cb in self.cbs:
            carry_on = carry_on and cb.before_epoch(epoch)
        return carry_on
    
    def after_epoch(self):
        carry_on = True
        for cb in self.cbs:
            carry_on = carry_on and cb.after_epoch()
        return carry_on
    
    def before_batch(self, xb, yb):
        carry_on = True
        for cb in self.cbs:
            carry_on = carry_on and cb.before_batch(xb, yb)
        return carry_on
        
    def before_validate(self):
        self.learner.model.eval()
        self.in_train = False
        carry_on = True
        for cb in self.cbs:
            carry_on = carry_on and cb.before_validate()
        return carry_on
        
    def after_loss(self, loss):
        carry_on = self.in_train
        for cb in self.cbs:
            carry_on = carry_on and cb.after_loss(loss)
        return carry_on
    
    def after_backward(self):
        carry_on = True
        for cb in self.cbs:
            carry_on = carry_on and cb.after_backward()
        return carry_on
    
    def after_step(self):
        carry_on = True
        for cb in self.cbs:
            carry_on = carry_on and cb.after_step()
        return carry_on
    
    def do_stop(self):
        try:     return self.learner.stop
        finally: self.learner.stop = False

## Train with Callbacks

In [12]:
def train_one_batch(xb, yb, cb):
    if not cb.before_batch(xb, yb): return
    yb_pred = cb.learner.model(xb)
    loss = cb.learner.loss_func(yb_pred, yb)
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learner.opt.step()
    if cb.after_step(): cb.learner.opt.zero_grad()

def train_all_batches(dl, cb):
    for xb,yb in dl:
        train_one_batch(xb, yb, cb)
        if cb.do_stop: return
    
def train(epochs, learner, cb):
    if not cb.before_train(learner): return
    for epoch in range(epochs):
        if not cb.before_epoch(epoch): continue
        train_all_batches(learner.data.train_dl, cb)
        
        if cb.before_validate():
            with torch.no_grad():
                train_all_batches(learner.data.test_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): return
    cb.after_train()

**TODO: Doesn't work. Fix it.**

In [13]:
train(config['epochs'], learner,
      CallbackController([Test(stop_train_at=10)]))

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20


# Improved Callback System
---
**Cleaning up the mess**

## Abstract Callback Class

In [14]:
def camel2snake(name):
    _camel_r1 = "(.)([A-Z][a-z]+)"
    _camel_r2 = "([a-z0-9])([A-Z])"
    s1 = re.sub(_camel_r1, r'\1_\2', name)
    return re.sub(_camel_r2, r'\1_\2', s1).lower()    

In [15]:
class Callback:
    '''Abstract Callback Class'''
    _order = 0
    
    def set_controller(self, control):  self.control = control
    def __getattr__(self, attr):        return getattr(self.control, attr)
    @property
    def name(self):                     return camel2snake(name or 'callback')

## Controller

In [16]:
class Controller:
    '''Main Controller responsible for callbacks and training loop'''
    
    def __init__(self, callback_list=[]):
        self.cbs = [TrainEval()] + callback_list
        self.stop = False
    
    @property
    def model(self):     return self.learner.model
    @property
    def data(self):      return self.learner.data
    @property
    def opt(self):       return self.learner.opt
    @property
    def loss_func(self): return self.learner.loss_func

    def run_one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        if self('before_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()
        self('after_batch')
        
    def run_all_batches(self, dl):
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop: break
            self.run_one_batch(xb, yb)
        self.stop = False
    
    def train(self, learner, epochs):
        self.learner, self.epochs = learner, epochs
        try:
            for cb in self.cbs: 
                cb.set_controller(self)
            if self('before_train'): return
            for epoch in range(self.epochs):
                self.epoch = epoch
                if not self('before_epoch'):
                    self.run_all_batches(self.data.train_dl)
                with torch.no_grad():
                    if not self('before_validate'):
                        self.run_all_batches(self.data.test_dl)
                if self('after_epoch'): break
        finally:
            self('after_train')
    
    def __call__(self, cb_func_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            # get the callback function if defined or else None 
            cb_func = getattr(cb, cb_func_name, None)
            # run only if cb_func is not None
            if cb_func and cb_func(): return True
        # when a callback function was not defined in any of the callbacks
        return False

## Callback: Train Eval

In [17]:
class TrainEval(Callback):
    """
    Callback to switch between train and eval mode,
    as well as keep track of iterations during training.
    
    Note: This callback is attached by default
    """
    
    def before_train(self):
        self.control.n_epochs = 0.
        self.control.n_iter = 0.
    
    def before_epoch(self):
        self.control.model.train()
        self.control.in_train = True
        self.control.n_epochs = self.control.epoch
    
    def before_batch(self):
        if not self.control.in_train: return
        self.control.n_epochs += 1./self.control.iters
        self.control.n_iter += 1
    
    def before_validate(self):
        self.control.model.eval()
        self.control.in_train = False

## Callback: Stats Reporter

In [18]:
class StatsReporter(Callback):
    '''Report training statistics in terms of the given metrics'''
    
    def __init__(self, metrics):
        self.metrics = [] if metrics is None else metrics
            
    def before_epoch(self):
        self.train_loss, self.valid_loss = 0., 0.
        self.train_metrics = torch.tensor([0.]).expand(len(self.metrics))
        self.valid_metrics = torch.tensor([0.]).expand(len(self.metrics))
        self.train_count, self.valid_count = 0, 0
    
    def after_loss(self):
        batch_len = self.control.xb.shape[0]
        if self.control.in_train:
            self.train_count += batch_len
            self.train_loss += self.control.loss*batch_len
            self.train_metrics += torch.tensor([m(self.control.pred,self.control.yb)*batch_len\
                                                for m in self.metrics])
        else:
            self.valid_count += batch_len
            self.valid_loss += self.control.loss*batch_len
            self.valid_metrics += torch.tensor([m(self.control.pred,self.control.yb)*batch_len\
                                                for m in self.metrics])

        
    def after_epoch(self):
        header = f"EPOCH#{self.control.epoch} \t"
        train_avg_loss = self.train_loss / self.train_count
        valid_avg_loss = self.valid_loss / self.valid_count
        train_avg_metrics = self.train_metrics.numpy() / self.train_count
        valid_avg_metrics = self.valid_metrics.numpy() / self.valid_count
        train_str = f"Train loss: {train_avg_loss:.3f} \t metrics: {train_avg_metrics} \t"
        valid_str = f"Valid loss: {valid_avg_loss:.3f} \t metrics: {valid_avg_metrics} \t"
        print(header + train_str + valid_str)

## Train with Callbacks

In [19]:
reporter = StatsReporter([utils.accuracy])
control = Controller(callback_list=[reporter])

**TODO: Doesn't train. Fix it.**

In [20]:
control.train(learner, 20)

EPOCH#0 	Train loss: 31200.246 	 metrics: [0.11316] 	Valid loss: 2.302 	 metrics: [0.1064] 	
EPOCH#1 	Train loss: 2.301 	 metrics: [0.11356] 	Valid loss: 2.302 	 metrics: [0.1064] 	
EPOCH#2 	Train loss: 2.301 	 metrics: [0.11356] 	Valid loss: 2.302 	 metrics: [0.1064] 	
EPOCH#3 	Train loss: 2.301 	 metrics: [0.11356] 	Valid loss: 2.302 	 metrics: [0.1064] 	
EPOCH#4 	Train loss: 2.301 	 metrics: [0.11356] 	Valid loss: 2.302 	 metrics: [0.1064] 	
EPOCH#5 	Train loss: 2.301 	 metrics: [0.11356] 	Valid loss: 2.302 	 metrics: [0.1064] 	


KeyboardInterrupt: 