### Vanilla policy gradient algorithm from scratch

In [None]:
# Steps:
# 1. Create an env wrapper that handles both continuous and discrete environments. For continuous action spaces, the policy network must output the parameters of a multivariate Gaussian distribution.

import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
import torch.distributions as dist

class PolicyNetwork(nn.Module):
    """Two-hidden-layer MLP policy for discrete or continuous action spaces.

    * Discrete (``is_continuous=False``): ``forward`` returns action probabilities,
      a softmax over the last dimension.
    * Continuous (``is_continuous=True``): ``forward`` returns ``(mean, covariance)``
      of a multivariate Gaussian whose covariance matrix is diagonal.

    Works for a single state of shape ``(state_space_dim,)`` or a batch of shape
    ``(*batch, state_space_dim)``.
    """

    def __init__(self, state_space_dim, action_space_dim, is_continuous):
        super().__init__()
        self.is_continuous = is_continuous
        self.state_dim = state_space_dim
        self.action_space_dim = action_space_dim
        self.fc1 = nn.Linear(state_space_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        # Continuous: the first half of the output is the mean, the second half the
        # log-variances of a diagonal covariance (action dims modelled as independent).
        self.fc3 = nn.Linear(64, 2 * action_space_dim if is_continuous else action_space_dim)

    def forward(self, state):
        """Map state(s) to action probabilities or Gaussian ``(mean, covariance)``."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        if not self.is_continuous:
            return torch.softmax(x, dim=-1)

        mean = x[..., :self.action_space_dim]
        # exp keeps every variance strictly positive.
        variances = torch.exp(x[..., self.action_space_dim:])
        # diag_embed builds one diagonal matrix per batch element. torch.diag would
        # instead *extract* a diagonal when handed a 2-D batch of variances.
        return mean, torch.diag_embed(variances)


class BaselineVNetwork(nn.Module):
    """State-and-time value baseline ``V(s, t)`` used to reduce policy-gradient variance.

    The timestep is divided by ``max_timesteps`` and appended to the state as one
    extra input feature, so the first layer takes ``state_space_dim + 1`` inputs.
    """

    def __init__(self, state_space_dim, max_timesteps):
        super().__init__()
        self.max_timesteps = max_timesteps
        self.fc1 = nn.Linear(state_space_dim + 1, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, timestep):
        """Return the baseline value(s), shape ``(*batch, 1)``.

        Args:
            state: tensor or array-like of shape ``(*batch, state_space_dim)``.
            timestep: int, tensor or array-like of shape ``()``, ``(*batch,)`` or
                ``(*batch, 1)``.
        """
        dtype = self.fc1.weight.dtype
        # Accept numpy arrays / Python numbers as well as tensors.
        state = torch.as_tensor(state, dtype=dtype)
        t = torch.as_tensor(timestep, dtype=dtype) / self.max_timesteps
        if t.dim() != state.dim():
            # Broadcast a scalar or per-sample timestep into one trailing feature column.
            t = t.expand(state.shape[:-1]).unsqueeze(-1)
        x = F.relu(self.fc1(torch.cat([state, t], dim=-1)))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Implement vpg alg
iterations = 100
max_timesteps = 500
t_per_iter = 35
env = gym.make('CartPole-v0')
state_space_dim = 4
action_space_dim = 2
policy = PolicyNetwork(state_space_dim, action_space_dim, False)
policy_opt = torch.optim.AdamW(policy.parameters(), lr=5e-4)
baseline = BaselineVNetwork(state_space_dim, max_timesteps)
baseline_opt = torch.optim.AdamW(baseline.parameters(), lr=1e-3)
baseline_opt_steps = 10
discount = 0.99

### Training run

In [None]:
def _rollout():
    """Run one episode with the current policy.

    Returns ``(states, log_probs, rewards)``:
      * ``states`` - list of float32 state tensors, one per step;
      * ``log_probs`` - 1-D tensor of log pi(a_t | s_t) for the sampled actions, still
        attached to the policy's autograd graph so the policy loss can use it;
      * ``rewards`` - list of float rewards.
    """
    states, log_probs, rewards = [], [], []
    current_state, _ = env.reset()
    for _ in range(max_timesteps):
        state_t = torch.as_tensor(current_state, dtype=torch.float32)
        if policy.is_continuous:
            mean, covar = policy(state_t)
            action_dist = dist.MultivariateNormal(mean, covar)
        else:
            action_dist = dist.Categorical(policy(state_t))
        action = action_dist.sample()

        # Gymnasium wants a Python int (Discrete) or numpy array (Box), and its step()
        # returns (obs, reward, terminated, truncated, info).
        env_action = action.numpy() if policy.is_continuous else action.item()
        next_state, reward, terminated, truncated, _ = env.step(env_action)

        states.append(state_t)
        log_probs.append(action_dist.log_prob(action))
        rewards.append(float(reward))
        current_state = next_state
        if terminated or truncated:
            break
    return states, torch.stack(log_probs), rewards


def _discounted_returns(rewards, gamma):
    """Return reward-to-go ``G_t = sum_{k>=t} gamma**(k-t) * r_k`` for every step t.

    Computed right-to-left in O(T) via ``G_t = r_t + gamma * G_{t+1}``.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return torch.tensor(returns, dtype=torch.float32)


def _baseline_values(states):
    """Evaluate the baseline at each (state, timestep) of one trajectory; shape ``(T,)``."""
    return torch.cat([baseline(s, torch.tensor([t])) for t, s in enumerate(states)])


for _ in range(iterations):
    # Collect a set of trajectories by executing the current policy.
    trajectories = [_rollout() for _ in range(t_per_iter)]
    returns = [_discounted_returns(rewards, discount) for _, _, rewards in trajectories]

    # At each timestep of each trajectory, compute the advantage estimate using the
    # baseline from previous iterations. no_grad keeps policy-loss gradients out of
    # the baseline.
    with torch.no_grad():
        adv_ests = [
            ret - _baseline_values(states)
            for (states, _, _), ret in zip(trajectories, returns)
        ]

    # Re-fit the baseline by regressing V(s_t, t) onto G_t. Predictions are recomputed
    # every step so each backward() runs on a fresh graph.
    all_returns = torch.cat(returns)
    for _ in range(baseline_opt_steps):
        values = torch.cat([_baseline_values(states) for states, _, _ in trajectories])
        baseline_loss = F.mse_loss(values, all_returns)
        baseline_opt.zero_grad()
        baseline_loss.backward()
        baseline_opt.step()

    # Update the policy with the gradient estimate
    # (1/N) * sum_traj sum_t grad log pi(a_t|s_t) * A_t.
    # The objective is maximised, so the loss is its negation.
    per_traj = [
        (log_probs * adv).sum()
        for (_, log_probs, _), adv in zip(trajectories, adv_ests)
    ]
    policy_loss = -torch.stack(per_traj).mean()
    # Take only one gradient step for the policy to avoid overfitting to this batch.
    policy_opt.zero_grad()
    policy_loss.backward()
    policy_opt.step()