# Callbacks

The `isaacai` library is a flexible framework that relies heavily on callbacks, probably more so than most other frameworks.  It is strongly influenced by the callback system in the `miniai` library developed as part of the fastai course, but goes a bit further in that direction in a couple of respects.  Because of this, it is important to understand how `isaacai` uses callbacks and how you can leverage them.

## Setup

Here I will set up the pieces needed for the tutorial.  This includes imports and loading a small subset of the Fashion MNIST dataset.

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#|hide
import inspect
from subprocess import check_output
from IPython.display import HTML, display

def view_source_code(f):
    output = check_output(["pygmentize","-f","html","-O","full,style=emacs","-l","python"],
        input=inspect.getsource(f), encoding='ascii')
    display(HTML(output))

In [None]:
from isaacai.all import *
import fastcore.all as fc
import matplotlib.pyplot as plt,matplotlib as mpl
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF

In [None]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
set_seed(42)

In [None]:
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)


## Basic Trainer

In [None]:
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  model = get_model_conv(),
                  callbacks=[BasicTrainCB(),MetricsCB(Accuracy=MulticlassAccuracy()), DeviceCB(),ProgressCB()])
trainer.fit()

| epoch | train | valid | Accuracy |
|---|---|---|---|
| 0 | 0.601941 | 0.436749 | 0.8439 |
| 1 | 0.389781 | 0.381396 | 0.8584 |
| 2 | 0.339881 | 0.348775 | 0.8709 |


So we passed in a `DataLoaders`, a PyTorch loss, a PyTorch optimizer, a PyTorch model, and some callbacks.  As you can see, calling `trainer.fit()` ran a full training loop.  **The training loop is defined entirely in the callbacks.**  This tutorial focuses on the callbacks; please refer to the PyTorch documentation for the PyTorch pieces.

### One batch

Let's see how a batch is processed.  The source code for `one_batch` is very small, and there are two things we need to understand about it: the `with_cbs` decorator and the `run_callbacks` method.

```python
@with_cbs('batch', CancelBatchException)
def one_batch(self):
    self.run_callbacks(['predict','get_loss'])
    if self.training: self.run_callbacks(['before_backward','backward','step','zero_grad'])
```

#### `run_callbacks`

The `run_callbacks` method is what actually executes the callback code.  As you can see, processing a batch consists entirely of callback calls.

The first `run_callbacks` does the following:

::: {.callout-tip}

##### run_callbacks pseudo code
+ Sorts all callbacks according to their `order` attribute (defaults to 0)
+ Loops through `['predict','get_loss']`
    + Loops through the ordered callbacks:
        + If the callback defines a method with the current name (`predict`, then `get_loss`), run it

:::
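The pseudo code above can be sketched in plain Python.  This is a minimal illustration and not the `isaacai` source: `ToyTrainer`, `ToyPredictCB`, and `ToyEarlyCB` are hypothetical names, and passing the trainer to each callback method is an assumption made for the example.

```python
class ToyTrainer:
    """Hypothetical stand-in for a trainer (not the isaacai implementation)."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.log = []  # records (callback name, hook name) as hooks fire

    def run_callbacks(self, names):
        # Sort by the optional `order` attribute, defaulting to 0.
        ordered = sorted(self.callbacks, key=lambda cb: getattr(cb, "order", 0))
        for name in names:                      # e.g. 'predict', then 'get_loss'
            for cb in ordered:
                method = getattr(cb, name, None)
                if method is not None:          # only run hooks the callback defines
                    method(self)                # assumption: hooks receive the trainer


class ToyPredictCB:
    def predict(self, trainer): trainer.log.append(("ToyPredictCB", "predict"))
    def get_loss(self, trainer): trainer.log.append(("ToyPredictCB", "get_loss"))


class ToyEarlyCB:
    order = -1  # lower order runs first
    def predict(self, trainer): trainer.log.append(("ToyEarlyCB", "predict"))


toy = ToyTrainer([ToyPredictCB(), ToyEarlyCB()])
toy.run_callbacks(["predict", "get_loss"])
print(toy.log)
```

`ToyEarlyCB` has no `get_loss` method, so it is simply skipped for that step, just as `BasicTrainCB` is skipped for `before_backward`.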

Let's look at the `BasicTrainCB` code.  Every step needed to process the batch is defined here.  `before_backward` is not defined, so this callback does nothing in that step.  We could, however, define a callback that runs before the backward pass if we want to add functionality there.

In [None]:
view_source_code(BasicTrainCB)

#### `with_cbs`

`with_cbs` adds two pieces of functionality.

+ The ability to exit and skip the rest of the function (i.e. the batch).  This is similar to how you can use `continue` in a for loop.  It is done by raising the exception that was passed to the decorator.
+ Before, after, and cleanup callbacks around the function.  The before and after callbacks run before and after the function.  The cleanup callback always runs, even if an exception is thrown.

::: {.callout-tip}
##### `one_batch` example

+ To run a callback before or after every batch, you would use `before_batch` and `after_batch`
+ To skip a batch, you would raise a `CancelBatchException` in a callback as that's what is passed to the decorator.
+ The `cleanup_batch` callback will always run if one exists, even if you skipped the batch.  `after_batch` will be skipped once the `CancelBatchException` is raised.

:::
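To make the before/after/cleanup behaviour concrete, here is a rough sketch of a decorator with the same shape.  It is not the `isaacai` implementation (view the real source in the next cell); `toy_with_cbs`, `ToyCancelBatchException`, `ToyTrainer`, and the two callbacks are hypothetical names invented for this example.

```python
from functools import wraps


class ToyCancelBatchException(Exception):
    """Hypothetical stand-in for CancelBatchException."""


def toy_with_cbs(name, cancel_exc):
    def decorator(f):
        @wraps(f)
        def wrapper(self, *args, **kwargs):
            try:
                self.run_callbacks([f"before_{name}"])
                f(self, *args, **kwargs)
                self.run_callbacks([f"after_{name}"])
            except cancel_exc:
                pass  # skip the rest of the function, like `continue`
            finally:
                self.run_callbacks([f"cleanup_{name}"])  # always runs
        return wrapper
    return decorator


class ToyTrainer:
    def __init__(self, callbacks):
        self.callbacks, self.log = callbacks, []

    def run_callbacks(self, names):
        ordered = sorted(self.callbacks, key=lambda cb: getattr(cb, "order", 0))
        for name in names:
            for cb in ordered:
                method = getattr(cb, name, None)
                if method is not None:
                    method(self)

    @toy_with_cbs("batch", ToyCancelBatchException)
    def one_batch(self):
        self.run_callbacks(["predict", "get_loss"])


class ToyLogCB:
    def before_batch(self, t): t.log.append("before_batch")
    def predict(self, t): t.log.append("predict")
    def get_loss(self, t): t.log.append("get_loss")
    def after_batch(self, t): t.log.append("after_batch")
    def cleanup_batch(self, t): t.log.append("cleanup_batch")


class ToySkipCB:
    order = 1  # runs after ToyLogCB within each step
    def predict(self, t): raise ToyCancelBatchException()


normal = ToyTrainer([ToyLogCB()])
normal.one_batch()
skipped = ToyTrainer([ToyLogCB(), ToySkipCB()])
skipped.one_batch()
print(normal.log)
print(skipped.log)
```

In the skipped run, `get_loss` and `after_batch` never fire, but `cleanup_batch` still does.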

In [None]:
view_source_code(with_cbs)

### Other Callbacks

Another example of a callback is the device callback, which puts things onto whatever device we want (e.g. a GPU).

In [None]:
view_source_code(DeviceCB)

In [None]:
view_source_code(MetricsCB)

In addition, the `MetricsCB` in the example above is responsible for calculating and tracking the losses and the metrics it is initialized with, and for logging them every epoch.
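As a rough illustration of how a metrics-style callback can be spread across several hooks, here is a toy loss tracker.  It is not how `MetricsCB` is written: `ToyLossTrackerCB` is a hypothetical class, and the `loss` and `batch_size` attributes on the trainer are assumptions made for the example.

```python
from types import SimpleNamespace


class ToyLossTrackerCB:
    """Hypothetical callback that averages batch losses over each epoch."""
    def __init__(self):
        self.history = []  # one mean loss per epoch

    def before_epoch(self, t):
        # Reset the running totals at the start of every epoch.
        self.total, self.count = 0.0, 0

    def after_batch(self, t):
        # Weight each batch loss by its size so the epoch mean is per-sample.
        self.total += t.loss * t.batch_size
        self.count += t.batch_size

    def after_epoch(self, t):
        self.history.append(self.total / self.count)


# Drive the hooks by hand with a fake trainer object.
tracker = ToyLossTrackerCB()
fake_trainer = SimpleNamespace(loss=None, batch_size=None)
tracker.before_epoch(fake_trainer)
for loss, batch_size in [(1.0, 2), (0.5, 2)]:
    fake_trainer.loss, fake_trainer.batch_size = loss, batch_size
    tracker.after_batch(fake_trainer)
tracker.after_epoch(fake_trainer)
print(tracker.history)
```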

**All functionality that is done in the training loop is managed through callbacks.**

Epochs work the same way as batches, and there is also a `fit` method that executes callbacks in the same way.

::: {.callout-tip}
##### Available Callback List

+ Batch callbacks
    + before_batch
    + predict
    + get_loss
    + before_backward
    + backward
    + step
    + zero_grad
    + after_batch
    + cleanup_batch
+ Epoch callbacks
    + before_epoch
    + after_epoch
    + cleanup_epoch
+ Fit callbacks
    + before_fit
    + after_fit
    + cleanup_fit
:::

::: {.callout-tip}

##### Cancel Exceptions

+ CancelBatchException
+ CancelEpochException
+ CancelFitException
:::
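Each cancel exception is caught at its own level, so raising one from a lower-level callback skips everything up to that level.  The sketch below shows this with a toy nested loop.  All names prefixed with `Toy` or `toy_` are hypothetical, and the way `fit`, `one_epoch`, and `one_batch` nest here is an assumption about the general shape, not the `isaacai` source.

```python
class ToyCancelBatchException(Exception): pass
class ToyCancelEpochException(Exception): pass
class ToyCancelFitException(Exception): pass


def toy_with_cbs(name, cancel_exc):
    def decorator(f):
        def wrapper(self, *args, **kwargs):
            try:
                self.run_callbacks([f"before_{name}"])
                f(self, *args, **kwargs)
                self.run_callbacks([f"after_{name}"])
            except cancel_exc:
                pass  # only this level's exception is swallowed here
            finally:
                self.run_callbacks([f"cleanup_{name}"])
        return wrapper
    return decorator


class ToyTrainer:
    def __init__(self, callbacks, n_epochs=2, n_batches=3):
        self.callbacks, self.log = callbacks, []
        self.n_epochs, self.n_batches = n_epochs, n_batches

    def run_callbacks(self, names):
        for name in names:
            for cb in sorted(self.callbacks, key=lambda cb: getattr(cb, "order", 0)):
                method = getattr(cb, name, None)
                if method is not None:
                    method(self)

    @toy_with_cbs("batch", ToyCancelBatchException)
    def one_batch(self):
        self.run_callbacks(["predict"])

    @toy_with_cbs("epoch", ToyCancelEpochException)
    def one_epoch(self):
        for i in range(self.n_batches):
            self.batch_idx = i
            self.one_batch()

    @toy_with_cbs("fit", ToyCancelFitException)
    def fit(self):
        for epoch in range(self.n_epochs):
            self.epoch = epoch
            self.one_epoch()


class ToyRecorderCB:
    def predict(self, t): t.log.append((t.epoch, t.batch_idx))


class ToyStopEpochEarlyCB:
    def before_batch(self, t):
        # Raised inside a batch, caught at the epoch level.
        if t.batch_idx >= 1: raise ToyCancelEpochException()


class ToyStopFitCB:
    def before_epoch(self, t):
        # Raised inside an epoch, caught at the fit level.
        if t.epoch >= 1: raise ToyCancelFitException()


short_epochs = ToyTrainer([ToyRecorderCB(), ToyStopEpochEarlyCB()])
short_epochs.fit()
short_fit = ToyTrainer([ToyRecorderCB(), ToyStopFitCB()])
short_fit.fit()
print(short_epochs.log)
print(short_fit.log)
```

With `ToyStopEpochEarlyCB`, each epoch ends after its first batch but the next epoch still starts.  With `ToyStopFitCB`, the first epoch runs in full and training then stops.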

### Callback Subclassing/Inheritance

For example, `MomentumTrainCB` can inherit from `BasicTrainCB`, because training with momentum is mostly the same as a normal training loop, with one small tweak that allows previous gradients to be accounted for.  In this way we can build callbacks from other, similar callbacks.

Rather than subclassing the Trainer, we subclass callbacks.

In [None]:
view_source_code(MomentumTrainCB)
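The same idea can be shown without PyTorch.  In this toy sketch the subclass overrides a single step and inherits the rest.  The classes are hypothetical, they operate on plain lists of floats, and scaling the old gradients in `zero_grad` is one plausible way to implement momentum that I am assuming for illustration; check the source above for what `MomentumTrainCB` actually does.

```python
from types import SimpleNamespace


class ToyTrainCB:
    """Hypothetical basic training callback working on lists of floats."""
    def step(self, t):
        t.weights = [w - t.lr * g for w, g in zip(t.weights, t.grads)]

    def zero_grad(self, t):
        t.grads = [0.0 for _ in t.grads]


class ToyMomentumTrainCB(ToyTrainCB):
    """Inherits `step` unchanged and only overrides `zero_grad`."""
    def __init__(self, mom=0.5):
        self.mom = mom

    def zero_grad(self, t):
        # Assumption: keep a decayed copy of the old gradients instead of
        # clearing them, so the next backward pass adds on top of them.
        t.grads = [g * self.mom for g in t.grads]


state = SimpleNamespace(weights=[1.0, 1.0], grads=[1.0, 2.0], lr=0.5)
cb = ToyMomentumTrainCB(mom=0.5)
cb.step(state)       # inherited from ToyTrainCB
cb.zero_grad(state)  # overridden in the subclass
print(state.weights, state.grads)
```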

## Multiple Callbacks

Now that we know how to modify and extend the training loop with individual callbacks, the next logical question is how do we create abstractions with this.  For example, we probably don't want to add `DeviceCB`, `MetricsCB`, and `BasicTrainCB` to every Trainer we create, as lots of Trainers will use those.  As we build more complex models we may also want combinations of callbacks that are commonly used together, rather than having to memorize lots of callback recipes.

To do this, adding callbacks is done recursively, which allows us to group callbacks together.  Instead of subclassing the `Trainer`, as may be more common in other frameworks, we group the callbacks.  We create a callback that is a combination of other callbacks by defining the `callbacks` attribute on it.

:::{.callout-tip}
##### Callbacks Attribute

Adding callbacks works recursively.  Once a callback is added, it will check for a `callbacks` attribute and add those `callbacks`.  Those in turn could have `callbacks` attributes of their own.
:::
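A minimal sketch of that recursion, assuming the group object itself is also kept in the trainer's list.  `ToyTrainer.add_callbacks` and every class below are hypothetical names for illustration, not the `isaacai` API.

```python
class ToyTrainer:
    def __init__(self, callbacks=()):
        self.callbacks = []
        self.add_callbacks(callbacks)

    def add_callbacks(self, callbacks):
        for cb in callbacks:
            self.callbacks.append(cb)
            # Recurse into any nested group; callbacks without the
            # attribute contribute an empty list and end the recursion.
            self.add_callbacks(getattr(cb, "callbacks", []))


class ToyTrainCB: pass
class ToyDeviceCB: pass
class ToyMetricsCB: pass
class ToyProgressCB: pass


class ToyReportingCBs:
    """A group nested inside another group."""
    def __init__(self):
        self.callbacks = [ToyMetricsCB(), ToyProgressCB()]


class ToyCoreCBs:
    def __init__(self):
        self.callbacks = [ToyTrainCB(), ToyDeviceCB(), ToyReportingCBs()]


toy = ToyTrainer([ToyCoreCBs()])
names = [type(cb).__name__ for cb in toy.callbacks]
print(names)
```

Passing one `ToyCoreCBs` ends up registering every callback in the tree, including the two inside the nested `ToyReportingCBs` group.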

Here is an example of a group of callbacks that will likely go together.  This simple class will add all of these callbacks when used.  While this class is not a callback itself, because it does not have a callback method (e.g. `before_batch`), you have the flexibility to add those methods to it in order to add behavior to your trainer that is specific to this grouping of callbacks.  When passed as a callback, it will add the 5 callbacks stored in `self.callbacks`.  It could also be a callback itself.

In [None]:
view_source_code(CoreCBs)

In [None]:
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(), 
                  torch.optim.Adam, 
                  get_model_conv(),
                  callbacks=[CoreCBs(Accuracy=MulticlassAccuracy())])

In [None]:
trainer.fit()

| epoch | train | valid | Accuracy |
|---|---|---|---|
| 0 | 0.599093 | 0.452553 | 0.8334 |
| 1 | 0.381516 | 0.377897 | 0.862 |
| 2 | 0.332677 | 0.343139 | 0.8763 |
