In [1]:
import gymnasium as gym
import torch
import torch.optim as optim
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, CallbackList, BaseCallback
from stable_baselines3.common.monitor import Monitor
from torch.utils.tensorboard import SummaryWriter

import config
from wrappers import CliffWalkingStateWrapper, RewardPredictorWrapper
from reward_model import RewardModel
from teacher import SemanticTeacher
from buffer import TrajectoryBuffer
from train_rm import train_reward_model

class RLHFDataCollectionCallback(BaseCallback):
    """Stream PPO rollout steps into a trajectory buffer for reward-model training.

    Only the first sub-environment (index 0) of the vectorized env is read.
    Each step is stored as ``(post-step observation, action)``, matching the
    random pre-training loop, and the buffer's current episode is closed when
    that sub-environment finishes an episode.

    Args:
        buffer: Object exposing ``add_step(obs, action)`` and
            ``finalize_episode()``.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, buffer, verbose=0):
        super().__init__(verbose)
        self.buffer = buffer

    def _on_step(self) -> bool:
        info = self.locals["infos"][0]
        done = bool(self.locals["dones"][0])
        action = self.locals["actions"][0]

        if done and "terminal_observation" in info:
            # The VecEnv auto-resets on episode end, so ``new_obs`` already
            # holds the first observation of the NEXT episode; the actual final
            # observation of the finished episode is kept in the info dict.
            obs = info["terminal_observation"]
        else:
            obs = self.locals["new_obs"][0]

        self.buffer.add_step(obs, action)

        if done or "episode" in info:
            self.buffer.finalize_episode()
        # Returning True tells SB3 to keep training.
        return True

def inject_optimal_trajectory(env, buffer, n_copies=50):
    """Add scripted optimal CliffWalking episodes to the trajectory buffer.

    Each copy resets ``env``, replays a fixed action script, records every
    step with ``buffer.add_step`` and closes the episode with exactly one
    ``buffer.finalize_episode`` call (even if the env never terminates).

    Steps are stored as ``(post-step observation, action)``, the same
    convention used by the random pre-training loop and the PPO
    data-collection callback, so the goal state is part of the demo.

    Args:
        env: Gymnasium-style env (``reset`` -> ``(obs, info)``,
            ``step`` -> ``(obs, reward, terminated, truncated, info)``).
        buffer: Object exposing ``add_step(obs, action)`` and
            ``finalize_episode()``.
        n_copies: Number of demonstration episodes to inject.
    """
    print("--- Injecting Optimal Demonstration ---")

    # Assumes Gymnasium CliffWalking action ids (0=up, 1=right, 2=down):
    # leave the start cell, walk the row above the cliff, drop onto the goal.
    actions = [0] + [1] * 11 + [2]

    for _ in range(n_copies):
        env.reset()
        for action in actions:
            next_obs, _reward, terminated, truncated, _info = env.step(action)
            buffer.add_step(next_obs, action)
            if terminated or truncated:
                # Never step an env past the end of its episode.
                break
        # Close the demo even if the script did not terminate, so the next
        # copy starts a fresh episode in the buffer.
        buffer.finalize_episode()

    print(f"Optimal trajectory added to buffer (x{n_copies} copies)!")

def _make_state_env():
    """Create a time-limited CliffWalking env wrapped in CliffWalkingStateWrapper."""
    raw_env = gym.make(config.ENV_ID)
    raw_env = gym.wrappers.TimeLimit(raw_env, max_episode_steps=config.MAX_STEPS)
    return CliffWalkingStateWrapper(raw_env)


def _collect_random_steps(env, buffer, n_steps):
    """Record ``n_steps`` uniformly random transitions from ``env`` into ``buffer``.

    Steps are stored as ``(post-step observation, action)``. Every episode,
    including a trailing partial one, is closed with ``finalize_episode``.
    """
    env.reset()
    episode_open = False
    for _ in range(n_steps):
        action = env.action_space.sample()
        obs, _reward, terminated, truncated, _info = env.step(action)
        buffer.add_step(obs, action)
        episode_open = True
        if terminated or truncated:
            buffer.finalize_episode()
            episode_open = False
            env.reset()
    if episode_open:
        # Close the leftover partial episode; otherwise whatever is added next
        # (the scripted demos) would presumably be appended to it.
        buffer.finalize_episode()


def _update_reward_model(reward_model, buffer, teacher, optimizer, n_updates):
    """Run up to ``n_updates`` reward-model updates on freshly sampled pairs.

    Returns:
        Mean loss over the updates that actually ran, or ``None`` when the
        buffer yielded no pairs on any attempt.
    """
    losses = []
    for _ in range(n_updates):
        pairs = buffer.sample_pairs(config.RM_BATCH_SIZE)
        if pairs:
            losses.append(train_reward_model(reward_model, pairs, teacher, optimizer))
    if not losses:
        return None
    # Average only over updates that ran, so skipped iterations do not
    # drag the reported loss toward zero.
    return sum(losses) / len(losses)


def run_training():
    """Train PPO on CliffWalking against a learned reward model (RLHF loop).

    Pipeline:
      1. Collect random transitions plus scripted optimal demos.
      2. Fit the reward model once on teacher-labelled pairs.
      3. For each round: train PPO for ``config.FEEDBACK_FREQ`` steps on the
         reward model's predictions (rollout steps are streamed into the
         buffer), then run several reward-model updates and log metrics.
    Final PPO and reward-model weights are saved under ``config.LOG_DIR``.
    """
    print(f"--- Setting up RLHF on {config.ENV_ID} ---")

    # Training env. Monitor sits beneath RewardPredictorWrapper, so it
    # presumably records the environment's own reward, not the predicted one.
    env = Monitor(_make_state_env())
    env = RewardPredictorWrapper(env, reward_model=None)

    # Evaluation env is not reward-wrapped: EvalCallback scores the policy
    # on the environment's own reward.
    eval_env = Monitor(_make_state_env())

    # Env for random pre-training data and scripted demos. It gets the same
    # TimeLimit as the other envs so random-walk episodes stay bounded.
    demo_env = _make_state_env()

    reward_model = RewardModel()
    env.reward_model = reward_model
    rm_optimizer = optim.Adam(reward_model.parameters(), lr=config.RM_LR)

    teacher = SemanticTeacher()
    trajectory_buffer = TrajectoryBuffer(config.BUFFER_CAPACITY, config.SEGMENT_LENGTH)

    print("--- Pre-training: Collecting random trajectories ---")
    _collect_random_steps(demo_env, trajectory_buffer, config.PRETRAIN_STEPS)

    inject_optimal_trajectory(demo_env, trajectory_buffer)

    print("--- Training Reward Model (Initial) ---")
    initial_loss = _update_reward_model(
        reward_model, trajectory_buffer, teacher, rm_optimizer, n_updates=1
    )
    if initial_loss is not None:
        print(f"Initial RM Loss: {initial_loss:.4f}")

    data_callback = RLHFDataCollectionCallback(buffer=trajectory_buffer)

    checkpoint_callback = CheckpointCallback(
        save_freq=config.CHECKPOINT_FREQ,
        save_path=config.LOG_DIR,
        name_prefix="ppo_cliff",
    )

    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=config.LOG_DIR,
        log_path=config.LOG_DIR,
        eval_freq=config.EVAL_FREQ,
        deterministic=True,
        render=False,
    )

    callback_list = CallbackList([data_callback, checkpoint_callback, eval_callback])

    # n_steps == FEEDBACK_FREQ, so each per-round learn() call collects exactly
    # one rollout before the reward model is updated.
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=config.PPO_LR,
        n_steps=config.FEEDBACK_FREQ,
        gamma=config.PPO_GAMMA,
        gae_lambda=config.PPO_GAE_LAMBDA,
        ent_coef=config.PPO_ENTROPY_COEF,
        clip_range=config.PPO_EPS_CLIP,
        batch_size=config.PPO_BATCH_SIZE,
        verbose=1,
        tensorboard_log=config.TB_LOG_DIR,
        device=config.DEVICE,
        n_epochs=config.PPO_K_EPOCHS,
    )

    num_rounds = config.TOTAL_TIMESTEPS // config.FEEDBACK_FREQ
    total_timesteps_so_far = 0
    rm_updates_per_round = 10

    writer = SummaryWriter(log_dir=f"{config.TB_LOG_DIR}/rm_metrics")
    try:
        for i in range(num_rounds):
            print(f"\n--- Round {i+1}/{num_rounds} ---")

            # Refresh the demonstrations every 5 rounds.
            # NOTE(review): at a round boundary the callback's current PPO
            # episode is usually still open, so these demo steps may be appended
            # to it - confirm how TrajectoryBuffer handles interleaved add_step.
            if i % 5 == 0 and i > 0:
                inject_optimal_trajectory(demo_env, trajectory_buffer)

            # reset_num_timesteps=False keeps PPO's step counter continuous
            # across the repeated learn() calls.
            model.learn(
                total_timesteps=config.FEEDBACK_FREQ,
                callback=callback_list,
                reset_num_timesteps=False,
            )
            total_timesteps_so_far += config.FEEDBACK_FREQ

            print("Training Reward Model...")
            mean_loss = _update_reward_model(
                reward_model, trajectory_buffer, teacher, rm_optimizer, rm_updates_per_round
            )
            if mean_loss is None:
                print("No preference pairs sampled; reward model not updated.")
            else:
                writer.add_scalar("RewardModel/Loss", mean_loss, total_timesteps_so_far)
            writer.add_scalar("RewardModel/Buffer_Size", len(trajectory_buffer), total_timesteps_so_far)

        print("\n--- Training Complete ---")
        model.save(f"{config.LOG_DIR}/ppo_cliffwalking_final")
        reward_model.save(f"{config.LOG_DIR}/reward_model_final.pth")
    finally:
        # Flush TensorBoard events even if training or saving fails.
        writer.close()
    print("Models saved.")

if __name__ == "__main__":
    # Run the full RLHF training pipeline only when executed as a script,
    # not when this module is imported.
    run_training()


--- Setting up RLHF on CliffWalking-v1 ---
--- Pre-training: Collecting random trajectories ---
--- Injecting Optimal Demonstration ---
Optimal trajectory added to buffer (x50 copies)!
--- Training Reward Model (Initial) ---
Initial RM Loss: 0.6701
Using cpu device
Wrapping the env in a DummyVecEnv.

--- Round 1/24 ---
Logging to ./rlhf_tb_logs/PPO_0
Eval num_timesteps=2000, episode_reward=-100.00 +/- 0.00
Episode length: 100.00 +/- 0.00
---------------------------------
| eval/              |          |
|    mean_ep_length  | 100      |
|    mean_reward     | -100     |
| time/              |          |
|    total_timesteps | 2000     |
---------------------------------
New best mean reward!
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 100      |
|    ep_rew_mean     | -897     |
| time/              |          |
|    fps             | 560      |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 204