# training

> Training loop

In [None]:
#|default_exp training

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from isaacai.utils import *
from isaacai.dataloaders import *
from isaacai.models import *

from datetime import datetime
import torchvision.transforms.functional as TF,torch.nn.functional as F

import matplotlib.pyplot as plt,matplotlib as mpl
import fastcore.all as fc
import torch
from torch import nn, Tensor
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import pandas as pd , numpy as np
from torcheval.metrics import MulticlassAccuracy,Mean

In [None]:
# Notebook-wide display and reproducibility settings.
# Print tensors with 2 decimal places, wide lines and no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
# Use a grayscale colormap by default when matplotlib renders images.
mpl.rcParams['image.cmap'] = 'gray'

import logging
# Suppress log records of severity WARNING and below for the rest of the session.
logging.disable(logging.WARNING)

# NOTE(review): set_seed is pulled in by a star import (presumably isaacai.utils).
# If it also seeds torch, it supersedes torch.manual_seed(1) above -- confirm.
set_seed(42)

In [None]:
# NOTE(review): sample_size is assigned here but is not referenced anywhere else
# in the visible source (sample_dataset_dict is called without it) -- confirm
# whether it was meant to be passed to the sampling helper.
sample_size = 2000

# Constants used by transformi to normalise image tensors as (x - xmean) / xstd.
xmean,xstd = 0.28, 0.35

@inplace
def transformi(b):
    """Replace `b['image']` with normalised tensors.

    Each image is converted with `TF.to_tensor` and then shifted/scaled with
    the module-level `xmean` and `xstd`. The function itself returns nothing;
    the `inplace` decorator (project helper) wraps it -- presumably so the
    mutated batch is handed back to the caller, TODO confirm.
    """
    normalised = []
    for img in b['image']:
        normalised.append((TF.to_tensor(img) - xmean) / xstd)
    b['image'] = normalised

# Load Fashion-MNIST and register transformi via with_transform, so the
# normalisation is applied when items are accessed.
_dataset = load_dataset('fashion_mnist').with_transform(transformi)

# NOTE(review): sample_dataset_dict is a project helper; presumably it
# sub-samples each split to keep the demo fast -- confirm its default size.
_dataset = sample_dataset_dict(_dataset)
    
# Build the data loaders from the dataset dict using 4 worker processes.
# NOTE(review): the positional 64 is presumably the batch size -- confirm.
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)

  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
#| export
def run_callbacks(callbacks, method_name, trainer=None):
    """Call the `method_name` hook on every callback that provides one.

    Callbacks are visited in ascending `order` attribute (a missing `order`
    counts as 0). A callback whose `method_name` attribute is absent or
    `None` is skipped. Each hook receives `trainer` as its only argument.
    """
    def _priority(cb): return getattr(cb, 'order', 0)

    ordered = list(callbacks)
    ordered.sort(key=_priority)
    for cb in ordered:
        hook = getattr(cb, method_name, None)
        if hook is None: continue
        hook(trainer)

In [None]:
#| export
class ProgressCB:
    """Callback that accumulates losses and metrics and logs one dict per epoch.

    Parameters
    ----------
    precision : number of decimal places used when rounding logged values.
    **metrics : name -> metric object. Each metric must provide
        `update(preds, targets)`, `compute()` and `reset()`; metrics are only
        updated on batches where `trainer.training` is false.

    Attributes
    ----------
    loss_train, loss_valid : `Mean` accumulators of the per-batch loss,
        weighted by the batch size.
    stats_epoch : list of the per-epoch summary dicts, in epoch order.
    """
    def __init__(self, precision=4, **metrics):
        fc.store_attr(names=['precision'])
        self.metrics = metrics
        self.loss_train, self.loss_valid = Mean(), Mean()
        self.stats_epoch = fc.L()

    def log(self, x):
        """Emit one epoch summary; override to redirect the output."""
        print(x)

    def before_batch(self, trainer):
        # Remember the batch size so the loss can be weighted in after_batch.
        self.batch_size = len(trainer.batch[1])

    def after_batch(self, trainer):
        """Fold the current batch's loss (and, when validating, metrics) in."""
        loss = to_cpu(trainer.loss.detach())
        if trainer.training:
            self.loss_train.update(loss, weight=self.batch_size)
            return
        self.loss_valid.update(loss, weight=self.batch_size)
        if not self.metrics: return
        # Convert predictions/targets once per batch rather than once per metric.
        preds, targets = to_cpu(trainer.preds.detach()), to_cpu(trainer.batch[1])
        for metric in self.metrics.values(): metric.update(preds, targets)

    def before_epoch(self, trainer):
        # Start time used for the 'elapsed' entry of the epoch summary.
        self.st = datetime.now()

    def after_epoch(self, trainer):
        """Compute the epoch summary, store and log it, then reset accumulators."""
        stats = {'epoch': trainer.epoch,
                 'train_loss': round(float(self.loss_train.compute()), self.precision),
                 'valid_loss': round(float(self.loss_valid.compute()), self.precision)}
        stats.update({name: round(float(metric.compute()), self.precision)
                      for name, metric in self.metrics.items()})
        stats['elapsed'] = str(datetime.now() - self.st)
        self.stats_epoch.append(stats)
        # Reset so the next epoch's averages start from scratch.
        for accumulator in (self.loss_train, self.loss_valid, *self.metrics.values()):
            accumulator.reset()
        self.log(stats)

In [None]:
#| export
class DeviceCB:
    """Callback that places the model and every batch on `device`.

    `device` defaults to the module-level `def_device`.
    """
    def __init__(self, device=def_device): fc.store_attr()

    def before_fit(self, trainer):
        # Objects without a `.to` method are left untouched.
        model = trainer.model
        if not hasattr(model, 'to'): return
        model.to(self.device)

    def before_batch(self, trainer):
        # Swap the trainer's current batch for its on-device counterpart.
        moved = to_device(trainer.batch, device=self.device)
        trainer.batch = moved

In [None]:
#| export 
class Trainer:
    """Minimal training loop driven by callbacks.

    Parameters
    ----------
    dls : object exposing `train` and `valid` iterables of batches; each batch
        is indexed as `batch[0]` (model input) and `batch[1]` (target).
    loss_func : callable `(preds, targets) -> loss`.
    opt_func : callable `(parameters, lr) -> optimizer`.
    model : the model to train; must provide `train()`, `eval()`,
        `parameters()` and a `training` flag.
    callbacks : iterable of callback objects. Each one is stored as an
        attribute named after its class, and `self.callbacks` keeps the list
        of those names.

    Raises
    ------
    ValueError : if two callbacks share a class name. They would be stored
        under the same attribute, so one would be lost and the other invoked
        twice per event.
    """
    def __init__(self, dls, loss_func, opt_func, model, callbacks):
        cbs = list(callbacks)
        names = [cb.__class__.__name__ for cb in cbs]
        seen, dupes = set(), []
        for name in names:
            if name in seen and name not in dupes: dupes.append(name)
            seen.add(name)
        if dupes:
            raise ValueError(
                f"Callbacks must have unique class names; duplicated: {', '.join(dupes)}")
        self.callbacks = names
        for name, cb in zip(names, cbs): setattr(self, name, cb)
        fc.store_attr(but='callbacks')

    def one_batch(self):
        """Run the forward pass on `self.batch` and, when training, one optimizer step."""
        self.run_callbacks('before_batch')
        self.preds = self.model(self.batch[0])
        self.loss = self.loss_func(self.preds, self.batch[1])
        if self.training:
            self.run_callbacks('before_backward')
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        self.run_callbacks('after_batch')

    def one_epoch(self):
        """Run one pass over `dls.train` followed by one pass over `dls.valid`."""
        self.run_callbacks('before_epoch')

        self.model.train()
        self.run_callbacks('before_train')
        for self.batch in self.dls.train: self.one_batch()
        self.run_callbacks('after_train')

        self.model.eval()
        self.run_callbacks('before_valid')
        # No backward pass happens during validation, so skip building the
        # autograd graph for these batches.
        with torch.no_grad():
            for self.batch in self.dls.valid: self.one_batch()
        self.run_callbacks('after_valid')

        self.run_callbacks('after_epoch')

    def fit(self, epochs=3, lr=1e-3):
        """Train for `epochs` epochs with a fresh optimizer built from `opt_func`."""
        # before_fit runs first so callbacks can prepare the model (e.g. move
        # it to a device) before the optimizer captures its parameters.
        self.run_callbacks('before_fit')
        self.opt = self.opt_func(self.model.parameters(), lr)
        for self.epoch in range(epochs): self.one_epoch()
        self.run_callbacks('after_fit')

    @property
    def training(self):
        """Whether the model is currently in training mode."""
        return self.model.training

    def run_callbacks(self, method_name):
        """Dispatch `method_name` to every registered callback, passing this trainer."""
        cbs = [getattr(self, o) for o in self.callbacks]
        run_callbacks(cbs, method_name, self)

In [None]:
# Assemble the demo trainer: cross-entropy loss, Adam optimizer, a SimpleNet
# model, a progress callback reporting multiclass accuracy, and a device
# callback using its default device.
# NOTE(review): SimpleNet(28*28,64,10) presumably means input size, hidden
# size and number of classes -- confirm against isaacai.models.
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  SimpleNet(28*28,64,10), 
                  callbacks=[ProgressCB(Accuracy=MulticlassAccuracy()), DeviceCB()])

In [None]:
# Train with the defaults declared by Trainer.fit (epochs=3, lr=1e-3).
trainer.fit()

{'epoch': 0, 'train_loss': 1.142, 'valid_loss': 0.7138, 'Accuracy': 0.739, 'elapsed': '0:00:02.553357'}
{'epoch': 1, 'train_loss': 0.6201, 'valid_loss': 0.639, 'Accuracy': 0.7585, 'elapsed': '0:00:01.516741'}
{'epoch': 2, 'train_loss': 0.5118, 'valid_loss': 0.5663, 'Accuracy': 0.7925, 'elapsed': '0:00:01.518346'}


In [None]:
#| hide
# Run nbdev's export step; presumably this writes the `#| export` cells to the
# module named by `#|default_exp` at the top of the notebook.
import nbdev; nbdev.nbdev_export()