In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *

# DataBunch/ Learner

In [3]:
x_train, y_train, x_valid, y_valid = get_data()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)
nh, bs = 50, 64
c = y_train.max().item() + 1
loss_func = F.cross_entropy

In [4]:
#export
class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c
        
    @property
    def train_ds(self): return self.train_dl.dataset
    
    @property
    def valid_ds(self): return self.valid_dl.dataset

In [5]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [6]:
data.train_dl.dataset

<exp.nb_03.Dataset at 0x7f8757b93b20>

In [7]:
data.train_ds.x.shape

torch.Size([60000, 784])

In [8]:
#export 
def get_model(data, lr=0.1, nh=50):
    m = data.train_ds.x.shape[1]
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh, data.c))
    return model, optim.Adam(model.parameters(), lr=lr)

class Learner():
    def __init__(self, model, opt, loss_func, data):
        self.model, self.opt, self.loss_func, self.data = model, opt, loss_func, data

In [9]:
learn = Learner(*get_model(data), loss_func, data)

In [10]:
xb, yb = train_ds[:bs]

In [11]:
loss_func(learn.model(xb), yb)

tensor(2.1768, grad_fn=<NllLossBackward>)

In [12]:
def fit(epochs, learn):
    for epoch in range(epochs):
        learn.model.train()
        for xb, yb in learn.data.train_dl:
            pred = learn.model(xb)
            loss = learn.loss_func(pred, yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()
            
        learn.model.eval()
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb, yb in learn.data.valid_dl:
                pred = learn.model(xb)
                tot_loss += learn.loss_func(pred, yb)
                tot_acc += accuracy(pred, yb)
                
        nv = len(learn.data.valid_dl)
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv

In [13]:
loss, acc = fit(5, learn)

0 tensor(0.2660) tensor(0.9228)
1 tensor(0.2639) tensor(0.9238)
2 tensor(0.2644) tensor(0.9184)
3 tensor(0.2323) tensor(0.9367)
4 tensor(0.2094) tensor(0.9399)


In [14]:
loss_func(learn.model(xb), yb)

tensor(0.0498, grad_fn=<NllLossBackward>)

In [15]:
assert acc > 0.8

# CallbackHandler

In [16]:
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb,yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()
        
def all_batches(dl, cb):
    for xb,yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return
        
def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb) # why not cb.learn.data.train_dl??
        
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
        cb.after_fit()

In [17]:
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True

In [18]:
class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True
    
    def after_step(self):
        self.n_iters += 1
        if self.n_iters % 100 ==  0:
            print(self.n_iters)
        if self.n_iters >= 600: self.learn.stop = True
        return True

In [19]:
class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []
        
    def begin_fit(self, learn):
        self.learn, self.in_train = learn, True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn) # why write this with res = res?
        return res
    
    def after_fit(self):
        res = not self.in_train # why do this at all?
        for cb in self.cbs: res = res and cb.after_fit()
        return res
    
    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res
    
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res
    
    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res
    
    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res
    
    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res
    
    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res
    
    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res
    
    def do_stop(self):
        try: return self.learn.stop
        finally: self.learn.stop = False

In [20]:
learn = Learner(*get_model(data), loss_func, data)

In [21]:
fit(1, learn, cb=CallbackHandler([TestCallback()]))
# fit(1, learn, cb=CallbackHandler([]))

100
200
300
400
500
600


In [29]:
accuracy(learn.model(xb), yb)

tensor(0.9375)

# Runner

In [23]:
#export 
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()

In [24]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [25]:
# self.__class__.__name__ why __class__ needed

In [32]:
#export
class Callback():
    _order = 0
    def set_runner(self, run): self.run = run
    def __getattr__(self, k): return getattr(self.run, k) # getattr gets called after the attribute was not found the first time so now the callback has all the attributes from the runner
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback') # why the or 'callback'

In [204]:
#export
class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs = 0
        self.run.n_iter = 0
    
    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1. / self.iters # self.iters gets initialized in all_batches() with len(dl), so the amount of data; so its the percentage of data that went through the model
        self.run.n_iter += 1
        
    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True
        
    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

In [345]:
937 * 2

1874

In [416]:
class TestCallback(Callback):
    def after_step(self):
        if self.n_iter % 500 == 0: print(self.n_iter)
#         print(self.run.n_iter) # prints from 0 to 937*2 (number of batches * 2)
        # i think it should break after 10 iters but instead it still computes everything but after the if statement is True, the opt.step() step gets skipped
        if self.n_iter >= 1850: # if the if-statement gets called before both epochs after processed, the accuracy drops immensely fast 24 batches skipped -> 70% accuracy reduction
                                # makes no sense 
#             self.stop = True # makes no difference
            return True #n_iter means number of batches

In [536]:
class TestCallback(Callback):
#     def after_batch(self): doenst work
#         if self.n_iter % 500 == 0: print(self.n_iter)
        
#         if self.n_iter >= 10: return True

    def begin_batch(self): # DOES WORK!!
        # after the if statement is True, the runner run still through every of the 1874 batches but every batch after the first 10 gets returned after the if self('begin_batch'): return line
        if self.n_iter >= 10: return True

In [537]:
#export 
from typing import *

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

In [538]:
#export
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb) # what does that do?
            cbs.append(cb)
        self.stop, self.cbs = False, [TrainEvalCallback()] + cbs
        
    @property
    def opt(self): return self.learn.opt
    @property
    def model(self): return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self): return self.learn.data
    
    def one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()
        
    def all_batches(self, dl):
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop = False
        
    def fit(self, epochs, learn):
        self.epochs, self.learn = epochs, learn
        
        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)
                    
                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break
                    
        finally:
            self('after_fit')
            self.learn = None
            
    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name,  None) # checks whether the callback has an attribute called cb_name, eg. after_step as the TestCallback
            if f and f(): return True # if f exists and the return value from f() is True, the call returns true and it hits a return or break in the if statement
        return False

In [539]:
#export
class AvgStats():
    def __init__(self, metrics, in_train): self.metrics, self.in_train = listify(metrics), in_train
        
    def reset(self):
        self.tot_loss, self.count = 0., 0
        self.tot_mets = [0] * len(self.metrics)
        
    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o / self.count for o in self.all_stats]
    
    def __repr__(self):
        if not self.count: return ''
        return f'{"train" if self.in_train else "valid"}: {self.avg_stats}'
    
    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
            
class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats, self.valid_stats = AvgStats(metrics, True), AvgStats(metrics, False)
        
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
        
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats # gets the self.in_train from Callback inherit and there from the runner + the __setattr__
        with torch.no_grad(): stats.accumulate(self.run)
            
    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)

In [540]:
stats = AvgStatsCallback([accuracy])

In [541]:
stats.__class__, stats.__class__.__name__

(__main__.AvgStatsCallback, 'AvgStatsCallback')

In [542]:
learn = Learner(*get_model(data), loss_func, data)

In [543]:
run = Runner(cbs=TestCallback())

In [544]:
%time run.fit(2, learn)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
CPU times: user 2.22 s, sys: 79.9 ms, total: 2.3 s
Wall time: 1.01 s


In [545]:
accuracy(learn.model(x_valid), y_valid)

tensor(0.3978)