#### TODO
- Write out training video to visualize models

In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box

## Policy Gradient Methods
Starting with the basics, we have some policy $\pi_\theta$ and some return $R$. The goal of our reinforcement learning algorithm is to maximize the expected return of the policy, which for convenience is denoted $J(\pi_\theta)$:
$$E[R(\pi_\theta)] = J(\pi_\theta)$$
One way to optimize this policy is gradient ascent (we step *along* the gradient because we are maximizing $J$), for example
$$\theta_{k+1} = \theta_k + \alpha \cdot \left. \nabla J(\pi_\theta) \right|_{\theta_k}$$

Note that the big thing here is finding $\nabla J(\pi_\theta)$, and the algorithms that do this are called *policy gradient algorithms*

**Derivation** breaks up into 2 parts
- First we figure out what the gradient is supposed to be (mathematically)
- Second we go about building a model to estimate that gradient as efficiently as possible

### Part 1- Finding the Gradient
$$ \nabla J(\pi_\theta) = \nabla E[R(\pi_\theta)] = \nabla \int_\tau P(\tau ~|~ \theta) R(\tau) = \int_\tau \nabla P(\tau ~|~ \theta) R(\tau)$$
$$\int_\tau \nabla P(\tau ~|~ \theta) R(\tau) = \int_\tau P(\tau ~|~ \theta) \nabla \log P(\tau ~|~ \theta) R(\tau) = E[\nabla \log P(\tau ~|~ \theta) R(\tau)] = E\left[\sum_{t=0}^T \nabla \log \pi_\theta(a_t ~|~ s_t) R(\tau)\right]$$
[more details](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#deriving-the-simplest-policy-gradient)

### Part 2 - Building the Algorithm
Now that we know
$$\nabla J(\pi_\theta) =  E\left[\sum_{t=0}^T \nabla \log \pi_\theta(a_t ~|~ s_t) R(\tau)\right]$$
This expression is an expectation, so we can approximate it with a sample mean over a large set $D$ of trajectories collected by running the current policy
$$E\left[\sum_{t=0}^T \nabla \log \pi_\theta(a_t ~|~ s_t) R(\tau)\right] \approx \frac{1}{|D|} \sum_{\tau \in D} \sum_{t=0}^T \nabla \log \pi_\theta (a_t~|~s_t)R(\tau) $$

In [2]:
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Assemble a fully connected network as an ``nn.Sequential``.

    ``sizes`` lists the layer widths from input to output. Every consecutive
    pair of widths becomes an ``nn.Linear`` followed by an activation module:
    ``activation`` after each layer except the last one, and
    ``output_activation`` after the last one. Both activation arguments are
    called with no arguments to create the module instances.
    """
    final_idx = len(sizes) - 2  # index of the last Linear layer
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        make_act = output_activation if idx >= final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(make_act())
    return nn.Sequential(*modules)

In [3]:
def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
          epochs=50, batch_size=5000, render=False):
    """Train a categorical policy with the simplest policy gradient.

    Each epoch runs whole episodes with the current policy until more than
    ``batch_size`` steps have been collected, then takes a single Adam step on
    ``-(log_prob(a_t | s_t) * R(tau)).mean()``, where every step of an episode
    is weighted by that episode's total return.

    Args:
        env_name: environment id passed to ``gym.make``.
        hidden_sizes: sequence of hidden-layer widths for the policy network.
        lr: learning rate for the Adam optimizer.
        epochs: number of collect-then-update iterations.
        batch_size: collection stops at the first episode end after more than
            this many steps have been gathered in the epoch.
        render: if True, render the first episode of every epoch.

    Raises:
        AssertionError: if the observation space is not a ``Box`` or the
            action space is not ``Discrete``.

    Returns:
        None. One progress line (loss, mean return, mean episode length) is
        printed per epoch.
    """
    # make environment, check spaces, get obs / act dims
    # NOTE: the loop below unpacks env.reset() as a bare observation and
    # env.step() as a 4-tuple (obs, rew, done, info).
    env = gym.make(env_name)
    assert isinstance(env.observation_space, Box), \
        "This example only works for envs with continuous state spaces."
    assert isinstance(env.action_space, Discrete), \
        "This example only works for envs with discrete action spaces."

    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n

    # make core of policy network; list() lets callers pass a tuple as well
    # as a list without changing the layer sizes
    logits_net = mlp(sizes=[obs_dim] + list(hidden_sizes) + [n_acts])

    # make function to compute action distribution
    def get_policy(obs):
        logits = logits_net(obs)
        return Categorical(logits=logits)

    # make action selection function (outputs int actions, sampled from policy)
    def get_action(obs):
        # sampling needs no gradient, so skip building an autograd graph
        # that would be thrown away on every environment step
        with torch.no_grad():
            return get_policy(obs).sample().item()

    # make loss function whose gradient, for the right data, is policy gradient
    def compute_loss(obs, act, weights):
        logp = get_policy(obs).log_prob(act)
        return -(logp * weights).mean()

    # make optimizer
    optimizer = Adam(logits_net.parameters(), lr=lr)

    # for training policy
    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []          # for observations
        batch_acts = []         # for actions
        batch_weights = []      # for R(tau) weighting in policy gradient
        batch_rets = []         # for measuring episode returns
        batch_lens = []         # for measuring episode lengths

        # reset episode-specific variables
        obs = env.reset()       # first obs comes from starting distribution
        done = False            # signal from environment that episode is over
        ep_rews = []            # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        while True:

            # rendering
            if (not finished_rendering_this_epoch) and render:
                env.render()

            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, done, _ = env.step(act)

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                # if episode is over, record info about episode
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a|s) is R(tau)
                batch_weights += [ep_ret] * ep_len

                # reset episode-specific variables
                obs, done, ep_rews = env.reset(), False, []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                # (checked only at episode boundaries, so batches hold
                # complete episodes)
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        optimizer.zero_grad()
        batch_loss = compute_loss(
            # stack the per-step arrays into one ndarray first; converting a
            # Python list of ndarrays element by element is much slower
            obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
            act=torch.as_tensor(batch_acts, dtype=torch.int32),
            weights=torch.as_tensor(batch_weights, dtype=torch.float32),
        )
        batch_loss.backward()
        optimizer.step()
        # return a plain float so the autograd graph is not kept alive
        # just for logging
        return batch_loss.item(), batch_rets, batch_lens

    # training loop
    try:
        for i in range(epochs):
            batch_loss, batch_rets, batch_lens = train_one_epoch()
            print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                    (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))
    finally:
        # release the environment (e.g. a render window) once training ends,
        # including when an epoch raises or is interrupted
        env.close()

In [4]:
# Run training on CartPole without rendering; epochs, batch_size and
# hidden_sizes are left at their defaults.
train(
    env_name='CartPole-v0',
    lr=1e-2,
    render=False,
)



epoch:   0 	 loss: 17.483 	 return: 20.279 	 ep_len: 20.279
epoch:   1 	 loss: 22.605 	 return: 23.933 	 ep_len: 23.933
epoch:   2 	 loss: 22.555 	 return: 25.677 	 ep_len: 25.677
epoch:   3 	 loss: 28.089 	 return: 29.857 	 ep_len: 29.857
epoch:   4 	 loss: 31.012 	 return: 34.986 	 ep_len: 34.986
epoch:   5 	 loss: 32.365 	 return: 35.610 	 ep_len: 35.610
epoch:   6 	 loss: 34.119 	 return: 38.462 	 ep_len: 38.462
epoch:   7 	 loss: 40.085 	 return: 44.416 	 ep_len: 44.416
epoch:   8 	 loss: 37.368 	 return: 47.667 	 ep_len: 47.667
epoch:   9 	 loss: 38.159 	 return: 47.657 	 ep_len: 47.657
epoch:  10 	 loss: 46.638 	 return: 59.631 	 ep_len: 59.631
epoch:  11 	 loss: 44.030 	 return: 57.364 	 ep_len: 57.364
epoch:  12 	 loss: 45.956 	 return: 60.452 	 ep_len: 60.452
epoch:  13 	 loss: 42.201 	 return: 61.108 	 ep_len: 61.108
epoch:  14 	 loss: 49.790 	 return: 66.447 	 ep_len: 66.447
epoch:  15 	 loss: 54.297 	 return: 69.932 	 ep_len: 69.932
epoch:  16 	 loss: 50.655 	 return: 71.0

### Expected Grad Log Prob Lemma
$$
E_{x \sim P_\theta}[\nabla \log P_\theta(x) ] = 0
$$
[Proof](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#expected-grad-log-prob-lemma)

### Reward to go
Note that in our original formulation, every action's log-probability is weighted by the sum of *all* rewards in the trajectory, including rewards that were collected before the action was taken. An action cannot have any effect on rewards that came before it, so it should only be reinforced by the rewards that come after it. Call this the reward-to-go policy gradient; the expression becomes

$$\nabla J(\pi_\theta) = E\left[ \sum_{t=0}^T \nabla \log \pi_\theta(a_t ~|~ s_t) \sum_{t'=t}^T R(s_{t'}, a_{t'}, s_{t'+1}) \right]$$

(note that most of this - including the feedforward architecture - is the same as our above simple policy gradient method)

In [5]:
def reward_to_go(rews):
    """Return, for every step, the sum of rewards from that step to the end.

    ``rews`` is a sequence of per-step rewards. The result is an array created
    with ``np.zeros_like(rews)`` whose entry ``t`` equals
    ``rews[t] + rews[t+1] + ... + rews[-1]``.
    """
    horizon = len(rews)
    rtgs = np.zeros_like(rews)
    tail = 0  # nothing is collected after the final step
    # walk backwards so each entry can reuse the sum already computed after it
    for t in range(horizon - 1, -1, -1):
        rtgs[t] = rews[t] + tail
        tail = rtgs[t]  # read back from the array so the stored value is carried
    return rtgs

def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
          epochs=50, batch_size=5000, render=False):
    """Train a categorical policy with the reward-to-go policy gradient.

    Each epoch runs whole episodes with the current policy until more than
    ``batch_size`` steps have been collected, then takes a single Adam step on
    ``-(log_prob(a_t | s_t) * w_t).mean()``, where ``w_t`` is the sum of the
    rewards received from step ``t`` to the end of that step's episode
    (computed by ``reward_to_go``).

    Note: this function shares its name with the simple policy gradient
    ``train`` defined earlier in the notebook, so running this cell rebinds
    ``train``.

    Args:
        env_name: environment id passed to ``gym.make``.
        hidden_sizes: sequence of hidden-layer widths for the policy network.
        lr: learning rate for the Adam optimizer.
        epochs: number of collect-then-update iterations.
        batch_size: collection stops at the first episode end after more than
            this many steps have been gathered in the epoch.
        render: if True, render the first episode of every epoch.

    Raises:
        AssertionError: if the observation space is not a ``Box`` or the
            action space is not ``Discrete``.

    Returns:
        None. One progress line (loss, mean return, mean episode length) is
        printed per epoch.
    """
    # make environment, check spaces, get obs / act dims
    # NOTE: the loop below unpacks env.reset() as a bare observation and
    # env.step() as a 4-tuple (obs, rew, done, info).
    env = gym.make(env_name)
    assert isinstance(env.observation_space, Box), \
        "This example only works for envs with continuous state spaces."
    assert isinstance(env.action_space, Discrete), \
        "This example only works for envs with discrete action spaces."

    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n

    # make core of policy network; list() lets callers pass a tuple as well
    # as a list without changing the layer sizes
    logits_net = mlp(sizes=[obs_dim] + list(hidden_sizes) + [n_acts])

    # make function to compute action distribution
    def get_policy(obs):
        logits = logits_net(obs)
        return Categorical(logits=logits)

    # make action selection function (outputs int actions, sampled from policy)
    def get_action(obs):
        # sampling needs no gradient, so skip building an autograd graph
        # that would be thrown away on every environment step
        with torch.no_grad():
            return get_policy(obs).sample().item()

    # make loss function whose gradient, for the right data, is policy gradient
    def compute_loss(obs, act, weights):
        logp = get_policy(obs).log_prob(act)
        return -(logp * weights).mean()

    # make optimizer
    optimizer = Adam(logits_net.parameters(), lr=lr)

    # for training policy
    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []          # for observations
        batch_acts = []         # for actions
        batch_weights = []      # for reward-to-go weighting in policy gradient
        batch_rets = []         # for measuring episode returns
        batch_lens = []         # for measuring episode lengths

        # reset episode-specific variables
        obs = env.reset()       # first obs comes from starting distribution
        done = False            # signal from environment that episode is over
        ep_rews = []            # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        while True:

            # rendering
            if (not finished_rendering_this_epoch) and render:
                env.render()

            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, done, _ = env.step(act)

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                # if episode is over, record info about episode
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a_t|s_t) is reward-to-go from t
                batch_weights += list(reward_to_go(ep_rews))

                # reset episode-specific variables
                obs, done, ep_rews = env.reset(), False, []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                # (checked only at episode boundaries, so batches hold
                # complete episodes)
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        optimizer.zero_grad()
        batch_loss = compute_loss(
            # stack the per-step arrays into one ndarray first; converting a
            # Python list of ndarrays element by element is much slower
            obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
            act=torch.as_tensor(batch_acts, dtype=torch.int32),
            weights=torch.as_tensor(batch_weights, dtype=torch.float32),
        )
        batch_loss.backward()
        optimizer.step()
        # return a plain float so the autograd graph is not kept alive
        # just for logging
        return batch_loss.item(), batch_rets, batch_lens

    # training loop
    try:
        for i in range(epochs):
            batch_loss, batch_rets, batch_lens = train_one_epoch()
            print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                    (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))
    finally:
        # release the environment (e.g. a render window) once training ends,
        # including when an epoch raises or is interrupted
        env.close()



In [6]:
# Run reward-to-go training on CartPole without rendering; epochs,
# batch_size and hidden_sizes are left at their defaults.
train(
    env_name='CartPole-v0',
    lr=1e-2,
    render=False,
)

epoch:   0 	 loss: 10.911 	 return: 23.758 	 ep_len: 23.758
epoch:   1 	 loss: 10.999 	 return: 25.293 	 ep_len: 25.293
epoch:   2 	 loss: 12.845 	 return: 27.762 	 ep_len: 27.762
epoch:   3 	 loss: 14.065 	 return: 31.880 	 ep_len: 31.880
epoch:   4 	 loss: 16.821 	 return: 34.979 	 ep_len: 34.979
epoch:   5 	 loss: 19.236 	 return: 40.589 	 ep_len: 40.589
epoch:   6 	 loss: 16.043 	 return: 36.022 	 ep_len: 36.022
epoch:   7 	 loss: 21.132 	 return: 49.069 	 ep_len: 49.069
epoch:   8 	 loss: 20.956 	 return: 50.616 	 ep_len: 50.616
epoch:   9 	 loss: 22.223 	 return: 52.811 	 ep_len: 52.811
epoch:  10 	 loss: 24.308 	 return: 58.918 	 ep_len: 58.918
epoch:  11 	 loss: 22.412 	 return: 61.427 	 ep_len: 61.427
epoch:  12 	 loss: 23.503 	 return: 63.430 	 ep_len: 63.430
epoch:  13 	 loss: 29.038 	 return: 73.145 	 ep_len: 73.145
epoch:  14 	 loss: 27.452 	 return: 71.648 	 ep_len: 71.648
epoch:  15 	 loss: 33.064 	 return: 90.018 	 ep_len: 90.018
epoch:  16 	 loss: 34.274 	 return: 91.9