In [None]:
#default_exp callbacks_05

In [None]:
#export
from ModernArchitecturesFromScratch.basic_operations_01 import *
from ModernArchitecturesFromScratch.fully_connected_network_02 import *
from ModernArchitecturesFromScratch.model_training_03 import *
from ModernArchitecturesFromScratch.convolutions_pooling_04 import *
from nbdev.showdoc import *
import math

# DataLoader

In [None]:
#export
#hide
class Dataset():
    "Pairs inputs `x` with targets `y` so that item `i` is the tuple `(x[i], y[i])`"
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, i):
        "Return the `(input, target)` pair stored at `i`"
        return (self.x[i], self.y[i])

    def __len__(self):
        "Number of items, measured on the inputs"
        return len(self.x)

    def __repr__(self):
        x_shape, y_shape = self.x.shape, self.y.shape
        return f'X: {x_shape}, Y: {y_shape}'

class DataLoader():
    "Iterates over `ds` in batches: `batcher` yields groups of indices, `collate_fcn` merges the fetched items"
    def __init__(self, ds, batcher, collate_fcn):
        self.ds = ds
        self.batcher = batcher
        self.collate_fcn = collate_fcn

    def __iter__(self):
        for idxs in self.batcher:
            samples = [self.ds[idx] for idx in idxs]
            yield self.collate_fcn(samples)

    @property
    def dataset(self):
        "The wrapped dataset"
        return self.ds

    def __len__(self):
        # Rounded up, so a final partial batch is counted as well.
        n_items, bs = len(self.ds), self.batcher.bs
        return math.ceil(n_items / bs)

    def __repr__(self):
        return f'Data: {self.ds}, bs = {self.batcher.bs}'

In [None]:
#export
class Databunch():
    "Wrapper to combine the training and validation dataloaders"
    def __init__(self, train_dl, valid_dl): self.train, self.valid = train_dl, valid_dl

    @property
    def train_ds(self):
        "Dataset behind the training dataloader"
        return self.train.dataset

    @property
    def valid_ds(self):
        "Dataset behind the validation dataloader"
        return self.valid.dataset

    # The `Valid` label needs the same `: ` separator as `Train`; without it the
    # output ran together as `ValidData: ...`.
    def __repr__(self): return f'Databunch(\nTrain: {self.train}, \nValid: {self.valid}\n)'

In [None]:
def get_databunch(xt, yt, xv, yv, bs=64):
    """Helper function to get a databunch of given `bs`.

    `xt`/`yt` are the training inputs/targets and `xv`/`yv` the validation ones.
    The training `Batcher` is built with batch size `bs` and flag `True`, the
    validation one with `bs*2` and flag `False` (presumably the shuffle flag —
    confirm against `Batcher`).
    """
    t_data, v_data = Dataset(xt, yt), Dataset(xv, yv)
    t_dl = DataLoader(t_data, Batcher(t_data, bs, True), collate)
    # The validation loader must wrap `v_data`; building it from `t_data` would
    # make every "validation" metric a second pass over the training set.
    v_dl = DataLoader(v_data, Batcher(v_data, bs*2, False), collate)
    return Databunch(t_dl, v_dl)

def get_mnist_databunch():
    "Grabs the MNIST databunch using `get_mnist` and `get_databunch`"
    mnist_splits = get_mnist()
    return get_databunch(*mnist_splits)

In [None]:
# Build the MNIST databunch with `get_databunch`'s default batch size
db = get_mnist_databunch()

In [None]:
db

Databunch(
Train: Data: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 64, 
ValidData: X: torch.Size([50000, 784]), Y: torch.Size([50000]), bs = 128
)

# Learner

In [None]:
class Learner():
    "Wrapper for model, loss function, optimizer and databunch"
    def __init__(self, model, loss_func, optimizer, db):
        "Keep the four training ingredients together as plain attributes"
        self.model = model
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.db = db

# Runner

In [None]:
#export
class Runner():
    "All-encompassing class to train a model with specific callbacks"
    def __init__(self, learner, cbs=None):
        "Store `learner` and register `cbs` (any iterable of callbacks, or `None`) after the built-in `TrainEvalCallback`"
        # `list(...)` lets callers pass a tuple or generator; concatenating those
        # directly onto a list would raise TypeError.
        extra_cbs = [] if cbs is None else list(cbs)
        self.stop,self.cbs = False,[TrainEvalCallback()]+extra_cbs
        
        # Give every callback a back-reference to this runner.
        for cb in self.cbs:
            cb.runner = self
            
        self.learner = learner
    
    @property
    def model(self): return self.learner.model
    @property
    def optimizer(self): return self.learner.optimizer
    @property
    def loss_func(self): return self.learner.loss_func
    @property
    def databunch(self): return self.learner.db
    
    def do_one_batch(self, xb, yb):
        "Applies forward and backward passes of model to one batch"
        # The batch, prediction and loss are stored on the runner so callbacks can read them.
        self.xb, self.yb = xb, yb
        
        self.pred = self.learner.model(xb)
        self.loss = self.learner.loss_func(self.pred, yb)
        # Stop after the forward pass when a callback asks to, or when the model is not training.
        if self.check_callbacks('after_loss') or not self.learner.model.training: return
        
        self.learner.loss_func.backward()
        if self.check_callbacks('after_loss_back'): return
        
        self.learner.model.backward()
        if self.check_callbacks('after_model_back'): return
        
        # `self.opt` is not set in this class; `TrainEvalCallback.before_fit` assigns it.
        self.opt.step()
        if self.check_callbacks('after_opt'): return
        
        self.opt.zero_grad()
        if self.check_callbacks('after_zero_grad'): return
    
    def do_all_batches(self, dl):
        "Runs every batch of a dataloader through `do_one_batch`"
        self.iters, self.iters_done = len(dl), 0
        for xb, yb in dl:
            # `stop` ends only the current pass over `dl`; it is cleared again below.
            if self.stop: break
            if self.check_callbacks('before_batch'): return
            self.do_one_batch(xb,yb)
            if self.check_callbacks('after_batch'): return
        self.iters = 0
            
        self.stop = False

    def fit(self, epochs, lr=0.1):
        "Method to fit the model for `epochs` epochs using learning rate `lr`"
        # Both values are stored on the runner so callbacks can read them.
        self.lr, self.epochs = lr, epochs
        if self.check_callbacks('before_fit'): return
        
        for epoch in range(epochs):
            self.epoch = epoch
            if self.check_callbacks('before_epoch'): return
            # A truthy `before_train` / `before_valid` result skips that phase.
            if not self.check_callbacks('before_train'): self.do_all_batches(self.learner.db.train)
            if not self.check_callbacks('before_valid'): self.do_all_batches(self.learner.db.valid)
            if self.check_callbacks('after_epoch'): break
        
        if self.check_callbacks('after_fit'): return
    
    def check_callbacks(self, state):
        "Helper function to run through each callback in `_order`, calling its `state` method if it has one; returns True as soon as one returns a truthy value"
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, state, None)
            if f and f(): return True
        return False

In [None]:
show_doc(Runner.__init__)

<h4 id="Runner.__init__" class="doc_header"><code>Runner.__init__</code><a href="__main__.py#L4" class="source_link" style="float:right">[source]</a></h4>

> <code>Runner.__init__</code>(**`learner`**, **`cbs`**=*`None`*)

Initialize self.  See help(type(self)) for accurate signature.

In [None]:
show_doc(Runner.do_one_batch)

<h4 id="Runner.do_one_batch" class="doc_header"><code>Runner.do_one_batch</code><a href="__main__.py#L22" class="source_link" style="float:right">[source]</a></h4>

> <code>Runner.do_one_batch</code>(**`xb`**, **`yb`**)

Applies forward and backward passes of model to one batch

In [None]:
show_doc(Runner.do_all_batches)

<h4 id="Runner.do_all_batches" class="doc_header"><code>Runner.do_all_batches</code><a href="__main__.py#L42" class="source_link" style="float:right">[source]</a></h4>

> <code>Runner.do_all_batches</code>(**`dl`**)

Runs every batch of a dataloader through `do_one_batch`

In [None]:
show_doc(Runner.fit)

<h4 id="Runner.fit" class="doc_header"><code>Runner.fit</code><a href="__main__.py#L54" class="source_link" style="float:right">[source]</a></h4>

> <code>Runner.fit</code>(**`epochs`**, **`lr`**=*`0.1`*)

Method to fit the model `epoch` times using learning rate `lr`

In [None]:
show_doc(Runner.check_callbacks)

<h4 id="Runner.check_callbacks" class="doc_header"><code>Runner.check_callbacks</code><a href="__main__.py#L68" class="source_link" style="float:right">[source]</a></h4>

> <code>Runner.check_callbacks</code>(**`state`**)

Helper functions to run through each callback, calling it's state method if applicable

# Callbacks

In [None]:
#export
class Callback():
    "Base class for callbacks, defines order of execution and allows abstraction of self to runner class"
    _order = 0
    def __getattr__(self,k):
        "If the callback doesn't have attribute `k`, look it up on the runner"
        # `__getattr__` only runs after normal lookup has failed. When the missing
        # name is `runner` itself, no runner has been attached yet: delegating again
        # would recurse until RecursionError, so raise the AttributeError that
        # `hasattr` and `getattr(..., default)` rely on.
        if k == 'runner':
            raise AttributeError(f'{self.__class__.__name__} has no runner attached')
        return getattr(self.runner, k)

    def __repr__(self): return f'{self.__class__.__name__}'

In [None]:
#export
class TrainEvalCallback(Callback):
    "Keeps track of training/eval mode of model and progress through training"
    _order = 10
    
    # Writes go through `self.runner.<name>` explicitly: assigning to `self.<name>`
    # would create the attribute on the callback instead of on the runner.
    def before_fit(self):
        "Instantiate the optimizer on the runner and reset the fractional epoch counter"
        self.runner.opt = self.learner.optimizer(self.learner.model.parameters(), self.lr)
        self.runner.epochs_done = 0.
        
    def before_batch(self):
        "Count the batch and, while training, advance the fractional epoch counter"
        self.runner.iters_done += 1
        # Only training batches move `epochs_done` forward; counting validation
        # batches as well made it grow by 2 per epoch instead of 1.
        if self.model.training: self.runner.epochs_done += 1/self.iters
        
    def before_valid(self):
        "Switch the model to evaluation mode"
        self.model.training = False
    
    def before_train(self):
        "Switch the model to training mode"
        self.model.training = True
    
    def after_epoch(self):
        "Reset the batch counter"
        self.runner.iters_done = 0

In [None]:
#export
class Stat():
    "Running, batch-size-weighted average of the metric computed by `calc`"
    def __init__(self, calc):
        self.calc = calc
        self.value = 0.
        self.count = 0

    def __call__(self, bs, *args):
        "Add one batch: `calc(*args)` weighted by the batch size `bs`"
        batch_metric = self.calc(*args)
        self.value += batch_metric * bs
        self.count += bs

    def reset(self):
        "Forget everything accumulated so far"
        self.value = 0.
        self.count = 0

    def __repr__(self):
        label = self.calc.__name__.capitalize()
        # Without any samples there is no average to show, only the metric's name.
        if self.count > 0: return f'{label}: {self.value / self.count}'
        return label
    
class StatTracker():
    "Class to implement the `Stats` callback using metrics of class `Stat`"
    def __init__(self, metrics, in_train):
        self.in_train = in_train
        self.metrics = [Stat(m) for m in metrics]
        # Start from a clean state so `repr` works even before the first explicit `reset()`.
        self.reset()
    
    def reset(self):
        "Clear the accumulated loss, the sample count and every metric"
        self.count, self.tot_loss = 0., 0.
        for met in self.metrics: met.reset()
    
    def __len__(self): return len(self.metrics)
    
    def accumulate(self, run):
        "Adds the batch held by `run` to the totals, scaling loss and metrics by the batch size"
        bs = run.xb.shape[0]
        # Accumulate with `+=`: overwriting here would leave only the last batch's
        # loss, which `__repr__` then divides by the full sample count.
        self.tot_loss += run.loss * bs
        self.count += bs
        for met in self.metrics:
            met(bs, run.pred, run.yb)
    
    def __repr__(self):
        if self.count < 1: return ""
        else:
            printed_stats = f'Loss: {self.tot_loss / self.count}'
            for met in self.metrics:
                printed_stats += f', {met}'
            return f'{"Train" if self.in_train else "Valid"}: {printed_stats}'
    
class Stats(Callback):
    "Callback that tracks the loss and `metrics` per epoch, separately for training and validation"
    def __init__(self, metrics):
        self.train = StatTracker(metrics, True)
        self.valid = StatTracker(metrics, False)

    def before_epoch(self):
        "Start every epoch with empty trackers"
        for tracker in (self.train, self.valid):
            tracker.reset()

    def after_loss(self):
        "Feed the current batch to the tracker matching the model's mode"
        if self.model.training: tracker = self.train
        else: tracker = self.valid
        tracker.accumulate(self.runner)

    def after_epoch(self):
        "Print the 1-based epoch number followed by both trackers"
        print(f'Epoch: {self.epoch+1}')
        for tracker in (self.train, self.valid):
            print(tracker)

In [None]:
# NOTE(review): `learn` is not defined in any cell shown above — presumably a `Learner` created in a cell that was not kept; confirm this runs on a fresh kernel
run = Runner(learn, [Stats([accuracy])])

In [None]:
# Train for 5 epochs with lr=0.1; the `Stats` callback prints the loss and metrics after each epoch
run.fit(5, 0.1)

Epoch: 1
Train: Loss: 0.00012180877820355818, Accuracy: 0.9136000275611877
Valid: Loss: 0.0002788938581943512, Accuracy: 0.9381999969482422
Epoch: 2
Train: Loss: 4.875722515862435e-05, Accuracy: 0.9558600187301636
Valid: Loss: 0.00014287835801951587, Accuracy: 0.9671800136566162
Epoch: 3
Train: Loss: 0.00010866371303563938, Accuracy: 0.966920018196106
Valid: Loss: 0.00022182443353813142, Accuracy: 0.9489399790763855
Epoch: 4
Train: Loss: 5.4490061302203685e-05, Accuracy: 0.9728999733924866
Valid: Loss: 7.179772364906967e-05, Accuracy: 0.9632599949836731
Epoch: 5
Train: Loss: 4.08835694543086e-05, Accuracy: 0.9775000214576721
Valid: Loss: 8.388313290197402e-05, Accuracy: 0.9725599884986877
