# Callbacks

The `isaacai` library is an extremely flexible framework that relies on callbacks **a lot**, arguably more heavily than most other training frameworks. Because of this, it's very important to understand how `isaacai` uses callbacks and how you can leverage them.
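To make the idea concrete, here is a minimal, framework-agnostic sketch of the callback pattern in plain Python. It is not `isaacai`'s implementation. `TinyTrainer`, `Callback`, the `order` attribute and the event names are all hypothetical. The trainer only fires named events, and every callback that defines a method with that name gets called, sorted by `order`.

```python
# Hypothetical sketch of the general callback pattern, NOT isaacai's code.
# The trainer fires named events; any callback that defines a method with
# that name is called, in ascending `order`.

class Callback:
    order = 0  # assumed convention for this sketch: lower values run first

class LoggerCB(Callback):
    order = 1
    def __init__(self): self.events = []
    def before_fit(self, trainer): self.events.append('before_fit')
    def after_epoch(self, trainer): self.events.append(f'after_epoch {trainer.epoch}')
    def after_fit(self, trainer): self.events.append('after_fit')

class CountEpochsCB(Callback):
    order = 0
    def __init__(self): self.n = 0
    def after_epoch(self, trainer): self.n += 1

class TinyTrainer:
    def __init__(self, callbacks):
        self.callbacks = sorted(callbacks, key=lambda cb: cb.order)

    def callback(self, event):
        # Call `event` on every callback that implements it; skip the others.
        for cb in self.callbacks:
            method = getattr(cb, event, None)
            if method is not None:
                method(self)

    def fit(self, n_epochs=2):
        self.callback('before_fit')
        for self.epoch in range(n_epochs):
            self.callback('before_epoch')  # no callback here defines this; it is skipped
            self.callback('after_epoch')
        self.callback('after_fit')

logger, counter = LoggerCB(), CountEpochsCB()
TinyTrainer([logger, counter]).fit(2)
print(logger.events)  # ['before_fit', 'after_epoch 0', 'after_epoch 1', 'after_fit']
print(counter.n)      # 2
```

Because `TinyTrainer.fit` only fires events, new behavior is added by writing another callback rather than by editing the loop.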

## Setup

Here I will set up the pieces needed for the tutorial. This includes imports and loading a small subset of the Fashion-MNIST dataset.

In [None]:
#|hide
%load_ext autoreload
%autoreload 2



In [None]:
from isaacai.all import *
import matplotlib.pyplot as plt,matplotlib as mpl
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF

In [None]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
set_seed(42)

In [None]:
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
_dataset = sample_dataset_dict(_dataset)
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)




## Very basic Trainer

In [None]:
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  SimpleNet(28*28,64,10), 
                  callbacks=[BasicTrainCB(),MetricsCB(Accuracy=MulticlassAccuracy()), DeviceCB()])
trainer.fit()

{'Accuracy': 0.5580000281333923, 'train_loss': 1.8707503662109375, 'valid_loss': 1.4094869384765625, 'epoch': 0, 'elapsed': datetime.timedelta(microseconds=237908)}
{'Accuracy': 0.6539999842643738, 'train_loss': 1.0959883117675782, 'valid_loss': 1.026219970703125, 'epoch': 1, 'elapsed': datetime.timedelta(microseconds=186135)}
{'Accuracy': 0.6819999814033508, 'train_loss': 0.793328239440918, 'valid_loss': 0.8982598266601562, 'epoch': 2, 'elapsed': datetime.timedelta(microseconds=181564)}


So we passed in a `DataLoaders`, a PyTorch loss function, a PyTorch optimizer class, a PyTorch model, and a list of callbacks. Calling `Trainer.fit` ran a full training loop, as the per-epoch metrics above show. **The training loop is defined entirely in the callbacks**. This tutorial focuses on the callbacks, so please refer to the PyTorch documentation for the PyTorch pieces.
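Here is a rough sketch of what "the training loop is defined entirely in the callbacks" can mean, using PyTorch on random toy data. `LoopTrainer`, `TrainStepCB`, `LossHistoryCB` and the event names are hypothetical and are not `isaacai`'s classes. The trainer only iterates over batches and fires events, while a callback performs the forward pass, loss, backward pass and optimizer step.

```python
import torch
from torch import nn

# Hypothetical sketch, NOT isaacai's code: all class and event names are made up.

class TrainStepCB:
    # Each event does one piece of the training step on the trainer `t`.
    def predict(self, t): t.preds = t.model(t.xb)
    def get_loss(self, t): t.loss = t.loss_func(t.preds, t.yb)
    def backward(self, t): t.loss.backward()
    def step(self, t): t.opt.step()
    def zero_grad(self, t): t.opt.zero_grad()

class LossHistoryCB:
    # Records the loss of every batch (relies on TrainStepCB setting t.loss).
    def __init__(self): self.losses = []
    def after_batch(self, t): self.losses.append(t.loss.item())

class LoopTrainer:
    def __init__(self, batches, loss_func, opt_func, model, callbacks, lr=0.1):
        self.batches, self.loss_func, self.model = batches, loss_func, model
        self.opt = opt_func(model.parameters(), lr=lr)  # optimizer class is instantiated here
        self.callbacks = callbacks

    def callback(self, event):
        for cb in self.callbacks:
            if hasattr(cb, event):
                getattr(cb, event)(self)

    def fit(self, n_epochs=1):
        for self.epoch in range(n_epochs):
            for self.xb, self.yb in self.batches:
                # The trainer never touches the model itself; every step is delegated.
                for event in ('predict', 'get_loss', 'backward', 'step', 'zero_grad', 'after_batch'):
                    self.callback(event)

# Toy data: 64 random samples, 10 features, 3 classes, split into 4 batches of 16.
torch.manual_seed(0)
x = torch.randn(64, 10)
y = torch.randint(0, 3, (64,))
batches = [(x[i:i+16], y[i:i+16]) for i in range(0, 64, 16)]

model, loss_func, history = nn.Linear(10, 3), nn.CrossEntropyLoss(), LossHistoryCB()
with torch.no_grad(): before = loss_func(model(x), y).item()
trainer = LoopTrainer(batches, loss_func, torch.optim.SGD, model, [TrainStepCB(), history])
trainer.fit(5)
with torch.no_grad(): after = loss_func(model(x), y).item()
print(f'{before:.3f} -> {after:.3f}')  # full-data training loss should go down
```

Drop `TrainStepCB` (and the loss logger that depends on it) and the model's parameters are never updated. In that sense the callbacks, not the trainer, define what training does.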

### Training Loop