In [None]:
import numpy as np
import gymnasium as gym  # Use gymnasium instead of gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# Policy network: maps an observation to a probability distribution over actions.
class Actor(nn.Module):
    """Two-layer MLP policy.

    Args:
        input_dim: Size of the last dimension of the observation tensor.
        output_dim: Number of discrete actions.

    ``forward`` returns action probabilities (a softmax over the last
    dimension), which ``a2c`` feeds to ``Categorical``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        # Registration order (fc, then actor) fixes the state_dict key order and
        # the sequence of RNG draws used for weight initialisation.
        self.fc = nn.Linear(input_dim, hidden_units)
        self.actor = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        """Return action probabilities for the observations in ``x``."""
        features = self.fc(x).relu()
        logits = self.actor(features)
        return logits.softmax(dim=-1)

# Value network: maps an observation to a single state-value estimate.
class Critic(nn.Module):
    """Two-layer MLP state-value estimator.

    Args:
        input_dim: Size of the last dimension of the observation tensor.

    ``forward`` returns a tensor whose last dimension has size 1.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_units = 128
        # Keep fc before critic so parameter order and init RNG draws are stable.
        self.fc = nn.Linear(input_dim, hidden_units)
        self.critic = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return the estimated value of each observation in ``x``."""
        return self.critic(torch.relu(self.fc(x)))

# Actor-critic training loop (one update per complete episode).
def a2c(env, num_episodes=1000, gamma=0.99, learning_rate=0.001, *,
        log_interval=100, seed=None):
    """Train an ``Actor`` and a ``Critic`` on ``env``.

    Each episode is rolled out to completion with the current policy. Then a
    single Adam step is taken on the whole trajectory:

    * Returns are discounted Monte-Carlo returns, so this is effectively
      REINFORCE with a learned baseline rather than n-step TD A2C. When the
      episode was cut off by ``truncated`` (e.g. a time limit) instead of
      ended by ``terminated``, the tail return is bootstrapped from the
      critic's estimate of the final observation rather than assumed to be 0.
    * The actor minimises ``-log pi(a|s) * (G - V(s))``, with ``V(s)``
      detached so that it acts only as a baseline.
    * The critic minimises the mean of ``(G - V(s)) ** 2``.

    Args:
        env: Gymnasium-style environment. ``observation_space.shape[0]`` is
            used as the network input size, and ``action_space.n`` as the
            number of discrete actions. ``reset()`` must return
            ``(obs, info)`` and ``step()`` must return
            ``(obs, reward, terminated, truncated, info)``.
        num_episodes: Number of episodes, and therefore updates, to run.
        gamma: Discount factor.
        learning_rate: Adam learning rate shared by actor and critic.
        log_interval: Print the episode's total reward every this many
            episodes. ``0`` or ``None`` disables printing.
        seed: Optional seed. If given, it seeds ``torch`` and the first
            ``env.reset``, for reproducible runs.

    Returns:
        The trained ``(actor, critic)`` pair.
    """
    if seed is not None:
        torch.manual_seed(seed)

    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    actor = Actor(input_dim, output_dim)
    critic = Critic(input_dim)
    optimizer = optim.Adam(list(actor.parameters()) + list(critic.parameters()), lr=learning_rate)

    for episode in range(num_episodes):
        # Only pass a seed on the first reset. Later episodes continue the env's
        # own RNG stream. With seed=None the call is identical to env.reset().
        reset_kwargs = {"seed": seed} if (seed is not None and episode == 0) else {}
        state, _ = env.reset(**reset_kwargs)

        rewards = []
        log_probs = []
        values = []
        terminated = truncated = False

        while not (terminated or truncated):
            # torch.tensor copies, so autograd's saved inputs cannot be changed
            # if the env later reuses its observation buffer.
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action_probs = actor(state_tensor)
            value = critic(state_tensor)

            dist = Categorical(action_probs)
            action = dist.sample()

            next_state, reward, terminated, truncated, _ = env.step(action.item())

            rewards.append(reward)
            log_probs.append(dist.log_prob(action))
            values.append(value)

            state = next_state

        # A truncated episode did not really end. Estimate the remaining return
        # with the critic instead of treating the cut-off state as terminal.
        if truncated and not terminated:
            with torch.no_grad():
                final_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                G = critic(final_tensor).item()
        else:
            G = 0.0

        # Discounted returns, built back-to-front and then reversed (O(n)).
        returns = []
        for r in reversed(rewards):
            G = r + gamma * G
            returns.append(G)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32)
        log_probs = torch.cat(log_probs)
        # squeeze(-1), not squeeze(): keeps shape (T,) even for a one-step episode.
        values = torch.cat(values).squeeze(-1)

        # Detach the baseline so the actor loss does not train the critic.
        advantages = returns - values.detach()

        actor_loss = -(log_probs * advantages).mean()
        critic_loss = (returns - values).pow(2).mean()
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_interval and episode % log_interval == 0:
            print(f'Episode {episode}, Total Reward: {sum(rewards)}')

    return actor, critic

# Script entry point: train on CartPole-v1, always releasing the environment.
if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    try:
        actor, critic = a2c(env)
    finally:
        # Close the env even if training raises or is interrupted (e.g. Ctrl+C).
        env.close()


Episode 0, Total Reward: 18.0
Episode 100, Total Reward: 26.0
Episode 200, Total Reward: 75.0
Episode 300, Total Reward: 315.0
Episode 400, Total Reward: 312.0
Episode 500, Total Reward: 300.0
Episode 600, Total Reward: 350.0
Episode 700, Total Reward: 500.0
Episode 800, Total Reward: 185.0
