<a href="https://colab.research.google.com/github/Sidhtang/implementation-of-research-papers/blob/main/reinforcement_learning_with_human_feedback.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Proximal policy optimization and policy gradient

# Policy-gradient methods train the policy much like a neural network: the gradient of the
# objective, i.e. the log-probability of the action taken in a given state (weighted by how
# good that action turned out to be), is taken with respect to the policy's parameters, and
# the policy is updated along that gradient.
# The drawbacks of vanilla policy gradient are its hypersensitivity to hyperparameters
# (especially the step size) and its poor sample efficiency, which PPO is designed to address.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from torch.distributions import Normal

class PPOMemory:
    """On-policy rollout buffer for PPO.

    Each transition field is kept in its own parallel Python list.
    ``generate_batches`` turns every list into a NumPy array and returns a
    random partition of the transition indices into minibatches of at most
    ``batch_size`` indices each.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.clear_memory()

    def store_memory(self, state, action, probs, vals, reward, done):
        """Append one transition: one element to each parallel list."""
        pairs = (
            (self.states, state),
            (self.actions, action),
            (self.probs, probs),
            (self.vals, vals),
            (self.rewards, reward),
            (self.dones, done),
        )
        for buffer, item in pairs:
            buffer.append(item)

    def clear_memory(self):
        """Forget all stored transitions by rebinding every list to a new empty list."""
        self.states, self.actions, self.probs = [], [], []
        self.vals, self.rewards, self.dones = [], [], []

    def generate_batches(self):
        """Return (states, actions, probs, vals, rewards, dones, batches).

        The first six items are ``np.array`` copies of the stored lists;
        ``batches`` is a list of int64 index arrays covering a shuffled
        ordering of all transitions.
        """
        size = len(self.states)
        step = self.batch_size
        starts = np.arange(0, size, step)
        order = np.arange(size, dtype=np.int64)
        np.random.shuffle(order)
        batches = [order[start:start + step] for start in starts]
        buffers = (self.states, self.actions, self.probs,
                   self.vals, self.rewards, self.dones)
        arrays = [np.array(buffer) for buffer in buffers]
        return (*arrays, batches)

class ActorNetwork(nn.Module):
    """Policy network: maps a state to a probability vector over discrete actions.

    Two ReLU hidden layers followed by a softmax, so ``forward`` returns raw
    action probabilities (a tensor, not a ``torch.distributions`` object).
    The network owns its Adam optimizer and moves itself to CUDA when one is
    available.
    """

    def __init__(self, input_dims, n_actions, alpha, fc1_dims=256, fc2_dims=256):
        super().__init__()
        layers = [
            nn.Linear(input_dims, fc1_dims), nn.ReLU(),
            nn.Linear(fc1_dims, fc2_dims), nn.ReLU(),
            nn.Linear(fc2_dims, n_actions), nn.Softmax(dim=-1),
        ]
        # Attribute name `actor` determines the state_dict keys used by checkpoints.
        self.actor = nn.Sequential(*layers)
        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return action probabilities for ``state`` (softmax over the last dim)."""
        return self.actor(state)

class CriticNetwork(nn.Module):
    """State-value network: maps a state to a single scalar estimate V(s).

    Two ReLU hidden layers and a linear output of width 1. Owns its Adam
    optimizer and moves itself to CUDA when available.
    """

    def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256):
        super().__init__()
        hidden = [
            nn.Linear(input_dims, fc1_dims),
            nn.ReLU(),
            nn.Linear(fc1_dims, fc2_dims),
            nn.ReLU(),
        ]
        head = nn.Linear(fc2_dims, 1)
        # Attribute name `critic` determines the state_dict keys used by checkpoints.
        self.critic = nn.Sequential(*hidden, head)
        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        self.to(self.device)

    def forward(self, state):
        """Return V(state) with a trailing dimension of size 1."""
        return self.critic(state)

class PPOAgent:
    """Clipped-surrogate PPO agent for discrete action spaces.

    Uses ``ActorNetwork`` (whose ``forward`` returns action *probabilities*),
    ``CriticNetwork`` and ``PPOMemory``. ``choose_action`` records the
    log-probability of the sampled action. ``learn`` then runs ``n_epochs``
    passes of minibatch updates over the stored rollout, using GAE
    advantages, and clears the memory.
    """

    def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
                 policy_clip=0.2, batch_size=64, n_epochs=10):
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda

        self.actor = ActorNetwork(input_dims, n_actions, alpha)
        self.critic = CriticNetwork(input_dims, alpha)
        self.memory = PPOMemory(batch_size)

    def remember(self, state, action, probs, vals, reward, done):
        """Store one transition; ``probs`` is the log-prob returned by ``choose_action``."""
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self):
        """Write actor/critic weights to ``actor.pth`` / ``critic.pth``."""
        torch.save(self.actor.state_dict(), 'actor.pth')
        torch.save(self.critic.state_dict(), 'critic.pth')

    def load_models(self):
        """Load weights from ``actor.pth`` / ``critic.pth`` onto the networks' device."""
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        self.actor.load_state_dict(torch.load('actor.pth', map_location=self.actor.device))
        self.critic.load_state_dict(torch.load('critic.pth', map_location=self.critic.device))

    @staticmethod
    def _compute_gae(rewards, values, dones, gamma, gae_lambda):
        """Return GAE advantages (float32 array) for a stored rollout.

        The last transition has no stored successor value, so its advantage
        is left at 0. The running sum is reset wherever ``dones[t]`` is true,
        so TD errors from the next episode do not leak into the previous one.
        """
        n = len(rewards)
        advantage = np.zeros(n, dtype=np.float32)
        gae = 0.0
        for t in range(n - 2, -1, -1):
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
            gae = delta + gamma * gae_lambda * not_done * gae
            advantage[t] = gae
        return advantage

    def choose_action(self, observation):
        """Sample an action for a single observation.

        Returns ``(action, log_prob, value)`` as Python scalars: the sampled
        action index, its log-probability under the current policy, and the
        critic's value estimate.
        """
        state = torch.as_tensor(np.asarray(observation), dtype=torch.float32,
                                device=self.actor.device).unsqueeze(0)

        # Acting needs no gradients; the policy is re-evaluated in learn().
        with torch.no_grad():
            # ActorNetwork returns probabilities, so wrap them in a distribution.
            dist = Categorical(probs=self.actor(state))
            value = self.critic(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)

        return action.item(), log_prob.item(), value.item()

    def learn(self):
        """Run ``n_epochs`` of clipped-PPO minibatch updates, then clear the memory."""
        device = self.actor.device
        state_arr, action_arr, old_prob_arr, vals_arr, \
            reward_arr, dones_arr, batches = self.memory.generate_batches()

        # Advantages depend only on the stored rollout, so compute them once.
        advantage = torch.as_tensor(
            self._compute_gae(reward_arr, vals_arr, dones_arr, self.gamma, self.gae_lambda),
            dtype=torch.float32, device=device)
        values = torch.as_tensor(vals_arr, dtype=torch.float32, device=device)
        returns = advantage + values

        states_all = torch.as_tensor(state_arr, dtype=torch.float32, device=device)
        actions_all = torch.as_tensor(action_arr, dtype=torch.int64, device=device)
        # Stored "probs" are log-probabilities of the taken actions.
        old_log_probs_all = torch.as_tensor(old_prob_arr, dtype=torch.float32, device=device)

        for epoch in range(self.n_epochs):
            if epoch > 0:
                # Fresh shuffle of minibatch indices for every epoch.
                batches = self.memory.generate_batches()[-1]

            for batch in batches:
                idx = torch.as_tensor(batch, dtype=torch.int64, device=device)
                states = states_all[idx]

                dist = Categorical(probs=self.actor(states))
                # squeeze(-1) keeps a batch dimension even for a batch of one.
                critic_value = self.critic(states).squeeze(-1)

                new_log_probs = dist.log_prob(actions_all[idx])
                prob_ratio = (new_log_probs - old_log_probs_all[idx]).exp()
                batch_adv = advantage[idx]
                weighted_probs = batch_adv * prob_ratio
                weighted_clipped_probs = torch.clamp(prob_ratio, 1 - self.policy_clip,
                                                     1 + self.policy_clip) * batch_adv
                actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()

                critic_loss = ((returns[idx] - critic_value) ** 2).mean()

                total_loss = actor_loss + 0.5 * critic_loss
                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self.memory.clear_memory()

In [None]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/958.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.3/958.1 kB[0m [31m6.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m952.3/958.1 kB[0m [31m15.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0


In [None]:
import gymnasium as gym
import numpy as np
import torch
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import os

# PPOMemory, ActorNetwork and CriticNetwork are reused unchanged from the previous cell.

class Agent:
    """PPO agent for discrete actions with GAE, advantage normalization and grad clipping.

    Relies on ``ActorNetwork`` (returns action probabilities), ``CriticNetwork``
    and ``PPOMemory`` from the earlier cell.
    """

    def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0001, gae_lambda=0.95,
                 policy_clip=0.2, batch_size=64, n_epochs=10):
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda

        # ActorNetwork's signature is (input_dims, n_actions, alpha); passing by
        # keyword prevents the two sizes from being swapped.
        self.actor = ActorNetwork(input_dims=input_dims, n_actions=n_actions, alpha=alpha)
        self.critic = CriticNetwork(input_dims, alpha)
        self.memory = PPOMemory(batch_size)

    def remember(self, state, action, probs, vals, reward, done):
        """Store one transition; ``probs`` is the log-prob returned by ``choose_action``."""
        self.memory.store_memory(state, action, probs, vals, reward, done)

    @staticmethod
    def _compute_gae(rewards, values, dones, gamma, gae_lambda):
        """Return GAE advantages (float32 array) for a stored rollout.

        The last transition has no stored successor value, so its advantage
        stays 0. The running sum is reset at ``done`` so advantages never mix
        TD errors from different episodes.
        """
        n = len(rewards)
        advantage = np.zeros(n, dtype=np.float32)
        gae = 0.0
        for t in range(n - 2, -1, -1):
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
            gae = delta + gamma * gae_lambda * not_done * gae
            advantage[t] = gae
        return advantage

    def choose_action(self, observation):
        """Sample an action; return ``(action, log_prob, value)`` as Python scalars."""
        state = torch.FloatTensor(np.asarray(observation)).unsqueeze(0).to(self.actor.device)

        with torch.no_grad():
            # ActorNetwork returns probabilities, so wrap them in a distribution.
            dist = Categorical(probs=self.actor(state))
            value = self.critic(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)

        return action.item(), log_prob.item(), value.item()

    def learn(self):
        """Run ``n_epochs`` of PPO minibatch updates on the stored rollout.

        Returns:
            (mean actor loss, mean critic loss) over all minibatch updates
            performed in this call (0.0 each if there was nothing to learn from).
        """
        total_actor_loss = 0.0
        total_critic_loss = 0.0
        n_updates = 0
        device = self.actor.device

        state_arr, action_arr, old_prob_arr, vals_arr, \
            reward_arr, dones_arr, batches = self.memory.generate_batches()

        # GAE depends only on the stored rollout, so compute it once per call.
        raw_advantage = torch.as_tensor(
            self._compute_gae(reward_arr, vals_arr, dones_arr, self.gamma, self.gae_lambda),
            dtype=torch.float32, device=device)
        values = torch.as_tensor(vals_arr, dtype=torch.float32, device=device)
        # Critic targets must use the raw (un-normalized) advantage.
        returns = raw_advantage + values

        # Normalize advantages for the policy loss only; std is undefined for < 2 samples.
        if raw_advantage.numel() > 1:
            advantage = (raw_advantage - raw_advantage.mean()) / (raw_advantage.std() + 1e-8)
        else:
            advantage = torch.zeros_like(raw_advantage)

        states_all = torch.as_tensor(state_arr, dtype=torch.float32, device=device)
        actions_all = torch.as_tensor(action_arr, dtype=torch.int64, device=device)
        old_log_probs_all = torch.as_tensor(old_prob_arr, dtype=torch.float32, device=device)

        for epoch in range(self.n_epochs):
            if epoch > 0:
                # Fresh shuffle of minibatch indices for every epoch.
                batches = self.memory.generate_batches()[-1]

            for batch in batches:
                idx = torch.as_tensor(batch, dtype=torch.int64, device=device)
                states = states_all[idx]

                dist = Categorical(probs=self.actor(states))
                critic_value = self.critic(states).squeeze(-1)

                new_log_probs = dist.log_prob(actions_all[idx])
                prob_ratio = (new_log_probs - old_log_probs_all[idx]).exp()
                batch_adv = advantage[idx]
                weighted_probs = batch_adv * prob_ratio
                weighted_clipped_probs = batch_adv * \
                    torch.clamp(prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip)
                actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()

                critic_loss = ((returns[idx] - critic_value) ** 2).mean()

                total_loss = actor_loss + 0.5 * critic_loss

                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=0.5)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=0.5)

                self.actor.optimizer.step()
                self.critic.optimizer.step()

                total_actor_loss += actor_loss.item()
                total_critic_loss += critic_loss.item()
                n_updates += 1

        self.memory.clear_memory()
        denom = max(n_updates, 1)
        return total_actor_loss / denom, total_critic_loss / denom

def train():
    """Train ``Agent`` on CartPole-v1 with early stopping and save learning curves.

    Runs up to ``N`` episodes and calls ``agent.learn()`` every ``batch_size``
    environment steps. Early stopping is checked only after ``min_episodes``:
    training stops once the 100-episode running average has not improved for
    ``patience`` consecutive episodes. Score and loss curves are written to
    ``plots/cartpole.png``.
    """
    env = gym.make('CartPole-v1')
    N = 1000  # Maximum number of episodes
    batch_size = 64  # Steps between learn() calls (also the minibatch size)
    n_epochs = 5
    alpha = 0.0001  # Learning rate

    agent = Agent(n_actions=env.action_space.n, batch_size=batch_size,
                  alpha=alpha, n_epochs=n_epochs,
                  input_dims=env.observation_space.shape[0],
                  gamma=0.99,
                  gae_lambda=0.95,
                  policy_clip=0.2)

    os.makedirs('plots', exist_ok=True)

    figure_file = 'plots/cartpole.png'
    best_score = float('-inf')  # best 100-episode average seen at any point
    score_history = []
    loss_history = {'actor': [], 'critic': []}
    n_steps = 0

    # Early stopping parameters
    patience = 150
    best_avg_score = float('-inf')  # best average since early-stopping tracking began
    episodes_without_improvement = 0
    min_episodes = 200  # Minimum episodes before early stopping

    try:
        for i in range(N):
            observation, _ = env.reset()
            done = False
            score = 0
            episode_steps = 0

            while not done:
                action, prob, val = agent.choose_action(observation)
                new_observation, reward, terminated, truncated, _ = env.step(action)

                # Penalize falling over before the 500-step limit of CartPole-v1.
                modified_reward = reward
                if terminated and episode_steps < 500:
                    modified_reward = -1

                done = terminated or truncated
                n_steps += 1
                episode_steps += 1
                score += reward  # Keep original reward for score
                agent.remember(observation, action, prob, val, modified_reward, done)

                if n_steps % batch_size == 0:
                    actor_loss, critic_loss = agent.learn()
                    loss_history['actor'].append(actor_loss)
                    loss_history['critic'].append(critic_loss)

                observation = new_observation

            score_history.append(score)
            avg_score = np.mean(score_history[-100:])
            best_score = max(best_score, avg_score)

            # Early stopping check with minimum episodes requirement
            if i >= min_episodes:
                if avg_score > best_avg_score:
                    best_avg_score = avg_score
                    episodes_without_improvement = 0
                else:
                    episodes_without_improvement += 1

                if episodes_without_improvement >= patience:
                    print(f'Early stopping triggered at episode {i} with best average score: {best_avg_score:.1f}')
                    break

            if i % 10 == 0:  # Print every 10 episodes
                print(f'Episode: {i} Score: {score:.1f} Avg Score: {avg_score:.1f} Best: {best_score:.1f}')
    finally:
        # Release the environment even if training raises or is interrupted.
        env.close()

    # Plot learning curves
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(score_history)
    plt.title('Learning Curve')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(loss_history['actor'], label='Actor Loss')
    plt.plot(loss_history['critic'], label='Critic Loss')
    plt.title('Loss History')
    plt.xlabel('Update Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(figure_file)
    plt.close()

if __name__ == '__main__':
    train()

Episode: 0 Score: 10.0 Avg Score: 10.0 Best: -inf
Episode: 10 Score: 12.0 Avg Score: 17.6 Best: -inf
Episode: 20 Score: 14.0 Avg Score: 24.7 Best: -inf
Episode: 30 Score: 35.0 Avg Score: 25.9 Best: -inf
Episode: 40 Score: 29.0 Avg Score: 26.2 Best: -inf
Episode: 50 Score: 29.0 Avg Score: 26.1 Best: -inf
Episode: 60 Score: 48.0 Avg Score: 25.9 Best: -inf
Episode: 70 Score: 15.0 Avg Score: 25.7 Best: -inf
Episode: 80 Score: 14.0 Avg Score: 25.2 Best: -inf
Episode: 90 Score: 25.0 Avg Score: 25.2 Best: -inf
Episode: 100 Score: 22.0 Avg Score: 25.1 Best: -inf
Episode: 110 Score: 30.0 Avg Score: 26.3 Best: -inf
Episode: 120 Score: 18.0 Avg Score: 25.4 Best: -inf
Episode: 130 Score: 15.0 Avg Score: 24.2 Best: -inf
Episode: 140 Score: 29.0 Avg Score: 24.0 Best: -inf
Episode: 150 Score: 43.0 Avg Score: 24.3 Best: -inf
Episode: 160 Score: 23.0 Avg Score: 25.1 Best: -inf
Episode: 170 Score: 33.0 Avg Score: 26.9 Best: -inf
Episode: 180 Score: 68.0 Avg Score: 28.6 Best: -inf
Episode: 190 Score: 48.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
from typing import List, Tuple

class PolicyNetwork(nn.Module):
    """Two-hidden-layer MLP mapping a state to a probability vector over actions."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        widths = [input_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Softmax over the last dim turns logits into action probabilities.
        layers += [nn.Linear(hidden_dim, output_dim), nn.Softmax(dim=-1)]
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return action probabilities for ``x``."""
        probabilities = self.network(x)
        return probabilities

class ValueNetwork(nn.Module):
    """Two-hidden-layer MLP producing a scalar state-value estimate (last dim of size 1)."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        first = nn.Linear(input_dim, hidden_dim)
        second = nn.Linear(hidden_dim, hidden_dim)
        head = nn.Linear(hidden_dim, 1)
        self.network = nn.Sequential(first, nn.ReLU(), second, nn.ReLU(), head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return V(x) with shape ``x.shape[:-1] + (1,)``."""
        value = self.network(x)
        return value

class RewardModel(nn.Module):
    """MLP regressing a scalar reward for a state, fitted to human feedback."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        dims = (input_dim, hidden_dim, hidden_dim)
        body = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            body.append(nn.Linear(d_in, d_out))
            body.append(nn.ReLU())
        self.network = nn.Sequential(*body, nn.Linear(hidden_dim, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the predicted reward for ``x`` (trailing dim of size 1)."""
        reward = self.network(x)
        return reward

class RLHF:
    """PPO trained against a learned reward model (a minimal RLHF loop).

    ``train_reward_model`` regresses ``RewardModel`` onto scalar human
    feedback for individual states. ``train`` then collects on-policy
    episodes, replaces the environment reward with the reward model's
    prediction, and optimizes the policy with the clipped PPO objective plus
    a value loss (weight ``c1``) and an entropy bonus (weight ``c2``).
    Both the Gymnasium (``reset -> (obs, info)``, 5-tuple ``step``) and the
    classic Gym (``reset -> obs``, 4-tuple ``step``) env APIs are accepted.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        epsilon: float = 0.2,
        c1: float = 1.0,
        c2: float = 0.01
    ):
        self.policy = PolicyNetwork(input_dim, hidden_dim, output_dim)
        self.value = ValueNetwork(input_dim, hidden_dim)
        self.reward_model = RewardModel(input_dim, hidden_dim)

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=learning_rate)
        self.reward_model_optimizer = optim.Adam(self.reward_model.parameters(), lr=learning_rate)

        self.gamma = gamma
        self.epsilon = epsilon  # PPO clipping parameter
        self.c1 = c1  # Value loss coefficient
        self.c2 = c2  # Entropy coefficient

    # ----- small helpers -------------------------------------------------

    @staticmethod
    def _to_state_tensor(state) -> torch.Tensor:
        """Convert one observation to a float32 tensor with a leading batch dim."""
        return torch.as_tensor(np.asarray(state), dtype=torch.float32).unsqueeze(0)

    @staticmethod
    def _normalize(x: torch.Tensor) -> torch.Tensor:
        """Zero-mean / unit-std normalization; all zeros when fewer than 2 elements (std undefined)."""
        if x.numel() < 2:
            return torch.zeros_like(x)
        return (x - x.mean()) / (x.std() + 1e-8)

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return only the observation (Gymnasium or classic Gym)."""
        out = env.reset()
        # Gymnasium returns (obs, info); classic Gym returns obs alone.
        if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
            return out[0]
        return out

    @staticmethod
    def _step_env(env, action):
        """Step ``env``; return ``(next_state, reward, done)`` for either API."""
        out = env.step(action)
        if len(out) == 5:
            next_state, reward, terminated, truncated, _ = out
            return next_state, reward, bool(terminated or truncated)
        next_state, reward, done, _ = out
        return next_state, reward, bool(done)

    # ----- reward model --------------------------------------------------

    def train_reward_model(self, demonstrations: List[Tuple[torch.Tensor, float]]):
        """Train reward model on human feedback demonstrations.

        Args:
            demonstrations: ``(state, human_reward)`` pairs; one SGD step per pair.
        """
        self.reward_model.train()
        loss_fn = nn.MSELoss()

        for state, human_reward in demonstrations:
            state = torch.as_tensor(state, dtype=torch.float32)
            predicted_reward = self.reward_model(state)
            # Match the prediction's shape/dtype so MSELoss never broadcasts
            # and integer feedback (e.g. 1) is accepted.
            target = torch.full_like(predicted_reward, float(human_reward))
            loss = loss_fn(predicted_reward, target)

            self.reward_model_optimizer.zero_grad()
            loss.backward()
            self.reward_model_optimizer.step()

    # ----- PPO -----------------------------------------------------------

    def compute_gae(
        self,
        rewards: List[float],
        values: List[float],
        next_value: float,
        dones: List[bool],
        gamma: "float | None" = None,
        lam: float = 0.95
    ) -> List[float]:
        """Compute Generalized Advantage Estimation.

        Args:
            rewards, values, dones: per-step sequences of equal length.
            next_value: V of the state after the last step (masked if that step is done).
            gamma: discount; defaults to ``self.gamma`` when omitted.
            lam: GAE lambda.

        Returns:
            One advantage per step, in time order.
        """
        if gamma is None:
            gamma = self.gamma

        next_values = list(values)[1:] + [next_value]
        advantages = [0.0] * len(rewards)
        gae = 0.0
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
            gae = delta + gamma * lam * not_done * gae
            advantages[t] = gae

        return advantages

    def ppo_update(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        returns: torch.Tensor,
        epochs: int = 10
    ):
        """Update policy and value networks with the clipped PPO objective.

        All per-step inputs are flattened to 1-D and treated as constants
        (detached), so repeated epochs never backpropagate through the
        rollout's graph.
        """
        actions = actions.reshape(-1)
        old_log_probs = old_log_probs.detach().reshape(-1)
        advantages = advantages.detach().reshape(-1)
        returns = returns.detach().reshape(-1)
        value_loss_fn = nn.MSELoss()

        for _ in range(epochs):
            # Get current policy distributions
            dist = Categorical(self.policy(states))
            curr_log_probs = dist.log_prob(actions)

            # Calculate ratio and surrogate losses
            ratios = torch.exp(curr_log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * advantages

            # Calculate policy loss with clipping
            policy_loss = -torch.min(surr1, surr2).mean()

            # Value head outputs (N, 1); squeeze to (N,) so MSELoss does not broadcast to (N, N).
            value_pred = self.value(states).squeeze(-1)
            value_loss = value_loss_fn(value_pred, returns)

            # Calculate entropy bonus
            entropy = dist.entropy().mean()

            # Combined loss
            total_loss = (
                policy_loss +
                self.c1 * value_loss -
                self.c2 * entropy
            )

            # Update networks
            self.policy_optimizer.zero_grad()
            self.value_optimizer.zero_grad()
            total_loss.backward()
            self.policy_optimizer.step()
            self.value_optimizer.step()

    def _rollout(self, env, max_steps: int) -> Tuple[List, ...]:
        """Run one episode (at most ``max_steps``); also return the final next-state."""
        states, actions, rewards, values, log_probs, dones = [], [], [], [], [], []
        state = self._reset_env(env)
        done = False
        steps = 0

        while not done and steps < max_steps:
            state_tensor = self._to_state_tensor(state)

            # Acting needs only forward passes; gradients are recomputed in ppo_update.
            with torch.no_grad():
                dist = Categorical(self.policy(state_tensor))
                action = dist.sample()
                log_prob = dist.log_prob(action)
                value = self.value(state_tensor)
                # The learned reward model, not the environment, supplies the reward.
                predicted_reward = self.reward_model(state_tensor).item()

            next_state, _, done = self._step_env(env, action.item())

            states.append(state)
            actions.append(action.item())
            rewards.append(predicted_reward)
            values.append(value.item())
            log_probs.append(log_prob.squeeze(0))  # 0-d, so stacking gives shape (T,)
            dones.append(done)

            state = next_state
            steps += 1

        return states, actions, rewards, values, log_probs, dones, state

    def collect_trajectory(self, env, max_steps: int = 1000) -> Tuple[List, ...]:
        """Collect a trajectory using current policy.

        Returns:
            ``(states, actions, rewards, values, log_probs, dones)`` where
            ``actions`` are ints, ``rewards`` are reward-model predictions and
            ``log_probs`` are detached 0-d tensors.
        """
        return self._rollout(env, max_steps)[:6]

    def train(
        self,
        env,
        n_episodes: int = 1000,
        max_steps: int = 1000,
        update_freq: int = 10
    ):
        """Main training loop.

        Every episode is collected with the current policy and buffered. A
        PPO update runs on all buffered episodes whenever
        ``episode % update_freq == 0``, so the first update happens after
        episode 0. The policy does not change between updates, so every
        buffered episode is still on-policy.
        """
        buffer = []

        for episode in range(n_episodes):
            states, actions, rewards, values, log_probs, dones, last_state = \
                self._rollout(env, max_steps)

            if states:
                # Bootstrap from the state reached *after* the last action;
                # compute_gae masks it out when the episode actually ended.
                with torch.no_grad():
                    next_value = self.value(self._to_state_tensor(last_state)).item()
                advantages = torch.as_tensor(
                    self.compute_gae(rewards, values, next_value, dones),
                    dtype=torch.float32)
                returns = advantages + torch.as_tensor(values, dtype=torch.float32)
                buffer.append((
                    torch.as_tensor(np.asarray(states), dtype=torch.float32),
                    torch.as_tensor(actions, dtype=torch.int64),
                    torch.stack(log_probs),
                    advantages,
                    returns,
                ))

            # PPO update
            if episode % update_freq == 0 and buffer:
                states_b, actions_b, old_log_probs_b, advantages_b, returns_b = (
                    torch.cat(parts) for parts in zip(*buffer)
                )
                self.ppo_update(states_b, actions_b, old_log_probs_b,
                                self._normalize(advantages_b), returns_b)
                buffer.clear()

# Example usage (kept inside a string literal, so it is NOT executed; as the last
# expression of the notebook cell, the string is simply echoed as the cell's output).
# `gym` here refers to the `import gymnasium as gym` from an earlier cell.
"""
# Initialize environment and RLHF
env = gym.make('CartPole-v1')
rlhf = RLHF(
    input_dim=4,  # CartPole state dimension
    hidden_dim=64,
    output_dim=2  # CartPole action dimension
)

# Create synthetic human feedback for demonstration
demonstrations = [
    (torch.randn(4), 1.0),  # (state, human_reward)
    (torch.randn(4), -1.0),
    # ... more demonstrations
]

# Train reward model
rlhf.train_reward_model(demonstrations)

# Train policy using RLHF
rlhf.train(env)
"""

"\n# Initialize environment and RLHF\nenv = gym.make('CartPole-v1')\nrlhf = RLHF(\n    input_dim=4,  # CartPole state dimension\n    hidden_dim=64,\n    output_dim=2  # CartPole action dimension\n)\n\n# Create synthetic human feedback for demonstration\ndemonstrations = [\n    (torch.randn(4), 1.0),  # (state, human_reward)\n    (torch.randn(4), -1.0),\n    # ... more demonstrations\n]\n\n# Train reward model\nrlhf.train_reward_model(demonstrations)\n\n# Train policy using RLHF\nrlhf.train(env)\n"