In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join

# Make the `scripts` directory importable: take sys.path[0], drop its last
# '/'-separated component, append 'scripts', and put the result first on
# sys.path. The split is on a literal '/', not on the OS path separator.
sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
# Star import: the names used below that this notebook does not define
# (Callback, CancelTrainException, AvgStats, ...) presumably come from here.
from param_scheduling import *

In [3]:
#export
class ItersStopper(Callback):
    """Cancel training once the iteration counter reaches ``end_iter``.

    After every optimizer step the callback compares ``self.iters_count``
    with ``end_iter`` and raises ``CancelTrainException`` when the counter is
    greater than or equal to it. ``iters_count`` is not defined in this
    class; presumably the base ``Callback`` / learner provides it.

    In the recorded notebook run with ``end_iter=5`` the counter values 0
    through 5 were printed before training was cancelled.

    Args:
        end_iter: Threshold compared against ``self.iters_count``.
        log: When True (the default, matching the previous behaviour), print
            the iteration counter after every step.

    Raises:
        CancelTrainException: from ``after_step`` when the threshold is hit.
    """

    def __init__(self, end_iter=10, log=True):
        self.end_iter = end_iter
        self.log = log
        # True only between this callback raising the cancellation and the
        # matching after_cancel_train call.
        self._stopped_here = False

    def after_step(self):
        """Optionally log the counter, then cancel if the threshold is reached."""
        if self.log:
            print(f'iteration: {self.iters_count}')
        if self.iters_count >= self.end_iter:
            self._stopped_here = True
            raise CancelTrainException()

    def after_cancel_train(self):
        """Report the cancellation, but only when this callback caused it.

        NOTE(review): presumably the learner invokes this hook on every
        callback whenever any of them cancels training (confirm against the
        Learner). Without the guard this message would then be printed for
        cancellations raised elsewhere.
        """
        if not self._stopped_here:
            return
        # Reset so the instance can be reused for another fit.
        self._stopped_here = False
        print(f'Training cancelled at the end of iteration {self.end_iter}')

class EpochsStopper(Callback):
    """Cancel training before the first epoch whose number exceeds ``end_epoch``.

    At the start of every epoch the callback compares ``self.epoch`` with
    ``end_epoch`` and raises ``CancelTrainException`` when it is strictly
    greater. ``epoch`` is not defined in this class; presumably the base
    ``Callback`` / learner provides it.

    In the recorded notebook run with ``end_epoch=1`` exactly one epoch was
    logged before the cancellation message.

    Args:
        end_epoch: Last epoch number that is allowed to run.

    Raises:
        CancelTrainException: from ``before_epoch`` once the limit is passed.
    """

    def __init__(self, end_epoch=10):
        self.end_epoch = end_epoch
        # True only between this callback raising the cancellation and the
        # matching after_cancel_train call.
        self._stopped_here = False

    def before_epoch(self):
        """Cancel training if the upcoming epoch is past the limit."""
        if self.epoch > self.end_epoch:
            self._stopped_here = True
            raise CancelTrainException()

    def after_cancel_train(self):
        """Report the cancellation, but only when this callback caused it.

        NOTE(review): presumably the learner invokes this hook on every
        callback whenever any of them cancels training (confirm against the
        Learner). Without the guard this message would then be printed for
        cancellations raised elsewhere.
        """
        if not self._stopped_here:
            return
        # Reset so the instance can be reused for another fit.
        self._stopped_here = False
        print(f'Training cancelled at the end of epoch {self.end_epoch}')

In [4]:
#export
class AccuracyStopper(Callback):
    """Early stopping driven by a per-epoch accuracy average.

    Accuracy is accumulated into an ``AvgStats`` tracker after every loss
    computation and read back at the end of each epoch. Training is
    cancelled once more than ``patience`` consecutive epochs have passed
    without a new best accuracy.

    Args:
        patience: Number of consecutive non-improving epochs tolerated; the
            run is cancelled when the count becomes greater than this value.
        log: If true, print the epoch number and its accuracy every epoch.

    Raises:
        CancelTrainException: from ``after_epoch`` when patience is exceeded.
    """

    def __init__(self, patience=5, log=True):
        self.patience, self.log = patience, log
        # Best accuracy seen so far and epochs elapsed since it was set.
        self.best_acc, self.waited = 0, 0
        self.valid_stats = AvgStats([compute_accuracy], False)

    def before_epoch(self):
        """Start each epoch with an empty accumulator."""
        self.valid_stats.reset()

    def after_loss(self):
        """Fold the current batch into the running statistics.

        NOTE(review): there is no train/validation check here, so whether
        training batches are excluded depends on AvgStats / the learner.
        Confirm that the tracked metric really is validation-only.
        """
        self.valid_stats.accumulate(self.learner)

    def _update(self):
        """Refresh the best accuracy and the count of epochs since it was set."""
        # Index 1 is the accuracy; presumably index 0 holds the loss (confirm).
        current = self.valid_stats.avg_stats[1]
        if self.best_acc < current:
            self.best_acc, self.waited = current, 0
        else:
            self.waited += 1

    def after_epoch(self):
        """Optionally log the epoch accuracy, then cancel if patience ran out."""
        if self.log:
            print(f'Epoch - {self.epoch}    Acc: {self.valid_stats.avg_stats[1]}')
        self._update()
        out_of_patience = self.waited > self.patience
        if out_of_patience:
            raise CancelTrainException()

In [5]:
# Shared objects for the demo cells below.
# NOTE(review): presumably a two-phase cosine schedule (40% / 60% of training)
# going 0.01 -> 0.5 -> 0.01; confirm against param_scheduling.
schedule = combine_schedules([0.4, 0.6], one_cycle_cos(0.01, 0.5, 0.01))

data_bunch = get_data_bunch(*get_mnist_data(), batch_size=64)
model = get_lin_model(data_bunch)
# NOTE(review): the same model and optimizer objects are handed to every
# Learner below, so each demo presumably continues from the weights left by
# the previous run rather than starting fresh.
optimizer = DynamicOpt(list(model.parameters()), learning_rate=0.1)
loss_fn = CrossEntropy()

In [6]:
# Demo: stop by epoch count. fit() is given a very large epoch count (10000);
# in the recorded run the stopper ended training after one logged epoch.
callbacks = [EpochsStopper(1), ParamScheduler('learning_rate', schedule), StatsLogging()]
learner = Learner(data_bunch, model, loss_fn, optimizer, callbacks)
learner.fit(10000)

Epoch - 1
train metrics - [8.296754837036133e-05, 0.888]
valid metrics - [2.171611785888672e-05, 0.9648]

Training cancelled at the end of epoch 1


In [7]:
# Demo: stop by iteration count. In the recorded run the counter values
# 0 through 5 were printed before the cancellation message.
callbacks = [ItersStopper(5), ParamScheduler('learning_rate', schedule), StatsLogging()]
learner = Learner(data_bunch, model, loss_fn, optimizer, callbacks)
learner.fit(10000)

iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
Training cancelled at the end of iteration 5


In [8]:
# Demo: early stopping on accuracy with patience=1. In the recorded run the
# best accuracy was logged at epoch 14 and training ended after epoch 16,
# the second consecutive epoch without a new best.
callbacks = [AccuracyStopper(1), ParamScheduler('learning_rate', schedule)]
learner = Learner(data_bunch, model, loss_fn, optimizer, callbacks)
learner.fit(10000)

Epoch - 1    Acc: 0.9596833333333333
Epoch - 2    Acc: 0.9655
Epoch - 3    Acc: 0.97105
Epoch - 4    Acc: 0.9741666666666666
Epoch - 5    Acc: 0.97565
Epoch - 6    Acc: 0.9777333333333333
Epoch - 7    Acc: 0.9789333333333333
Epoch - 8    Acc: 0.9802666666666666
Epoch - 9    Acc: 0.9794833333333334
Epoch - 10    Acc: 0.98075
Epoch - 11    Acc: 0.9817666666666667
Epoch - 12    Acc: 0.9828833333333333
Epoch - 13    Acc: 0.9845166666666667
Epoch - 14    Acc: 0.9848666666666667
Epoch - 15    Acc: 0.9837833333333333
Epoch - 16    Acc: 0.9838
