# In [26]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gym

# Pin the PyTorch and NumPy random number generators to a fixed seed.
# NOTE(review): the gym environment created below is not seeded here.
torch.manual_seed(0)
np.random.seed(0)

# Training hyperparameters
gamma = 0.99            # discount applied to future rewards
learning_rate = 5e-4    # step size given to the Adam optimizer
n_epochs = 1            # iterations of the outer training loop
update_timestep = 10    # environment steps between policy updates
# Named after a KL bound, but the update in this file uses it as the
# half-width of the probability-ratio clipping interval.
kl_constraint = 1e-2

# Build the environment and read the network input/output sizes from it
env = gym.make('CartPole-v1')
action_dim = env.action_space.n
state_dim = env.observation_space.shape[0]

# Actor-critic network architecture
class ActorCritic(nn.Module):
    """Shared two-layer MLP trunk with a policy head and a value head.

    Args:
        in_dim: Length of the observation vector. ``None`` (the default) uses
            the module-level ``state_dim``.
        n_actions: Number of discrete actions. ``None`` (the default) uses the
            module-level ``action_dim``.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, in_dim=None, n_actions=None, hidden_dim=64):
        super().__init__()
        if in_dim is None:
            in_dim = state_dim
        if n_actions is None:
            n_actions = action_dim

        # Layers are created in the order fc1, fc2, actor, critic so that a
        # fixed torch seed produces the same initial weights as before.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, n_actions)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return ``(logits, value)`` for a batch of observations.

        ``logits`` are unnormalized action scores (softmax is applied by the
        caller); ``value`` keeps a trailing dimension of size 1.
        """
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value

# Clipped-surrogate actor-critic agent.
# NOTE: despite the class name, the update implemented here is the PPO-style
# clipped probability-ratio objective; no KL-constrained trust-region step is
# performed (see the placeholder `trust_region_update`).
class TRPO:
    """Actor-critic agent trained with a clipped probability-ratio surrogate.

    Args:
        lr: Adam learning rate. ``None`` (the default) uses the module-level
            ``learning_rate``.
        clip_eps: Half-width of the ratio clipping interval
            ``[1 - clip_eps, 1 + clip_eps]``. ``None`` (the default) uses the
            module-level ``kl_constraint``.
    """

    def __init__(self, lr=None, clip_eps=None):
        self.policy = ActorCritic()
        self.clip_eps = kl_constraint if clip_eps is None else clip_eps
        self.optimizer = optim.Adam(
            self.policy.parameters(),
            lr=learning_rate if lr is None else lr,
        )

    def select_action(self, state):
        """Sample one action index from the current policy for ``state``."""
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits, _ = self.policy(state)
        action_probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(action_probs, 1)
        return action.item()

    def train(self, states, actions, advantages, returns=None):
        """Take one optimizer step on a batch of transitions.

        Args:
            states: Sequence of observations, one per transition.
            actions: Sequence of integer action indices, one per transition.
            advantages: Sequence of scalar advantage estimates.
            returns: Sequence of scalar targets for the value head. When
                omitted, the module-level ``returns`` variable is used, which
                is how the training script in this file supplies it.

        Raises:
            ValueError: If ``returns`` is omitted and no module-level
                ``returns`` variable exists.
        """
        if returns is None:
            # Legacy call path: earlier versions read `returns` as a global.
            try:
                returns = globals()["returns"]
            except KeyError:
                raise ValueError(
                    "returns must be passed to train() or defined at module level"
                ) from None

        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        if states.shape[0] == 0:
            # Nothing to learn from; a mean over an empty batch would be NaN.
            return
        actions = torch.as_tensor(actions, dtype=torch.long).reshape(-1, 1)
        # Keep these one-dimensional so they line up element-wise with `ratio`
        # instead of broadcasting (N, 1) against (N,) into an (N, N) matrix.
        advantages = torch.as_tensor(advantages, dtype=torch.float32).reshape(-1)
        returns = torch.as_tensor(returns, dtype=torch.float32).reshape(-1)

        logits, values = self.policy(states)
        values = values.reshape(-1)
        action_probs = torch.softmax(logits, dim=-1)
        new_action_probs = action_probs.gather(1, actions).squeeze(1)

        # The "old" probabilities are those of the current parameters, treated
        # as constants. They must be detached: if both sides of the ratio carry
        # gradients from the same parameters, d(p/p) is zero and the actor
        # receives no learning signal.
        old_action_probs = new_action_probs.detach()
        ratio = new_action_probs / old_action_probs

        # Clipped surrogate objective
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        # Value function regression towards the supplied returns
        critic_loss = nn.MSELoss()(values, returns)

        loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def trust_region_update(self, loss):
        """Placeholder for a trust-region update; currently does nothing."""
        pass

# Initialize TRPO agent
trpo_agent = TRPO()

# Main training loop: one episode per epoch, with a policy update every
# `update_timestep` environment steps (the step counter runs across episodes).
total_timesteps = 0
try:
    for epoch in range(n_epochs):
        # The rollout buffer starts empty for every episode, so transitions
        # left over from a previous episode's partial window are dropped.
        states, actions, rewards, dones, next_states = [], [], [], [], []
        episode_reward = 0  # running total, available for inspection/logging

        # Newer gym releases return (observation, info) from reset(); older
        # ones return the observation alone.
        reset_result = env.reset()
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result

        while True:
            action = trpo_agent.select_action(state)

            # Newer gym releases return a 5-tuple with separate `terminated`
            # and `truncated` flags; older ones return a 4-tuple with `done`.
            step_result = env.step(action)
            if len(step_result) == 5:
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            next_states.append(next_state)

            episode_reward += reward
            state = next_state
            total_timesteps += 1

            if total_timesteps % update_timestep == 0:
                # Critic estimates are only used as constants here, so no
                # autograd graph is needed.
                with torch.no_grad():
                    # Baseline V(s_t) for the state each action was taken in.
                    _, state_values = trpo_agent.policy(
                        torch.as_tensor(np.asarray(states), dtype=torch.float32)
                    )
                    # Value of the state reached after the last stored step,
                    # used to bootstrap when the window ends mid-episode.
                    _, bootstrap_value = trpo_agent.policy(
                        torch.as_tensor(
                            np.asarray(next_states[-1]), dtype=torch.float32
                        ).unsqueeze(0)
                    )
                state_values = state_values.reshape(-1).tolist()

                # Walk the window backwards accumulating discounted returns.
                # The running sum is zeroed at episode ends so no reward (and
                # no bootstrap value) leaks across an episode boundary.
                discounted_sum = bootstrap_value.item()
                returns, advantages = [], []
                for reward_t, done_t, value_t in zip(
                    reversed(rewards), reversed(dones), reversed(state_values)
                ):
                    if done_t:
                        discounted_sum = 0.0
                    discounted_sum = reward_t + gamma * discounted_sum
                    returns.append(discounted_sum)
                    advantages.append(discounted_sum - value_t)
                returns.reverse()
                advantages.reverse()

                # train() takes the critic targets from the module-level
                # `returns` list built just above.
                trpo_agent.train(states, actions, advantages)

                states, actions, rewards, dones, next_states = [], [], [], [], []

            # env.render()

            if done:
                break

        # Uncomment to log progress after each episode:
        # print(f"Epoch: {epoch + 1}, Total Timesteps: {total_timesteps}, Episode Reward: {episode_reward}")
finally:
    # Release the environment even if an episode raises.
    env.close()


# Cell output captured from an earlier run:
# Next Value: tensor([[-0.1228],
#         [-0.1147],
#         [-0.1226],
#         [-0.1139],
#         [-0.1069],
#         [-0.0964],
#         [-0.1060],
#         [-0.0939],
#         [-0.0790],
#         [-0.0892]], grad_fn=<AddmmBackward0>)
