<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Policy_Gradient_Methods.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym

# Define the policy network
class PolicyNetwork(nn.Module):
    """Two-layer MLP mapping a state vector to action probabilities.

    The final ``Softmax`` is applied over the last dimension, so the output
    sums to 1 along that axis for both a single state and a batch of states.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        """Build the network.

        Args:
            state_dim: Number of input features per state.
            action_dim: Number of discrete actions (size of the output).
            hidden_dim: Width of the single hidden layer. Defaults to 128.
        """
        super().__init__()
        # Attribute name `fc` is kept so existing state_dict keys stay valid.
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities for the state tensor ``x``."""
        return self.fc(x)

# Initialize environment and policy network
try:
    env = gym.make('CartPole-v1', new_step_api=True)
except TypeError:
    # NOTE(review): `new_step_api` is presumably rejected as an unexpected
    # keyword by gym releases that do not define it -- confirm against the
    # installed version. Fall back to the release's default step API.
    env = gym.make('CartPole-v1')

# Size the network from the environment's own spaces instead of repeating
# CartPole's dimensions as literals.
policy_net = PolicyNetwork(
    state_dim=int(env.observation_space.shape[0]),
    action_dim=int(env.action_space.n),
)
optimizer = optim.Adam(policy_net.parameters(), lr=0.01)

# Function to select action based on policy network's output probabilities
def select_action(state, policy=None):
    """Sample a discrete action from a policy's output distribution.

    Args:
        state: Observation for a single step; anything ``torch.as_tensor``
            accepts (NumPy array, list of floats, tensor).
        policy: Callable mapping a float32 state tensor to a 1-D tensor of
            action probabilities. Defaults to the module-level ``policy_net``.

    Returns:
        The sampled action index as a Python ``int``.
    """
    if policy is None:
        policy = policy_net
    state_tensor = torch.as_tensor(state, dtype=torch.float32)
    # Only the sampled index is returned, so no autograd graph is needed.
    with torch.no_grad():
        probabilities = policy(state_tensor)
    return torch.multinomial(probabilities, 1).item()

# Training function for policy gradient method (plus its private helpers)
def _unpack_step(result):
    """Normalise an ``env.step`` result to ``(obs, reward, terminated, truncated)``."""
    if len(result) == 5:
        # 5-tuple layout: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated = result[:4]
    else:
        # 4-tuple layout: (obs, reward, done, info) -- no truncation flag
        obs, reward, terminated = result[:3]
        truncated = False
    return obs, reward, terminated, truncated


def _discounted_returns(rewards, gamma):
    """Return a float32 tensor of discounted returns, one per reward."""
    returns = []
    running = 0.0
    # Accumulate back-to-front, then reverse once (avoids O(n^2) insert(0, ...)).
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)


def train(env, policy_net, optimizer, episodes, *, gamma=0.99, max_steps=1000):
    """Train ``policy_net`` with one policy-gradient update per episode.

    For each episode the policy is rolled out in ``env``, discounted returns
    are computed and (when there is more than one step) standardised, and the
    loss ``sum(-log_prob * return)`` is minimised with ``optimizer``.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``. ``step``
            may return either a 4-tuple ``(obs, reward, done, info)`` or a
            5-tuple ``(obs, reward, terminated, truncated, info)``; ``reset``
            may return the observation or an ``(obs, info)`` tuple.
        policy_net: Callable mapping a float32 state tensor to a 1-D tensor
            of action probabilities. Actions are sampled from this network.
        optimizer: Optimizer over ``policy_net``'s parameters.
        episodes: Number of episodes to run.
        gamma: Discount factor for returns.
        max_steps: Upper bound on steps per episode.

    Returns:
        None. Prints the loss after every episode.
    """
    for episode in range(episodes):
        state = env.reset()
        if isinstance(state, tuple):
            # reset() returned (obs, info); keep only the observation.
            state = state[0]

        log_probs = []
        rewards = []
        for _ in range(max_steps):
            # Single forward pass: sample from it and keep its log-probability.
            probabilities = policy_net(
                torch.as_tensor(state, dtype=torch.float32)
            )
            action = torch.multinomial(probabilities, 1).item()
            state, reward, terminated, truncated = _unpack_step(env.step(action))
            log_probs.append(torch.log(probabilities[action]))
            rewards.append(reward)
            if terminated or truncated:
                break

        if not log_probs:
            # max_steps < 1: nothing was collected, so there is nothing to learn from.
            continue

        returns = _discounted_returns(rewards, gamma)
        if returns.numel() > 1:
            # std() of a single element is NaN, so only standardise longer episodes.
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)

        policy_loss = -(torch.stack(log_probs) * returns).sum()
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        print(f"Episode {episode + 1}, Loss: {policy_loss.item():.4f}")

# Train the module-level policy network on the module-level environment for
# 100 episodes; train() prints the loss after each episode.
train(env, policy_net, optimizer, episodes=100)