In [2]:
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

class Actor(nn.Module):
    """Policy network: maps an observation to action probabilities.

    One 128-unit hidden layer with a ReLU activation feeds a linear head
    with ``output_dim`` outputs, which are normalised by a softmax over the
    last dimension.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Hidden layer first, then the policy head.
        self.fc = nn.Linear(input_dim, 128)
        self.actor = nn.Linear(128, output_dim)

    def forward(self, x):
        """Return the softmax action probabilities computed from ``x``."""
        hidden = self.fc(x).relu()
        logits = self.actor(hidden)
        return logits.softmax(dim=-1)

class Critic(nn.Module):
    """State-value network: maps an observation to a single scalar output.

    One 128-unit hidden layer with a ReLU activation feeds a linear head
    with one output; no activation is applied to that output.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Hidden layer first, then the value head.
        self.fc = nn.Linear(input_dim, 128)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        """Return the value head's output for ``x`` (last dimension is 1)."""
        hidden = self.fc(x).relu()
        return self.critic(hidden)

def _discounted_returns(rewards, gamma, bootstrap_value=0.0):
    """Return the discounted return for every step of one episode.

    ``bootstrap_value`` is the value assumed for the state reached after the
    last reward: 0 for a terminal state, or a value estimate when the
    episode was cut short before terminating.
    """
    returns = []
    running = bootstrap_value
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    # Accumulated back-to-front; flip once rather than inserting at index 0
    # on every step, which is quadratic in the episode length.
    returns.reverse()
    return returns


def a2c(env, num_episodes=1000, gamma=0.99, learning_rate=0.001):
    """Train an actor and a critic with one gradient update per episode.

    Each episode is rolled out to completion by sampling actions from the
    actor. Discounted returns are then computed for every step, and a single
    Adam step minimises ``actor_loss + critic_loss``, where the actor loss
    is the negative mean of ``log_prob * advantage`` and the critic loss is
    the mean squared error between returns and predicted values.

    Args:
        env: Environment following the Gymnasium API used here:
            ``reset()`` returns ``(observation, info)`` and ``step(action)``
            returns ``(observation, reward, terminated, truncated, info)``.
            ``observation_space.shape[0]`` and ``action_space.n`` are read,
            so a 1-D observation and a discrete action space are required.
        num_episodes: Number of training episodes.
        gamma: Discount factor applied to future rewards.
        learning_rate: Learning rate of the shared Adam optimizer.

    Returns:
        Tuple ``(actor, critic, rewards_per_episode)``, where the last item
        is a list holding the undiscounted reward sum of every episode.
    """
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    actor = Actor(input_dim, output_dim)
    critic = Critic(input_dim)

    # One optimizer updates both networks from the combined loss.
    optimizer = optim.Adam(
        list(actor.parameters()) + list(critic.parameters()),
        lr=learning_rate,
    )

    rewards_per_episode = []

    for episode in range(num_episodes):
        state, _ = env.reset()
        terminated = truncated = False

        rewards = []
        log_probs = []
        values = []

        while not (terminated or truncated):
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            action_probs = actor(state_tensor)
            value = critic(state_tensor)

            dist = Categorical(action_probs)
            action = dist.sample()

            log_probs.append(dist.log_prob(action))
            values.append(value)

            state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)

        # A truncated episode (e.g. a time limit) did not actually end, so
        # the tail of the return is estimated with the critic's value of the
        # last observed state instead of being treated as zero.
        bootstrap_value = 0.0
        if truncated and not terminated:
            with torch.no_grad():
                final_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
                bootstrap_value = critic(final_tensor).item()

        returns = torch.tensor(
            _discounted_returns(rewards, gamma, bootstrap_value),
            dtype=torch.float32,
        )
        log_probs = torch.cat(log_probs)
        # squeeze(-1) keeps a 1-D tensor even for a one-step episode; a bare
        # squeeze() would collapse it to 0-dim.
        values = torch.cat(values).squeeze(-1)
        # Detached so the policy-gradient term does not backpropagate into
        # the critic; the critic is trained only through critic_loss.
        advantages = returns - values.detach()

        actor_loss = -(log_probs * advantages).mean()
        critic_loss = (returns - values).pow(2).mean()
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        episode_reward = sum(rewards)
        rewards_per_episode.append(episode_reward)

        if episode % 100 == 0:
            # Probe the policy on a fresh initial state. The next episode
            # calls env.reset() again, so this does not affect the rollout.
            sample_state, _ = env.reset()
            with torch.no_grad():
                sample_tensor = torch.as_tensor(sample_state, dtype=torch.float32).unsqueeze(0)
                sample_probs = actor(sample_tensor).numpy()
            print(f"[Episode {episode}] Reward: {episode_reward:.1f}")
            print(f"Sample Action Probabilities: {sample_probs}")

    return actor, critic, rewards_per_episode

if __name__ == "__main__":
    # Train on CartPole-v1 with the default hyperparameters, then plot the
    # per-episode reward curve.
    env = gym.make('CartPole-v1')
    try:
        actor, critic, rewards = a2c(env)
    finally:
        # Close the environment even if training raises or is interrupted
        # (e.g. KeyboardInterrupt), so it is never left open.
        env.close()

    plt.plot(rewards)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('A2C on CartPole')
    plt.show()


[Episode 0] Reward: 32.0
Sample Action Probabilities: [[0.52109104 0.4789089 ]]
[Episode 100] Reward: 31.0
Sample Action Probabilities: [[0.5990624  0.40093765]]


KeyboardInterrupt: 