# Callbacks

The `isaacai` library is an extremely flexible framework that uses callbacks **a lot**, probably more heavily than any other framework. It's strongly influenced by the callback system in the `miniai` library developed as part of the fastai course, but goes a bit further in that direction in a couple of respects. Because of this, it's very important to understand how `isaacai` uses callbacks and how you can leverage them.

## Setup

Here I will set up the pieces needed for the tutorial. This includes imports and loading a small subset of the Fashion-MNIST dataset.

```python
#|hide
%load_ext autoreload
%autoreload 2
```

```python
from isaacai.all import *
import fastcore.all as fc
import matplotlib.pyplot as plt,matplotlib as mpl
import torch
from datasets import load_dataset
from torch import nn
from torcheval.metrics import MulticlassAccuracy
import torchvision.transforms.functional as TF
```

```python
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
set_seed(42)
```

```python
xmean,xstd = 0.28, 0.35
@inplace
def transformi(b): b['image'] = [(TF.to_tensor(o)-xmean)/xstd for o in b['image']]

_dataset = load_dataset('fashion_mnist').with_transform(transformi)
_dataset = sample_dataset_dict(_dataset)
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)
```


## Basic Trainer

```python
trainer = Trainer(dls,
                  nn.CrossEntropyLoss(),
                  torch.optim.Adam,
                  SimpleNet(28*28,64,10),
                  callbacks=[BasicTrainCB(),MetricsCB(Accuracy=MulticlassAccuracy()), DeviceCB()])
trainer.fit()
```

```
{'Accuracy': 0.5580000281333923, 'train_loss': 1.8707503662109375, 'valid_loss': 1.4094869384765625, 'epoch': 0, 'elapsed': datetime.timedelta(microseconds=238897)}
{'Accuracy': 0.6539999842643738, 'train_loss': 1.0959883117675782, 'valid_loss': 1.026219970703125, 'epoch': 1, 'elapsed': datetime.timedelta(microseconds=222187)}
{'Accuracy': 0.6819999814033508, 'train_loss': 0.793328239440918, 'valid_loss': 0.8982598266601562, 'epoch': 2, 'elapsed': datetime.timedelta(microseconds=178892)}
```


So we passed in a `DataLoaders`, a PyTorch loss function, a PyTorch optimizer, a PyTorch model, and some callbacks. Calling `Trainer.fit` ran a full training loop, yet **the training loop is defined entirely in the callbacks**. This tutorial focuses on the callbacks; please refer to the PyTorch documentation for the PyTorch pieces.

### One batch

Let's see how a batch is processed. The source code for processing one batch is very small, and there are two things we need to understand about it: the decorator and the `run_callbacks` method.

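```python
@with_cbs('batch', CancelBatchException)
def one_batch(self):
    # Forward pass and loss run for both training and validation batches
    self.run_callbacks(['predict','get_loss'])
    # Weights are only updated while training
    if self.training: self.run_callbacks(['before_backward','backward','step','zero_grad'])
```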

#### `run_callbacks`

The `run_callbacks` method is what actually executes the callback code. As you can see, a batch is made up entirely of callbacks.

The first `run_callbacks` does the following:

::: {.callout-tip}

##### `run_callbacks` pseudocode
+ Sorts all callbacks according to their `order` attribute (defaults to 0)
+ Loops through `['predict','get_loss']`
    + Loops through the ordered callbacks:
        + If the callback defines a method with that name (e.g. `predict`), runs it

:::
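The pseudocode above could be sketched roughly as follows. This is a hypothetical `MiniTrainer`, not `isaacai`'s source: it assumes callbacks live in a `callbacks` list and that each callback method receives the trainer as its only argument. Accepting a single string as well as a list is an assumption based on how `with_cbs` calls `run_callbacks(f'before_{self.nm}')`. `RecordCB` and `EarlyCB` are made-up callbacks for the demo.

```python
class MiniTrainer:
    """Hypothetical stand-in for a trainer; not isaacai's Trainer."""
    def __init__(self, callbacks): self.callbacks = list(callbacks)

    def run_callbacks(self, names):
        # Accept a single callback name or a list of names
        if isinstance(names, str): names = [names]
        # Sort by each callback's `order` attribute, defaulting to 0
        cbs = sorted(self.callbacks, key=lambda cb: getattr(cb, 'order', 0))
        for name in names:
            for cb in cbs:
                method = getattr(cb, name, None)
                # Callbacks that don't define this step are simply skipped
                if method is not None: method(self)

class RecordCB:
    """Hypothetical callback that records the steps it handles."""
    def __init__(self, tag, order=0): self.tag, self.order = tag, order
    def predict(self, trainer): trainer.calls.append(f'{self.tag}.predict')
    def get_loss(self, trainer): trainer.calls.append(f'{self.tag}.get_loss')

class EarlyCB:
    """Hypothetical callback with a lower order and only a predict step."""
    order = -1
    def predict(self, trainer): trainer.calls.append('early.predict')

trainer = MiniTrainer([RecordCB('a', order=1), EarlyCB()])
trainer.calls = []
trainer.run_callbacks(['predict', 'get_loss'])
print(trainer.calls)  # ['early.predict', 'a.predict', 'a.get_loss']
```

Because `getattr(cb, name, None)` returns `None` for missing methods, a callback only has to implement the steps it cares about.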

Let's look at the `BasicTrainCB` code. As you can see, each step needed to process a batch is defined here. `before_backward` is not defined, so this callback does nothing at that step. We could, however, write another callback with a `before_backward` method if we wanted to add functionality to our trainer at that point.

```python
class BasicTrainCB:
    def predict(self,trainer): trainer.preds = trainer.model(trainer.batch[0])
    def get_loss(self,trainer): trainer.loss = trainer.loss_func(trainer.preds,trainer.batch[1])
    def backward(self,trainer): trainer.loss.backward()
    def step(self,trainer): trainer.opt.step()
    def zero_grad(self,trainer): trainer.opt.zero_grad()
```
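For example, a hypothetical `SkipNaNLossCB` could use `before_backward` to skip the backward pass and optimizer step whenever the loss is NaN, by raising `CancelBatchException`. To keep the sketch runnable on its own, it defines a local stand-in for `CancelBatchException` and uses a `SimpleNamespace` in place of a real trainer; in `isaacai` you would raise the library's own exception instead.

```python
import math
from types import SimpleNamespace

class CancelBatchException(Exception):
    """Local stand-in for isaacai's CancelBatchException (assumption for this sketch)."""

class SkipNaNLossCB:
    """Hypothetical callback: cancel the batch before backward() if the loss is NaN."""
    def before_backward(self, trainer):
        # float() accepts Python numbers and single-element torch tensors
        if math.isnan(float(trainer.loss)):
            raise CancelBatchException()

cb = SkipNaNLossCB()
cb.before_backward(SimpleNamespace(loss=0.5))  # finite loss: nothing happens

skipped = False
try:
    cb.before_backward(SimpleNamespace(loss=float('nan')))
except CancelBatchException:
    skipped = True
print(skipped)  # True
```

Since `before_backward` runs inside `one_batch`, the exception would propagate to the `with_cbs` wrapper, so `backward`, `step`, `zero_grad`, and `after_batch` are all skipped for that batch.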

#### `with_cbs`

`with_cbs` adds two pieces of functionality:

+ The ability to exit early and skip the rest of the decorated function (i.e. the batch), similar to using `continue` in a `for` loop. You do this by raising the exception that was passed to the decorator.
+ `before`, `after`, and `cleanup` callbacks around the function. `before` and `after` run before and after the function body (`after` is skipped if the exception is raised). `cleanup` always runs, even if an exception is thrown.

::: {.callout-tip}
##### `one_batch` example

+ To run a callback before or after every batch, you would use `before_batch` and `after_batch`
+ To skip a batch, you would raise a `CancelBatchException` in a callback as that's what is passed to the decorator.
+ The `cleanup_batch` callback will always run if one exists, even if you skipped the batch.  `after_batch` will be skipped once the `CancelBatchException` is raised.

:::

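```python
import fastcore.all as fc

class with_cbs:
    # nm: name used to build callback names (e.g. 'batch' -> 'before_batch')
    # exception: the Cancel*Exception that silently ends the wrapped function
    def __init__(self, nm, exception): fc.store_attr()
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.run_callbacks(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.run_callbacks(f'after_{self.nm}')  # skipped if the exception is raised
            except self.exception: pass
            finally: o.run_callbacks(f'cleanup_{self.nm}')  # always runs
        return _f
```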
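To see the ordering guarantees concretely, here is a self-contained sketch that reuses the `with_cbs` logic above (with plain attribute assignment instead of `fc.store_attr()`) on a hypothetical `Demo` object whose `run_callbacks` just records the names it is asked to run. `CancelBatchException` is a local stand-in for the library's exception.

```python
class CancelBatchException(Exception):
    """Local stand-in for isaacai's CancelBatchException."""

class with_cbs:
    # Same logic as above; plain assignment replaces fc.store_attr()
    def __init__(self, nm, exception): self.nm, self.exception = nm, exception
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.run_callbacks(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.run_callbacks(f'after_{self.nm}')
            except self.exception: pass
            finally: o.run_callbacks(f'cleanup_{self.nm}')
        return _f

class Demo:
    """Hypothetical object whose run_callbacks just records the names it receives."""
    def __init__(self, cancel): self.cancel, self.log = cancel, []
    def run_callbacks(self, name): self.log.append(name)

    @with_cbs('batch', CancelBatchException)
    def one_batch(self):
        self.log.append('body start')
        if self.cancel: raise CancelBatchException()
        self.log.append('body end')

normal, cancelled = Demo(cancel=False), Demo(cancel=True)
normal.one_batch(); cancelled.one_batch()
print(normal.log)     # ['before_batch', 'body start', 'body end', 'after_batch', 'cleanup_batch']
print(cancelled.log)  # ['before_batch', 'body start', 'cleanup_batch']
```

The cancelled run never reaches `after_batch`, but `cleanup_batch` still runs, matching the `one_batch` callout above.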

### Other Callbacks

Another example of a callback is `DeviceCB`, which puts the model and each batch onto whatever device we want (e.g. a GPU).

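```python
class DeviceCB:
    # def_device and to_device are isaacai helpers (not defined here)
    def __init__(self, device=def_device): self.device = device
    def before_fit(self, trainer):
        # Move the model once, before training starts
        if hasattr(trainer.model, 'to'): trainer.model.to(self.device)
    def before_batch(self, trainer):
        # Move each batch (inputs and targets) to the same device
        trainer.batch = to_device(trainer.batch, device=self.device)
```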

In addition, the `MetricsCB` in the example above is responsible for calculating and tracking the losses and the metrics it's initialized with, and for logging them every epoch.
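As a rough illustration of the tracking side of this, here is a hypothetical `AvgLossCB` that computes a batch-size-weighted average loss over whatever batches run between `before_epoch` and `after_epoch`. It is not `isaacai`'s `MetricsCB`; it assumes `trainer.batch[0]` holds the inputs (as in `BasicTrainCB`) and `trainer.loss` holds the batch loss, and it drives the hooks by hand with a `SimpleNamespace` as a fake trainer.

```python
from types import SimpleNamespace

class AvgLossCB:
    """Hypothetical: batch-size-weighted mean loss per epoch (not isaacai's MetricsCB)."""
    def before_epoch(self, trainer):
        self.total, self.count = 0.0, 0          # reset the running sums
    def after_batch(self, trainer):
        n = len(trainer.batch[0])                # number of samples in this batch
        self.total += float(trainer.loss) * n
        self.count += n
    def after_epoch(self, trainer):
        trainer.avg_loss = self.total / self.count
        print({'avg_loss': trainer.avg_loss})

# Drive the hooks by hand with a fake trainer
trainer = SimpleNamespace()
cb = AvgLossCB()
cb.before_epoch(trainer)
for xs, loss in [([0, 0], 1.0), ([0, 0, 0, 0], 4.0)]:
    trainer.batch, trainer.loss = (xs, None), loss
    cb.after_batch(trainer)
cb.after_epoch(trainer)  # prints {'avg_loss': 3.0}
```

Weighting by batch size keeps the average correct when the last batch is smaller than the others.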

**All functionality that is done in the training loop is managed through callbacks.**

Epochs are wrapped with callbacks the same way batches are, and the `fit` method runs its callbacks in the same way too.

::: {.callout-tip}
##### Available Callback List

+ Batch callbacks
    + before_batch
    + predict
    + get_loss
    + before_backward
    + backward
    + step
    + zero_grad
    + after_batch
    + cleanup_batch
+ Epoch callbacks
    + before_epoch
    + after_epoch
    + cleanup_epoch
+ Fit callbacks
    + before_fit
    + after_fit
    + cleanup_fit
:::

::: {.callout-tip}

##### Cancel Exceptions

+ CancelBatchException
+ CancelEpochException
+ CancelFitException
:::
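Putting the fit, epoch, and cancel pieces together, here is a hypothetical early-stopping callback that raises `CancelFitException` once the loss stops improving. The sketch is self-contained: `with_cbs` is rewritten as a function with the same behavior as the class above, `MiniTrainer` is a stand-in with no real model (it replays a list of made-up per-epoch losses), and `EarlyStopCB` is illustrative, not part of `isaacai`.

```python
class CancelEpochException(Exception): pass
class CancelFitException(Exception): pass

def with_cbs(nm, exception):
    """Function version of the with_cbs decorator shown earlier (same behavior)."""
    def deco(f):
        def _f(o, *args, **kwargs):
            try:
                o.run_callbacks(f'before_{nm}')
                f(o, *args, **kwargs)
                o.run_callbacks(f'after_{nm}')
            except exception: pass
            finally: o.run_callbacks(f'cleanup_{nm}')
        return _f
    return deco

class MiniTrainer:
    """Hypothetical fit/epoch skeleton that replays made-up per-epoch losses."""
    def __init__(self, callbacks, fake_losses):
        self.callbacks, self.fake_losses = callbacks, fake_losses
    def run_callbacks(self, nm):
        for cb in sorted(self.callbacks, key=lambda c: getattr(c, 'order', 0)):
            if hasattr(cb, nm): getattr(cb, nm)(self)

    @with_cbs('epoch', CancelEpochException)
    def one_epoch(self):
        self.loss = self.fake_losses[self.epoch]  # stands in for training + validation
        self.epochs_run += 1

    @with_cbs('fit', CancelFitException)
    def fit(self):
        self.epochs_run = 0
        for self.epoch in range(len(self.fake_losses)): self.one_epoch()

class EarlyStopCB:
    """Hypothetical: stop fitting after `patience` epochs without improvement."""
    def __init__(self, patience=1): self.patience = patience
    def before_fit(self, trainer): self.best, self.wait = float('inf'), 0
    def after_epoch(self, trainer):
        if trainer.loss < self.best: self.best, self.wait = trainer.loss, 0
        else:
            self.wait += 1
            if self.wait >= self.patience: raise CancelFitException()
    def cleanup_fit(self, trainer): trainer.stopped_at = trainer.epoch  # always runs

t = MiniTrainer([EarlyStopCB(patience=1)], fake_losses=[1.0, 0.8, 0.9, 0.7, 0.6])
t.fit()
print(t.epochs_run, t.stopped_at)  # 3 2
```

The exception raised in `after_epoch` passes through the epoch wrapper (which only catches `CancelEpochException`, though it still runs `cleanup_epoch`) and is swallowed by the fit wrapper, which skips `after_fit` but still runs `cleanup_fit`.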

## Inheritance

Now that we know how to modify and extend the training loop, the next thing we want to do is subclass. It would be really annoying if we had to remember to pass in the right combination of callbacks every time! As an example, let's create a `MomentumTrainer` that trains with momentum and uses a GPU if one is available. There are three steps for that:

1. Create a `MomentumTrainCB`, similar to the `BasicTrainCB` above, that implements momentum
1. Create a `MomentumTrainer`, similar to the `Trainer` above, that adds the appropriate callbacks without us having to pass them in
1. Update the `__init__` function signature so it shows any new parameters those callbacks need

### MomentumTrainCB

We can inherit from the `BasicTrainCB` above because momentum is mostly the same as a normal training loop with one small tweak: instead of zeroing the gradients after each step, we scale them by the momentum factor, so previous gradients carry over into the next update.

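```python
import torch

class MomentumTrainCB(BasicTrainCB):
    def __init__(self, momentum): self.momentum = momentum
    def zero_grad(self, trainer):
        # Scale the gradients instead of zeroing them; the next backward()
        # adds the new gradients on top of these decayed ones
        with torch.no_grad():
            for p in trainer.model.parameters():
                if p.grad is not None: p.grad *= self.momentum
```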
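Why does scaling instead of zeroing give momentum? PyTorch's `backward()` adds new gradients into `.grad` rather than overwriting them, so after `zero_grad` multiplies `.grad` by the momentum factor, the next backward pass leaves `momentum * previous + new` in `.grad`. This plain-Python sketch (a hypothetical `effective_grads` helper with made-up gradient values) traces that recurrence for a single parameter:

```python
def effective_grads(raw_grads, momentum):
    """Gradient seen by opt.step() at each step when zero_grad multiplies .grad
    by `momentum` instead of zeroing it (hypothetical helper for illustration)."""
    grad, out = 0.0, []
    for g in raw_grads:
        grad = grad * momentum + g   # zero_grad scales, backward() accumulates
        out.append(grad)
    return out

print(effective_grads([1.0, 1.0, 1.0, 1.0], momentum=0.5))
# [1.0, 1.5, 1.75, 1.875]
```

With a constant gradient `g`, the effective gradient grows toward `g / (1 - momentum)` (here `2.0`), so consistent gradient directions produce larger steps.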

### MomentumTrainer

Next we need to subclass `Trainer` and add the `MomentumTrainCB` and `DeviceCB` to the subclass. For this there is a special method called `subclassing_method` that runs arbitrary code on init, so we can add any callbacks (or anything else) there.

```python
@init_delegates()
class MomentumTrainer(Trainer):
    def subclassing_method(self,momentum=0.85,**kwargs):
        super().subclassing_method(**kwargs)
        self.add_callbacks([MomentumTrainCB(momentum),DeviceCB(),TrackingCB()])
```
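To make the control flow concrete, here is a hypothetical, heavily simplified base class whose `__init__` forwards its extra keyword arguments to `self.subclassing_method`. That forwarding is an assumption about how `isaacai`'s `Trainer` uses its `**kwargs`; `SketchTrainer`, `LoggingTrainer`, and `log_every` are made up for illustration, and `MomentumTrainCB` here is only a placeholder. Each level of the chain takes the arguments it knows about and passes the rest up via `super()`.

```python
class SketchTrainer:
    """Hypothetical stand-in for Trainer: __init__ hands extra kwargs to subclassing_method."""
    def __init__(self, callbacks=(), **kwargs):
        self.callbacks = list(callbacks)
        self.subclassing_method(**kwargs)      # hook that subclasses override
    def subclassing_method(self, **kwargs): pass
    def add_callbacks(self, cbs): self.callbacks.extend(cbs)

class MomentumTrainCB:
    """Placeholder that only stores the momentum value for this demo."""
    def __init__(self, momentum): self.momentum = momentum

class MomentumTrainer(SketchTrainer):
    def subclassing_method(self, momentum=0.85, **kwargs):
        super().subclassing_method(**kwargs)   # let parents handle their arguments
        self.add_callbacks([MomentumTrainCB(momentum)])

class LoggingTrainer(MomentumTrainer):
    def subclassing_method(self, log_every=10, **kwargs):
        super().subclassing_method(**kwargs)   # momentum is passed up to MomentumTrainer
        self.log_every = log_every

t = LoggingTrainer(momentum=0.9, log_every=5)
print(t.callbacks[0].momentum, t.log_every)  # 0.9 5
```

Note that `inspect.signature(LoggingTrainer)` in this sketch reports only `(callbacks=(), **kwargs)`, which is the problem described next.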

The main problem with this approach is that `__init__` normally won't be aware of the new `momentum` argument, so it shows up as `**kwargs` instead of in the actual signature. While it wouldn't be difficult to look at the subclassing method for the arguments, what if you subclass `MomentumTrainer`? You'd then need to check the parent class's subclassing method too.

It's better if the `__init__` signature lists every argument all the way down the chain (including each `subclassing_method`) rather than `**kwargs`. The `init_delegates` decorator handles that for us, as you can see below (`momentum` is added to the init signature).

```python
import inspect
print(f"        Trainer.Init:   {inspect.signature(Trainer.__init__)}")
print(f"MomentumTrainer.Init:   {inspect.signature(MomentumTrainer.__init__)}")
```

```
        Trainer.Init:   (self, dls, loss_func, opt_func, model, callbacks, **kwargs)
MomentumTrainer.Init:   (self, dls, loss_func, opt_func, model, callbacks, momentum=0.85)
```
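This is not `init_delegates`' actual implementation, but a hypothetical `merge_init_signature` decorator showing one way the idea can work with the standard `inspect` module: drop `**kwargs` from the inherited `__init__` signature, append the parameters of every `subclassing_method` in the MRO, and attach the result to a thin wrapper `__init__` via `__signature__`, so the parent's `__init__` is left untouched. `SketchTrainer` is a made-up minimal base class.

```python
import inspect

def merge_init_signature(cls):
    """Hypothetical sketch (not init_delegates' real code): advertise every
    subclassing_method parameter in cls.__init__'s signature instead of **kwargs."""
    base_init = cls.__init__
    params = [p for p in inspect.signature(base_init).parameters.values()
              if p.kind is not inspect.Parameter.VAR_KEYWORD]
    seen = {p.name for p in params}
    # Walk the MRO so parameters from parent subclassing_methods are included too
    for klass in cls.__mro__:
        meth = klass.__dict__.get('subclassing_method')
        if meth is None: continue
        for p in inspect.signature(meth).parameters.values():
            if p.name in seen or p.kind in (inspect.Parameter.VAR_POSITIONAL,
                                            inspect.Parameter.VAR_KEYWORD): continue
            seen.add(p.name)
            params.append(p)

    # Wrap instead of mutating base_init, so the parent's signature is unchanged
    def __init__(self, *args, **kwargs): base_init(self, *args, **kwargs)
    __init__.__signature__ = inspect.Signature(params)
    cls.__init__ = __init__
    return cls

class SketchTrainer:
    """Hypothetical minimal base class for the demo."""
    def __init__(self, dls, loss_func, callbacks=(), **kwargs):
        self.dls, self.loss_func, self.callbacks = dls, loss_func, list(callbacks)
        self.subclassing_method(**kwargs)
    def subclassing_method(self, **kwargs): pass

@merge_init_signature
class MomentumTrainer(SketchTrainer):
    def subclassing_method(self, momentum=0.85, **kwargs):
        super().subclassing_method(**kwargs)
        self.momentum = momentum

print(inspect.signature(SketchTrainer.__init__))    # (self, dls, loss_func, callbacks=(), **kwargs)
print(inspect.signature(MomentumTrainer.__init__))  # (self, dls, loss_func, callbacks=(), momentum=0.85)
print(MomentumTrainer('dls', 'loss', momentum=0.9).momentum)  # 0.9
```

In this sketch, a `subclassing_method` parameter without a default would make `inspect.Signature` raise `ValueError` whenever an earlier parameter has a default, so the sketch shares the first limitation listed below.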


:::{.callout-important}
##### `init_delegates` limitations

There are two known limitations of `init_delegates` currently. I believe these aren't too hard to solve, but I haven't gotten to them yet. If you want to see them fixed, let me know and I'll jump on it.

1. The `subclassing_method` arguments must have defaults
1. The signature does not update annotations

:::