# In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
from collections import deque

# Define the Policy Network
class Policy(nn.Module):
    """Feed-forward policy: Linear -> ReLU -> Linear -> Softmax.

    Maps a state tensor to a probability distribution over the discrete
    actions (softmax taken over the last dimension).
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(state_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_size),
            nn.Softmax(dim=-1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Return action probabilities for ``state``."""
        action_probs = self.network(state)
        return action_probs

# REINFORCE Agent
class ReinforceAgent:
    """Monte-Carlo policy-gradient (REINFORCE) agent for discrete actions.

    During an episode, ``select_action`` records the log-probability of each
    sampled action and the caller appends each step's reward to
    ``self.rewards``. ``finish_episode`` then performs one gradient step from
    those buffers and empties them.
    """

    def __init__(self, state_size, action_size, learning_rate=0.01, gamma=0.99):
        """Create the policy network and its Adam optimizer.

        Args:
            state_size: Length of the observation vector fed to the policy.
            action_size: Number of discrete actions.
            learning_rate: Adam step size.
            gamma: Discount factor applied to future rewards.
        """
        self.policy = Policy(state_size, action_size)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.saved_log_probs = []  # log-prob of each sampled action this episode
        self.rewards = []  # appended by the caller, one entry per env step

    def select_action(self, state):
        """Sample an action for ``state`` and remember its log-probability.

        Args:
            state: NumPy observation; converted to a float tensor with a
                leading batch dimension of 1.

        Returns:
            The sampled action index as a Python int.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs.append(m.log_prob(action))
        return action.item()

    def finish_episode(self):
        """Apply one REINFORCE update from the stored episode and reset buffers.

        Returns:
            The scalar policy loss, i.e. the sum over steps of
            ``-log_prob * return``, or ``0.0`` if no steps were recorded
            (no update is performed in that case).
        """
        if not self.rewards:
            # Nothing to learn from; stacking an empty list would raise.
            self.rewards.clear()
            self.saved_log_probs.clear()
            return 0.0

        # Discounted return G_t = r_t + gamma * G_{t+1}, built back-to-front.
        returns = deque()
        running = 0.0
        for reward in reversed(self.rewards):
            running = reward + self.gamma * running
            returns.appendleft(running)
        # Explicit float dtype so integer rewards don't produce an int tensor
        # (on which .mean()/.std() are not defined).
        returns = torch.tensor(list(returns), dtype=torch.float32)

        # Normalize returns to reduce gradient variance. The unbiased std of a
        # single value is NaN, which would poison every weight, so skip
        # normalization for one-step episodes.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-6)

        # stack (not cat) works whether each log-prob is 0-d or shape (1,).
        policy_loss = torch.stack(
            [-log_prob * ret for log_prob, ret in zip(self.saved_log_probs, returns)]
        ).sum()

        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

        self.rewards.clear()
        self.saved_log_probs.clear()

        return policy_loss.item()

def _rolling_mean(values, window):
    """Return the mean of every full ``window``-length slice of ``values``.

    Entry ``k`` averages ``values[k:k + window]``; the result has
    ``len(values) - window + 1`` entries, or none if ``values`` is shorter
    than ``window``.
    """
    return [np.mean(values[k:k + window]) for k in range(len(values) - window + 1)]


def _plot_training(scores, losses, window=100):
    """Plot per-episode scores with their rolling average, and per-episode loss."""
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # Training scores and rolling average
    ax1.plot(np.arange(1, len(scores) + 1), scores, label='Training Scores')
    rolling_avg = _rolling_mean(scores, window)
    # Window k covers episodes k+1 .. k+window (1-based), so plot each mean at
    # the last episode of its window.
    ax1.plot(np.arange(window, window + len(rolling_avg)), rolling_avg, color='r',
             label=f'Rolling Average Score ({window} episodes)')
    ax1.set_ylabel('Score')
    ax1.set_title('Training Scores and Rolling Average')
    ax1.legend()
    ax1.grid(True)

    # Policy loss per episode (ReinforceAgent.finish_episode returns the
    # summed per-step loss, not an average).
    ax2.plot(np.arange(1, len(losses) + 1), losses, color='orange', label='Policy Loss per Episode')
    ax2.set_xlabel('Episode #')
    ax2.set_ylabel('Loss')
    ax2.set_title('Policy Loss per Episode')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()


def main():
    """Train a REINFORCE agent on CartPole-v1, then plot scores and losses."""
    # Environment setup
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    # Agent setup
    agent = ReinforceAgent(state_size, action_size)

    # Training parameters
    n_episodes = 1500
    print_every = 100
    max_steps = 1000  # hard cap on steps per episode; the env may end it earlier
    solve_window = 100
    solve_threshold = 475.0

    # Logging
    scores_deque = deque(maxlen=solve_window)
    scores = []
    losses = []

    try:
        for i_episode in range(1, n_episodes + 1):
            state, _ = env.reset()
            episode_reward = 0
            for _ in range(max_steps):
                action = agent.select_action(state)
                state, reward, terminated, truncated, _ = env.step(action)
                agent.rewards.append(reward)
                episode_reward += reward
                if terminated or truncated:
                    break

            scores_deque.append(episode_reward)
            scores.append(episode_reward)
            losses.append(agent.finish_episode())

            avg_score = np.mean(scores_deque)
            if i_episode % print_every == 0:
                print(f'Episode {i_episode}\tAverage Score: {avg_score:.2f}')

            # Only declare success over a full window of consecutive episodes.
            # Otherwise a lucky start could "solve" the task early and the
            # message below would report a negative episode count.
            if len(scores_deque) == solve_window and avg_score >= solve_threshold:
                print(f'\nEnvironment solved in {i_episode - solve_window} episodes!\tAverage Score: {avg_score:.2f}')
                break
    finally:
        env.close()

    _plot_training(scores, losses, window=solve_window)

# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()