# training

> A minimal callback-driven training loop: the `Trainer` class plus the `ProgressCB` and `DeviceCB` callbacks.

In [1]:
#|default_exp training

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
from isaacai.utils import *
from isaacai.dataloaders import *
from isaacai.models import *
from datetime import datetime

from matplotlib import pyplot as plt
from fastcore.all import *
import torch
from torch import nn
from torch import Tensor
from datasets import load_dataset
from torch.utils.data import DataLoader
import pandas as pd 
import numpy as np
from datasets import Dataset
from torcheval.metrics import MulticlassAccuracy,Mean

In [4]:
# Build the Fashion-MNIST loaders (batch size 128) and tell downstream code
# which batch keys hold the inputs and the targets.
dls = DataLoaders(*load_fashion_mnist(batch_size=128))
dls.x_name = 'image'
dls.y_name = 'label'

Found cached dataset fashion_mnist (/home/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1)


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
#| export
def run_callbacks(callbacks, method_name, trainer=None):
    """Call `method_name(trainer)` on every callback that defines it.

    Callbacks are visited in ascending `order` (a missing `order` attribute
    counts as 0); callbacks with equal order keep their given sequence.
    A callback whose `method_name` attribute is absent or `None` is skipped.
    """
    def _priority(cb):
        return getattr(cb, 'order', 0)

    ordered = list(callbacks)
    ordered.sort(key=_priority)
    for cb in ordered:
        hook = getattr(cb, method_name, None)
        if hook is None:
            continue
        hook(trainer)

In [17]:
#| export
class ProgressCB:
    """Callback that accumulates losses and metrics and logs one stats dict per epoch.

    `precision` is the number of decimals kept when rounding reported values.
    Every extra keyword argument is a named metric object; the code below calls
    `update(preds, targets)`, `compute()` and `reset()` on each of them.
    """
    def __init__(self, precision=4, **metrics):
        store_attr(names=['precision'])
        self.metrics = metrics
        # Running means of the per-batch loss, kept separately for each phase.
        self.loss_train, self.loss_valid = Mean(), Mean()
        # One stats dict per finished epoch, in order.
        self.stats_epoch = L()

    def log(self, x):
        "Emit one epoch's stats; override to redirect the output."
        print(x)

    def after_batch(self, trainer):
        "Accumulate the batch loss; on validation batches also update every metric."
        loss = to_cpu(trainer.loss.detach())
        if trainer.training:
            self.loss_train.update(loss)
            return
        self.loss_valid.update(loss)
        if not self.metrics: return
        # Convert predictions and targets once and share them across all
        # metrics instead of repeating the `to_cpu` call for every metric.
        preds = to_cpu(trainer.preds.detach())
        targets = to_cpu(trainer.batch[trainer.dls.y_name])
        for metric in self.metrics.values():
            metric.update(preds, targets)

    def before_epoch(self, trainer):
        "Remember when the epoch started so `after_epoch` can report `elapsed`."
        self.st = datetime.now()

    def after_epoch(self, trainer):
        "Compute the epoch's stats, store and log them, then reset all accumulators."
        stats = {'epoch': trainer.epoch,
                 'train_loss': self._rounded(self.loss_train),
                 'valid_loss': self._rounded(self.loss_valid)}
        stats.update({name: self._rounded(metric) for name, metric in self.metrics.items()})
        stats['elapsed'] = str(datetime.now() - self.st)

        self.stats_epoch.append(stats)
        # Start the next epoch from a clean slate.
        for accumulator in (self.loss_train, self.loss_valid, *self.metrics.values()):
            accumulator.reset()

        self.log(stats)

    def _rounded(self, accumulator):
        "Return `accumulator.compute()` as a float rounded to `self.precision` decimals."
        return round(float(accumulator.compute()), self.precision)


In [18]:
#| export
class DeviceCB:
    "Callback that puts the model and each batch on `device`."
    def __init__(self, device=def_device):
        # No-argument `store_attr()` saves the constructor arguments on `self`
        # (here: `self.device`, which the hooks below read).
        store_attr()

    def before_fit(self, trainer):
        "Move the model to `self.device` before fitting, if the model supports `.to`."
        model = trainer.model
        if not hasattr(model, 'to'):
            return
        model.to(self.device)

    def before_batch(self, trainer):
        "Replace the trainer's current batch with the result of `to_device` for `self.device`."
        moved = to_device(trainer.batch, device=self.device)
        trainer.batch = moved

In [19]:
#| export 
class Trainer:
    """Minimal callback-driven training loop.

    Stores `dls` (must expose `train`, `valid`, `x_name`, `y_name`), `loss_func`,
    `opt_func` (called as `opt_func(params, lr)`), `model` and `callbacks`.
    While running, the current `batch`, `preds`, `loss`, `epoch` and `opt` are
    kept as attributes so callbacks can read or replace them.
    """
    def __init__(self, dls, loss_func, opt_func, model, callbacks):
        store_attr()

    def one_batch(self):
        "Process `self.batch`: forward pass and loss, plus an optimizer step when training."
        self.run_callbacks('before_batch')
        # Track gradients only while training; validation batches would
        # otherwise build an autograd graph that is never back-propagated.
        with torch.set_grad_enabled(self.training):
            self.preds = self.model(self.batch[self.dls.x_name])
            self.loss = self.loss_func(self.preds, self.batch[self.dls.y_name])
        if self.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        self.run_callbacks('after_batch')

    def one_epoch(self):
        "Run one pass over `dls.train` in train mode, then one over `dls.valid` in eval mode."
        self.run_callbacks('before_epoch')

        self.model.train()
        self.run_callbacks('before_train')
        for self.batch in self.dls.train:
            self.one_batch()
        self.run_callbacks('after_train')

        self.model.eval()
        self.run_callbacks('before_valid')
        for self.batch in self.dls.valid:
            self.one_batch()
        self.run_callbacks('after_valid')

        self.run_callbacks('after_epoch')

    def fit(self, epochs=3, lr=1e-3):
        "Train for `epochs` epochs with a fresh optimizer built from `opt_func` and `lr`."
        # `before_fit` runs first so callbacks can act on the model
        # before the optimizer captures its parameters.
        self.run_callbacks('before_fit')
        self.opt = self.opt_func(self.model.parameters(), lr)
        for self.epoch in range(epochs):
            self.one_epoch()
        self.run_callbacks('after_fit')

    @property
    def training(self):
        "Whether the model is currently in training mode."
        return self.model.training

    def run_callbacks(self, method_name):
        "Dispatch `method_name` to every registered callback, passing this trainer."
        run_callbacks(self.callbacks, method_name, self)

In [20]:
# Wire up the trainer: cross-entropy loss, Adam optimizer, and callbacks for
# progress reporting (with validation accuracy) and device placement.
trainer = Trainer(
    dls=dls,
    loss_func=nn.CrossEntropyLoss(),
    opt_func=torch.optim.Adam,
    model=SimpleNet(784, 64, 10),
    callbacks=[ProgressCB(Accuracy=MulticlassAccuracy()), DeviceCB()],
)

In [21]:
# Train with `fit`'s defaults (epochs=3, lr=1e-3); ProgressCB prints one stats dict per epoch.
trainer.fit()

{'epoch': 0, 'train_loss': 0.6705, 'valid_loss': 0.4896, 'Accuracy': 0.8294, 'elapsed': '0:00:20.304069'}
{'epoch': 1, 'train_loss': 0.4605, 'valid_loss': 0.4393, 'Accuracy': 0.848, 'elapsed': '0:00:20.163087'}
{'epoch': 2, 'train_loss': 0.4249, 'valid_loss': 0.4177, 'Accuracy': 0.8529, 'elapsed': '0:00:20.070252'}


In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the
# module named by `default_exp` at the top of this notebook — confirm against nbdev docs).
import nbdev; nbdev.nbdev_export()