In [97]:
import numpy as np
import torch as th
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

In [102]:
class CustomCallback(BaseCallback):
    """
    Stable-Baselines3 callback that can print the rollout buffer.

    When ``display_rollout`` is True, the contents of
    ``self.model.rollout_buffer`` are printed as a fixed-width table at the
    end of every rollout, i.e. before the policy is updated.

    :param verbose: (int) Verbosity level 0: no output 1: info 2: debug
    :param display_rollout: (bool) If True, print the rollout buffer at the
        end of each rollout.
    """

    def __init__(self, verbose=0, display_rollout=False):
        super().__init__(verbose)
        self.display_rollout = display_rollout
        # BaseCallback provides, among others: self.model (the RL model),
        # self.training_env, self.n_calls, self.num_timesteps, self.locals,
        # self.globals, self.logger and self.parent.

        # Per-update loss histories; nothing in this class appends to them yet.
        self.value_losses = []
        self.total_losses = []

    def _on_training_start(self) -> None:
        """
        This method is called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        A rollout is the collection of environment interaction
        using the current policy.
        This event is triggered before collecting new samples.
        """
        # Logged metrics can be inspected here through
        # self.model.logger.name_to_value, e.g. 'train/value_loss',
        # 'rollout/ep_rew_mean' or 'rollout/ep_len_mean'.
        pass

    def _on_step(self) -> bool:
        """
        This method will be called by the model after each call to `env.step()`.

        For child callback (of an `EventCallback`), this will be called
        when the event is triggered.

        :return: (bool) If the callback returns False, training is aborted early.
        """
        return True

    def _on_rollout_end(self) -> None:
        """
        This event is triggered before updating the policy.
        """
        if self.display_rollout:
            self.display_rollout_buffer()

    def _on_training_end(self) -> None:
        """
        This event is triggered before exiting the `learn()` method.
        """
        pass

    def display_rollout_buffer(self):
        """
        Print the model's current rollout buffer as a fixed-width table.

        Each buffer entry becomes one row. A "New Episode" banner is printed
        before every row whose ``episode_starts`` flag is non-zero, and a
        double rule separates consecutive episodes.

        NOTE(review): every buffer array is flattened with ``reshape(-1)``, so
        this assumes scalar observations/actions and a single environment
        (as with ``Taxi-v3``). With vector observations the column lengths
        differ and the DataFrame cannot be built; with several envs the
        flattened rows would presumably interleave envs — verify before
        using with ``n_envs > 1``.
        """
        print("\nRollout Buffer Contents:")

        buffer = self.model.rollout_buffer
        columns = ['observations', 'actions', 'rewards', 'returns',
                   'episode_starts', 'values', 'log_probs', 'advantages']
        # One flattened column per buffer array.
        df = pd.DataFrame({name: getattr(buffer, name).reshape(-1) for name in columns})

        # Row indices where an episode starts. A set gives O(1) membership
        # tests; `i in ndarray` would rescan the whole array for every row.
        episode_start_rows = set(np.where(df['episode_starts'])[0].tolist())

        header = "  observations  actions  rewards  returns  episode_starts  values  log_probs  advantages"
        width = len(header) + 5  # +5 for index space
        single_rule = "-" * width
        double_rule = "=" * width

        print(single_rule)
        print("idx " + header)
        print(single_rule)

        rows = df.itertuples(index=False, name=None)
        for i, (obs, act, rew, ret, start, val, logp, adv) in enumerate(rows):
            if i in episode_start_rows:
                # Close the previous episode with a double rule (nothing to
                # close before the very first row).
                if i > 0:
                    print(double_rule)
                print(f"{i:3d} {'New Episode':^{len(header)}}")
                print(single_rule)

            print(f"{i:3d}  {obs:8.0f}    {act:4.0f}    {rew:6.2f}  "
                  f"{ret:7.2f}  {start:8.0f}      "
                  f"{val:6.2f}  {logp:8.2f}  {adv:9.2f}")

        # Final double rule closes the last episode.
        print(double_rule)
        print('\n')

In [103]:
env_taxi = gym.make('Taxi-v3')

# Fixed seed so that re-running the notebook reproduces the training run
# (SB3 seeds its RNGs and the env; exact determinism is not guaranteed on GPU).
SEED = 42

model_taxi = PPO(
    policy="MlpPolicy",
    env=env_taxi,
    verbose=0,
    seed=SEED,

    # Rollout / optimisation schedule
    n_steps=2048,
    batch_size=64,
    n_epochs=10,

    # Optimiser and PPO loss hyper-parameters
    learning_rate=5e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    clip_range_vf=None,
    ent_coef=0.0,
    vf_coef=0.5,
    max_grad_norm=0.5,
)

# `model` is kept as an alias of `model_taxi` in case other cells use the
# generic name.
model = model_taxi


In [104]:
## Metrics to track: Total Loss, Policy Loss, Value Loss, Entropy Loss,
## Avg. Reward, Discr. Loss, Discr. Acc.

# PPO training metrics recorded by the SB3 logger
# (readable via `model.logger.name_to_value`):
# train/learning_rate
# train/entropy_loss
# train/policy_gradient_loss
# train/value_loss
# train/approx_kl
# train/clip_fraction
# train/loss
# train/explained_variance
# train/n_updates
# train/clip_range

In [105]:
TOTAL_TIMESTEPS = 75_000

# display_rollout=True prints the entire rollout buffer after every rollout,
# which produces very long output; set it to False for quieter training.
callback_taxi = CustomCallback(verbose=0, display_rollout=True)
try:
    model_taxi.learn(total_timesteps=TOTAL_TIMESTEPS, callback=callback_taxi)
finally:
    # Release the environment even if training fails or is interrupted
    # (e.g. KeyboardInterrupt from the notebook's stop button).
    env_taxi.close()


Rollout Buffer Contents:
---------------------------------------------------------------------------------------------
idx   observations  actions  rewards  returns  episode_starts  values  log_probs  advantages
---------------------------------------------------------------------------------------------
  0                                       New Episode                                       
---------------------------------------------------------------------------------------------
  0       233       2     -1.00   -53.34         1       -0.14     -1.79     -53.20
  1       253       1     -1.00   -55.66         0        0.04     -1.79     -55.69
  2       153       3     -1.00   -58.11         0       -0.12     -1.79     -57.99
  3       153       3     -1.00   -60.71         0       -0.12     -1.79     -60.59
  4       153       0     -1.00   -63.48         0       -0.12     -1.79     -63.36
  5       253       1     -1.00   -66.44         0        0.04     -1.79     -66.48
  