In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
from exports.e_02_MNISTLoader import loadMNIST
from exports.e_04_DataAPI import Dataset
from exports.e_05_Losses_Optimizers_TrainEval import make_dls

import torch
from torch import nn, optim
from torch.functional import F

### Data API so far:

In [3]:
# Load the four MNIST arrays and wrap each split in a Dataset.
x_train, y_train, x_valid, y_valid = loadMNIST()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

# x_train is 2-D: (number of samples, inputs per sample).
n_sampl, n_inp = x_train.shape
n_out = 10  # presumably one output per digit class -- confirm against the labels
n_hid = 50

batch_size = 64

# The returned pair of DataLoaders is only displayed here, not stored;
# later cells call make_dls again when they need the loaders.
make_dls(train_ds, valid_ds, batch_size)

(<torch.utils.data.dataloader.DataLoader at 0x7f67a8497df0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f673448f100>)

# Callbacks

## Data Wrapper
Easy access to both the train and valid DataLoaders and their underlying Datasets.

In [4]:
#--export--#
class DataWrapper():
    """Hold the train/valid DataLoaders together with the model output size `n_out`."""

    def __init__(self, train_dl, valid_dl, n_out):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.n_out = n_out

    # Read-only views: each one looks up `.dataset` on the loader at access time.
    @property
    def train_ds(self):
        """Dataset behind the training DataLoader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """Dataset behind the validation DataLoader."""
        return self.valid_dl.dataset

In [5]:
# make_dls returns a pair of DataLoaders (presumably in train, valid order),
# which is unpacked into DataWrapper's first two arguments.
data_w = DataWrapper(*make_dls(train_ds, valid_ds, batch_size), n_out)

## Model Wrapper

In [6]:
#--export--#
def SimpleModel(data_w, lr=0.3, n_hid=50):
    """Build a one-hidden-layer MLP sized from `data_w`, plus an SGD optimizer for it.

    The input width is the second dimension of `data_w.train_ds.x` and the
    output width is `data_w.n_out`; `n_hid` is the hidden width and `lr` the
    SGD learning rate. Returns the tuple `(model, optimizer)`.
    """
    in_features = data_w.train_ds.x.shape[1]
    out_features = data_w.n_out

    layers = [
        nn.Linear(in_features, n_hid),
        nn.ReLU(),
        nn.Linear(n_hid, out_features),
    ]
    net = nn.Sequential(*layers)
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd

class ModelWrapper():
    """Group a model with its optimizer, loss function and data wrapper."""

    def __init__(self, model, opt, loss_f, data_w):
        self.model = model
        self.opt = opt
        self.loss_f = loss_f
        self.data_w = data_w

In [7]:
# SimpleModel returns (model, optimizer), unpacked into ModelWrapper's first
# two arguments; F.cross_entropy is passed as the loss function.
model_w = ModelWrapper(*SimpleModel(data_w), F.cross_entropy, data_w)

## Callback Class and Job Handler

In [8]:
#--export--#
# Names of every hook that DLJob dispatches to its callbacks, listed in the
# order they fire. The begin_batch..after_batch hooks fire once per batch, in
# both the training and the validation pass. 'after_pred' and 'after_batch'
# are included because DLJob.one_batch / DLJob.all_batch dispatch them.
STAGES = ['begin_fit',
          'begin_epoch',
          'begin_batch',
          'after_pred',
          'after_loss',
          'after_backward',
          'after_step',
          'after_batch',
          'begin_valid',
          'after_epoch',
          'after_fit']

class Callback():
    """Base class for DLJob callbacks.

    Subclasses define methods named after the stages they want to hook.
    Attributes not found on the callback itself are looked up on the job it
    is attached to (`self.job`), so a callback can read e.g. `self.training`.
    """
    # Sort key used by the job when dispatching a stage: lower values run first.
    _order = 0

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If `job` itself is
        # missing (callback not attached yet, or an instance being rebuilt by
        # copy/pickle), reading `self.job` below would re-enter this method
        # forever, so fail with a regular AttributeError instead.
        if attr == 'job':
            raise AttributeError(
                f"{type(self).__name__} is not attached to a job yet, "
                "so attributes cannot be delegated to it")
        return getattr(self.job, attr)

    @property
    def name(self):
        """Class name of the concrete callback."""
        return self.__class__.__name__

class DLJob():
    """Training/validation loop that hands control to callbacks at each stage.

    Callbacks are invoked through `self(stage)`. A callback method returning a
    truthy value vetoes the step guarded by that stage (and hides the stage
    from callbacks later in the order); returning None/False lets the loop
    carry on.
    """
    def __init__(self, callbacks=None):
        # A `[]` default would be a single list shared by every DLJob built
        # without arguments; create a fresh list per instance instead.
        self.cbs = [] if callbacks is None else callbacks
        self.stop = False
        # Model wrapper in use; only set for the duration of fit().
        self.mw = None

    # Shortcuts into the model wrapper passed to fit().
    @property
    def opt(self): return self.mw.opt
    @property
    def model(self): return self.mw.model
    @property
    def loss_f(self): return self.mw.loss_f
    @property
    def data_w(self): return self.mw.data_w


    def one_batch(self, xb, yb):
        """Process one batch: forward pass, loss, and (when training) backward + step."""
        self.xb, self.yb = xb, yb
        if not self('begin_batch'): return
        self.pred = self.model(self.xb)
        if not self('after_pred'): return
        self.loss = self.loss_f(self.pred, self.yb)
        # Validation batches end here: no backward pass and no optimizer step.
        if not self('after_loss') or not self.training: return
        self.loss.backward()
        if not self('after_backward'): return
        self.opt.step()
        # A veto at 'after_step' skips zero_grad, so gradients keep accumulating.
        if not self('after_step'): return
        self.opt.zero_grad()

    def all_batch(self, dl):
        """Run one_batch over every batch of `dl`, or until a callback sets `self.stop`."""
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        # The stop request only applies to the pass that was running.
        self.stop = False

    def fit(self, epochs, model_wrapper):
        """Train and validate `model_wrapper` for `epochs` epochs.

        'after_fit' is dispatched even when the loop exits early or raises,
        and the reference to the model wrapper is dropped afterwards.
        """
        self.epochs, self.mw = epochs, model_wrapper

        try:
            # Give every callback a handle on this job before any stage fires.
            for cb in self.cbs: cb.job = self
            if not self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if self('begin_epoch'):
                    self.training = True
                    self.model.train()
                    self.all_batch(self.data_w.train_dl)

                with torch.no_grad():
                    if self('begin_valid'):
                        self.training = False
                        self.model.eval()
                        self.all_batch(self.data_w.valid_dl)
                if not self('after_epoch'): break

        finally:
            self('after_fit')
            self.mw = None


    def __call__(self, stage):
        """Dispatch `stage` to the callbacks in `_order`; return False if one vetoes it."""
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, stage, None)
            if f and f(): return False
        return True

In [9]:
class TestCallback(Callback):
    """ Stops training after 10 optimization steps. """
    def begin_fit(self):
        self.job.training = True
        self.n_iters = 0

    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        # The job checks `stop` before each batch, so this ends the current pass.
        if self.n_iters>=10: self.job.stop = True
        # Deliberately return nothing: a truthy return from 'after_step' makes
        # the job skip opt.zero_grad(), which would let gradients pile up
        # across steps and remain on the model after this fit.

In [10]:
# One epoch with TestCallback attached: it prints a counter after each
# optimizer step and requests a stop once the counter reaches 10.
job = DLJob([TestCallback()])
job.fit(1, model_w)

1
2
3
4
5
6
7
8
9
10


## Callback to Get Running Average of Custom Metrics

In [11]:
#--export--#
class AvgStats():
    """Batch-size-weighted running totals of the loss and custom metrics.

    `metrics` is a sequence of callables `metric(pred, yb)`; `training` only
    selects the "train"/"valid" label used by repr().
    """
    def __init__(self, metrics, training):
        self.metrics, self.training = metrics, training
        # Start zeroed so repr() and all_stats work before the first reset().
        self.reset()

    def reset(self):
        """Zero the running totals and the sample count."""
        self.tot_loss, self.count = 0., 0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """Totals as `[loss, *metrics]`, with the loss converted to a Python float."""
        # tot_loss stays a plain float until accumulate() adds a value with
        # `.item()` to it, so only call `.item()` when it is available.
        loss = self.tot_loss
        loss = loss.item() if hasattr(loss, 'item') else float(loss)
        return [loss] + self.tot_mets

    @property
    def avg_stats(self):
        """Per-sample averages of all_stats; requires at least one accumulated batch."""
        return [stat/self.count for stat in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f'{"train" if self.training else "valid"}: {self.avg_stats}'

    def accumulate(self, job):
        """Add the current batch of `job` (its xb, yb, pred and loss) to the totals."""
        batch_size = job.xb.shape[0]
        # Weight by batch size so a smaller batch does not skew the average.
        self.tot_loss += job.loss * batch_size
        self.count += batch_size
        for i, metric in enumerate(self.metrics):
            self.tot_mets[i] += metric(job.pred, job.yb) * batch_size
            
class AvgStatsCB(Callback):
    """Callback that tracks average loss/metrics per epoch, separately for train and valid."""
    def __init__(self, metrics=None):
        # None instead of a `[]` default, so instances built without arguments
        # never share one list object.
        metrics = [] if metrics is None else metrics
        self.train_stats, self.valid_stats = AvgStats(metrics, True), AvgStats(metrics, False)

    def begin_epoch(self):
        # Averages are per epoch: clear both accumulators at the start.
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        # `self.training` is not set on the callback; it is read from the job.
        stats = self.train_stats if self.training else self.valid_stats
        # Keep the bookkeeping out of autograd.
        with torch.no_grad(): stats.accumulate(self.job)

    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)

### Print running loss and accuracy

In [None]:
#--export--#
def acc_f(pred, lab):
    """Mean of the element-wise match between the argmax of `pred` along dim 1 and `lab`."""
    predicted = torch.argmax(pred, dim=1)
    hits = (predicted == lab).float()
    return hits.mean()

In [12]:
# Five epochs with running loss/accuracy printed after each one. This reuses
# `model_w`, so training continues from the state left by the previous fit.
job = DLJob([AvgStatsCB([acc_f])])
job.fit(5, model_w)

train: [0.29554005859375, tensor(0.9106)]
valid: [0.2020339599609375, tensor(0.9417)]
train: [0.1687086328125, tensor(0.9496)]
valid: [0.14223358154296875, tensor(0.9587)]
train: [0.132412578125, tensor(0.9601)]
valid: [0.12514912109375, tensor(0.9620)]
train: [0.110084892578125, tensor(0.9665)]
valid: [0.1145890625, tensor(0.9657)]
train: [0.097337568359375, tensor(0.9693)]
valid: [0.119880615234375, tensor(0.9633)]


In [15]:
!python utils/export_notebook.py 06_Callbacks.ipynb

Notebook 06_Callbacks.ipynb has been converted to module ./exports/e_06_Callbacks.py!
