### Vanilla policy gradient algorithm from scratch (with A2C advantage estimation)

In [118]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
import torch.distributions as dist
from enum import Enum

class PolicyNetwork(nn.Module):
    """Two-hidden-layer MLP policy for discrete or continuous actions.

    Args:
        state_space_dim: Size of the last dimension of the state input.
        action_space_dim: Number of discrete actions, or the number of
            action dimensions when ``is_continuous`` is true.
        is_continuous: Selects the output head. ``False`` yields softmax
            action probabilities; ``True`` yields the mean and a diagonal
            covariance matrix of a Gaussian.
    """

    def __init__(self, state_space_dim, action_space_dim, is_continuous):
        super().__init__()
        self.is_continuous = is_continuous
        self.state_dim = state_space_dim
        self.action_space_dim = action_space_dim
        self.fc1 = nn.Linear(state_space_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        # The continuous head emits a mean and a log-variance per action
        # dimension, i.e. only the diagonal of the covariance matrix is
        # modelled (action dimensions are treated as independent).
        self.fc3 = nn.Linear(64, 2 * action_space_dim if is_continuous else action_space_dim)

    def forward(self, state):
        """Map a state (optionally with leading batch dims) to action-distribution parameters.

        Returns:
            Discrete: a tensor of action probabilities over the last dim.
            Continuous: ``(mean, covariance)`` where ``covariance`` has shape
            ``(..., action_space_dim, action_space_dim)`` and is diagonal.
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        if self.is_continuous:
            mean = x[..., :self.action_space_dim]
            # exp() keeps the predicted variances strictly positive.
            variances = torch.exp(x[..., self.action_space_dim:])
            # diag_embed builds one diagonal matrix per leading index, so it
            # works for a single state and for a batch; torch.diag would
            # instead *extract* a diagonal when handed a 2-D batch.
            return mean, torch.diag_embed(variances)
        return torch.softmax(x, dim=-1)


class BaselineVNetwork(nn.Module):
    """Time-aware state-value baseline V(s, t) implemented as a small MLP.

    Args:
        state_space_dim: Size of the last dimension of the state input.
        max_timesteps: Divisor used to scale the timestep feature before it
            is appended to the state.
    """

    def __init__(self, state_space_dim, max_timesteps):
        super().__init__()
        self.max_timesteps = max_timesteps
        # +1 input feature for the normalised timestep.
        self.fc1 = nn.Linear(state_space_dim + 1, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, timestep):
        """Return the value estimate for ``state`` at ``timestep``.

        Args:
            state: Tensor of shape ``(..., state_space_dim)``.
            timestep: Tensor of shape ``(...)`` matching the leading
                dimensions of ``state`` (a 0-d tensor for a single state).

        Returns:
            Tensor of shape ``(..., 1)``.
        """
        timestep = timestep.float() / self.max_timesteps
        # Concatenate along the feature axis so that batched inputs work;
        # for a single 1-D state this is the same as the default dim=0.
        features = torch.cat([state, timestep.unsqueeze(-1)], dim=-1)
        x = F.relu(self.fc1(features))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class AdvantageEstimation(Enum):
    """Selects how per-timestep returns are estimated in the training loop."""

    # Discounted sum of all remaining rewards in the trajectory
    VPG = 0
    # Use k=5 step lookahead, then add the discounted baseline value of the
    # state reached (when the trajectory is long enough)
    A2C = 1


# Hyperparameters and object setup for the policy-gradient training run
iterations = 100  # outer iterations: collect trajectories, refit baseline, one policy step
max_timesteps = 500  # per-episode step cap; also the divisor for the baseline's timestep input
t_per_iter = 50  # trajectories collected per iteration
# NOTE(review): render_mode='human' presumably draws every environment step,
# which would slow training considerably; confirm whether rendering is
# wanted during training or only in the evaluation cell.
env = gym.make('CartPole-v1', render_mode='human', max_episode_steps=max_timesteps)
state_space_dim = 4  # hard-coded; presumably CartPole-v1's observation size (confirm against env.observation_space)
action_space_dim = 2  # hard-coded; presumably CartPole-v1's action count (confirm against env.action_space)
policy = PolicyNetwork(state_space_dim, action_space_dim, False)  # False selects the discrete (softmax) head
policy_opt = torch.optim.AdamW(policy.parameters(), lr=5e-4)
baseline = BaselineVNetwork(state_space_dim, max_timesteps)
baseline_opt = torch.optim.AdamW(baseline.parameters(), lr=1e-2)
baseline_opt_steps = 10  # gradient steps taken on the baseline per iteration
discount = 0.99  # reward discount factor
adv_est_scheme = AdvantageEstimation.A2C  # which return estimator the training loop uses

### Training run

In [119]:
# Number of real rewards summed before bootstrapping from the baseline
# (only used by the A2C return estimator).
lookahead_steps = 5

for iteration in range(iterations):
    # Collect a set of trajectories by executing the current policy
    trajectories = []
    for t_ind in range(t_per_iter):
        current_state, _ = env.reset()
        current_state = torch.tensor(current_state, dtype=torch.float32)
        traj = []
        for _ in range(max_timesteps):
            # Sampling needs no autograd graph: log-probabilities are
            # recomputed with gradients enabled in the policy update below.
            with torch.no_grad():
                actions_dist = policy(current_state)
                if policy.is_continuous:
                    mean, covar = actions_dist
                    mgd = dist.MultivariateNormal(mean, covar)
                    action = mgd.sample()
                else:
                    d = dist.Categorical(actions_dist)
                    action = d.sample()

            # Continuous (Box) action spaces take one array entry per action
            # dimension; discrete spaces take a plain int.
            env_action = action.numpy() if policy.is_continuous else action.item()
            next_state, reward, terminated, truncated, _ = env.step(env_action)
            traj.append((current_state, action, reward))
            current_state = torch.tensor(next_state, dtype=torch.float32)

            if terminated or truncated:
                break

        trajectories.append(traj)

    # At each timestep in each trajectory, compute the return and advantage estimate
    def calc_returns():
        """Return a list (per trajectory) of per-timestep return estimates as floats.

        VPG: discounted reward-to-go. A2C: ``lookahead_steps`` discounted
        rewards plus the discounted baseline value of the state reached,
        when the trajectory extends that far.
        """
        returns = []
        if adv_est_scheme == AdvantageEstimation.VPG:
            for traj in trajectories:
                # Backward recursion G_t = r_t + discount * G_{t+1}: O(T)
                # instead of re-summing the tail for every timestep.
                reward_to_go = 0.0
                traj_returns = []
                for _, _, step_reward in reversed(traj):
                    reward_to_go = step_reward + discount * reward_to_go
                    traj_returns.append(reward_to_go)
                traj_returns.reverse()
                returns.append(traj_returns)
            return returns

        for traj in trajectories:
            traj_returns = []
            for time in range(len(traj)):
                window_end = min(len(traj), time + lookahead_steps)
                t_return = sum(
                    discount ** (t_prime - time) * traj[t_prime][2]
                    for t_prime in range(time, window_end)
                )
                if len(traj) > time + lookahead_steps:
                    # The bootstrapped value is a regression *target*: treat
                    # it as a constant so no gradient flows through it when
                    # the baseline is fitted.
                    with torch.no_grad():
                        bootstrap = baseline(
                            traj[time + lookahead_steps][0],
                            torch.tensor(time + lookahead_steps, dtype=torch.float32),
                        ).item()
                    t_return += discount ** lookahead_steps * bootstrap
                traj_returns.append(t_return)
            returns.append(traj_returns)
        return returns

    def calc_baselines():
        """Return baseline value tensors (shape ``(1,)``) for every visited state."""
        return [
            [
                baseline(step_state, torch.tensor(time, dtype=torch.float32))
                for time, (step_state, _, _) in enumerate(traj)
            ]
            for traj in trajectories
        ]

    def calc_adv_ests(returns):
        """Return ``return - baseline`` for every timestep of every trajectory."""
        baselines = calc_baselines()
        return [
            [ret - value for ret, value in zip(traj_returns, traj_values)]
            for traj_returns, traj_values in zip(returns, baselines)
        ]

    def normalize_adv_ests(adv_ests):
        """Standardise advantages using the mean/std of the whole batch."""
        flattened = torch.stack([a for t in adv_ests for a in t])
        # Compute the batch statistics once rather than once per element.
        adv_mean = flattened.mean()
        adv_std = flattened.std() + 1e-8
        return [[(ae - adv_mean) / adv_std for ae in t_ae] for t_ae in adv_ests]

    def calc_baseline_loss():
        """Sum of squared (return - baseline) errors, averaged over trajectories."""
        errors = [a for traj_a in calc_adv_ests(calc_returns()) for a in traj_a]
        return (torch.stack(errors) ** 2).sum() / len(trajectories)

    # Re-fit the baseline
    last_baseline_loss = 0
    for i in range(baseline_opt_steps):
        baseline_loss = calc_baseline_loss()
        last_baseline_loss = baseline_loss
        baseline_opt.zero_grad()
        baseline_loss.backward()
        baseline_opt.step()

    # Update the policy using a policy gradient estimate
    # Advantages are constants with respect to the policy parameters, so they
    # are computed without building a graph through the baseline network.
    with torch.no_grad():
        returns = calc_returns()
        adv_ests = normalize_adv_ests(calc_adv_ests(returns))

    g = []
    for traj, traj_advs in zip(trajectories, adv_ests):
        for (step_state, step_action, _), adv in zip(traj, traj_advs):
            # step_state is already a float32 tensor; pass it through as-is.
            if policy.is_continuous:
                mean, covar = policy(step_state)
                mvg = dist.MultivariateNormal(mean, covar)
                log_pdf_val = mvg.log_prob(step_action)
            else:
                action_dist = policy(step_state)
                log_pdf_val = torch.log(action_dist[step_action])
            g.append(log_pdf_val * adv)

    # Take only one gradient step for policy to avoid overfitting
    policy_loss = -torch.stack(g).sum()

    # Report the actual discounted episode return from t=0, independent of
    # which estimator is used for the advantages.
    episode_returns = [
        sum(discount ** t * step_reward for t, (_, _, step_reward) in enumerate(traj))
        for traj in trajectories
    ]
    print("--------------------")
    print("Iteration:", iteration)
    print("Policy loss:", policy_loss)
    print("Last baseline loss:", last_baseline_loss)
    print("Avg return:", sum(episode_returns) / len(episode_returns))
    print("Avg steps:", sum(len(t) for t in trajectories) / len(trajectories))
    print("--------------------")
    policy_opt.zero_grad()
    policy_loss.backward()
    policy_opt.step()

KeyboardInterrupt: 

In [52]:
# Test implementation: roll out a single episode with the current policy
current_state, _ = env.reset()
current_state = torch.tensor(current_state, dtype=torch.float32)
for _ in range(max_timesteps):
    # Pure inference: no autograd graph is needed while evaluating.
    with torch.no_grad():
        actions_dist = policy(current_state)
        if policy.is_continuous:
            mean, covar = actions_dist
            mgd = dist.MultivariateNormal(mean, covar)
            action = mgd.sample()
        else:
            d = dist.Categorical(actions_dist)
            action = d.sample()

    # Continuous (Box) action spaces take one array entry per action
    # dimension; discrete spaces take a plain int.
    env_action = action.numpy() if policy.is_continuous else action.item()
    next_state, reward, terminated, truncated, _ = env.step(env_action)
    current_state = torch.tensor(next_state, dtype=torch.float32)

    if terminated or truncated:
        break

### Reflections

**Normalizing advantage estimates**
While training on CartPole, I ran into a problem: the policy was learning very slowly and at one point appeared to stagnate. After some investigation, I realized that whenever a trajectory had a very large advantage estimate (positive or negative), that value would dominate the gradient step, so the policy would fit itself to that one trajectory, often overfitting because of the very large gradients. To solve this, I normalized the advantage estimates (subtracting the batch mean and dividing by the batch standard deviation). Although the resulting update is mathematically no longer exactly the gradient of the original loss, it worked very well in practice.
