<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/reinforcementlearning_actor_critic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
# Build the CartPole-v1 environment and read the sizes the networks need
# directly from its observation and action spaces.
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]  # size of the first observation axis -> network input width
action_dim = env.action_space.n  # `n` of the action space -> actor output width

# Actor-Critic Networks
class Actor(nn.Module):
    """Policy network that turns a state vector into action probabilities.

    Architecture: Linear(state_dim -> 128) -> ReLU -> Linear(128 -> action_dim)
    -> Softmax over the last dimension.

    Args:
        state_dim: Width of the input state vector.
        action_dim: Number of outputs (one probability per action).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 128
        # Layers are created in forward order and wrapped in a Sequential
        # stored as `net`, so parameter names are `net.0.*` and `net.2.*`.
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the softmax-normalised action probabilities for `state`."""
        probs = self.net(state)
        return probs

class Critic(nn.Module):
    """State-value network that maps a state vector to a single scalar output.

    Architecture: Linear(state_dim -> 128) -> ReLU -> Linear(128 -> 1).

    Args:
        state_dim: Width of the input state vector.
    """

    def __init__(self, state_dim):
        super().__init__()
        hidden_units = 128
        # Layers are created in forward order and wrapped in a Sequential
        # stored as `net`, so parameter names are `net.0.*` and `net.2.*`.
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the value estimate for `state`; the last dimension has size 1."""
        value = self.net(state)
        return value

In [5]:
# Instantiate the policy (actor) and value (critic) networks, each with its
# own Adam optimizer so the two losses are stepped independently.
actor = Actor(state_dim, action_dim)
critic = Critic(state_dim)
actor_optim = optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = optim.Adam(critic.parameters(), lr=1e-3)
gamma = 0.99  # multiplier applied to the critic's next-state value in the TD target
num_episodes = 1000  # number of episodes the training loop runs

In [None]:
# One-step actor-critic training: after every environment transition the
# critic is regressed towards a bootstrapped TD target and the actor takes a
# policy-gradient step weighted by the resulting advantage estimate.
for episode in range(num_episodes):
    state, _ = env.reset()
    done = False
    total_reward = 0

    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32)

        # Sample an action from the current policy and keep its log-probability
        # for the actor update below.
        action_probs = actor(state_tensor)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        next_state, reward, terminated, truncated, _ = env.step(action.item())
        # The episode loop stops on either signal...
        done = terminated or truncated
        total_reward += reward

        next_state_tensor = torch.tensor(next_state, dtype=torch.float32)

        # ...but only a true termination makes the next state worth zero.
        # A truncation (time limit) cuts the episode short without the next
        # state being terminal, so we must still bootstrap from the critic.
        # The target is a constant for both losses, so no graph is needed.
        with torch.no_grad():
            next_value = critic(next_state_tensor).item()
        bootstrap = 0.0 if terminated else next_value
        target = reward + gamma * bootstrap

        value = critic(state_tensor)
        # Detached advantage: the actor loss must not backpropagate into the critic.
        advantage = target - value.item()

        # Critic update (value loss): squared TD error.
        critic_loss = (value - target) ** 2
        critic_optim.zero_grad()
        critic_loss.backward()
        critic_optim.step()

        # Actor update (policy gradient): raise the log-probability of actions
        # with positive advantage, lower it for negative advantage.
        actor_loss = -log_prob * advantage
        actor_optim.zero_grad()
        actor_loss.backward()
        actor_optim.step()

        state = next_state

    if episode % 100 == 0:
        print(f"Episode {episode}, Total Reward: {total_reward:.2f}")

Episode 0, Total Reward: 20.00
Episode 100, Total Reward: 12.00
