In [None]:
# default_exp training

In [None]:
# hide
%load_ext autoreload
%autoreload 2

# Training

> Utilities for a basic PyTorch training loop: a running-average meter plus train/eval step and `fit` helpers.

In [None]:
# export
import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm, trange

In [1]:
# export
class AverageMeter:
    """Running (optionally weighted) average of a stream of values.

    Args:
        store_vals: if True, every value passed to `update` is kept in `self.values`.
        store_avgs: if True, `reset` appends the finished average to `self.avgs`
            (only when at least one value was recorded since the last reset).

    Attributes:
        sum: weighted sum of the values seen since the last reset.
        n: total weight (count) of the values seen since the last reset.
        avg: running weighted mean of the values seen since the last reset.
    """

    def __init__(self, store_vals=False, store_avgs=False):
        self.store_vals = store_vals
        self.store_avgs = store_avgs
        if store_vals: self.values = []
        if store_avgs: self.avgs = []
        self.sum, self.n, self.avg = 0, 0, 0

    def update(self, v, n=1):
        """Record value `v` with weight `n` (e.g. a batch mean over `n` samples).

        Raises:
            ValueError: if `n` is not positive.
        """
        if n <= 0:
            raise ValueError(f"weight n must be positive, got {n}")
        if self.store_vals: self.values.append(v)
        self.sum += v * n
        self.n += n
        # Incremental mean update; identical to the plain running mean when n == 1.
        self.avg += (v - self.avg) * n / self.n

    def reset(self):
        """Start a new averaging window, archiving the finished average if requested."""
        # Test the count, not the average's truthiness, so an average of exactly 0 is kept.
        if self.store_avgs and self.n: self.avgs.append(self.avg)
        self.sum, self.n, self.avg = 0, 0, 0

In [None]:
def train_step(batch, model, optimizer, loss_func, scheduler, device=None, max_grad_norm=1.):
    """Run one optimisation step on a single batch.

    The model's output is indexed as a sequence: `out[0]` and the moved batch are
    passed to `loss_func` as its first two arguments, followed by `out[1:]`
    (presumably a reconstruction-style objective — confirm against the loss used).

    Args:
        batch: tensor moved to `device` before the forward pass.
        model: module being trained (caller is responsible for `model.train()`).
        optimizer: optimizer over `model`'s parameters.
        loss_func: callable returning `(loss_tensor, extra)`.
        scheduler: stepped after every optimizer step, or None.
        device: where to move the batch; defaults to the device of the model's
            first parameter (CPU if the model has no parameters).
        max_grad_norm: total gradient-norm clipping threshold; None disables clipping.

    Returns:
        `(loss.item(), extra)`.
    """
    if device is None:
        # Infer the target device from the model instead of relying on a global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    xb = batch.to(device)
    out = model(xb)

    loss, extra = loss_func(out[0], xb, *out[1:])

    loss.backward()
    if max_grad_norm is not None:
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    # Gradients are cleared after the step, so the next call starts from clean grads.
    optimizer.zero_grad()
    return loss.item(), extra


In [None]:

def eval_step(batch, model, loss_func, device=None):
    """Compute the loss on one batch without tracking gradients.

    Mirrors `train_step`'s forward/loss call: `loss_func(out[0], xb, *out[1:])`.
    Does not switch the model's mode; the caller is responsible for `model.eval()`.

    Args:
        batch: tensor moved to `device` before the forward pass.
        model: module to evaluate.
        loss_func: callable returning `(loss_tensor, extra)`.
        device: where to move the batch; defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        `(loss.item(), extra)`.
    """
    if device is None:
        # Infer the target device from the model instead of relying on a global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    xb = batch.to(device)
    with torch.no_grad():
        out = model(xb)
        loss, extra = loss_func(out[0], xb, *out[1:])
    return loss.item(), extra


In [None]:

def _mean_valid_metrics(model, valid_dl, loss_func, n_metrics):
    """Running mean of each `extra` metric over `valid_dl`; all-NaN if it yields no batches."""
    running = np.zeros(n_metrics)
    n_batches = 0
    for batch in valid_dl:
        _, h = eval_step(batch, model, loss_func)
        n_batches += 1
        running += (np.array(h) - running) / n_batches
    # An empty loader has no meaningful average; NaN avoids reporting a fake 0 loss.
    return running if n_batches else np.full(n_metrics, np.nan)


def fit(n_epoch, model, train_dl, valid_dl, optimizer, loss_func, scheduler=None, n_metrics=3):
    """Train `model` for `n_epoch` epochs, recording the loss function's `extra` metrics.

    Each training step's `extra` (from `train_step`) is stored as one row of
    `train_losses`; each epoch's validation row is the mean of `extra` over `valid_dl`.

    Args:
        n_epoch: number of passes over `train_dl`.
        model, optimizer, loss_func, scheduler: forwarded to `train_step` / `eval_step`.
        train_dl: iterable of training batches supporting `len()`.
        valid_dl: iterable of validation batches, or None to skip validation.
        n_metrics: length of the `extra` value returned by `loss_func`.

    Returns:
        `(train_losses, valid_losses)` with shapes `(n_epoch * len(train_dl), n_metrics)`
        and `(n_epoch, n_metrics)`. Rows with no data (e.g. skipped or empty
        validation) are NaN.
    """
    steps_per_epoch = len(train_dl)
    total_steps = n_epoch * steps_per_epoch
    # NaN rather than 1 so any row that never gets filled is obvious.
    train_losses = np.full((total_steps, n_metrics), np.nan)
    valid_losses = np.full((n_epoch, n_metrics), np.nan)
    for e in trange(n_epoch):

        model.train()
        train_pbar = tqdm(train_dl, leave=False)
        for step, batch in enumerate(train_pbar):
            total_step = e * steps_per_epoch + step
            loss, h = train_step(batch, model, optimizer, loss_func, scheduler)
            train_losses[total_step, :] = np.array(h)
            train_pbar.set_description(f"{loss:.2f}")

        # Always leave the model in eval mode at the end of an epoch, as before.
        model.eval()
        if valid_dl is not None:
            valid_losses[e, :] = _mean_valid_metrics(model, valid_dl, loss_func, n_metrics)
    return train_losses, valid_losses

In [None]:
# hide
# Run nbdev's export step so `# export`-tagged cells are written into the library module(s).
# NOTE(review): `nbdev.export.notebook2script` is the nbdev v1 API (with no args it
# presumably converts every notebook in the project) — confirm against the pinned nbdev version.
from nbdev.export import notebook2script
notebook2script()