In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gym
from copy import deepcopy

# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)

# Define hyperparameters
gamma = 0.99
learning_rate = 0.0005
n_epochs = 1
update_timestep = 5
kl_constraint = 0.01  # KL divergence constraint

# Create environment
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Actor-critic network architecture
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.actor = nn.Linear(64, action_dim)
        self.critic = nn.Linear(64, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value

# Trust Region Policy Optimization (TRPO) algorithm
class TRPO:
    def __init__(self):
        self.policy_old = ActorCritic()
        self.policy_new = deepcopy(self.policy_old)  # Create a new instance for the new policy
        self.optimizer = optim.Adam(self.policy_new.parameters(), lr=learning_rate)

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        logits, _ = self.policy_old(state)
        action_probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(action_probs, 1)
        return action.item()

    def update_policy(self):
        # Copy parameters from the new policy to the old policy
        self.policy_old.load_state_dict(self.policy_new.state_dict())

    def train(self, states, actions, rewards, dones):
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)

        # Compute old action probabilities
        logits_old, _ = self.policy_old(states)
        action_probs_old = torch.softmax(logits_old, dim=-1)
        old_action_probs = action_probs_old.gather(1, actions.unsqueeze(1)).squeeze()

        # Compute gradients of policy parameters
        logits_new, values = self.policy_new(states)
        action_probs_new = torch.softmax(logits_new, dim=-1)
        new_action_probs = action_probs_new.gather(1, actions.unsqueeze(1)).squeeze()
        ratio = new_action_probs / old_action_probs
        print("Ratio: {}".format(ratio))

        # Compute advantages
        advantages = rewards - values.squeeze().detach()
        advantages = self.discount_rewards(advantages, dones)

        # Compute surrogate loss
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - kl_constraint, 1 + kl_constraint) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        # Compute value function loss
        critic_loss = nn.MSELoss()(values, rewards)

        # Compute total loss
        loss = actor_loss + critic_loss

        # Perform backpropagation and update the new policy
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update the old policy to the new policy
        self.update_policy()

    def discount_rewards(self, rewards, dones):
        discounted_rewards = torch.zeros_like(rewards)
        running_add = 0
        for t in reversed(range(len(rewards))):
            running_add = running_add * gamma * (1 - dones[t]) + rewards[t]
            discounted_rewards[t] = running_add
        return discounted_rewards

# Initialize TRPO agent
trpo_agent = TRPO()

# Main training loop
total_timesteps = 0
for epoch in range(n_epochs):
    states, actions, rewards, dones = [], [], [], []
    episode_reward = 0
    state = env.reset()

    while True:
        action = trpo_agent.select_action(state)
        next_state, reward, done, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        dones.append(done)

        episode_reward += reward
        state = next_state

        total_timesteps += 1

        if total_timesteps % update_timestep == 0:
            trpo_agent.train(states, actions, rewards, dones)
            states, actions, rewards, dones = [], [], [], []

        if done:
            break

    print(f"Epoch: {epoch + 1}, Total Timesteps: {total_timesteps}, Episode Reward: {episode_reward}")

env.close()


Ratio: tensor([1., 1., 1., 1., 1.], grad_fn=<DivBackward0>)
Ratio: tensor([1., 1., 1., 1., 1.], grad_fn=<DivBackward0>)
Epoch: 1, Total Timesteps: 14, Episode Reward: 14.0
