<a href="https://colab.research.google.com/github/LondonNode/Pearl-tutorials/blob/main/7_Callbacks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pearll

# Introduction

This notebook is a tutorial for the `callbacks` module in Pearl. Callbacks let you inject custom logic into the training flow: they are called every time the agent steps in the environment via the `step_env()` method of `BaseAgent`.
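
As a rough mental model, the idea can be sketched in plain Python: the training loop calls every callback once per environment step and stops as soon as one of them returns `False`. This is only an illustration under that assumption, not Pearl's actual `step_env()` implementation; `run_training` and `stop_after` are hypothetical names.

```python
from typing import Callable, List

# Hypothetical stand-in: a callback is any function that takes the current
# step and returns True to continue training or False to abort.
StepCallback = Callable[[int], bool]


def run_training(num_steps: int, callbacks: List[StepCallback]) -> int:
    """Run up to num_steps steps and return how many were actually taken."""
    steps_taken = 0
    for step in range(num_steps):
        # ... the agent would act in the environment here ...
        steps_taken += 1
        # Call every callback; abort if any of them asks to stop.
        if not all([callback(step) for callback in callbacks]):
            break
    return steps_taken


def stop_after(limit: int) -> StepCallback:
    """Build a callback that aborts once `limit` steps have been taken."""
    return lambda step: step + 1 < limit


steps = run_training(num_steps=100, callbacks=[stop_after(10)])
print(steps)  # 10
```

The list comprehension inside `all(...)` is deliberate: every callback runs on every step, even if an earlier one has already asked to stop.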

# Base Callback

`BaseCallback` is the base class for all other callbacks. It declares the abstract method `_on_step()`, which subclasses must implement. `_on_step()` should return a boolean; returning `False` aborts training early.

In [12]:
from pearll.callbacks import BaseCallback
from pearll.common.logging_ import Logger
from pearll.models.actor_critics import ActorCritic, Dummy
from pearll.settings import LoggerSettings

from gym.spaces import Box


logger_settings = LoggerSettings()
logger = Logger(
  tensorboard_log_path=logger_settings.tensorboard_log_path,
  file_handler_level=logger_settings.file_handler_level,
  stream_handler_level=logger_settings.stream_handler_level,
  verbose=logger_settings.verbose,
  num_envs=1,
)

space = Box(-100, 100, shape=(1,))
actor = Dummy(space=space)
critic = Dummy(space=space)
model = ActorCritic(actor, critic)

class YourCallback(BaseCallback):
  def __init__(self, logger: Logger, model: ActorCritic):
    super().__init__(logger, model)

  def _on_step(self) -> bool:
    # Returning True lets training continue; returning False would abort it early.
    return True

callback = YourCallback(logger, model)
print(f"n_calls records the number of times the main on_step method of the callback is called\nat initialization n_calls = {callback.n_calls}\n")
print(f"step records how many steps the agent has taken in the environment\nat initialization step = {callback.step}\n")

keep_training = callback.on_step(step=5)
print(f"After on_step is called\nn_calls = {callback.n_calls}\nstep = {callback.step}")


n_calls records the number of times the main on_step method of the callback is called
at initialization n_calls = 0

step records how many steps the agent has taken in the environment
at initialization step = 0

After on_step is called
n_calls = 1
step = 5
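
To make the contract above concrete, here is a minimal sketch of a callback that actually stops training. It uses a simplified stand-in base class rather than Pearl's `BaseCallback` (so it runs without a logger or model); `SketchCallback` and `StopAtStep` are hypothetical names, and the bookkeeping in `on_step` is an assumption based on the `n_calls` and `step` values printed above.

```python
from abc import ABC, abstractmethod


class SketchCallback(ABC):
    """Simplified stand-in for the callback contract (not Pearl's code)."""

    def __init__(self) -> None:
        self.n_calls = 0  # number of times on_step has been called
        self.step = 0     # latest environment step reported by the agent

    def on_step(self, step: int) -> bool:
        # Update the bookkeeping, then defer to the subclass hook.
        self.n_calls += 1
        self.step = step
        return self._on_step()

    @abstractmethod
    def _on_step(self) -> bool:
        """Return False to abort training early."""


class StopAtStep(SketchCallback):
    """Hypothetical callback that stops training once `max_step` is reached."""

    def __init__(self, max_step: int) -> None:
        super().__init__()
        self.max_step = max_step

    def _on_step(self) -> bool:
        return self.step < self.max_step


callback = StopAtStep(max_step=3)
results = [callback.on_step(step=s) for s in range(5)]
print(results)           # [True, True, True, False, False]
print(callback.n_calls)  # 5
print(callback.step)     # 4
```

In a real training loop the agent would stop at the first `False`; the list above keeps calling only to show the return values.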


# Checkpoint Callback

For now, `CheckpointCallback` is the only callback implemented. It saves the model periodically as the agent trains. Let's demonstrate with a simple DQN on CartPole.
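
Before the full run, the saving schedule can be sketched in isolation. The rule below (save whenever the number of calls is a multiple of `save_freq`, and name the file after the zero-based step) is an assumption inferred from the file names logged in this section, not taken from Pearl's source; `checkpoint_paths` is a hypothetical helper.

```python
import os
from typing import List


def checkpoint_paths(
    num_steps: int, save_freq: int, save_path: str, name_prefix: str
) -> List[str]:
    """List the checkpoint files the assumed schedule would write."""
    paths = []
    for step in range(num_steps):
        n_calls = step + 1  # the callback is called once per step
        if n_calls % save_freq == 0:
            file_name = f"{name_prefix}_{step}_steps.pt"
            paths.append(os.path.join(save_path, file_name))
    return paths


paths = checkpoint_paths(
    num_steps=30, save_freq=10, save_path="checkpoints", name_prefix="agent"
)
names = [os.path.basename(p) for p in paths]
print(names)  # ['agent_9_steps.pt', 'agent_19_steps.pt', 'agent_29_steps.pt']
```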

In [18]:
from pearll.agents import DQN
from pearll.models import ActorCritic, EpsilonGreedyActor, Critic
from pearll.models.encoders import IdentityEncoder
from pearll.models.torsos import MLP
from pearll.models.heads import CategoricalHead
from pearll.settings import ExplorerSettings, Settings
from pearll.callbacks import CheckpointCallback

import gym
import os
from dataclasses import dataclass

@dataclass
class CheckpointSettings(Settings):
  save_freq: int = 10
  save_path: str = os.path.join(os.getcwd(), "checkpoints")
  name_prefix: str = "agent"

env = gym.make("CartPole-v0")
agent = DQN(
  env=env,
  model=None,
  # NOTE: multiple callbacks can be used, just add them to the list!
  callbacks=[CheckpointCallback],
  # NOTE: each callback needs its own settings object.
  callback_settings=[CheckpointSettings()],
  explorer_settings=ExplorerSettings(start_steps=0),
)
agent.fit(
  num_steps=100, batch_size=32, critic_epochs=16, train_frequency=("episode", 1)
)

Using device cpu
Saving weights to /content/checkpoints/agent_9_steps.pt
16: Log(reward=17.0, actor_loss=None, critic_loss=None, divergence=None, entropy=None)
Saving weights to /content/checkpoints/agent_19_steps.pt
29: Log(reward=13.0, actor_loss=None, critic_loss=None, divergence=None, entropy=None)
Saving weights to /content/checkpoints/agent_29_steps.pt
Saving weights to /content/checkpoints/agent_39_steps.pt
44: Log(reward=15.0, actor_loss=None, critic_loss=0.8602405413985252, divergence=None, entropy=None)
Saving weights to /content/checkpoints/agent_49_steps.pt
Saving weights to /content/checkpoints/agent_59_steps.pt
66: Log(reward=22.0, actor_loss=None, critic_loss=0.8353521041572094, divergence=None, entropy=None)
Saving weights to /content/checkpoints/agent_69_steps.pt
Saving weights to /content/checkpoints/agent_79_steps.pt
Saving weights to /content/checkpoints/agent_89_steps.pt
93: Log(reward=27.0, actor_loss=None, critic_loss=0.7781975753605366, divergence=None, entropy=