# training

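> Callbacks for the training loop: basic training steps, device placement, learning-rate scheduling and Accelerate integration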

In [None]:
#|default_exp training

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from AIsaac.utils import *
from AIsaac.dataloaders import *
from AIsaac.models import *
from AIsaac.initialization import *
from AIsaac.trainer import *

from datetime import datetime, timedelta
import torchvision.transforms.functional as TF,torch.nn.functional as F
import math, time

import matplotlib.pyplot as plt
import matplotlib as mpl
import fastcore.all as fc
import torch
from torch import nn, Tensor
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import pandas as pd , numpy as np
from torcheval.metrics import MulticlassAccuracy,Mean
from torch.optim.lr_scheduler import ExponentialLR

import dill as pickle
from fastprogress.fastprogress import master_bar, progress_bar
import inspect
import torchinfo
from accelerate import Accelerator

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#|hide
import logging

In [None]:
#|hide
# Compact, non-scientific tensor printing for notebook output.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(seed=1)
# Optional: uncomment to render images with a grayscale colormap.
# mpl.rcParams['image.cmap'] = 'gray'

# Ignore all log records at WARNING severity and below.
logging.disable(level=logging.WARNING)

# NOTE(review): set_seed presumably reseeds torch as well, which would make the
# manual_seed(1) call above redundant — confirm against AIsaac.utils.set_seed.
set_seed(42)


In [None]:
# Normalisation constants applied to every image tensor below.
xmean, xstd = 0.28, 0.35

@inplace
def transformi(b):
    '''Convert each image in batch `b` to a tensor, then shift by `xmean` and scale by `xstd`.'''
    normalised = []
    for img in b['image']:
        normalised.append((TF.to_tensor(img) - xmean) / xstd)
    b['image'] = normalised

_dataset = load_dataset('fashion_mnist').with_transform(transformi)

100%|██████████| 2/2 [00:00<00:00, 345.18it/s]


In [None]:
# Rebind `_dataset` to the result of `sample_dataset_dict`.
# NOTE(review): presumably a subsample of each split for faster iteration — confirm in AIsaac.dataloaders.
_dataset = sample_dataset_dict(_dataset)

In [None]:
# Build the train/valid loaders used below (`dls.train` / `dls.valid`), with num_workers=4.
# NOTE(review): 1024 is presumably the batch size — confirm the positional parameter name.
dls = DataLoaders.from_dataset_dict(_dataset, 1024, num_workers=4)

## Core

In [None]:
#| export
class OneBatchCB(Callback):
    '''Stop a fit after its first batch; useful for quickly exercising the training loop end to end.'''
    # Large order value, presumably so this runs after other callbacks' after_batch hooks.
    order = 100

    def after_batch(self, learn):
        '''Abort the fit by raising `CancelFitException`.'''
        # `learn` is the trainer; the name is kept for compatibility with existing callers.
        raise CancelFitException

In [None]:
#| export
class BasicTrainCB(Callback):
    '''Callback for basic pytorch training loop'''

    def predict(self, trainer):
        '''Forward pass on the batch inputs (`batch[0]`); stores `trainer.preds`.'''
        model = trainer.model
        trainer.preds = model(trainer.batch[0])

    def get_loss(self, trainer):
        '''Compute `trainer.loss` from predictions and batch targets (`batch[1]`).'''
        loss_fn = trainer.loss_func
        trainer.loss = loss_fn(trainer.preds, trainer.batch[1])

    def backward(self, trainer):
        '''Backpropagate `trainer.loss`.'''
        trainer.loss.backward()

    def step(self, trainer):
        '''Apply one optimizer update.'''
        trainer.opt.step()

    def zero_grad(self, trainer):
        '''Reset gradients through the optimizer.'''
        trainer.opt.zero_grad()

In [None]:
#| export
class DeviceCB(Callback):
    '''Callback to train on specific device'''

    def __init__(self, device=def_device):
        # Target device for both the model and every batch.
        self.device = device

    def before_fit(self, trainer):
        '''Moves model to device'''
        model = trainer.model
        # Only objects exposing `.to` (e.g. nn.Module) can be moved.
        if not hasattr(model, 'to'):
            return
        model.to(self.device)

    def before_batch(self, trainer):
        '''moves batch to device'''
        moved = to_device(trainer.batch, device=self.device)
        trainer.batch = moved

In [None]:
#| export
class MomentumTrainCB(BasicTrainCB):
    '''Training callback that decays gradients by `momentum` instead of zeroing them.

    Because `zero_grad` scales the existing gradients rather than clearing them,
    the next backward pass accumulates onto a decayed copy of the previous ones.

    Args:
        momentum: factor every existing parameter gradient is multiplied by.
    '''
    def __init__(self, momentum): self.momentum = momentum

    def zero_grad(self, trainer):
        '''Multiply grads by momentum (instead of zero)'''
        with torch.no_grad():
            for p in trainer.model.parameters():
                # Parameters that never received a gradient (frozen, or unused in
                # the forward pass) have `grad is None`; scaling them would raise
                # TypeError, and there is nothing to decay anyway.
                if p.grad is not None:
                    p.grad *= self.momentum

In [None]:
# One-batch smoke test of the training loop on the conv model.
trainer = Trainer(
    dls,
    nn.CrossEntropyLoss(),
    torch.optim.Adam,
    get_model_conv(),
    callbacks=[BasicTrainCB(), DeviceCB(), OneBatchCB()],
)

In [None]:
# OneBatchCB raises CancelFitException after the first batch (presumably caught
# inside Trainer.fit), so this call only runs a single batch.
trainer.fit()

In [None]:
# Tabulate which callback implements each training step, along with its docstring.
trainer.summarize_callbacks()

Unnamed: 0,Step,Callback,Doc String
,before_fit,DeviceCB,Moves model to device
,before_batch,DeviceCB,moves batch to device
,predict,BasicTrainCB,
,get_loss,BasicTrainCB,
,backward,BasicTrainCB,
,step,BasicTrainCB,
,zero_grad,BasicTrainCB,
,after_batch,OneBatchCB,


## Optimization

In [None]:
#| export
class BaseSchedulerCB(Callback):
    '''Base callback that builds a learning-rate scheduler when a fit starts.

    Args:
        scheduler_func: callable taking the trainer's optimizer and returning a
            scheduler with a `.step()` method (e.g. a scheduler class or a
            `functools.partial` of one). It is stored as `self.scheduler_func`.

    Subclasses decide *when* to advance the schedule by calling `_step`.
    '''
    def __init__(self, scheduler_func): fc.store_attr()

    def before_fit(self, trainer):
        '''Initializes scheduler with opt'''
        self.scheduler = self.scheduler_func(trainer.opt)

    def _step(self, trainer):
        # Skip stepping while `trainer.training` is False (presumably the validation pass).
        if trainer.training: self.scheduler.step()

In [None]:
#|export        
class BatchSchedulerCB(BaseSchedulerCB):
    '''Scheduler callback that steps the scheduler after every training batch.'''

    def after_batch(self, trainer):
        # `_step` is a no-op unless `trainer.training` is set.
        self._step(trainer)
    
class EpochSchedulerCB(BaseSchedulerCB):
    '''Scheduler callback that steps the scheduler once at the end of each epoch.'''

    def after_epoch(self, trainer):
        # `_step` is a no-op unless `trainer.training` is set.
        self._step(trainer)

In [None]:
#| export
class OneCycleSchedulerCB(BatchSchedulerCB):
    '''Per-batch `torch.optim.lr_scheduler.OneCycleLR` schedule.

    The peak learning rate is taken from `trainer.lr`. The schedule length is
    `trainer.n_epochs * len(trainer.dls.train)` batches. Any other `OneCycleLR`
    keyword arguments passed here are forwarded unchanged.
    '''
    @fc.delegates(to=torch.optim.lr_scheduler.OneCycleLR,
                  but=['optimizer', 'max_lr', 'total_steps', 'steps_per_epoch', 'epochs'])
    def __init__(self, **kwargs):
        self.scheduler_kwargs = kwargs
        self.scheduler_func = torch.optim.lr_scheduler.OneCycleLR

    def before_fit(self, trainer):
        '''Initializes Scheduler'''
        n_steps = trainer.n_epochs * len(trainer.dls.train)
        self.scheduler = self.scheduler_func(
            trainer.opt,
            max_lr=trainer.lr,
            total_steps=n_steps,
            **self.scheduler_kwargs,
        )

## Acceleration

In [None]:
#| export
class AccelerateCB(BasicTrainCB):
    '''Training callback built on `accelerate.Accelerator`.

    At the start of a fit it passes the model, optimizer and both dataloaders
    through `Accelerator.prepare`. It then runs the backward pass via
    `Accelerator.backward`.

    Args:
        mixed_precision: forwarded to `Accelerator(mixed_precision=...)`.
    '''
    # Ordered relative to DeviceCB; presumably callbacks run in ascending `order`.
    order = DeviceCB.order + 10

    def __init__(self, mixed_precision="fp16"):
        self.acc = Accelerator(mixed_precision=mixed_precision)

    def before_fit(self, trainer):
        '''Wraps model, opt, data in accelerate'''
        prepared = self.acc.prepare(
            trainer.model, trainer.opt, trainer.dls.train, trainer.dls.valid)
        trainer.model, trainer.opt, trainer.dls.train, trainer.dls.valid = prepared

    def backward(self, trainer):
        '''Using accelerate for backward pass'''
        self.acc.backward(trainer.loss)

In [None]:
#| hide
# Regenerate the library modules from the `#| export` cells of this notebook.
import nbdev
nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
