## Training a CNN on synthetic data with the fastai training loops and callbacks


In [20]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

# Make ../src/ importable; a relative sys.path entry resolves against the kernel's working directory.
import sys
sys.path.append('../src/')
# NOTE(review): star import — presumably the source of the framework names used below
# (Dataset, normalize_to, DataBunch, get_dls, Learner, Runner, AvgStatsCallback,
# CudaCallback, accuracy, nn, optim); consider importing them explicitly.
from exp.nb_06 import *

# Project-local model and data modules (expected under ../src/ — TODO confirm).
import luke_model
import luke_data

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Load the train/validation splits, normalise them (presumably the validation set
# with the training-set statistics — see normalize_to), and wrap each in a Dataset.
x_train, y_train, x_valid, y_valid = luke_data.get_data()
x_train, x_valid = normalize_to(x_train, x_valid)
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)

In [22]:
# `c` is handed to DataBunch (presumably the number of target classes); `bs` is the batch size.
# NOTE(review): c=1 alongside nn.CrossEntropyLoss looks unusual — confirm the intended value.
c, bs = 1, 20
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [33]:
# Build the model, loss function and optimiser, and bundle them with the data into a Learner.
# No Runner is created here: the training cell constructs its own `run` (with the live-plot
# callback) before calling fit, so a Runner built in this cell would never be used.
# TODO(review): no random seed is set, so weight initialisation (and presumably batch
# shuffling) differs between runs.
model = luke_model.Luke(cn_dropout=0.05)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
learn = Learner(model, optimizer, loss_func, data)

In [35]:
from livelossplot import PlotLosses
class LivelossCallback(AvgStatsCallback):
    """Live training feedback layered on top of ``AvgStatsCallback``.

    During training batches it rewrites a single console line with the epoch
    number, the batch counter and the current batch loss. At the end of every
    epoch it collects the train/validation averages held in the parent's
    ``train_stats`` / ``valid_stats`` and hands them to a livelossplot
    ``PlotLosses`` instance, which then redraws its plot.

    Args:
        metrics: metric function, or list of metric functions, forwarded
            unchanged to ``AvgStatsCallback``. Each metric is logged under its
            ``__name__`` for training and ``'val_' + __name__`` for validation.
    """

    def __init__(self, metrics):
        super().__init__(metrics)
        self.liveloss = PlotLosses(skip_first=0)
        # listify (as Runner does for its callbacks) so a single bare metric
        # callable also works; iterating a function directly raises TypeError.
        self.metricnames = [m.__name__ for m in listify(metrics)]
        self.logs = {}
        # Also initialised here so after_loss never touches a missing
        # attribute if it fires before the first begin_epoch.
        self.iteration = 0

    def begin_epoch(self):
        """Reset the parent's running stats, this epoch's log dict and the batch counter."""
        super().begin_epoch()
        self.logs = {}
        self.iteration = 0

    def after_loss(self):
        """Accumulate stats (parent), then refresh the in-place training progress line."""
        super().after_loss()
        if self.in_train:
            self.iteration += 1
            # '\r' plus end='' overwrites the same console line on every batch
            # instead of emitting one line per batch.
            print('\r[%d, %5d] Train_loss: %.3f' % (self.epoch + 1, self.iteration, self.loss), end='')

    def after_epoch(self):
        """Log the epoch's train/validation averages to livelossplot and redraw."""
        # Terminate the in-place progress line so whatever is printed next
        # (e.g. by the parent's after_epoch) starts on a fresh line.
        if self.iteration:
            print()
        super().after_epoch()
        # Compute each avg_stats list once; the layout indexed here is
        # [loss, metric_0, metric_1, ...].
        train_avg = self.train_stats.avg_stats
        valid_avg = self.valid_stats.avg_stats
        # float() accepts both 1-element tensors and plain Python numbers,
        # whereas .item() fails on a plain float.
        self.logs['loss'] = float(train_avg[0])
        self.logs['val_loss'] = float(valid_avg[0])
        for i, metric in enumerate(self.metricnames, start=1):
            self.logs[metric] = float(train_avg[i])
            self.logs['val_' + metric] = float(valid_avg[i])
        self.liveloss.update(self.logs)
        self.liveloss.draw()

In [None]:
# Build a Runner with the live-plot callback plus CudaCallback (presumably moves the
# model and batches to the GPU), then train for 30 epochs.
# NOTE(review): `learn` is created in an earlier cell, so re-running only this cell
# presumably keeps training the same model rather than starting from scratch.
run = Runner(cbs=[LivelossCallback([accuracy])], cb_funcs=[CudaCallback])
run.fit(30, learn)

In [15]:
# torch.save(model.state_dict(), '../../local/luke_model.pth')

In [16]:
Runner??

Init signature: Runner(cbs=None, cb_funcs=None)
Docstring:      <no docstring>
Source:
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        self.in_train = False
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()


In [17]:
Callback??

Init signature: Callback()
Docstring:      <no docstring>
Source:
class Callback():
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k): return getattr(self.run, k)

    @property
    def name(self):
[0;34m[0m        [0mname[0m [0;34m=[0m [0mre[0m[0;34m.[0m[0msub