### GluonTS Callbacks
This notebook illustrates how one can control the training with GluonTS callbacks. A callback is an object whose hook methods get called at one or more specific points during training.
You can use predefined GluonTS callbacks such as the logging callback TrainingHistory, ModelAveraging or TerminateOnNaN, or you can implement your own callback.

#### 1. Using a single Callback


In [None]:
# Fetch the M4 hourly dataset through the GluonTS dataset repository.
from gluonts.dataset.repository.datasets import get_dataset

dataset = get_dataset("m4_hourly")

# Sampling frequency and forecast horizon are read from the dataset metadata.
freq = dataset.metadata.freq
prediction_length = dataset.metadata.prediction_length

In [None]:
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.mx.trainer import Trainer
from gluonts.mx.trainer.callback import TrainingHistory

# TrainingHistory logs the training loss of each epoch so that it can be
# inspected once training has finished.
history = TrainingHistory()

# Pass the callback as a one-element list, the same form that is used when
# several callbacks are supplied.
trainer = Trainer(epochs=20, callbacks=[history])
estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length,
    freq=freq,
    trainer=trainer,
)

predictor = estimator.train(dataset.train, num_workers=None)


In [None]:
# Print the training loss that was logged for each epoch.
print(history.loss_history)

# If a validation dataset is used during training, the validation loss is
# available as history.validation_loss_history.


#### 2. Using multiple Callbacks
To continue the training from a given predictor, you can use the WarmStart callback. When you want to use more than one callback, provide them as a list:

In [None]:
from gluonts.mx.trainer.callback import WarmStart

# WarmStart is built from the predictor trained in the previous cell.
# `history` is the same TrainingHistory instance as before, so the losses of
# this run are appended to the ones already recorded.
warm_start = WarmStart(predictor=predictor)
trainer = Trainer(epochs=10, callbacks=[history, warm_start])

estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length,
    freq=freq,
    trainer=trainer,
)
predictor = estimator.train(dataset.train, num_workers=None)

In [None]:
# The training loss history of all 20 + 10 epochs the model was trained for.
print(history.loss_history)


#### 3. Default Callbacks
In addition to the callbacks you specify, the GluonTS Trainer uses two default callbacks: ModelAveraging and LearningRateReduction. You can turn them off by setting add_default_callbacks=False when initializing the Trainer.

In [None]:
# Use the TrainingHistory callback together with the default callbacks.
trainer = Trainer(epochs=20, callbacks=[history])

# Use only the TrainingHistory callback; the default callbacks are disabled.
trainer = Trainer(epochs=20, callbacks=[history], add_default_callbacks=False)

# Use no callback at all.
trainer = Trainer(epochs=20, add_default_callbacks=False)

#### 4. Custom Callbacks
To implement your own callback, you can write a class which inherits from the GluonTS Callback class and override one or more of the hooks.

In [None]:
# Print the source of the abstract Callback class to see which hooks exist
# and which arguments each of them receives.
# Hook methods with a boolean return value stop the training when they
# return False.
import inspect

from gluonts.mx.trainer.callback import Callback

lines = inspect.getsource(Callback)
print(lines)

In [None]:
# Here is an example of a custom callback that performs early stopping based on a metric value.
# It only implements the hook method "on_epoch_end()",
# which gets called after all batches of one epoch have been processed.

In [None]:
from typing import List, Optional

import mxnet as mx
import numpy as np
from gluonts.dataset.common import Dataset
from gluonts.evaluation import Evaluator
from gluonts.mx.model.predictor import GluonPredictor
from gluonts.support.util import copy_parameters
from mxnet import gluon
from mxnet.gluon import nn

class MetricInferenceEarlyStopping(Callback):
    """
    Early stopping driven by a forecast metric of the prediction network.

    After every epoch the parameters of the training network are copied into
    the prediction network of ``predictor``, forecasts are produced for
    ``validation_dataset`` and scored with ``evaluator``. Training is stopped
    once the selected metric has failed to improve by more than ``min_delta``
    for ``patience`` epochs in a row. This makes it possible to base early
    stopping directly on a metric of interest instead of on the
    training/validation loss.

    In the same way as test datasets are used during model evaluation,
    the time series of the validation_dataset can overlap with the train
    dataset time series, except for a prediction_length part at the end of
    each time series.

    Parameters
    ----------
    validation_dataset
        An out-of-sample dataset which is used to monitor metrics.
    predictor
        A gluon predictor, with a prediction network that matches the
        training network.
    evaluator
        The Evaluator used to calculate the validation metrics. When None
        (the default), ``Evaluator(num_workers=None)`` is created for this
        instance.
    metric
        The key of the aggregate metric on which to base the early stopping.
    patience
        Number of consecutive epochs without improvement after which
        training is stopped. With 0, training stops at the first epoch that
        does not improve the metric.
    min_delta
        Minimum change in the monitored metric counting as an improvement.
    verbose
        Controls, if the validation metric is printed after each epoch.
    minimize_metric
        The metric objective: True if lower values are better, False if
        higher values are better.
    restore_best_network
        Controls, if the best model, as assessed by the validation metrics,
        is restored after training.
    num_samples
        The amount of samples drawn to calculate the inference metrics.
    """

    def __init__(
        self,
        validation_dataset: Dataset,
        predictor: GluonPredictor,
        evaluator: Optional[Evaluator] = None,
        metric: str = "MSE",
        patience: int = 10,
        min_delta: float = 0.0,
        verbose: bool = True,
        minimize_metric: bool = True,
        restore_best_network: bool = True,
        num_samples: int = 100,
    ):
        assert (
            patience >= 0
        ), "EarlyStopping Callback patience needs to be >= 0"
        assert (
            min_delta >= 0
        ), "EarlyStopping Callback min_delta needs to be >= 0.0"
        assert (
            num_samples >= 1
        ), "EarlyStopping Callback num_samples needs to be >= 1"

        # Materialise the dataset once; it is iterated again on every epoch.
        self.validation_dataset = list(validation_dataset)
        self.predictor = predictor
        # Create the default evaluator per instance rather than sharing one
        # object that was built when the function was defined.
        self.evaluator = (
            evaluator if evaluator is not None else Evaluator(num_workers=None)
        )
        self.metric = metric
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.restore_best_network = restore_best_network
        self.num_samples = num_samples

        # `_improvement_margin` shifts the best value seen so far by
        # min_delta in the direction of improvement, so that only changes
        # larger than min_delta count as an improvement.
        if minimize_metric:
            self.best_metric_value = np.inf
            self.is_better = np.less
            self._improvement_margin = -min_delta
        else:
            self.best_metric_value = -np.inf
            self.is_better = np.greater
            self._improvement_margin = min_delta

        self.validation_metric_history: List[float] = []
        # Not read anywhere in this class; kept so the attribute still exists.
        self.best_network = None
        self.n_stale_epochs = 0
        # True once this instance has written "best_network.params"; prevents
        # restoring from a file that is missing or left over from another run.
        self._checkpoint_saved = False

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: dict,
        ctx: mx.Context,
    ) -> bool:
        """
        Evaluate the validation metric and decide whether to keep training.

        Returns
        -------
        bool
            True if training should continue, False if it should stop.
        """
        from gluonts.evaluation.backtest import make_evaluation_predictions

        # Bring the prediction network up to date with the current weights.
        copy_parameters(training_network, self.predictor.prediction_net)

        forecast_it, ts_it = make_evaluation_predictions(
            dataset=self.validation_dataset,
            predictor=self.predictor,
            num_samples=self.num_samples,
        )

        agg_metrics, _ = self.evaluator(ts_it, forecast_it)
        current_metric_value = agg_metrics[self.metric]
        self.validation_metric_history.append(current_metric_value)

        if self.verbose:
            print(
                f"Validation metric {self.metric}: {current_metric_value}, best: {self.best_metric_value}"
            )

        # The metric has to beat the best value by more than min_delta.
        threshold = self.best_metric_value + self._improvement_margin
        if self.is_better(current_metric_value, threshold):
            self.best_metric_value = current_metric_value

            if self.restore_best_network:
                # Relative path: the file is written to the working directory.
                training_network.save_parameters("best_network.params")
                self._checkpoint_saved = True

            self.n_stale_epochs = 0
            return True

        self.n_stale_epochs += 1
        # `<` rather than `!=` so that patience == 0 stops at the first
        # non-improving epoch instead of never stopping.
        if self.n_stale_epochs < self.patience:
            return True

        print(
            f"EarlyStopping callback initiated stop of training at epoch {epoch_no}."
        )

        if self.restore_best_network and self._checkpoint_saved:
            print(
                f"Restoring best network from epoch {epoch_no - self.n_stale_epochs}."
            )
            training_network.load_parameters("best_network.params")

        return False

In [None]:
# Use the custom callback.

from gluonts.dataset.repository.datasets import get_dataset
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.mx.trainer import Trainer

dataset = get_dataset("m4_hourly")
prediction_length = dataset.metadata.prediction_length
freq = dataset.metadata.freq

estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, freq=freq
)

# The callback takes a predictor as an argument, so one is built here from a
# training network and the transformation created by the estimator.
training_network = estimator.create_training_network()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(
    transformation=transformation, trained_network=training_network
)

es_callback = MetricInferenceEarlyStopping(
    validation_dataset=dataset.test, predictor=predictor, metric="MSE"
)

# Pass the callback as a one-element list, the same form that is used when
# several callbacks are supplied.
trainer = Trainer(epochs=200, callbacks=[es_callback])

# The estimator was created without a trainer; attach the configured one.
estimator.trainer = trainer

pred = estimator.train(dataset.train)