In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from pathlib import Path

# Make the `scripts/` directory that sits next to the directory in sys.path[0] importable,
# so that `from callback import *` below resolves. pathlib handles OS-specific separators
# and trailing slashes, which the previous split-on-'/' string manipulation did not.
# An empty sys.path[0] still yields the relative path 'scripts', as before.
sys.path.insert(0, str(Path(sys.path[0]).parent / 'scripts'))

from callback import *

In [3]:
#export
class EpochLogger(Callback):
    '''Minimal callback: prints "Epoch <n>" at the start of every epoch.'''

    def __init__(self):
        super().__init__()

    def before_epoch(self):
        # `epoch` is read through `self`. Presumably `Callback` forwards unknown attributes
        # to the attached learner (which sets `epoch` in `Learner.fit`) — confirm in callback.py.
        # The implicit None return is falsy, so training is never stopped by this callback.
        print(f'Epoch {self.epoch}')

In [4]:
#export
class CancelTrainException(Exception):
    '''Raised (typically from a callback) to stop the whole training run early.

    `Learner.fit` catches it, fires the ``after_cancel_train`` event and then ``after_fit``.
    Accepts optional arguments like any Exception, e.g. ``CancelTrainException("reason")``.
    '''


class CancelEpochException(Exception):
    '''Raised (typically from a callback) to stop iterating over the remaining batches of the current pass.

    `Learner.all_batches` catches it and fires ``after_cancel_epoch``. Because `all_batches` runs
    once for training and once for validation, only the pass in progress is cut short.
    '''


class CancelBatchException(Exception):
    '''Raised (typically from a callback) to skip the rest of the current batch.

    `Learner.one_batch` catches it and fires ``after_cancel_batch``; iteration continues with the
    next batch.
    '''

In [5]:
#export
class Learner():
    '''Bundles a data bunch, model, loss function, optimizer and callbacks into a trainable unit.

    Training is driven by `fit`, which fires named callback events through `__call__`. Callbacks
    are consulted in ascending ``order``. The first callback that returns a truthy value for an
    event stops the remaining callbacks for that event, and makes the caller skip the rest of
    the current step.
    '''

    def __init__(self, data_bunch, model, loss_fn, optimizer, callbacks=None):
        '''Store the training components and attach the callbacks.

            data_bunch: data bunch with training and validation data (read via `train_dl` / `valid_dl`)
            model: Sequential model
            loss_fn: fn that takes in predicted labels and labels to compute loss
            optimizer: optimizer that keeps track of hyperparameters and updates parameters
            callbacks: optional iterable of extra callbacks for a flexible training procedure;
                a `TrainEval` callback is always added. Defaults to no extra callbacks.
        '''
        self.data_bunch = data_bunch
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        # `None` default instead of a shared mutable `[]`. Any iterable (list, tuple, ...) is accepted.
        extra_callbacks = list(callbacks) if callbacks is not None else []
        # sorted() is stable, so callbacks with equal `order` keep their insertion order.
        self.callbacks = sorted([TrainEval()] + extra_callbacks, key=lambda cb: cb.order)
        for callback in self.callbacks:
            callback.learner = self

    def __repr__(self):
        callback_names = [cb.__class__.__name__ for cb in self.callbacks]
        return f'{self.data_bunch}\n{self.model}\n{self.loss_fn}\n{self.optimizer}\n(Callbacks) {callback_names}'

    def one_batch(self, x_batch, y_batch):
        '''Process one batch.

        Runs the forward pass and the loss. In training mode it also runs the backward pass and
        the optimizer step. Each stage fires a callback event; a truthy event result skips the
        rest of the batch. Raising `CancelBatchException` from a callback also skips the rest of
        the batch, and fires ``after_cancel_batch``.
        '''
        try:
            self.x_batch = x_batch
            self.y_batch = y_batch
            if self('before_batch'):     return
            self.pred = self.model(self.x_batch)
            if self('after_pred'):       return
            self.loss = self.loss_fn(self.pred, self.y_batch)
            if self('after_loss'):       return
            # Validation pass: stop here, so no backward pass and no parameter update happen.
            if not self.model.training:  return
            # Gradients are propagated by calling backward() on the loss function, then on the model.
            self.loss_fn.backward()
            if self('after_loss_back'):  return
            self.model.backward()
            if self('after_model_back'): return
            self.optimizer.step()
            # A truthy after_step skips zero_grad, so the gradients are NOT cleared for the next batch.
            if self('after_step'):       return
            self.optimizer.zero_grad()
        except CancelBatchException:
            self('after_cancel_batch')

    def all_batches(self):
        '''Iterate over one data loader: `train_dl` when the model is in training mode, otherwise `valid_dl`.

        Sets `iters` (the number of batches) and `iters_count` (batches processed so far).
        ``after_batch`` fires after every batch, including batches cut short by a truthy event
        or by `CancelBatchException`. Raising `CancelEpochException` stops the remaining batches
        of this pass and fires ``after_cancel_epoch``.
        '''
        data_loader = self.data_bunch.train_dl if self.model.training else self.data_bunch.valid_dl
        self.iters_count, self.iters = 0, len(data_loader)
        try:
            for x_batch, y_batch in data_loader:
                self.one_batch(x_batch, y_batch)
                self.iters_count += 1
                self('after_batch')
        except CancelEpochException:
            self('after_cancel_epoch')

    def fit(self, num_epochs):
        '''Train for `num_epochs` epochs (numbered from 1), with one training pass and one validation pass per epoch.

        A truthy ``before_fit`` aborts before training starts; ``after_fit`` is NOT fired in that
        case. A truthy ``before_epoch``, ``before_train``, ``before_valid`` or ``after_epoch`` ends
        training early, and ``after_fit`` still runs via ``finally``. Raising `CancelTrainException`
        from a callback stops training and fires ``after_cancel_train`` (then ``after_fit``).
        '''
        self.num_epochs = num_epochs

        for callback in self.callbacks:
            callback.set_learner(self)

        if self('before_fit'):       return
        try:
            for epoch in range(1, num_epochs + 1):
                self.epoch = epoch
                if self('before_epoch'): return
                # `all_batches` chooses train_dl/valid_dl from `model.training`. Presumably a
                # callback (e.g. TrainEval) toggles it on before_train/before_valid — confirm.
                if self('before_train'): return
                self.all_batches()
                if self('before_valid'): return
                self.all_batches()
                if self('after_epoch'): break
        except CancelTrainException:
            self('after_cancel_train')
        finally:
            self('after_fit')

    def __call__(self, callback_name):
        '''Fire `callback_name` on the callbacks in order.

        Returns True as soon as one of them returns a truthy value; later callbacks are then not
        called. Returns False otherwise.
        '''
        for callback in self.callbacks:
            if callback(callback_name):
                return True
        return False

# Tests

In [6]:
# Full MNIST with the linear model and the plain Optimizer.
loss_fn = CrossEntropy()
callbacks = [EpochLogger()]
data_bunch = get_data_bunch(*get_mnist_data(), batch_size=64)
model = get_lin_model(data_bunch)
optimizer = Optimizer([*model.parameters()], learning_rate=0.1)

In [7]:
learner = Learner(data_bunch=data_bunch, model=model, loss_fn=loss_fn,
                  optimizer=optimizer, callbacks=callbacks)
print(learner)

(DataBunch) 
    (DataLoader) 
        (Dataset) x: (50000, 784), y: (50000,)
        (Sampler) total: 50000, batch_size: 64, shuffle: True
    (DataLoader) 
        (Dataset) x: (10000, 784), y: (10000,)
        (Sampler) total: 10000, batch_size: 128, shuffle: False
(Model)
    Linear(784, 50)
    ReLU()
    Linear(50, 10)
(CrossEntropy)
(Optimizer) learning_rate: 0.1
(Callbacks) ['TrainEval', 'EpochLogger']


In [8]:
learner.fit(num_epochs=3)

Epoch 1
Epoch 2
Epoch 3


In [9]:
x_train, y_train, x_valid, y_valid = get_mnist_data()
# Keep only an 8000-sample training / 2000-sample validation subset of MNIST.
x_train, y_train = x_train[:8000], y_train[:8000]
x_valid, y_valid = x_valid[:2000], y_valid[:2000]

loss_fn = CrossEntropy()
callbacks = [EpochLogger()]
data_bunch = get_data_bunch(x_train, y_train, x_valid, y_valid, batch_size=64)
model = get_conv_model(data_bunch)
optimizer = DynamicOpt([*model.parameters()], learning_rate=0.1)  # optimizer with dynamic hyper-parameters

In [10]:
learner = Learner(data_bunch=data_bunch, model=model, loss_fn=loss_fn,
                  optimizer=optimizer, callbacks=callbacks)
print(learner)

(DataBunch) 
    (DataLoader) 
        (Dataset) x: (8000, 784), y: (8000,)
        (Sampler) total: 8000, batch_size: 64, shuffle: True
    (DataLoader) 
        (Dataset) x: (2000, 784), y: (2000,)
        (Sampler) total: 2000, batch_size: 128, shuffle: False
(Model)
    Reshape(1, 28, 28)
    Conv(1, 8, 5, 4)
    ReLU()
    Conv(8, 16, 3, 2)
    Flatten()
    Linear(256, 10)
(CrossEntropy)
(DynamicOpt) hyper_params: ['learning_rate']
(Callbacks) ['TrainEval', 'EpochLogger']


In [11]:
learner.fit(num_epochs=3)

Epoch 1
Epoch 2
Epoch 3


In [12]:
x_train, y_train, x_valid, y_valid = get_mnist_data()
# Same 8000 / 2000 MNIST subset as above, now for the final conv model.
x_train, y_train = x_train[:8000], y_train[:8000]
x_valid, y_valid = x_valid[:2000], y_valid[:2000]

loss_fn = CrossEntropy()
callbacks = [EpochLogger()]
data_bunch = get_data_bunch(x_train, y_train, x_valid, y_valid, batch_size=64)
model = get_conv_final_model(data_bunch)
optimizer = DynamicOpt([*model.parameters()], learning_rate=0.1)

In [13]:
learner = Learner(data_bunch=data_bunch, model=model, loss_fn=loss_fn,
                  optimizer=optimizer, callbacks=callbacks)
print(learner)

(DataBunch) 
    (DataLoader) 
        (Dataset) x: (8000, 784), y: (8000,)
        (Sampler) total: 8000, batch_size: 64, shuffle: True
    (DataLoader) 
        (Dataset) x: (2000, 784), y: (2000,)
        (Sampler) total: 2000, batch_size: 128, shuffle: False
(Model)
    Reshape(1, 28, 28)
    Conv(1, 4, 5, 2)
    AvgPool(2, 1)
    BatchNorm()
    Conv(4, 16, 3, 2)
    BatchNorm()
    Flatten()
    Linear(400, 64)
    ReLU()
    Linear(64, 10)
(CrossEntropy)
(DynamicOpt) hyper_params: ['learning_rate']
(Callbacks) ['TrainEval', 'EpochLogger']


In [14]:
learner.fit(num_epochs=3)

Epoch 1
Epoch 2
Epoch 3
