<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/vanilla_policy_gradient/vanilla_policy_gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Vanilla Policy Optimisation

We will look at an implementation of the vanilla policy gradient algorithm and use it to train a policy to play CartPole. The goal is to balance a pole on a sliding cart. The agent can push the cart left or right. The episode ends when the pole tilts too far, or the cart moves too far from the center.

![CartPole](https://pytorch.org/tutorials/_images/cartpole.gif)



Read all the code, then:
- Complete the ... in the compute_loss function.
- Use https://github.com/patrick-kidger/jaxtyping to type the functions `get_policy` and `get_action`. You can draw inspiration from the `compute_loss` function.
- Answer the questions.

Questions:
- Run the script with the default parameters in the terminal.
- Explain `from torch.distributions.categorical import Categorical`.
- Is vanilla policy gradient (VPG) model-based or model-free?
- Is VPG on-policy or off-policy?
- Search for "gym python". Why is it useful?
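
As a starting point for the `Categorical` question, here is a minimal sketch of what a categorical distribution parameterised by logits computes. It uses only the standard library, and the helper names (`log_softmax`, `sample`, `log_prob`) are chosen for illustration; they mirror the idea behind `Categorical(logits=...)`, not PyTorch's actual implementation.

```python
import math
import random


def log_softmax(logits):
    """Turn unnormalised scores (logits) into log-probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sample(logits, rng):
    """Draw one action index with probability softmax(logits)."""
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


def log_prob(logits, action):
    """Log-probability of a given action: this evaluates, it does not sample."""
    return log_softmax(logits)[action]


# Illustrative logits: softmax gives probabilities [0.25, 0.75].
logits = [0.0, math.log(3.0)]
rng = random.Random(0)
draws = [sample(logits, rng) for _ in range(10_000)]
print(sum(draws) / len(draws))  # fraction of action 1, close to 0.75
print(math.exp(log_prob(logits, 1)))  # probability of action 1
```

Note the two distinct operations: sampling picks an action, while `log_prob` scores an action that was already taken. The loss function in this notebook relies on the second one.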

Don't start working on this algorithm until you understand this blog post: https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html
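
For reference, the simplest form of the policy gradient derived on that page can be written as follows. The notation follows that page: $\tau$ is a trajectory sampled by running the policy $\pi_\theta$, $R(\tau)$ is its total return, and $\mathcal{D}$ is a batch of collected trajectories.

$$
\nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau) \right]
$$

$$
\hat{g} = \frac{1}{|\mathcal{D}|} \sum_{\tau \in \mathcal{D}} \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)
$$

The second line is the sample estimate of the first: the expectation is replaced by an average over the trajectories in the batch.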

This exercise is short, but you should aim to understand everything in this code. Simply completing the types is not sufficient. The important thing here is to have a good understanding of each line of code, as well as the policy gradient theorem that we are using.
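
To make the link between the theorem and the code concrete, here is a small sketch of how per-step weights and the loss value are formed. It uses plain Python numbers and made-up data instead of tensors, and the function names `return_weights` and `surrogate_loss` are hypothetical. In the notebook the log-probabilities are differentiable tensors, so the gradient of this quantity is what matters, not its value.

```python
def return_weights(episode_rewards):
    """Give every time step of an episode that episode's total undiscounted return."""
    weights = []
    for rewards in episode_rewards:
        ep_ret, ep_len = sum(rewards), len(rewards)
        weights += [ep_ret] * ep_len  # same weight R(tau) for each step
    return weights


def surrogate_loss(log_probs, weights):
    """Negative mean of log-probability times weight, as a plain number."""
    assert len(log_probs) == len(weights)
    return -sum(lp * w for lp, w in zip(log_probs, weights)) / len(weights)


# Two toy episodes with a reward of 1.0 per step (made-up data).
episodes = [[1.0, 1.0, 1.0], [1.0, 1.0]]
weights = return_weights(episodes)
print(weights)  # [3.0, 3.0, 3.0, 2.0, 2.0]

# Made-up log-probabilities, one per time step.
log_probs = [-0.5, -0.5, -0.5, -1.0, -1.0]
print(surrogate_loss(log_probs, weights))  # 1.7
```

The weights are piecewise constant: every step in an episode shares that episode's return, which is why the weight list has one entry per time step rather than one per episode.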

In [None]:
!pip install jaxtyping typeguard==2.13.3 gymnasium

In [None]:
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete, Box

from jaxtyping import Float, Int, jaxtyped
from typeguard import typechecked


def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    # Build a feedforward neural network.
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]

    # What does * mean here? Search for unpacking in python
    return nn.Sequential(*layers)


def train(
    env_name="CartPole-v0", hidden_sizes=[32], lr=1e-2, epochs=50, batch_size=5000, render=False
):

    # make environment, check spaces, get obs / act dims
    env = gym.make(env_name)
    assert isinstance(
        env.observation_space, Box
    ), "This example only works for envs with continuous state spaces."
    assert isinstance(
        env.action_space, Discrete
    ), "This example only works for envs with discrete action spaces."

    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n

    # Core of policy network
    # What should be the sizes of the layers of the policy network?
    logits_net = mlp(sizes=[obs_dim] + hidden_sizes + [n_acts])

    # make function to compute action distribution
    @jaxtyped  # Checks that the sizes are consistent between tensors
    @typechecked  # TODO: What is the shape of obs?
    def get_policy(obs: Float[Tensor, "batch obs_dim"]) -> Categorical:
        logits = logits_net(obs)
        # Tip: Categorical is a convenient PyTorch object that stores logits (or a batch of logits)
        # and lets you sample from the corresponding probability distribution with the ".sample()" method.
        return Categorical(logits=logits)

    # make action selection function (outputs int actions, sampled from policy)
    # What is the shape of obs?
    @jaxtyped
    @typechecked  # TODO: What is the shape of obs?
    def get_action(obs: Float[Tensor, "obs_dim"]) -> int:
        return get_policy(obs.unsqueeze(0)).sample().item()

    # make loss function whose gradient, for the right data, is policy gradient
    @jaxtyped
    @typechecked
    def compute_loss(
        obs: Float[Tensor, "batch obs_dim"],
        acts: Int[Tensor, "batch"],
        rewards: Float[Tensor, "batch"],
    ) -> Float[Tensor, ""]:
        # TODO:
        # rewards: a piecewise constant vector containing the total reward of each episode.

        # Use the get_policy function to get the Categorical object, then evaluate the
        # log-probability of the actions taken with its 'log_prob' method.
        log_probs = get_policy(obs).log_prob(acts)
        return -(log_probs * rewards).mean()

    # make optimizer
    optimizer = Adam(logits_net.parameters(), lr=lr)

    # for training policy
    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []  # for observations
        batch_acts = []  # for actions
        batch_weights = []  # for R(tau) weighting in policy gradient
        batch_rets = []  # for measuring episode returns # What is the return?
        batch_lens = []  # for measuring episode lengths

        # reset episode-specific variables
        obs, _ = env.reset()  # first obs comes from starting distribution
        ep_rews = []  # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        while True:

            # rendering
            if (not finished_rendering_this_epoch) and render:
                env.render()

            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, terminated, truncated, _ = env.step(act)

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if terminated or truncated:
                # if episode is over, record info about episode
                # Is the reward discounted?
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a|s) is R(tau)
                # Why do we use a constant vector here?
                batch_weights += [ep_ret] * ep_len

                # reset episode-specific variables
                obs, _ = env.reset()
                ep_rews = []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        optimizer.zero_grad()

        batch_loss = compute_loss(
            # Stack the list of arrays first: converting a list of ndarrays directly is slow.
            obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
            acts=torch.as_tensor(batch_acts, dtype=torch.int32),
            rewards=torch.as_tensor(batch_weights, dtype=torch.float32),
        )
        batch_loss.backward()
        optimizer.step()
        return batch_loss, batch_rets, batch_lens

    # training loop
    for i in range(epochs):
        batch_loss, batch_rets, batch_lens = train_one_epoch()
        print(
            "epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f"
            % (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens))
        )


train(env_name="CartPole-v0", hidden_sizes=[32], lr=1e-2, epochs=50, batch_size=50, render=False)

Original algorithm: https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/vpg/vpg.py