In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gym

# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)

# Define hyperparameters
gamma = 0.99
eps_clip = 0.2
learning_rate = 0.0005
n_epochs = 1
update_timestep = 20

# Create environment
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Actor-critic network architecture
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__() #fc: fully connected
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.actor = nn.Linear(64, action_dim)
        self.critic = nn.Linear(64, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value

# Proximal Policy Optimization (PPO) algorithm
class PPO:
    def __init__(self):
        self.policy = ActorCritic()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        logits, _ = self.policy(state)
        action_probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(action_probs, 1)
        return action.item()

    def train(self, states, actions, advantages, returns):
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        advantages = torch.FloatTensor(advantages).unsqueeze(1)
        returns = torch.FloatTensor(returns)

        unique_actions, unique_indices = torch.unique(actions, return_inverse=True)

        for _ in range(update_timestep // len(states)):
            logits, values = self.policy(states)
            values = values.squeeze()

            action_probs = torch.softmax(logits, dim=-1)
            action_masks = torch.zeros_like(action_probs).scatter_(1, unique_actions.unsqueeze(1), 1)
            old_action_probs = torch.sum(action_probs * action_masks[unique_indices.unsqueeze(1)], dim=1)

            ratios = torch.exp(torch.log(old_action_probs + 1e-10) - torch.log(action_probs + 1e-10))

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = nn.MSELoss()(returns, values)

            loss = actor_loss + 0.5 * critic_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


# Initialize PPO agent
ppo_agent = PPO()

# Main training loop
total_timesteps = 0
for epoch in range(n_epochs):
    states, actions, rewards, action_probs, dones, next_states = [], [], [], [], [], []
    episode_reward = 0
    state = env.reset()

    while True:
        action = ppo_agent.select_action(state)
        next_state, reward, done, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        dones.append(done)
        next_states.append(next_state)

        episode_reward += reward
        state = next_state

        total_timesteps += 1

        if total_timesteps % update_timestep == 0:
            _, next_value = ppo_agent.policy(torch.FloatTensor(next_states))
            returns, advantages = [], []
            discounted_sum = 0
            for i in range(len(rewards) - 1, -1, -1):
                discounted_sum = rewards[i] + gamma * discounted_sum * (1 - dones[i])
                advantage = discounted_sum - next_value[i].item()
                advantages.insert(0, advantage)
                returns.insert(0, discounted_sum)

            ppo_agent.train(states, actions, advantages, returns)

            states, actions, rewards, action_probs, dones, next_states = [], [], [], [], [], []

        env.render()

        if done:
            break

    print(f"Epoch: {epoch + 1}, Total Timesteps: {total_timesteps}, Episode Reward: {episode_reward}")

env.close()


Epoch: 1, Total Timesteps: 12, Episode Reward: 12.0
Epoch: 2, Total Timesteps: 27, Episode Reward: 15.0
Epoch: 3, Total Timesteps: 44, Episode Reward: 17.0
Epoch: 4, Total Timesteps: 65, Episode Reward: 21.0
Epoch: 5, Total Timesteps: 77, Episode Reward: 12.0
Epoch: 6, Total Timesteps: 93, Episode Reward: 16.0
Epoch: 7, Total Timesteps: 106, Episode Reward: 13.0
Epoch: 8, Total Timesteps: 125, Episode Reward: 19.0
Epoch: 9, Total Timesteps: 139, Episode Reward: 14.0
Epoch: 10, Total Timesteps: 152, Episode Reward: 13.0
Epoch: 11, Total Timesteps: 170, Episode Reward: 18.0
Epoch: 12, Total Timesteps: 193, Episode Reward: 23.0
Epoch: 13, Total Timesteps: 206, Episode Reward: 13.0
Epoch: 14, Total Timesteps: 216, Episode Reward: 10.0
Epoch: 15, Total Timesteps: 229, Episode Reward: 13.0
Epoch: 16, Total Timesteps: 249, Episode Reward: 20.0
Epoch: 17, Total Timesteps: 307, Episode Reward: 58.0
Epoch: 18, Total Timesteps: 320, Episode Reward: 13.0
Epoch: 19, Total Timesteps: 355, Episode Re

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 728, Total Timesteps: 10012, Episode Reward: 13.0
Epoch: 729, Total Timesteps: 10023, Episode Reward: 11.0
Epoch: 730, Total Timesteps: 10034, Episode Reward: 11.0
Epoch: 731, Total Timesteps: 10049, Episode Reward: 15.0
Epoch: 732, Total Timesteps: 10064, Episode Reward: 15.0
Epoch: 733, Total Timesteps: 10074, Episode Reward: 10.0
Epoch: 734, Total Timesteps: 10088, Episode Reward: 14.0
Epoch: 735, Total Timesteps: 10102, Episode Reward: 14.0
Epoch: 736, Total Timesteps: 10115, Episode Reward: 13.0
Epoch: 737, Total Timesteps: 10128, Episode Reward: 13.0
Epoch: 738, Total Timesteps: 10143, Episode Reward: 15.0
Epoch: 739, Total Timesteps: 10155, Episode Reward: 12.0
Epoch: 740, Total Timesteps: 10169, Episode Reward: 14.0
Epoch: 741, Total Timesteps: 10180, Episode Reward: 11.0
Epoch: 742, Total Timesteps: 10193, Episode Reward: 13.0
Epoch: 743, Total Timesteps: 10203, Episode Reward: 10.0
Epoch: 744, Total Timesteps: 10215, Episode Reward: 12.0
Epoch: 745, Total Timesteps: 10