In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inside the notebook output.
%matplotlib inline

In [2]:
#export
import sys
from os.path import join

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from hyper_param_scheduling import *

In [3]:
#export
class CancelTrainException(Exception):
    """Raised by a callback to abort the whole training run.

    Runner.fit catches it and dispatches the 'after_cancel_train' event.
    """
class CancelEpochException(Exception):
    """Raised by a callback to abort the current pass over a data loader.

    Runner.all_batches catches it and dispatches the 'after_cancel_epoch' event.
    """
class CancelBatchException(Exception):
    """Raised by a callback to abort processing of the current batch.

    Runner.one_batch catches it and dispatches the 'after_cancel_batch' event.
    """

In [4]:
#export
class Runner():
    """Runs the training loop of a learner and dispatches events to callbacks.

    The runner reads ``model``, ``loss_fn``, ``optimizer`` and ``data_bunch``
    from the wrapped learner. At each stage of the loop it calls every
    callback with the stage name (``runner('before_batch')`` and so on); a
    callback returning a truthy value short-circuits the current stage.

    Attributes set while running (readable by callbacks):
        x_batch, y_batch: the current batch.
        pred, loss: outputs of the model and of the loss function.
        iters, iters_count: number of batches in the current data loader and
            how many of them have been processed so far.
        epoch, num_epochs: current epoch (starting at 1) and requested total.
        stop: when set truthy, batch and epoch loops exit at their next check.
    """

    def __init__(self, learner, callbacks=None):
        """Attach ``callbacks`` (plus a ``TrainEval`` callback) to ``learner``.

        Args:
            learner: object exposing ``model``, ``loss_fn``, ``optimizer`` and
                ``data_bunch`` attributes.
            callbacks: optional iterable of callback objects. Each must have
                an ``order`` attribute (used for sorting) and a ``name``
                attribute (used by ``__repr__``). ``None`` means no extra
                callbacks.
        """
        self.learner = learner
        self.stop = False
        # ``None`` avoids a mutable default shared between calls, and ``list``
        # lets callers pass any iterable (tuple, generator, ...), not only a list.
        extra_callbacks = [] if callbacks is None else list(callbacks)
        self.callbacks = sorted([TrainEval()] + extra_callbacks, key=lambda cb: cb.order)
        for callback in self.callbacks:
            callback.runner = self

    def __repr__(self):
        """Return the learner's repr followed by the callback names in call order."""
        callback_names = ' '.join(cb.name for cb in self.callbacks)
        return f'{self.learner}\n(Callbacks) ' + callback_names

    # Read-only shortcuts to the learner's components.
    @property
    def data_bunch(self): return self.learner.data_bunch
    @property
    def model(self):      return self.learner.model
    @property
    def loss_fn(self):    return self.learner.loss_fn
    @property
    def optimizer(self):  return self.learner.optimizer

    def one_batch(self, x_batch, y_batch):
        """Run forward, loss, backward and optimizer step for a single batch.

        After each stage the matching event is dispatched; if any callback
        returns a truthy value the rest of the batch is skipped. When
        ``model.training`` is falsy, processing stops after 'after_loss'
        (no backward pass, no optimizer step). A ``CancelBatchException``
        raised along the way is caught and followed by 'after_cancel_batch'.
        """
        try:
            self.x_batch = x_batch
            self.y_batch = y_batch

            if self('before_batch'):     return
            self.pred = self.model(self.x_batch)
            if self('after_pred'):       return
            self.loss = self.loss_fn(self.pred, self.y_batch)
            if self('after_loss'):       return
            # Evaluation mode: nothing to back-propagate or update.
            if not self.model.training:  return
            self.loss_fn.backward()
            if self('after_loss_back'):  return
            self.model.backward()
            if self('after_model_back'): return
            self.optimizer.step()
            if self('after_step'):       return
            self.optimizer.zero_grad()
        except CancelBatchException:
            self('after_cancel_batch')

    def all_batches(self):
        """Process every batch of the training or validation data loader.

        The training loader is used when ``model.training`` is truthy,
        otherwise the validation loader. ``iters_count`` is reset to 0 and
        incremented after each batch, followed by the 'after_batch' event.
        The loop exits early when ``self.stop`` is truthy. A
        ``CancelEpochException`` is caught and followed by
        'after_cancel_epoch'.
        """
        data_loader = self.data_bunch.train_dl if self.model.training else self.data_bunch.valid_dl
        self.iters_count, self.iters = 0, len(data_loader)
        try:
            for x_batch, y_batch in data_loader:
                if self.stop: break
                self.one_batch(x_batch, y_batch)
                self.iters_count += 1
                self('after_batch')
        except CancelEpochException:
            self('after_cancel_epoch')

    def fit(self, num_epochs):
        """Train for up to ``num_epochs`` epochs (numbered from 1).

        Each epoch dispatches 'before_epoch' and 'before_train', runs
        ``all_batches``, dispatches 'before_valid', runs ``all_batches``
        again, then dispatches 'after_epoch'. Which data loader each pass
        uses depends on ``model.training`` at that moment.

        A truthy result from any of those events ends training. A
        ``CancelTrainException`` is caught and followed by
        'after_cancel_train'. 'after_fit' is always dispatched once the epoch
        loop has been entered; if 'before_fit' itself returns truthy, the
        method returns before that and 'after_fit' is not dispatched.
        """
        self.num_epochs = num_epochs

        for callback in self.callbacks:
            callback.set_runner(self)

        if self('before_fit'):       return
        try:
            for epoch in range(1, num_epochs+1):
                if self.stop: break
                self.epoch = epoch
                if self('before_epoch'): return
                if self('before_train'): return
                self.all_batches()
                if self('before_valid'): return
                self.all_batches()
                if self('after_epoch'): break
        except CancelTrainException:
            self('after_cancel_train')
        finally:
            self('after_fit')

    def __call__(self, callback_name):
        """Dispatch ``callback_name`` to the callbacks in sorted order.

        Returns:
            True as soon as one callback returns a truthy value (later
            callbacks are then not called for this event), otherwise False.
        """
        for callback in self.callbacks:
            if callback(callback_name):
                return True
        return False

In [5]:
#export
class ItersStop(Callback):
    """Callback that cancels training after a fixed number of iterations.

    Intended as a quick smoke test of the Runner's cancellation path.
    """

    # Value of ``iters_count`` at which training is cancelled. Override on a
    # subclass or an instance to stop earlier or later.
    max_iters = 10

    def after_step(self):
        """Print the iteration counter and cancel training once the limit is hit.

        Raises:
            CancelTrainException: when ``self.iters_count >= self.max_iters``.
        """
        # NOTE(review): ``iters_count`` is set on the Runner; presumably the
        # Callback base class forwards the attribute lookup to it. Confirm in
        # hyper_param_scheduling.
        print(f'iteration: {self.iters_count}')
        if self.iters_count >= self.max_iters:
            raise CancelTrainException()

In [6]:
# Hyper-parameters for the demo run.
num_hidden, batch_size = 50, 64
num_epochs, learning_rate = 20, 0.1  # num_epochs is not referenced by the visible cells

# Learning-rate schedule consumed by ParamScheduler below.
schedule = combine_schedules(
    [0.4, 0.6],
    one_cycle_cos(0.01, 0.5, 0.01),
)

# Data, model, optimizer and loss bundled into a Learner.
data_bunch = get_data_bunch(*get_mnist_data(), batch_size)
model, optimizer = get_model(data_bunch, learning_rate, num_hidden)
loss_fn = CrossEntropy()
learner = Learner(model, optimizer, loss_fn, data_bunch)

# The Runner adds its own TrainEval callback and sorts all callbacks by `order`.
runner = Runner(
    learner,
    [
        ItersStop(),
        ParamScheduler('learning_rate', schedule),
        StatsLogging([compute_accuracy]),
    ],
)
print(runner)

(DataBunch) 
	(DataLoader) 
		(Dataset) x: (50000, 784), y: (50000,)
		(Sampler) total: 50000, batch_size: 64, shuffle: True
	(DataLoader) 
		(Dataset) x: (10000, 784), y: (10000,)
		(Sampler) total: 10000, batch_size: 128, shuffle: False
(Sequential)
	(Layer1) Linear(784, 50)
	(Layer2) ReLU()
	(Layer3) Linear(50, 10)
(CrossEntropy)
(Optimizer) num_params: 4, hyper_params: ['learning_rate']
(Callbacks) TrainEval ItersStop ParamScheduler StatsLogging


In [7]:
# Deliberately request far more epochs than will run: ItersStop raises
# CancelTrainException once iters_count reaches 10, so training ends after
# the handful of iterations printed below.
runner.fit(10000)

iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
