In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym

class ActorCritic(nn.Module):
    """Two-headed network: one shared hidden layer feeding a policy head and a value head.

    Args:
        state_dim: Number of input features per state.
        action_dim: Number of outputs of the policy (actor) head.
        hidden_dim: Width of the shared hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Shared trunk, followed by the two heads that both read from it.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(logits, value)`` for input ``x``.

        ``logits`` are the raw (un-normalised) actor outputs with last
        dimension ``action_dim``; ``value`` is the critic output with last
        dimension 1.
        """
        hidden = torch.relu(self.fc1(x))
        return self.actor(hidden), self.critic(hidden)

# Discount factor applied to future rewards when computing returns.
gamma = 0.99

# Environment and the problem sizes read from its observation/action spaces.
env = gym.make("CartPole-v1")
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n

# Network and the optimizer that updates all of its parameters.
model = ActorCritic(state_dim, action_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for episode in range(1000):
    # Roll out one complete on-policy episode, then perform a single
    # Monte-Carlo actor-critic update from the collected trajectory.
    state, _ = env.reset()
    log_probs = []
    values = []
    rewards = []

    done = False
    while not done:
        # Add a batch dimension of 1 so the network sees shape (1, state_dim).
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        logits, value = model(state_tensor)
        # Passing logits directly lets Categorical normalise them itself,
        # avoiding a separate softmax and log(0) on saturated probabilities.
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()

        next_state, reward, terminated, truncated, _ = env.step(action.item())
        # The episode ends on failure (terminated) OR when the time limit is
        # hit (truncated); ignoring `truncated` keeps stepping a finished env.
        done = terminated or truncated

        log_probs.append(dist.log_prob(action))  # shape (1,)
        values.append(value)                     # shape (1, 1)
        rewards.append(reward)
        state = next_state

    # Compute discounted returns G_t = r_t + gamma * G_{t+1}, walking the
    # episode backwards. Append + reverse is O(T); insert(0, ...) was O(T^2).
    # NOTE(review): a truncated episode is treated like a terminal one here
    # (no value bootstrap for the cut-off tail) — confirm this is acceptable.
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)  # shape (T,)

    # Concatenate (not stack) so log_probs is (T,), matching `advantage`.
    # Stacking produced (T, 1), and (T, 1) * (T,) silently broadcast to
    # (T, T), pairing every step's log-prob with every step's advantage.
    log_probs = torch.cat(log_probs)
    values = torch.cat(values).squeeze(-1)  # shape (T,)
    # Detach so the actor loss does not push gradients into the critic head.
    advantage = returns - values.detach()

    # Losses
    actor_loss = -(log_probs * advantage).mean()
    critic_loss = F.mse_loss(values, returns)
    loss = actor_loss + critic_loss

    # Update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if episode % 50 == 0:
        print(f"Episode {episode}, Total Reward: {sum(rewards)}")


Episode 0, Total Reward: 32.0
Episode 50, Total Reward: 10.0
Episode 100, Total Reward: 18.0
Episode 150, Total Reward: 18.0
Episode 200, Total Reward: 32.0
Episode 250, Total Reward: 24.0
Episode 300, Total Reward: 99.0
Episode 350, Total Reward: 42.0
Episode 400, Total Reward: 40.0
Episode 450, Total Reward: 36.0
Episode 500, Total Reward: 27.0
Episode 550, Total Reward: 80.0
Episode 600, Total Reward: 42.0
Episode 650, Total Reward: 58.0
Episode 700, Total Reward: 63.0
Episode 750, Total Reward: 126.0
Episode 800, Total Reward: 106.0
Episode 850, Total Reward: 123.0
Episode 900, Total Reward: 169.0
Episode 950, Total Reward: 127.0
