# Callbacks usage

When using discrete-optimization to solve a problem, 
it is possible to execute your own code at various stages of the solving process.


To achieve that, you have to
- either create your own callback by inheriting from the [Callback](https://airbus.github.io/discrete-optimization/master/api/discrete_optimization.generic_tools.callbacks.html#discrete_optimization.generic_tools.callbacks.callback.Callback) base class,
  and implementing the hooks you need
- or directly use one of the already implemented ones available in the `discrete_optimization.generic_tools.callbacks` submodules
  (like
  [loggers](https://airbus.github.io/discrete-optimization/master/api/discrete_optimization.generic_tools.callbacks.html#module-discrete_optimization.generic_tools.callbacks.loggers),
  [early_stoppers](https://airbus.github.io/discrete-optimization/master/api/discrete_optimization.generic_tools.callbacks.html#module-discrete_optimization.generic_tools.callbacks.early_stoppers),
  or [optuna](https://airbus.github.io/discrete-optimization/master/api/discrete_optimization.generic_tools.callbacks.html#module-discrete_optimization.generic_tools.callbacks.optuna)
  )
- and pass them via the `callbacks` argument of `SolverDO.solve()`, as shown in the [API doc](https://airbus.github.io/discrete-optimization/master/api/discrete_optimization.generic_tools.html#discrete_optimization.generic_tools.do_solver.SolverDO.solve).

The main use cases for callbacks are
- Logging: you need to display more information about what happens during the solving process;
- Backing up: you need to store a model at an intermediate stage;
- Early stopping: you want to stop the solving process under your own specific condition, not available in the solver API;
- [Tuning hyperparameters with Optuna](./optuna.ipynb): you want to let Optuna access intermediate results so that it can decide whether to drop the current trial.
  (See the dedicated notebook.)


Here we are using the knapsack problem, already presented in a [dedicated notebook](../Knapsack%20tutorial.ipynb). For details about the problem and the solvers used, please refer to it.

## Prerequisites

Concerning the Python kernel to use for this notebook:
- If running locally, be sure to use an environment with discrete-optimization and minizinc.
- If running on colab, the next cell does it for you.
- If running on binder, the environment should be ready.


In [None]:
# On Colab: install the library
on_colab = "google.colab" in str(get_ipython())
if on_colab:
    import importlib
    import os
    import sys  # noqa: avoid having this import removed by pycln

    !{sys.executable} -m pip install -U pip

    # uninstall google protobuf conflicting with ray and sb3
    ! pip uninstall -y protobuf

    # install dev version for dev doc, or release version for release doc
    !{sys.executable} -m pip install git+https://github.com/airbus/discrete-optimization@master#egg=discrete-optimization

    # be sure to load the proper cffi (downgraded compared to the one initially on colab)
    import cffi

    importlib.reload(cffi)

    # install and configure minizinc
    !curl -o minizinc.AppImage -L https://github.com/MiniZinc/MiniZincIDE/releases/download/2.6.3/MiniZincIDE-2.6.3-x86_64.AppImage
    !chmod +x minizinc.AppImage
    !./minizinc.AppImage --appimage-extract
    os.environ["PATH"] = f"{os.getcwd()}/squashfs-root/usr/bin/:{os.environ['PATH']}"
    os.environ["LD_LIBRARY_PATH"] = (
        f"{os.getcwd()}/squashfs-root/usr/lib/:{os.environ['LD_LIBRARY_PATH']}"
    )

### Imports

In [None]:
from __future__ import annotations

import logging
import random

import nest_asyncio
import numpy as np

from discrete_optimization.datasets import fetch_data_from_coursera
from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.callbacks.early_stoppers import TimerStopper
from discrete_optimization.generic_tools.callbacks.loggers import ObjectiveLogger
from discrete_optimization.generic_tools.cp_tools import CPSolverName, ParametersCP
from discrete_optimization.generic_tools.do_problem import get_default_objective_setup
from discrete_optimization.generic_tools.lns_cp import LNS_CP
from discrete_optimization.knapsack.knapsack_parser import (
    get_data_available,
    parse_file,
)
from discrete_optimization.knapsack.solvers.cp_solvers import CPKnapsackMZN2
from discrete_optimization.knapsack.solvers.knapsack_lns_cp_solver import (
    ConstraintHandlerKnapsack,
)
from discrete_optimization.knapsack.solvers.knapsack_lns_solver import (
    InitialKnapsackMethod,
    InitialKnapsackSolution,
)

# patch asyncio so that applications using async functions can run in jupyter
nest_asyncio.apply()

# set logging level
logging.basicConfig(level=logging.WARNING, format="%(asctime)s:%(message)s")

### Download datasets

If not yet available, we download the datasets from [coursera](https://github.com/discreteoptimization/assignment).

In [None]:
needed_datasets = ["ks_500_0"]
download_needed = False
try:
    files_available_paths = get_data_available()
    for dataset in needed_datasets:
        if len([f for f in files_available_paths if dataset in f]) == 0:
            download_needed = True
            break
except Exception:  # e.g. the data folder does not exist yet
    download_needed = True

if download_needed:
    fetch_data_from_coursera()

We will use the dataset [ks_500_0](https://github.com/discreteoptimization/assignment/blob/master/knapsack/data/ks_500_0) where we have 500 items at hand to put in the knapsack.

In [None]:
files_available_paths = get_data_available()
model_file = [f for f in files_available_paths if "ks_500_0" in f][0]
model = parse_file(model_file, force_recompute_values=True)
print(type(model))

### Set random seed

If reproducible results are wanted, we can fix the random seed.

In [None]:
def set_random_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)


set_random_seed()
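
As a quick sanity check of what fixing the seed provides, the sketch below reseeds Python's `random` generator and verifies that the same sequence is drawn again. It only uses the standard library and is independent of discrete-optimization; the helper name `draw_sequence` is hypothetical and introduced for illustration only.

```python
from __future__ import annotations

import random


def draw_sequence(seed: int, n: int = 5) -> list[float]:
    """Seed the global generator, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw_sequence(42)
second = draw_sequence(42)
other = draw_sequence(43)

# Same seed gives the same sequence; a different seed gives a different one.
print(first == second)
print(first != other)
```

This is why `set_random_seed()` is called again before each solve in the cells below: every run then starts from the same generator state.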

## Using existing callbacks

We first show how to plug in already existing callbacks.

We initialize a solver where we will plug the callbacks.

In [None]:
set_random_seed()
params_objective_function = get_default_objective_setup(problem=model)
params_cp = ParametersCP.default()
params_cp.time_limit = 5
params_cp.time_limit_iter0 = 5
nb_iteration_lns = 5

# Base CP solver.
cp_solver = CPKnapsackMZN2(
    model,
    cp_solver_name=CPSolverName.CHUFFED,
    params_objective_function=params_objective_function,
)

# initial solution: DUMMY corresponds to a starting solution filled with 0!
initial_solution_provider = InitialKnapsackSolution(
    problem=model,
    initial_method=InitialKnapsackMethod.DUMMY,
    params_objective_function=params_objective_function,
)

# constraint handler: will fix 80% of variables to current solution.
constraint_handler = ConstraintHandlerKnapsack(problem=model, fraction_to_fix=0.8)

# LNS Solver.
lns_solver = LNS_CP(
    problem=model,
    cp_solver=cp_solver,
    initial_solution_provider=initial_solution_provider,
    constraint_handler=constraint_handler,
    params_objective_function=params_objective_function,
)
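
To give an intuition of what fixing 80% of the variables means in a large neighborhood search, here is a minimal standalone sketch. It is not the code of `ConstraintHandlerKnapsack`, whose selection strategy may differ; the function name `split_fixed_and_free` and the uniform random choice of variables are assumptions made for illustration.

```python
from __future__ import annotations

import random


def split_fixed_and_free(
    current_solution: list[int], fraction_to_fix: float, rng: random.Random
) -> tuple[dict[int, int], list[int]]:
    """Pick a random subset of variables to freeze at their current value.

    Returns the frozen assignments (index -> value) and the indices left
    free for the sub-solver to re-optimize.
    """
    n = len(current_solution)
    nb_fixed = round(fraction_to_fix * n)
    fixed_indices = set(rng.sample(range(n), nb_fixed))
    fixed = {i: current_solution[i] for i in sorted(fixed_indices)}
    free = [i for i in range(n) if i not in fixed_indices]
    return fixed, free


rng = random.Random(42)
solution = [0] * 10  # dummy starting solution: no item taken
fixed, free = split_fixed_and_free(solution, fraction_to_fix=0.8, rng=rng)
print(len(fixed), len(free))
```

At each LNS iteration, only the free variables are re-optimized, which keeps each CP sub-problem small.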

### Logger

In this first example, we add a callback that tracks the current iteration and displays the current objective.
We set the logging level to warning because the LNS solver used here already produces a lot of logs at info level,
and we want to show only the information displayed by our callback.

To plug in the callback, we only need to pass it to the `callbacks` argument of `solve()`.

In [None]:
# callbacks
tracker = ObjectiveLogger(
    step_verbosity_level=logging.WARNING, end_verbosity_level=logging.WARNING
)

# solve
set_random_seed()
result_lns = lns_solver.solve(
    parameters_cp=params_cp, nb_iteration_lns=nb_iteration_lns, callbacks=tracker
)

Note that 6 iterations are logged here even though we set `nb_iteration_lns = 5`.

This is because the callback is already called at the end of the initial CP solve, 
which the `LNS_CP` solver does not count towards `nb_iteration_lns`.
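
The count can be reproduced with a toy loop. This is a simplified sketch of the control flow described above, not the actual `LNS_CP` implementation; `toy_lns`, `CountingCallback` and the step numbering starting at 0 are hypothetical choices made for illustration.

```python
class CountingCallback:
    """Records the step index each time it is called."""

    def __init__(self):
        self.steps = []

    def on_step_end(self, step):
        self.steps.append(step)
        return False  # never ask to stop


def toy_lns(nb_iteration_lns, callback):
    # Initial solve: the callback is called once it finishes (step 0 here).
    if callback.on_step_end(step=0):
        return
    # The LNS iterations themselves, numbered from 1 in this sketch.
    for iteration in range(1, nb_iteration_lns + 1):
        if callback.on_step_end(step=iteration):
            break


counter = CountingCallback()
toy_lns(nb_iteration_lns=5, callback=counter)
print(len(counter.steps))  # 6 calls: 1 initial solve + 5 iterations
```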

### Timer
Here we use a callback that can stop the solving process after a given timeout. 

More precisely, as it is called only at the end of an iteration, 
the solve stops at the end of the first iteration that finishes after the given time has elapsed since the start of the solving process.
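
The sketch below illustrates this behaviour with a fake clock, so that it runs instantly and deterministically. It is a standalone illustration rather than the code of `TimerStopper`; the classes `FakeClock` and `ToyTimerStopper`, the `clock` parameter, the simplified hook signatures and the 4-second iteration duration are all assumptions.

```python
class FakeClock:
    """Deterministic clock: time only advances when we say so."""

    def __init__(self):
        self.now = 0.0

    def __call__(self):
        return self.now


class ToyTimerStopper:
    """Asks to stop once `total_seconds` have elapsed since solve start."""

    def __init__(self, total_seconds, clock):
        self.total_seconds = total_seconds
        self.clock = clock
        self.start = None

    def on_solve_start(self):
        self.start = self.clock()

    def on_step_end(self):
        # Only evaluated between iterations, never in the middle of one.
        return self.clock() - self.start >= self.total_seconds


clock = FakeClock()
toy_timer = ToyTimerStopper(total_seconds=10, clock=clock)
toy_timer.on_solve_start()

stopped_at = None
for iteration in range(1, 6):
    clock.now += 4.0  # pretend each iteration takes 4 seconds
    if toy_timer.on_step_end():
        stopped_at = clock.now
        break

print(stopped_at)  # 12.0: the 10 s timeout is only noticed after the 3rd iteration
```

The actual stopping time therefore depends on how long a single iteration lasts, and can exceed the requested timeout.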

As we want to plug in several callbacks, we need to pass them as a list.

In [None]:
# callbacks
tracker = ObjectiveLogger(
    step_verbosity_level=logging.WARNING, end_verbosity_level=logging.WARNING
)
timer = TimerStopper(total_seconds=10)
callbacks = [tracker, timer]

# solve
set_random_seed()
result_lns = lns_solver.solve(
    parameters_cp=params_cp, nb_iteration_lns=nb_iteration_lns, callbacks=callbacks
)

## Implementing your own callback

You can implement your own callback to display specific information at each step or stop the solving process under your own particular condition.

First you need to derive from the base class `Callback`, and then implement one or several methods depending on 
the points at which you want your code to be executed:

- start of solve: `on_solve_start()`
- end of solve: `on_solve_end()`
- end of a step in the optimization process: `on_step_end()`, which must return `True` if you want to stop the solving process
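
To make the role of each hook concrete, here is a simplified, standalone sketch of how a solver loop could drive them. It does not reproduce the library's internals: `ToyCallback`, `toy_solve` and the reduced hook signatures are assumptions for illustration (the real `on_step_end()` also receives the result storage and the solver).

```python
class ToyCallback:
    """Logs every hook call and stops once the objective reaches a target."""

    def __init__(self, target):
        self.target = target
        self.events = []

    def on_solve_start(self):
        self.events.append("start")

    def on_step_end(self, step, objective):
        self.events.append(f"step {step}: {objective}")
        return objective >= self.target  # True means "stop now"

    def on_solve_end(self):
        self.events.append("end")


def toy_solve(objectives, callback):
    """Pretend solver: each step yields the next objective value."""
    callback.on_solve_start()
    for step, objective in enumerate(objectives):
        if callback.on_step_end(step, objective):
            break
    callback.on_solve_end()


callback = ToyCallback(target=30)
toy_solve(objectives=[10, 20, 30, 40], callback=callback)
print(callback.events)
```

Note that `on_solve_end()` is still called in this sketch after an early stop, so a callback can finalize its work (for instance, close a log file) in either case.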



In [None]:
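# names used in the type hints below
from typing import Optional

from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.result_storage.result_storage import (
    ResultStorage,
)


class MyCallback(Callback):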
    """Custom callback

    We print the number of items taken at each iteration
    and stop whenever at least 15 items have been taken.

    """

    def on_step_end(
        self, step: int, res: ResultStorage, solver: SolverDO
    ) -> Optional[bool]:
        """Called at the end of an optimization step.

        Args:
            step: index of step
            res: current result storage
            solver: solver using the callback

        Returns:
            If `True`, the optimization process is stopped, else it goes on.

        """
        # best solution found so far (its fitness is not needed here)
        sol, _ = res.get_best_solution_fit()
        nb_items = sum(sol.list_taken)
        print(f"Number of items: {nb_items}")

        # we stop if at least 15 items have been taken
        stopping = nb_items >= 15

        return stopping

In [None]:
# callbacks
tracker = ObjectiveLogger(
    step_verbosity_level=logging.WARNING, end_verbosity_level=logging.WARNING
)
timer = TimerStopper(total_seconds=30)
mycallback = MyCallback()
callbacks = [tracker, timer, mycallback]

# solve
set_random_seed()
result_lns = lns_solver.solve(
    parameters_cp=params_cp, nb_iteration_lns=nb_iteration_lns, callbacks=callbacks
)