<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/reinforcementlearning_policy_based.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# Build the environment and size the policy network from its spaces: the
# network input width is the observation length, the output width is the
# number of discrete actions.
env = gym.make("CartPole-v1")
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n

# Define the policy network
class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to action probabilities.

    Args:
        state_dim: size of the last dimension of the input state.
        action_dim: number of discrete actions (size of the output).
        hidden_dim: width of the single hidden layer; defaults to 128.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Keep the same Sequential layout so state_dict keys (net.0.*,
        # net.2.*) are unchanged from the fixed-width version.
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            # Normalize over the last dim so batched inputs work too; the
            # result is fed to Categorical as probabilities.
            nn.Softmax(dim=-1),
        )

    def forward(self, state):
        """Return action probabilities for ``state`` (last dim = state_dim)."""
        return self.net(state)

# Training state: the discount factor for future rewards, the policy to
# train, and an Adam optimizer over its parameters.
gamma = 0.99
policy_net = PolicyNetwork(state_dim, action_dim)
optimizer = optim.Adam(params=policy_net.parameters(), lr=1e-2)

# Training REINFORCE
def compute_returns(rewards, gamma):
    """Compute the discounted reward-to-go for every step of an episode.

    Uses the recurrence G_t = r_t + gamma * G_{t+1}, evaluated backwards
    from the last step.

    Args:
        rewards: per-step rewards in time order (a reversible sequence).
        gamma: discount factor.

    Returns:
        A 1-D float32 tensor of the same length as ``rewards`` whose
        element t is G_t. An empty ``rewards`` yields an empty tensor.
    """
    running = 0
    returns = []
    # Walk backwards so each G_t reuses G_{t+1}. Appending then reversing
    # once is O(n); repeated list.insert(0, ...) would be O(n^2).
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)

# REINFORCE (Monte-Carlo policy gradient): roll out one full episode with
# the current policy, then take one gradient step using that episode.
num_episodes = 1000
try:
    for episode in range(num_episodes):
        state, _ = env.reset()
        log_probs = []
        rewards = []
        done = False

        # Roll out one episode, recording log pi(a_t | s_t) (with autograd
        # history, needed for the update) and each step's reward.
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32)
            action_probs = policy_net(state_tensor)
            dist = Categorical(action_probs)
            action = dist.sample()

            log_probs.append(dist.log_prob(action))

            state, reward, terminated, truncated, _ = env.step(action.item())
            # End the episode on either a terminal state or a truncation.
            done = terminated or truncated
            rewards.append(reward)

        # Discounted reward-to-go for every step of the episode.
        returns = compute_returns(rewards, gamma)
        # Standardize returns to reduce gradient variance. Tensor.std() is the
        # sample std, which is NaN for a single element and would poison every
        # weight after the update, so only normalize when there are >= 2 steps.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        # Policy-gradient update: minimize -sum_t G_t * log pi(a_t | s_t).
        loss = -torch.sum(torch.stack(log_probs) * returns)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 100 == 0:
            print(f"Episode {episode}, total reward: {sum(rewards)}")
finally:
    # Release the environment even if training raises or is interrupted.
    env.close()


Episode 0, total reward: 11.0
Episode 100, total reward: 106.0
Episode 200, total reward: 125.0
Episode 300, total reward: 29.0
Episode 400, total reward: 59.0
