**Load the libraries**

In [1]:
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO, SAC
import torch as th

from rllte.xplore.reward import E3B

**For on-policy RL algorithms**

In [4]:
class RLeXploreWithOnPolicyRL(BaseCallback):
    """
    A custom callback for combining RLeXplore and on-policy algorithms from SB3.
    """
    def __init__(self, irs, verbose=0):
        super(RLeXploreWithOnPolicyRL, self).__init__(verbose)
        self.irs = irs
        self.buffer = None

    def init_callback(self, model: BaseAlgorithm) -> None:
        super().init_callback(model)
        self.buffer = self.model.rollout_buffer

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        observations = self.locals["obs_tensor"]
        device = observations.device
        actions = th.as_tensor(self.locals["actions"], device=device)
        rewards = th.as_tensor(self.locals["rewards"], device=device)
        dones = th.as_tensor(self.locals["dones"], device=device)
        next_observations = th.as_tensor(self.locals["new_obs"], device=device)

        # ===================== watch the interaction ===================== #
        self.irs.watch(observations, actions, rewards, dones, dones, next_observations)
        # ===================== watch the interaction ===================== #
        return True

    def _on_rollout_end(self) -> None:
        # ===================== compute the intrinsic rewards ===================== #
        # prepare the data samples
        obs = th.as_tensor(self.buffer.observations)
        # get the new observations
        new_obs = obs.clone()
        new_obs[:-1] = obs[1:]
        new_obs[-1] = th.as_tensor(self.locals["new_obs"])
        actions = th.as_tensor(self.buffer.actions)
        rewards = th.as_tensor(self.buffer.rewards)
        dones = th.as_tensor(self.buffer.episode_starts)
        print(obs.shape, actions.shape, rewards.shape, dones.shape, obs.shape)
        # compute the intrinsic rewards
        intrinsic_rewards = irs.compute(
            samples=dict(observations=obs, actions=actions, 
                         rewards=rewards, terminateds=dones, 
                         truncateds=dones, next_observations=new_obs),
            sync=True)
        # add the intrinsic rewards to the buffer
        self.buffer.advantages += intrinsic_rewards.cpu().numpy()
        self.buffer.returns += intrinsic_rewards.cpu().numpy()
        # ===================== compute the intrinsic rewards ===================== #

# Parallel environments
device = 'cuda'
n_envs = 4
envs = make_vec_env("Pendulum-v1", n_envs=n_envs)

# ===================== build the reward ===================== #
irs = E3B(envs, device=device)
# ===================== build the reward ===================== #

model = PPO("MlpPolicy", envs, verbose=1, device=device)
model.learn(total_timesteps=25000, callback=RLeXploreWithOnPolicyRL(irs))

Using cuda device
torch.Size([2048, 4, 3]) torch.Size([2048, 4, 1]) torch.Size([2048, 4]) torch.Size([2048, 4]) torch.Size([2048, 4, 3])
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -1.31e+03 |
| time/              |           |
|    fps             | 1706      |
|    iterations      | 1         |
|    time_elapsed    | 4         |
|    total_timesteps | 8192      |
----------------------------------
torch.Size([2048, 4, 3]) torch.Size([2048, 4, 1]) torch.Size([2048, 4]) torch.Size([2048, 4]) torch.Size([2048, 4, 3])
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 200          |
|    ep_rew_mean          | -1.25e+03    |
| time/                   |              |
|    fps                  | 1111         |
|    iterations           | 2            |
|    time_elapsed         | 14           |
|    total_timesteps      | 16384        |
| train

<stable_baselines3.ppo.ppo.PPO at 0x1951542f3d0>

**For off-policy RL algorithms**

In [2]:
class RLeXploreWithOffPolicyRL(BaseCallback):
    """
    A custom callback for combining RLeXplore and off-policy algorithms from SB3. 
    """
    def __init__(self, irs, verbose=0):
        super(RLeXploreWithOffPolicyRL, self).__init__(verbose)
        self.irs = irs
        self.buffer = None

    def init_callback(self, model: BaseAlgorithm) -> None:
        super().init_callback(model)
        self.buffer = self.model.replay_buffer
        

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        device = self.irs.device
        obs = th.as_tensor(self.locals['self']._last_obs, device=device)
        actions = th.as_tensor(self.locals["actions"], device=device)
        rewards = th.as_tensor(self.locals["rewards"], device=device)
        dones = th.as_tensor(self.locals["dones"], device=device)
        next_obs = th.as_tensor(self.locals["new_obs"], device=device)

        # ===================== watch the interaction ===================== #
        self.irs.watch(obs, actions, rewards, dones, dones, next_obs)
        # ===================== watch the interaction ===================== #
        
        # ===================== compute the intrinsic rewards ===================== #
        intrinsic_rewards = irs.compute(samples={'observations':obs.unsqueeze(0), 
                                            'actions':actions.unsqueeze(0), 
                                            'rewards':rewards.unsqueeze(0),
                                            'terminateds':dones.unsqueeze(0),
                                            'truncateds':dones.unsqueeze(0),
                                            'next_observations':next_obs.unsqueeze(0)}, 
                                            sync=False)
        # ===================== compute the intrinsic rewards ===================== #

        try:
            # add the intrinsic rewards to the original rewards
            self.locals['rewards'] += intrinsic_rewards.cpu().numpy().squeeze()
            # update the intrinsic reward module
            replay_data = self.buffer.sample(batch_size=self.irs.batch_size)
            self.irs.update(samples={'observations': th.as_tensor(replay_data.observations).unsqueeze(1).to(device), # (n_steps, n_envs, *obs_shape)
                                     'actions': th.as_tensor(replay_data.actions).unsqueeze(1).to(device),
                                     'rewards': th.as_tensor(replay_data.rewards).to(device),
                                     'terminateds': th.as_tensor(replay_data.dones).to(device),
                                     'truncateds': th.as_tensor(replay_data.dones).to(device),
                                     'next_observations': th.as_tensor(replay_data.next_observations).unsqueeze(1).to(device)
                                     })
        except:
            pass

        return True

    def _on_rollout_end(self) -> None:
        pass

# Parallel environments
device = 'cuda'
n_envs = 4
envs = make_vec_env("Pendulum-v1", n_envs=n_envs)

# ===================== build the reward ===================== #
irs = E3B(envs, device=device)
# ===================== build the reward ===================== #

model = SAC("MlpPolicy", envs, verbose=1, device=device)
model.learn(total_timesteps=25000, callback=RLeXploreWithOffPolicyRL(irs))