# trainer

> A callback-driven training loop

In [None]:
#|default_exp trainer

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from AIsaac.utils import *
from AIsaac.dataloaders import *
from AIsaac.models import *
from AIsaac.initialization import *

from datetime import datetime, timedelta
import torchvision.transforms.functional as TF,torch.nn.functional as F
import math, time

import matplotlib.pyplot as plt
import matplotlib as mpl
import fastcore.all as fc
import torch
from torch import nn, Tensor
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import pandas as pd , numpy as np
from torcheval.metrics import MulticlassAccuracy,Mean
from torch.optim.lr_scheduler import ExponentialLR

import dill as pickle
from fastprogress.fastprogress import master_bar, progress_bar
import inspect
import torchinfo


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#|hide
import logging

In [None]:
#|hide
# Compact tensor printing for notebook output.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
# Optional: grayscale colormap for image plots.
# mpl.rcParams['image.cmap'] = 'gray'

# Disable all log records at WARNING severity and below.
logging.disable(logging.WARNING)

# NOTE(review): `set_seed` comes from the AIsaac package; if it also seeds torch,
# the `torch.manual_seed(1)` above is overridden here — confirm which seed is intended.
set_seed(42)


In [None]:
# Normalization constants used to standardize every image tensor below.
xmean, xstd = 0.28, 0.35

@inplace
def transformi(b):
    "Replace `b['image']` with tensors standardized by `xmean` and `xstd`."
    images = b['image']
    b['image'] = [(TF.to_tensor(img) - xmean) / xstd for img in images]

_dataset = load_dataset('fashion_mnist').with_transform(transformi)

100%|██████████| 2/2 [00:00<00:00, 354.29it/s]


In [None]:
_dataset = sample_dataset_dict(_dataset)  # presumably subsamples each split to keep the notebook fast — confirm in AIsaac

In [None]:
dls = DataLoaders.from_dataset_dict(_dataset, 1024, num_workers=4)  # assumes the positional 1024 is the batch size — confirm

## Base Trainer

In [None]:
#| export
class CancelFitException(Exception):
    "Signals cancelling the whole fit; `Trainer._fit` is wrapped with `with_cbs('fit', CancelFitException)`."

class CancelBatchException(Exception):
    "Signals cancelling the current batch; `Trainer.one_batch` is wrapped with `with_cbs('batch', CancelBatchException)`."

class CancelEpochException(Exception):
    "Signals cancelling the current epoch; `Trainer._one_epoch` is wrapped with `with_cbs('epoch', CancelEpochException)`."

In [None]:
#| export
class Callback:
    "Base class for trainer callbacks; `summarize_callbacks` lists callbacks by ascending `order`."
    order = 0  # default position; subclasses override it to change their ordering

In [None]:
#| export
# Names of the callback hooks in training-lifecycle order
# (summarize_callbacks reports steps in exactly this order).
cb_steps = [
    # entering fit / epoch / batch
    'before_fit', 'before_epoch', 'before_batch',
    # per-batch work
    'predict', 'get_loss', 'before_backward', 'backward', 'step', 'zero_grad',
    # leaving batch / epoch / fit
    'after_batch', 'cleanup_batch',
    'after_epoch', 'cleanup_epoch',
    'after_fit', 'cleanup_fit',
]

In [None]:
#| export
def summarize_callbacks(trainer):
    """Tabulate which registered callback implements which step.

    Rows follow `cb_steps` order. Within a step, callbacks are listed by their
    `order` attribute; the sort is stable, so ties keep registration order.

    Args:
        trainer: object whose `callbacks` attribute is a list of attribute names
            under which callback objects are stored on it.

    Returns:
        pd.DataFrame with columns 'Step', 'Callback' (class name) and 'Doc String'
        ('' when the method has no docstring). Every row is indexed by ''.
    """
    columns = ['Step', 'Callback', 'Doc String']
    # Resolve and sort once; the order is the same for every step.
    callbacks = sorted((getattr(trainer, name) for name in trainer.callbacks),
                       key=lambda cb: cb.order)
    rows = []
    for step in cb_steps:
        for cb in callbacks:
            method = getattr(cb, step, None)
            if method is None: continue
            doc = method.__doc__
            rows.append([step, cb.__class__.__name__, '' if doc is None else doc])
    if not rows: return pd.DataFrame(columns=columns)
    # Build the frame in one go. Repeated `pd.concat` in a loop is quadratic, and
    # concatenating onto an empty frame triggers a FutureWarning in recent pandas.
    return pd.DataFrame(rows, columns=columns, index=[''] * len(rows))

In [None]:
#| export
def summarize_model(model, batch, row_settings=("var_names",), verbose=0, depth=3,
                    col_names=("input_size", "output_size", "kernel_size")):
    """Return `torchinfo.summary` of `model` evaluated on `batch`.

    Args:
        model: module to summarize.
        batch: input passed to torchinfo as `input_data`.
        row_settings, verbose, depth, col_names: forwarded unchanged to
            `torchinfo.summary`.
    """
    # Exported because the exported `Trainer.summarize_model` calls this module-level
    # function; without `#| export` the generated module lacks it (NameError).
    # Other useful columns: "num_params", "mult_adds"
    return torchinfo.summary(model, input_data=batch, row_settings=row_settings,
                             verbose=verbose, depth=depth, col_names=col_names)

In [None]:
#| export 
class Trainer:
    """Callback-driven training loop.

    The trainer only sequences stages. The work of each step (`predict`, `get_loss`,
    `before_backward`, `backward`, `step`, `zero_grad`) is delegated to callbacks
    via `run_callbacks`. `self.callbacks` is a list of attribute *names*: each
    callback object is stored on the trainer under that name and looked up with
    `getattr`.

    Args:
        dls: object exposing `.train` and `.valid` dataloaders.
        loss_func: loss function, stored as `self.loss_func` for callbacks.
        opt_func: optimizer factory, called as `opt_func(model.parameters(), lr)` in `fit`.
        model: module to train (must provide `.train()`, `.training` and `.parameters()`).
        callbacks: callbacks kept for the trainer's whole lifetime.
    """
    def __init__(self, dls, loss_func, opt_func, model, callbacks):
        self.add_callbacks(callbacks)
        # Store the remaining constructor arguments as attributes; `callbacks` was
        # registered above instead.
        fc.store_attr(but='callbacks')

    @with_cbs('batch', CancelBatchException)
    def one_batch(self):
        "Run one batch: prediction and loss always; backward, step and zero_grad only when training."
        self.run_callbacks(['predict', 'get_loss'])
        if self.training: self.run_callbacks(['before_backward', 'backward', 'step', 'zero_grad'])

    @with_cbs('epoch', CancelEpochException)
    def _one_epoch(self):
        "Iterate over `self.dl`, exposing the index as `self.batch_num` and the batch as `self.batch`."
        for self.batch_num, self.batch in zip(self.batches, self.dl): self.one_batch()

    def one_epoch(self, training):
        "Run one pass over the train (`training=True`) or valid dataloader, with the model in the matching mode."
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        # Set before `_one_epoch` so epoch callbacks can presumably wrap it (e.g. a progress bar).
        self.batches = range(len(self.dl))
        self._one_epoch()

    @with_cbs('fit', CancelFitException)
    def _fit(self, train, valid):
        "For each epoch in `self.epochs`, run a train and/or a valid epoch; validation runs under `torch.no_grad()`."
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            if valid: torch.no_grad()(self.one_epoch)(False)

    def fit(self, n_epochs=3, lr=1e-3, callbacks=None, train=True, valid=True):
        """Train for `n_epochs` epochs with a fresh optimizer.

        Args:
            n_epochs: number of epochs.
            lr: learning rate passed to `opt_func`.
            callbacks: None, a single callback, or a list of callbacks that are
                active only for this call. They are removed from `self.callbacks`
                afterwards, even if fitting raises.
            train: whether to run a training epoch each epoch.
            valid: whether to run a validation epoch each epoch.
        """
        fc.store_attr('n_epochs,lr')
        # Normalize once: the original iterated the raw argument in `finally`, which
        # raised TypeError for the default `callbacks=None` (or a single callback).
        extra = fc.L(callbacks)
        try:
            self.add_callbacks(extra)
            self.opt = self.opt_func(self.model.parameters(), self.lr)
            self.epochs = range(self.n_epochs)
            self._fit(train, valid)
        finally:
            # NOTE: removal is by class name, so a lifetime callback of the same class
            # as a temporary one is removed from `self.callbacks` too.
            extra_names = {cb.__class__.__name__ for cb in extra}
            self.callbacks = [name for name in self.callbacks if name not in extra_names]

    @property
    def training(self):
        "True while the model is in training mode (mirrors `model.training`)."
        return self.model.training

    def add_callbacks(self, callbacks, force=False):
        "Register `callbacks` on this trainer by delegating to the module-level `add_callbacks`."
        add_callbacks(self, callbacks, force)

    def run_callbacks(self, method_names):
        "Run each step in `method_names` (a name or list of names) across all registered callbacks."
        cbs = [getattr(self, o) for o in self.callbacks]
        for method_name in fc.L(method_names): run_callbacks(cbs, method_name, self)

    def summarize_model(self):
        "torchinfo summary of `self.model` on the first element of the first training batch."
        return summarize_model(model=self.model, batch=fc.first(self.dls.train)[0])

    def summarize_callbacks(self):
        "Table of which callback implements which step (see the module-level `summarize_callbacks`)."
        return summarize_callbacks(self)

In [None]:
# Conv model + Adam + cross-entropy; no lifetime callbacks (per-run ones go to `fit`).
trainer = Trainer(
    dls=dls,
    loss_func=nn.CrossEntropyLoss(),
    opt_func=torch.optim.Adam,
    model=get_model_conv(),
    callbacks=[],
)

In [None]:
# torchinfo summary of the model run on the first element of the first training batch.
trainer.summarize_model()

Layer (type (var_name))                  Input Shape               Output Shape              Kernel Shape
Sequential (Sequential)                  [500, 1, 28, 28]          [500, 10]                 --
├─Sequential (0)                         [500, 1, 28, 28]          [500, 8, 14, 14]          --
│    └─Conv2d (0)                        [500, 1, 28, 28]          [500, 8, 14, 14]          [3, 3]
│    └─ReLU (1)                          [500, 8, 14, 14]          [500, 8, 14, 14]          --
├─Sequential (1)                         [500, 8, 14, 14]          [500, 16, 7, 7]           --
│    └─Conv2d (0)                        [500, 8, 14, 14]          [500, 16, 7, 7]           [3, 3]
│    └─ReLU (1)                          [500, 16, 7, 7]           [500, 16, 7, 7]           --
├─Sequential (2)                         [500, 16, 7, 7]           [500, 32, 4, 4]           --
│    └─Conv2d (0)                        [500, 16, 7, 7]           [500, 32, 4, 4]           [3, 3]
│    └─ReLU (1)   

## Trainer Summaries

In [None]:
#| hide
# Regenerate the library modules from this notebook's `#| export` cells.
import nbdev
nbdev.nbdev_export()