## Librarires

In [1]:
import re
from exp.nb_09 import *

## Useful functions

In [2]:
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

# Converts camel names to snake names
def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()

c2s = camel2snake("BoomDoomWeeDg")
c2s, _camel_re1

('boom_doom_wee_dg', re.compile(r'(.)([A-Z][a-z]+)', re.UNICODE))

## Using Exceptions as flow control

In [3]:
#export
class Callback():
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k): return getattr(self.run, k) # if cannot find the attribute inside the cb, 
    #look inside the runner
    
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
    
    def __call__(self, cb_name):
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False

class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs=0.
        self.run.n_iter=0
    
    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter   += 1
        
    def begin_epoch(self):
        self.run.n_epochs=self.epoch
        self.model.train()
        self.run.in_train=True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False

class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

class TestCbsCallback(Callback):
    def begin_fit(self):
        print("I'm a test callback")



## Combining Learner and runner

In [4]:
#export
def param_getter(m): return m.parameters()

def sgd_opt():
    pass

class LearnerRunnerDJ():
    
    def __init__(self, model, data, loss_func, opt_func=sgd_opt, lr=1e-2, splitter=param_getter,
                 cbs=None, cb_funcs=None): # callback functions used to create callbacks
        
        # model - the model
        # data 
        # loss func - usually cross entropy
        # opt func - optimiser
        self.model , self.data, self.loss_func, self.opt_func, self.lr, self.splitter = model, data, loss_func, opt_func,lr, splitter
        
        self.logger, self.in_train, self.opt = print, False, None # the logger is a print function to print the output from a nn
        
        self.cbs = []
        self.add_cb(TrainEvalCallback())
        self.add_cbs(cbs)
        self.add_cbs(cbf() for cbf in listify(cb_funcs)) # goes through all callback creastor funcs and runs them
        
    def add_cbs(self, cbs):
        for cb in listify(cbs): self.add_cb(cb) # call add cb on all callbacks
            
    def add_cb(self, cb):
        cb.set_runner(self) # sets a 'pointer' to the runner inside the callback (so can access runner members such as)
        #set_trace()
        setattr(self, cb.name, cb) # set the callback as a member of the learner class
        self.cbs.append(cb) # add the callback to the call backs list

    def remove_cbs(self, cbs):
        for cb in listify(cbs): self.cbs.remove(cb)
            
    def all_batches(self):
        pass
    
    def one_batch(self):
        pass
    
    def fit(self):
        pass
    def __call__(self):
        res = False
        assert cb_name in self.ALL_CBS
        for cb in sorted(self.cbs, key=lambda x: x._order): res = cb(cb_name) and res
        #set_trace()
        return res # callback needs to return True to stop. With no return Python returns None (False)
        

## Average stats call back

In [5]:
#export
class AvgStats(): # average stats is a class to calculate and store training stats (i.e accuracy)
    def __init__(self, metrics, in_train): 
        self.metrics,self.in_train = listify(metrics),in_train
        print("Metrics:")
        print(metrics)
        
    def reset(self):
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)
        
    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o/self.count for o in self.all_stats]
    
    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run): # run is presumably the runner
        bn = run.xb.shape[0] # xb is the mini batch
        self.tot_loss += run.loss * bn # I think this is accounting for the batch size
        self.count += bn
        for i,m in enumerate(self.metrics): # applies the metric function to the prediction and the loss
            self.tot_mets[i] += m(run.pred, run.yb) * bn # can have any metrics on the predictions and truth 

class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        # train_stats and valid_stats are containers for stats for training and valdation sets
        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)
        
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
        
    def after_loss(self): # after loss, the accumulates the loss on each mini batch
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)
    
    def after_epoch(self):
        #We use the logger function of the `Learner` here, it can be customized to write in a file or in a progress bar
        self.logger(self.train_stats)
        self.logger(self.valid_stats) 

## Testing

In [20]:
#x = Callback()
tr = TrainEvalCallback()
#x.__class__.__name__, x.name, tr.name


def makeTestCB():
    return TestCbsCallback()

run = LearnerRunnerDJ(1, 2, 3, cb_funcs= makeTestCB)



run.cbs[0].run, run.cbs[0].name, run.cbs[1].name, run.train_eval


avgSts = AvgStats(accuracy, True)

x = tensor([1, 2])
y = tensor([[2, 6, 3], [2, 5, 3]])

accuracy(y, x)

avgSts.accuml

Metrics:
<function accuracy at 0x7f3c180e4d90>


tensor(0.5000)

In [None]:
class test():
    
    def print(self):
        print("Hello")
        return True
        
    def __call__(self, func):
        f = getattr(self, func, None)
        print(f)
        if f and f(): return True # note f() runs the function
        return False
    
ttt = test()
ttt('print')