In [63]:
from email import policy
import gym
from platformdirs import user_desktop_dir
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

import random, math
import numpy as np
from utils import pref_save, pref_load
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import torch.nn.functional as F

import torch.nn as nn

import torch.optim as optim
from reinforce_PPORLHF import reinforce_rwd2go_PPO_RLHF

from torch.distributions import Normal

In [65]:
# Gym environment id used by this notebook.
ENV_NAME = "MountainCarContinuous-v0"
# Single environment instance; it is handed to the REINFORCE training call further down.
env = gym.make(ENV_NAME)

In [None]:
# Two saved Stable-Baselines3 PPO checkpoints. policy2 (the 10000-step checkpoint) is the one
# wrapped by sb3Wrapper further down; policy1 is not referenced in the cells visible here.
policy1 = PPO.load('./policies/ppo_mountain_ctn_final.zip')
policy2 = PPO.load('./policies/ppo_mountain_ctn_10000_steps.zip')
# Preference data. The reward-model training loop iterates it as (s0, tau_plus, tau_minus)
# triples and reads tau["states"] / tau["actions"] from each trajectory.
# NOTE(review): the .pickle extension suggests pref_load unpickles this file; unpickling can
# execute arbitrary code, so only load files from a trusted source — confirm in utils.
pref_data = pref_load('./pref_data/pref_data_100_MountainCarContinuous-v0.pickle')

In [68]:
class RewardModel(nn.Module):
    """Small MLP that scores a (state, action) pair with a reward in (0, 1).

    Architecture: Linear(state_size + action_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> 1) -> sigmoid.

    Args:
        state_size: Number of state features.
        action_size: Number of action features.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, state_size=2, action_size=1, hidden_size=32):
        super(RewardModel, self).__init__()
        self.fc1 = nn.Linear(state_size + action_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, state):
        """Return the sigmoid-squashed reward for already-concatenated inputs.

        Args:
            state: Tensor whose last dimension is ``state_size + action_size``
                (despite the name, this is the state-action concatenation).

        Returns:
            Tensor with the last dimension reduced to 1, values in (0, 1).
        """
        x = F.relu(self.fc1(state))
        x = self.fc2(x)
        return torch.sigmoid(x)

    def predict_reward(self, state, action):
        """Score a single (state, action) pair.

        Args:
            state: Tensor of shape ``(1, state_size)``.
            action: A single scalar action (Python number, NumPy scalar/array
                or tensor holding exactly one element).

        Returns:
            Reward tensor of shape ``(1, 1)``, moved to the CPU.
        """
        # Build the action with the state's dtype and device. A float64 NumPy action
        # would otherwise promote the concatenated input to float64 and fail inside the
        # float32 linear layer, and a state on another device would fail in torch.cat.
        # detach() keeps the action a constant, as torch.tensor(action) did.
        action = torch.as_tensor(action, dtype=state.dtype, device=state.device).detach().reshape(1, 1)

        state_action = torch.cat((state, action), dim=1)
        reward = self.forward(state_action).cpu()
        return reward
    

# Optimisation settings for fitting the reward model.
lr, epochs = 3e-2, 8

# Reward network over concatenated (state, action) inputs: 2 state dims + 1 action dim.
reward_model = RewardModel(2, 1)

# Adam over every reward-model weight.
optimizer = optim.Adam(reward_model.parameters(), lr=lr)

def trajectory_reward(reward_model, states, actions):
    """Sum the predicted per-step rewards along one trajectory.

    Args:
        reward_model: Module exposing ``predict_reward(state, action)``, where
            ``state`` is a float32 tensor with a leading batch dimension of 1.
        states: Iterable of per-step states, each convertible to a float32 tensor.
        actions: Iterable of per-step actions, passed through to
            ``predict_reward`` unchanged. Iteration stops at the shorter of
            ``states`` and ``actions``.

    Returns:
        Tensor holding the summed reward (0-dim when ``predict_reward`` returns a
        ``(1, 1)`` tensor); gradients flow back to the reward model's parameters.
        An empty trajectory yields a zero tensor.
    """
    # Create inputs on the device the model actually lives on, rather than on a global
    # device the model may never have been moved to.
    first_param = next(reward_model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    total_reward = torch.tensor(0., device=model_device)
    for s, a in zip(states, actions):
        s_t = torch.as_tensor(s, dtype=torch.float32, device=model_device)
        reward = reward_model.predict_reward(s_t.unsqueeze(0), a).squeeze(0)
        # Out-of-place add; predict_reward may hand the reward back on another device.
        total_reward = total_reward + reward.squeeze(0).to(model_device)
    return total_reward

# Fit the reward model on the preference pairs with a Bradley-Terry style objective:
# for each pair, minimise -log( exp(R+) / (exp(R+) + exp(R-)) ), where R is the summed
# predicted reward of a trajectory. One full-batch optimizer step is taken per epoch.
losses_reward_model = []
for epoch in range(1, epochs+1):
    total_loss = 0.0

    # Each item is (s0, preferred trajectory, rejected trajectory); s0 is not used here.
    for _s0, tau_plus, tau_minus in pref_data:
        reward_plus = trajectory_reward(reward_model, tau_plus["states"], tau_plus["actions"])
        reward_minus = trajectory_reward(reward_model, tau_minus["states"], tau_minus["actions"])
        # logsumexp gives the log of the normaliser without overflowing on large sums.
        stacked = torch.stack([reward_plus, reward_minus])
        log_Z   = torch.logsumexp(stacked, dim=0)
        total_loss += - (reward_plus - log_Z)

    # Gradients come from the summed (not averaged) loss over all pairs.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    avg_loss = total_loss / len(pref_data)
    # .item() works for tensors on any device; .numpy() raises for CUDA tensors.
    losses_reward_model.append(avg_loss.item())

    print(f"Epoch {epoch}/{epochs} — avg loss: {avg_loss:.4f}")

Epoch 1/8 — avg loss: 461.0858
Epoch 2/8 — avg loss: 384.6009
Epoch 3/8 — avg loss: 308.2630
Epoch 4/8 — avg loss: 235.1981
Epoch 5/8 — avg loss: 168.4429
Epoch 6/8 — avg loss: 110.3193
Epoch 7/8 — avg loss: 63.1294
Epoch 8/8 — avg loss: 27.4724


In [71]:
class sb3Wrapper(nn.Module):
    """Expose the actor of a loaded policy as a Gaussian policy with a fixed std.

    The action mean is ``action_net(policy_net(x))``; the standard deviation is the
    constant ``std_`` rather than anything stored in the wrapped model.

    Args:
        model: Object exposing ``policy.mlp_extractor.policy_net`` and
            ``policy.action_net`` (e.g. a loaded Stable-Baselines3 PPO model).
        std_: Fixed standard deviation of the action distribution; must be > 0.

    Raises:
        ValueError: If ``std_`` is not strictly positive.
    """

    def __init__(self, model, std_=3.0):
        super(sb3Wrapper,self).__init__()
        if std_ <= 0:
            raise ValueError(f"std_ must be positive, got {std_}")
        # The whole extractor is kept as a registered sub-module, so its parameters are
        # part of self.parameters(); forward() only uses policy_net and action_net.
        self.extractor = model.policy.mlp_extractor
        self.policy_net = model.policy.mlp_extractor.policy_net
        self.action_net = model.policy.action_net
        # Previously the argument was ignored and act() hard-coded 3.
        self.std = float(std_)

    def forward(self,x):
        """Return the action mean for a batch of states."""
        x = self.policy_net(x)
        x = self.action_net(x)
        return x

    def act(self, state):
        """Sample an action for one state.

        Args:
            state: NumPy array holding a single (unbatched) state.

        Returns:
            Tuple ``(action, log_prob)``: ``action`` is a NumPy array with the batch
            dimension removed; ``log_prob`` is the Gaussian log-density of the sampled
            action summed over the action dimension, with the batch dimension removed.
            ``log_prob`` stays attached to the graph through the action mean.
        """
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(next(self.parameters()).device)
        mean_act = self.forward(state_tensor)

        std = torch.full_like(mean_act, self.std)
        dist = Normal(mean_act, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)

        action_np = action.detach().squeeze(0).cpu().numpy()
        return action_np, log_prob.squeeze(0)

In [73]:
# Wrap the policy2 checkpoint so its action-mean network can be used as a plain nn.Module.
# policyCopied is assigned again before the training call further down.
policyCopied = sb3Wrapper(policy2)

In [75]:
#proba2

In [77]:
class sb3Wrapper(nn.Module):
    """Gaussian policy wrapper that scores actions by interval probability.

    The action mean is ``action_net(policy_net(x))`` and the standard deviation is the
    constant ``std_``. Instead of the Gaussian log-density, ``act`` returns the log of
    the probability mass in ``[action - delta, action + delta]``.

    Args:
        model: Object exposing ``policy.mlp_extractor.policy_net`` and
            ``policy.action_net`` (e.g. a loaded Stable-Baselines3 PPO model).
        std_: Fixed standard deviation of the action distribution; must be > 0.
        delta: Half-width of the interval around the sampled action; must be > 0.

    Raises:
        ValueError: If ``std_`` or ``delta`` is not strictly positive.
    """

    def __init__(self, model, std_=10.0, delta=0.2):
        super(sb3Wrapper,self).__init__()
        if std_ <= 0:
            raise ValueError(f"std_ must be positive, got {std_}")
        if delta <= 0:
            raise ValueError(f"delta must be positive, got {delta}")
        # The whole extractor is kept as a registered sub-module, so its parameters are
        # part of self.parameters(); forward() only uses policy_net and action_net.
        self.extractor = model.policy.mlp_extractor
        self.policy_net = model.policy.mlp_extractor.policy_net
        self.action_net = model.policy.action_net
        # Previously std_ was ignored (act() hard-coded 10) and delta was an inline 0.2.
        self.std = float(std_)
        self.delta = float(delta)

    def forward(self,x):
        """Return the action mean for a batch of states."""
        x = self.policy_net(x)
        x = self.action_net(x)
        return x

    def act(self, state):
        """Sample an action for one state.

        Args:
            state: NumPy array holding a single (unbatched) state.

        Returns:
            Tuple ``(action, log_prob)``: ``action`` is a NumPy array with the batch
            dimension removed; ``log_prob`` is ``log(P(|a' - action| <= delta) + 1e-10)``
            per action dimension with the batch dimension removed (it is not summed
            over the action dimension). ``log_prob`` stays attached to the graph
            through the action mean.
        """
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(next(self.parameters()).device)
        mean_act = self.forward(state_tensor)

        std = torch.full_like(mean_act, self.std)
        dist = Normal(mean_act, std)
        action = dist.sample()

        # Probability mass of the band around the sampled action.
        prob_interval = dist.cdf(action + self.delta) - dist.cdf(action - self.delta)

        # The small constant keeps the log finite if the mass underflows to zero.
        log_prob = torch.log(prob_interval + 1e-10)

        action_np = action.detach().squeeze(0).cpu().numpy()
        return action_np, log_prob.squeeze(0)

In [81]:
# Re-wrap policy2 with whichever sb3Wrapper definition was executed last
# (a later class definition with the same name shadows an earlier one).
policyCopied = sb3Wrapper(policy2)
# Adam over every parameter registered on the wrapper.
opt1 = optim.Adam(policyCopied.parameters(), lr=1e-3)
# eval() only switches module mode flags; RewardModel has no dropout/batch-norm layers,
# and this call does not disable gradient tracking.
reward_model.eval()

# NOTE(review): reward_evaluation_every is assigned here but is not passed to the call
# below — confirm whether reinforce_rwd2go_PPO_RLHF reads it some other way or it is dead.
reward_evaluation_every=10
# Fine-tune the wrapped policy against the learned reward model for 1000 episodes.
# NOTE(review): the recorded output shows avg100 falling from about -6340 to about -9626
# over training — worth checking the loss sign / return definition inside the helper.
losses, mean_returns, std_returns = reinforce_rwd2go_PPO_RLHF(env, policyCopied, opt1, reward_model, n_episodes=1000)

Ep 100	avg100: -6339.68
Ep 200	avg100: -7265.26
Ep 300	avg100: -7666.10
Ep 400	avg100: -8036.48
Ep 500	avg100: -8233.76
Ep 600	avg100: -9230.97
Ep 700	avg100: -9428.03
Ep 800	avg100: -9378.62
Ep 900	avg100: -9626.19


In [83]:
mean_returns

array([ -5393.29970061,  -5712.46297244,  -6886.96661506,  -3522.33361626,
        -6628.95613077,  -6902.51355744,  -5863.53092378,  -5887.57982098,
        -4635.32796411,  -7088.54656869,  -4611.7968164 ,  -5280.48155177,
        -6015.49543261,  -6582.74921262,  -7612.94159208,  -9822.3370435 ,
        -4489.8062188 ,  -9600.85280448,  -8039.94871635,  -6125.21295883,
        -5989.72666766,  -8767.01499473,  -9003.08321822,  -5083.42619445,
        -8281.76813405,  -3763.96004788,  -9751.30844791,  -6869.54999997,
        -6257.14062601,  -9335.45634646,  -8162.45716704,  -9034.08646427,
        -7226.03268554,  -9110.7390822 ,  -5814.09392934,  -4749.76913099,
        -8485.97417231,  -8986.54212531, -10526.95213295,  -8469.21583391,
       -10348.45247739,  -7935.31496262,  -6324.15043948,  -9564.5327546 ,
        -8797.73413904,  -7338.57789092,  -9768.67332368,  -6726.93354003,
        -7683.51502292,  -5811.75279539, -10038.41662882,  -5800.86336981,
        -8829.02740212, -