In [0]:
from pathlib import Path
import torch.nn
import matplotlib.pyplot as plt

# Data

Random sampling: shuffle the data in the training set.

We want our training set to be in a random order, and that order should differ each iteration. But the validation set shouldn't be randomized.

PyTorch's `DataLoader` already handles this for us: pass `shuffle=True` for the training set and `shuffle=False` for the validation set.

In [0]:
# train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True)
# valid_dl = DataLoader(valid_ds, bs, shuffle=False)

In [0]:
#export
class DataBunch():
    """Pairs a training DataLoader with a validation DataLoader.

    NOTE(review): `c` is stored as-is and not used here; presumably the number
    of classes/outputs — confirm against callers.
    """
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c

    @property
    def train_ds(self):
        """Dataset wrapped by the training DataLoader."""
        loader = self.train_dl
        return loader.dataset

    @property
    def valid_ds(self):
        """Dataset wrapped by the validation DataLoader."""
        loader = self.valid_dl
        return loader.dataset

# Callbacks

In [0]:
class Callback():
    """Base class for training-loop hooks driven by `Runner`.

    Hooks are invoked by name through `__call__`. The default hooks do nothing
    and return True. Attribute lookups that fail on the callback itself are
    forwarded to the attached runner, so subclasses can write `self.model`
    or `self.in_train` instead of `self.run.model`.
    """
    # Callbacks are invoked in ascending `_order`.
    _order = 0

    def set_runner(self, run):
        """Attach the runner whose state this callback reads and writes."""
        self.run = run

    def __getattr__(self, k):
        # Only reached when normal lookup fails. Guard 'run' so a callback that
        # has not been attached yet raises AttributeError (which `hasattr`,
        # `getattr(..., default)` and `copy` expect) instead of recursing
        # forever: self.run -> __getattr__('run') -> self.run -> ...
        if k == 'run': raise AttributeError(k)
        return getattr(self.run, k)

    # Default hooks: no-ops that return True. Every event the Runner fires has a
    # default here, so hook lookup never falls through to the runner.
    def begin_fit(self): return True
    def after_fit(self): return True
    def begin_epoch(self): return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self): return True
    def after_pred(self): return True
    def after_loss(self): return True
    def after_backward(self): return True
    def after_step(self): return True
    def after_batch(self): return True
    def after_cancel_batch(self): return True
    def after_cancel_epoch(self): return True
    def after_cancel_train(self): return True

    def __call__(self, cb_name):
        """Run hook `cb_name` if it exists; return True only if it returned a truthy value."""
        f = getattr(self, cb_name, None)
        return bool(f and f())
        

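**This callback is required** (the `Runner` always adds it).
It calls `model.train()` and `model.eval()` at the appropriate times. It also maintains the progress counters (`n_epochs`, `n_iter`) and the `in_train` flag that the other callbacks rely on.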

In [0]:
class TrainEvalCallback(Callback):
    """Required callback: switches train/eval mode and tracks training progress.

    State it maintains on the runner:
      - `n_epochs`: float progress, advanced by 1/iters per training batch
      - `n_iter`: number of training batches processed so far
      - `in_train`: True during the training pass, False during validation
    """
    def begin_fit(self):
        # Reset progress counters at the start of every fit.
        run = self.run
        run.n_epochs = 0.
        run.n_iter = 0

    def after_batch(self):
        # Only training batches count towards progress.
        if self.in_train:
            run = self.run
            run.n_epochs += 1. / self.iters
            run.n_iter += 1

    def begin_epoch(self):
        # Snap fractional progress to the integer epoch index, then enter training mode.
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        # Put the model in eval mode for the validation pass.
        self.model.eval()
        self.run.in_train = False

class CancelTrainException(Exception):
    """Raise from a callback to abort `Runner.fit`; `after_fit` still runs."""

class CancelEpochException(Exception):
    """Raise from a callback to stop iterating the current DataLoader."""

class CancelBatchException(Exception):
    """Raise from a callback to skip the rest of the current batch."""

In [0]:
class AvgStatsCallback(Callback):
    """Accumulates per-epoch averages and prints one summary row after each epoch.

    Training stats track only the loss. Validation stats track the loss plus
    every function in `metrics`. The header row is printed before the first
    epoch's row.
    """
    def __init__(self, metrics):
        self.metrics = metrics
        self.train_stats,self.valid_stats = AvgStats([],True),AvgStats(metrics,False)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        # Route the batch to the train or valid accumulator based on the current phase.
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)

    def after_epoch(self):
        if self.epoch == 0:
            # One 14-character column per metric. Joining, rather than reassigning
            # inside a loop, keeps every metric name instead of only the last one.
            mtrs = "".join("%14s" % metric.__name__ for metric in self.metrics)
            print("%14s"%"train loss" + "%14s"%"valid loss" + mtrs)
        print(str(self.train_stats) + str(self.valid_stats))
    

We use a `Recorder` to keep track of the loss and the scheduled learning rate. We also use a `ParamScheduler`, which can schedule any hyperparameter as long as it is registered in the optimizer's `state_dict`, i.e. it appears as a key in each of `opt.param_groups`.

In [0]:
class Recorder(Callback):
    """Records the learning rate and loss of every training batch for plotting."""
    def begin_fit(self):
        # Start a fresh history for each call to fit.
        self.lrs = []
        self.losses = []

    def after_batch(self):
        if self.in_train:
            # Learning rate of the last parameter group.
            self.lrs.append(self.opt.param_groups[-1]['lr'])
            # Detach and move to CPU so the history doesn't hold the autograd graph or GPU memory.
            self.losses.append(self.loss.detach().cpu())

    def plot_lr(self):
        """Plot the recorded learning rate per training batch."""
        plt.plot(self.lrs)

    def plot_loss(self):
        """Plot the recorded training loss per training batch."""
        plt.plot(self.losses)

In [0]:
class ParamScheduler(Callback):
    """Sets optimizer hyperparameter `pname` from `sched_func` before each training batch.

    `sched_func` receives the fraction of training completed, `n_epochs / epochs`
    (0 at the first batch). Its result is assigned to `pname` in every param group.
    """
    # Sorted after callbacks with the default _order of 0 (e.g. TrainEvalCallback).
    _order = 1

    def __init__(self, pname, sched_func):
        self.pname, self.sched_func = pname, sched_func

    def set_param(self):
        # The position is the same for every param group, so evaluate the schedule once.
        value = self.sched_func(self.n_epochs/self.epochs)
        for pg in self.opt.param_groups:
            pg[self.pname] = value

    def begin_batch(self):
        if self.in_train: self.set_param()


# Backward-compatible alias for the original (misspelled) class name.
ParamSchheduler = ParamScheduler

**TODO:** add the scheduler functions (the `sched_func` argument of the parameter scheduler) here.

In [0]:
class CudaCallback(Callback):
    """Moves the model, and then every batch, to the default CUDA device."""
    def begin_fit(self):
        self.model.cuda()

    def begin_batch(self):
        # Replace the runner's batch with its CUDA copy before the forward pass.
        xb_gpu, yb_gpu = self.xb.cuda(), self.yb.cuda()
        self.run.xb, self.run.yb = xb_gpu, yb_gpu

# Training

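A basic hand-written optimizer, kept for reference only (it is commented out). Note: if re-enabled, `step` should use `self.lr` rather than the global `lr`.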

In [0]:
# class Optimizer():
  
#   def __init__(self, params, lr=0.5): self.params,self.lr=list(params),lr
    
#   def step(self):
#     with torch.no_grad():
#       for p in self.params: p -= p.grad*lr
        
#   def zero_grad(self):
#     for p in self.params: p.grad.data.zero_()

In [0]:
class Learner():
    """Bundles what `Runner.fit` trains: model, optimizer, loss function and data."""
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [0]:
class Runner():
    """Drives the training loop and dispatches hook events to callbacks.

    A `TrainEvalCallback` is always prepended to `callbacks`. Callbacks read and
    write runner state (`xb`, `yb`, `pred`, `loss`, `in_train`, `epoch`, ...) directly.
    """
    def __init__(self, callbacks = None):
        # Treat `callbacks=None` as "no extra callbacks" (previously `[...] + None` raised TypeError).
        self.cbs = [TrainEvalCallback()] + list(callbacks or [])
        # NOTE(review): not read by the training loop below.
        self.stop = False

    @property
    def opt(self):
        return self.learn.opt
    @property
    def model(self):
        return self.learn.model
    @property
    def loss_func(self):
        return self.learn.loss_func
    @property
    def data(self):
        return self.learn.data

    def one_batch(self, xb, yb):
        """Run the forward pass and loss for one batch; in training also backward and step."""
        try:
            self.xb, self.yb = xb, yb
            self('begin_batch')
            self.pred = self.model(self.xb)
            self('after_pred')
            self.loss = self.loss_func(self.pred, self.yb)
            self('after_loss')
            # Validation batches stop here; `finally` still fires `after_batch`.
            if not self.in_train: return
            self.loss.backward()
            self('after_backward')
            self.opt.step()
            self('after_step')
            self.opt.zero_grad()
        except CancelBatchException: self('after_cancel_batch')
        finally: self('after_batch')

    def all_batches(self, dl):
        """Run `one_batch` over every batch of `dl`; `CancelEpochException` ends the loop early."""
        self.iters = len(dl)
        try:
            for xb, yb in dl: self.one_batch(xb, yb)
        except CancelEpochException: self('after_cancel_epoch')

    def fit(self, epochs, learn):
        """Train `learn` for `epochs` epochs, running a validation pass after each one."""
        self.epochs, self.learn, self.loss = epochs, learn, torch.tensor(0.)

        try:
            for cb in self.cbs: cb.set_runner(self)
            self('begin_fit')
            for epoch in range(epochs):
                self.epoch = epoch
                self('begin_epoch')
                self.all_batches(self.data.train_dl)

                # Gradients are not tracked during validation.
                with torch.no_grad():
                    self('begin_validate')
                    self.all_batches(self.data.valid_dl)
                self('after_epoch')

        except CancelTrainException: self('after_cancel_train')
        finally:
            self('after_fit')
            # Drop the reference to the learner once training is over.
            self.learn = None

    def __call__(self, cb_name):
        """Fire hook `cb_name` on every callback in ascending `_order`.

        Every callback is always called (the call is evaluated before `or`).
        Returns True if any callback returned True.
        """
        res = False
        for cb in sorted(self.cbs, key=lambda x: x._order): res = cb(cb_name) or res
        return res


# Metrics

In [0]:
def accuracy(out, yb):
    """Fraction of rows of `out` whose argmax along dim 1 equals the target in `yb`."""
    predicted = out.argmax(dim=1)
    return (predicted == yb).float().mean()

In [0]:
class AvgStats():
    """Running, batch-size-weighted averages of the loss and of each metric.

    Args:
        metrics: callables `m(pred, yb)` evaluated on every accumulated batch.
        in_train: whether these stats belong to the training pass (stored only).
    """
    def __init__(self, metrics, in_train):
        self.metrics, self.in_train = metrics, in_train
        # Initialise the accumulators so the object is usable (and printable)
        # before the first explicit reset().
        self.reset()

    def reset(self):
        """Zero the accumulated loss, metric totals and sample count."""
        self.tot_loss, self.count = 0., 0
        self.tot_mets = [0.]*len(self.metrics)

    @property
    def all_stats(self):
        """Sample-weighted totals: [loss, *metrics]."""
        # float() handles both the initial 0. and a one-element tensor;
        # `.item()` fails on the plain float before anything is accumulated.
        return [float(self.tot_loss)] + self.tot_mets

    @property
    def avg_stats(self):
        """Per-sample averages: [loss, *metrics]. Requires count > 0."""
        return [o/self.count for o in self.all_stats]

    @property
    def printAvgStats(self):
        """Averages with 6 decimals, each right-aligned in a 14-character column."""
        return "".join("%14s" % ('%.6f' % stat) for stat in self.avg_stats)

    def __repr__(self):
        # Empty until at least one batch has been accumulated.
        if not self.count: return ""
        return self.printAvgStats

    def accumulate(self, run):
        """Add one batch from `run` (reads xb, loss, pred, yb), weighted by batch size."""
        bn = run.xb.shape[0]
        self.tot_loss += run.loss*bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb)*bn