In [None]:
import gymnasium as gym
import torch
import torch.optim as optim
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

import config
from wrappers import CliffWalkingStateWrapper, RewardPredictorWrapper
from reward_model import RewardModel
from teacher import Teacher
from buffer import TrajectoryBuffer
from train_rm import train_reward_model

class RLHFDataCollectionCallback(BaseCallback):
    """Record PPO rollout steps into a trajectory buffer for reward-model training.

    On every environment step the callback reads the first sub-environment's
    entry from the rollout locals and, when the info dict carries an
    ``'original_reward'`` key, stores ``(observation, action, original_reward)``
    in ``buffer``. The observation stored is the one *resulting from* the
    action (post-step). When the episode ends the buffer is told to finalize
    the episode.

    Only sub-environment index 0 is recorded; additional parallel
    environments (if any) are ignored.

    Args:
        buffer: Object exposing ``add_step(obs, action, reward)`` and
            ``finalize_episode()``.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, buffer, verbose=0):
        super().__init__(verbose)
        self.buffer = buffer

    def _on_step(self) -> bool:
        """Store the latest transition; always return True to keep training."""
        info = self.locals['infos'][0]
        done = self.locals['dones'][0]
        obs = self.locals['new_obs'][0]
        action = self.locals['actions'][0]

        # Stable-Baselines3 vectorized envs auto-reset on episode end, so
        # `new_obs` is then already the first observation of the NEXT episode.
        # The true final observation is handed over in the info dict; use it
        # so the last step of an episode is not paired with a reset state.
        if done and 'terminal_observation' in info:
            obs = info['terminal_observation']

        if 'original_reward' in info:
            self.buffer.add_step(obs, action, info['original_reward'])

        if 'episode' in info or done:
            self.buffer.finalize_episode()

        return True
    
def inject_optimal_trajectory(env, buffer, num_copies=50):
    """Replay a fixed scripted action sequence and store it in ``buffer``.

    The script is one action ``0``, eleven actions ``1`` and one action ``2``.
    In Gymnasium's CliffWalking action encoding (0=up, 1=right, 2=down) this
    walks up from the start, along the row above the cliff, and down onto
    the goal.

    For every step whose info dict contains ``'original_reward'``, the tuple
    ``(observation before the step, action, original_reward)`` is added to
    the buffer. Each replay is finalized as its own episode.

    Args:
        env: Environment with the Gymnasium ``reset()`` / ``step()`` API.
            It is reset at the start of every replay and left in whatever
            state the last replay ended in.
        buffer: Object exposing ``add_step(obs, action, reward)`` and
            ``finalize_episode()``.
        num_copies: How many times to replay the script (default 50).
    """
    print("--- Injecting Optimal Demonstration ---")

    actions = [0] + [1] * 11 + [2]

    for _ in range(num_copies):
        obs, _ = env.reset()

        for action in actions:
            next_obs, _reward, terminated, truncated, info = env.step(action)

            if 'original_reward' in info:
                buffer.add_step(obs, action, info['original_reward'])

            obs = next_obs

            if terminated or truncated:
                # Never step an environment that has already ended.
                break

        # Finalize once per replay, even if the script ran out of actions
        # without ending the episode, so steps from separate replays are
        # never merged into one episode.
        buffer.finalize_episode()

    print(f"Optimal trajectory added to buffer (x{num_copies} copies)!")
    

def _build_env(reward_model):
    """Create a fresh ``config.ENV_ID`` environment with this script's wrappers.

    The raw environment is wrapped in ``CliffWalkingStateWrapper`` and then in
    ``RewardPredictorWrapper`` bound to ``reward_model``.
    """
    return RewardPredictorWrapper(
        CliffWalkingStateWrapper(gym.make(config.ENV_ID)),
        reward_model,
    )


def main():
    """Run the full RLHF pipeline and save the resulting models.

    Steps:
      1. Build the reward model, its optimizer, the teacher and the buffer.
      2. Fill the buffer with random-action steps plus scripted demonstrations.
      3. Train the reward model once on an initial batch of segment pairs.
      4. Alternate PPO training (``config.FEEDBACK_FREQ`` steps per round)
         with ten reward-model updates, re-injecting the scripted
         demonstrations every fifth round.
      5. Save the PPO policy and the reward model to disk.
    """
    print(f"--- Setting up RLHF on {config.ENV_ID} ---")

    reward_model = RewardModel()
    rm_optimizer = optim.Adam(reward_model.parameters(), lr=config.RM_LR)

    # PPO gets its own environment. Scripted data collection (random
    # pre-training and demonstration replays) runs on a separate instance so
    # that it never steps or resets the environment PPO is in the middle of
    # using: with reset_num_timesteps=False, PPO resumes from its cached last
    # observation and would otherwise continue from a stale state.
    env = _build_env(reward_model)
    scripted_env = _build_env(reward_model)

    teacher = Teacher()
    trajectory_buffer = TrajectoryBuffer(
        capacity=config.BUFFER_CAPACITY,
        segment_length=config.SEGMENT_LENGTH
    )

    print("--- Pre-training: Collecting random trajectories ---")
    obs, _ = scripted_env.reset()
    for _ in range(config.PRETRAIN_STEPS):
        action = scripted_env.action_space.sample()
        obs, reward, terminated, truncated, info = scripted_env.step(action)

        # NOTE(review): `obs` here is the observation returned BY this step
        # (post-action), whereas inject_optimal_trajectory stores the
        # observation from BEFORE the step. Confirm which pairing the reward
        # model expects and align the two.
        if 'original_reward' in info:
            trajectory_buffer.add_step(obs, action, info['original_reward'])

        if terminated or truncated:
            trajectory_buffer.finalize_episode()
            obs, _ = scripted_env.reset()

    inject_optimal_trajectory(scripted_env, trajectory_buffer)

    print(f"Buffer populated with {len(trajectory_buffer)} segments.")

    print("--- Training Reward Model (Initial) ---")
    pairs = trajectory_buffer.sample_pairs(config.RM_BATCH_SIZE)
    if pairs:
        initial_loss = train_reward_model(
            reward_model, pairs, teacher, rm_optimizer
        )
        print(f"Initial RM Loss: {initial_loss:.4f}")

    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=config.PPO_LR,
        n_steps=config.FEEDBACK_FREQ,
        gamma=config.PPO_GAMMA,
        gae_lambda=config.PPO_GAE_LAMBDA,
        ent_coef=config.PPO_ENTROPY_COEF,
        clip_range=config.PPO_EPS_CLIP,
        batch_size=config.PPO_BATCH_SIZE,
        verbose=0,
        device=config.DEVICE
    )

    callback = RLHFDataCollectionCallback(buffer=trajectory_buffer)

    # Integer division: any remainder of TOTAL_TIMESTEPS is not trained on.
    num_rounds = config.TOTAL_TIMESTEPS // config.FEEDBACK_FREQ
    training_steps = 10

    for i in range(num_rounds):
        print(f"\n--- Round {i+1}/{num_rounds} ---")

        if i % 5 == 0 and i > 0:
            print(">>> Re-injecting Optimal Trajectories to refresh memory...")
            inject_optimal_trajectory(scripted_env, trajectory_buffer)

        model.learn(
            total_timesteps=config.FEEDBACK_FREQ,
            callback=callback,
            reset_num_timesteps=False
        )

        print("Training Reward Model...")
        total_loss = 0.0
        updates_run = 0

        for _ in range(training_steps):
            pairs = trajectory_buffer.sample_pairs(config.RM_BATCH_SIZE)
            if len(pairs) > 0:
                total_loss += train_reward_model(
                    reward_model, pairs, teacher, rm_optimizer
                )
                updates_run += 1

        # Average over the updates that actually ran; dividing by the fixed
        # `training_steps` would under-report the loss whenever sampling
        # returned no pairs.
        if updates_run:
            print(f"Avg RM Loss: {total_loss / updates_run:.4f}")
        else:
            print("Avg RM Loss: n/a (no preference pairs sampled)")

    print("\n--- Training Complete ---")

    model.save("ppo_cliffwalking_rlhf")
    reward_model.save("reward_model_final.pth")
    print("Models saved.")

# Run the pipeline only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()


--- Setting up RLHF on CliffWalking-v1 ---
--- Pre-training: Collecting random trajectories ---
--- Injecting Optimal Demonstration ---
Optimal trajectory added to buffer (x50 copies)!
Buffer populated with 90 segments.
--- Training Reward Model (Initial) ---
Initial RM Loss: 0.4066

--- Round 1/24 ---




Training Reward Model...
Avg RM Loss: 0.3968

--- Round 2/24 ---
Training Reward Model...
Avg RM Loss: 0.3604

--- Round 3/24 ---
Training Reward Model...
Avg RM Loss: 0.3925

--- Round 4/24 ---
Training Reward Model...
Avg RM Loss: 0.3368

--- Round 5/24 ---
Training Reward Model...
Avg RM Loss: 0.3392

--- Round 6/24 ---
>>> Re-injecting Optimal Trajectories to refresh memory...
--- Injecting Optimal Demonstration ---
Optimal trajectory added to buffer (x50 copies)!
Training Reward Model...
Avg RM Loss: 0.3484

--- Round 7/24 ---
Training Reward Model...
Avg RM Loss: 0.3325

--- Round 8/24 ---
Training Reward Model...
Avg RM Loss: 0.3088

--- Round 9/24 ---
Training Reward Model...
Avg RM Loss: 0.2795

--- Round 10/24 ---
Training Reward Model...
Avg RM Loss: 0.3740

--- Round 11/24 ---
>>> Re-injecting Optimal Trajectories to refresh memory...
--- Injecting Optimal Demonstration ---
Optimal trajectory added to buffer (x50 copies)!
Training Reward Model...
Avg RM Loss: 0.2867

--- Ro