# From A2C to PPO

## Helper function

In [83]:
import numpy as np
import gym
import time
import scipy.signal
from gym.spaces import Box, Discrete
import pathlib

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from torch.optim import Adam

import wandb

In [84]:
def combined_shape(length, shape=None):
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)


def mlp(sizes, activation, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes)-1):
        act = activation if j < len(sizes)-2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
    return nn.Sequential(*layers)


def count_vars(module):
    return sum([np.prod(p.shape) for p in module.parameters()])


def discount_cumsum(x, discount):
    """
    magic from rllab for computing discounted cumulative sums of vectors.
    input: 
        vector x, 
        [x0, 
         x1, 
         x2]
    output:
        [x0 + discount * x1 + discount^2 * x2,  
         x1 + discount * x2,
         x2]
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


## Model

In [85]:
class Actor(nn.Module):

    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        raise NotImplementedError

    def forward(self, obs, act=None):
        # Produce action distributions for given observations, and 
        # optionally compute the log likelihood of given actions under
        # those distributions.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a


class MLPCategoricalActor(Actor):
    
    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        logits = self.logits_net(obs)
        return Categorical(logits=logits)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act)


class MLPGaussianActor(Actor):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
        self.mu_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act).sum(axis=-1)    # Last axis sum needed for Torch Normal distribution  # TODO: why sum?


class MLPCritic(nn.Module):

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs):
        return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape.


class MLPActorCritic(nn.Module):
    """
        ===========  ================  ======================================
        Symbol       Shape             Description
        ===========  ================  ======================================
        ``a``        (batch, act_dim)  | Numpy array of actions for each 
                                       | observation.
        ``v``        (batch,)          | Numpy array of value estimates
                                       | for the provided observations.
        ``logp_a``   (batch,)          | Numpy array of log probs for the
                                       | actions in ``a``.
        ===========  ================  ======================================

        The ``pi`` module's forward call should accept a batch of 
        observations and optionally a batch of actions, and return:
        ===========  ================  ======================================
        Symbol       Shape             Description
        ===========  ================  ======================================
        ``pi``       N/A               | Torch Distribution object, containing
                                       | a batch of distributions describing
                                       | the policy for the provided observations.
        ``logp_a``   (batch,)          | Optional (only returned if batch of
                                       | actions is given). Tensor containing 
                                       | the log probability, according to 
                                       | the policy, of the provided actions.
                                       | If actions not given, will contain
                                       | ``None``.
        ===========  ================  ======================================
        
        The ``v`` module's forward call should accept a batch of observations
        and return:
        ===========  ================  ======================================
        Symbol       Shape             Description
        ===========  ================  ======================================
        ``v``        (batch,)          | Tensor containing the value estimates
                                       | for the provided observations. (Critical: 
                                       | make sure to flatten this!)
        ===========  ================  ======================================    
    
    """

    def __init__(self, observation_space, action_space, 
                 hidden_sizes=(64,64), activation=nn.Tanh):
        super().__init__()

        obs_dim = observation_space.shape[0]

        # policy builder depends on action space
        if isinstance(action_space, Box):
            self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)
        elif isinstance(action_space, Discrete):
            self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)

        # build value function
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
        return a.numpy(), v.numpy(), logp_a.numpy()

    def act(self, obs): # only return action
        return self.step(obs)[0]
    

## Agent

$L^{CLIP}(\theta) = E_t[min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t)]$

In [87]:
class PPOAgent():
    """ update model and take action """
    def __init__(self, observation_space, action_space, train_pi_iters, train_v_iters, clip_ratio, pi_lr, vf_lr,
                 target_kl, hidden_sizes=(256, 256), activation=nn.ReLU):
        
        self.train_pi_iters = train_pi_iters
        self.train_v_iters = train_v_iters
        self.clip_ratio = clip_ratio
        self.target_kl = target_kl
        
        self.ac = MLPActorCritic(observation_space, action_space, hidden_sizes=hidden_sizes, activation=activation)
    
        # Set up optimizers for policy and value function
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=pi_lr)
        self.vf_optimizer = Adam(self.ac.v.parameters(), lr=vf_lr)

        # Count variables
        var_counts = tuple(count_vars(module) for module in [self.ac.pi, self.ac.v])
        print('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)

    # Set up function for computing PPO policy loss
    def _compute_loss_pi(self, data):
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']

        # Policy loss
        pi, logp = self.ac.pi(obs, act) # logp: pi_new(a_t|s_t); logp_old: pi_old(a_t|s_t), calculated by policy with params when collecting data. 
        ratio = torch.exp(logp - logp_old) 
        clip_adv = torch.clamp(ratio, 1-self.clip_ratio, 1+self.clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()  # expectation under pi_old, since act is sample from pi_old

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()  # TODO: Why use this?
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1+self.clip_ratio) | ratio.lt(1-self.clip_ratio) 
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

    # Set up function for computing value loss
    def _compute_loss_v(self, data):
        obs, ret = data['obs'], data['ret']  # TODO: important: how to calculate the ret
        return ((self.ac.v(obs) - ret)**2).mean()

    def update(self, data):

        # compute loss of pi and v before updating, used to calculate DelatLossPi and DeltaLossV
        pi_l_old, pi_info_old = self._compute_loss_pi(data)
        pi_l_old = pi_l_old.item()
        v_l_old = self._compute_loss_v(data).item()

        # Train policy with multiple steps of gradient descent
        for i in range(self.train_pi_iters): # in vanilla PG, the policy is trained with a single step
            self.pi_optimizer.zero_grad()
            loss_pi, pi_info = self._compute_loss_pi(data)
            kl = pi_info['kl']
            if kl > 1.5 * self.target_kl:
                print('Early stopping at step %d due to reaching max kl.'%i)
                break
                
            loss_pi.backward()
            self.pi_optimizer.step()

        # Value function learning
        for i in range(self.train_v_iters):
            self.vf_optimizer.zero_grad()
            loss_v = self._compute_loss_v(data)
            loss_v.backward()
            self.vf_optimizer.step()

        # Log changes from update
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        
        wandb.log({"LossPi":pi_l_old, "LossV":v_l_old, "KL":kl, "Entropy":ent, "ClipFrac":cf,
                  "DeltaLossPi":(loss_pi.item() - pi_l_old), 
                  "DeltaLossV":(loss_v.item() - v_l_old)})
        
    def get_action(obs):
        return self.ac.act(obs)

## Buffer

$\hat A_t = \delta_t + (\lambda \gamma)\delta_{t+1}+ \cdots + (\lambda \gamma)^{T-t+1}\delta_{T-1}$, 
where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$

$V(s_t) = \sum_{t'=t}^{t+n} \gamma^{t'-t} r(s_{t'}, a_{t'}) + V(s_{t+n})$

In [88]:
class PPOBuffer:
    """
    A buffer for storing trajectories experienced by a PPO agent interacting
    with the environment, and using Generalized Advantage Estimation (GAE-Lambda)
    for calculating the advantages of state-action pairs.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        """
        Append one timestep of agent-environment interaction to the buffer.
        """
        assert self.ptr < self.max_size     # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        """
        Calculate the return and advantage:
        
        Call this at the end of a trajectory, or when one gets cut off
        by an epoch ending. This looks back in the buffer to where the
        trajectory started, and uses rewards and value estimates from
        the whole trajectory to compute advantage estimates with GAE-Lambda,
        as well as compute the rewards-to-go for each state, to use as
        the targets for the value function.
        The "last_val" argument should be 0 if the trajectory ended
        because the agent reached a terminal state (died), and otherwise
        should be V(s_T), the value function estimated for the last state.
        This allows us to bootstrap the reward-to-go calculation to account
        for timesteps beyond the arbitrary episode horizon (or epoch cutoff).
        """

        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        
        # the next two lines implement GAE-Lambda advantage calculation
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1] 
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
        
        # the next line computes rewards-to-go, to be targets for the value function
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        
        self.path_start_idx = self.ptr

    def get(self):
        """
        Call this at the end of an epoch to get all of the data from
        the buffer, with advantages appropriately normalized (shifted to have
        mean zero and std one). Also, resets some pointers in the buffer.
        """
        assert self.ptr == self.max_size    # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0
#         # the next two lines implement the advantage normalization trick
#         adv_mean, adv_std = mpi_statistics_scalar(self.adv_buf)
#         self.adv_buf = (self.adv_buf - adv_mean) / adv_std

        data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf,
                    adv=self.adv_buf, logp=self.logp_buf)
        return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in data.items()}


## Training and testing

In [91]:
if __name__ == "__main__":
    
    # setup hyperparameter
    wandb.init(project="ppo")
    config = wandb.config
    config.logdir = pathlib.Path(".")
    config.env = "HalfCheetah-v2"
    config.seed = 0
    config.steps_per_epoch = 4000
    config.epochs = 50
    config.gamma = 0.99
    config.clip_ratio = 0.2
    config.pi_lr = 3e-4
    config.vf_lr = 1e-3
    config.train_pi_iters = 80
    config.train_v_iters = 80
    config.lam = 0.97  # Lambda for GAE-Lambda. (Always between 0 and 1, close to 1.)
    config.max_ep_len = 1000
    config.target_kl = 0.01
    config.save_freq=10
        
    # setup random seed and num_threads
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    torch.set_num_threads(torch.get_num_threads())

    # setup env
    env, test_env = gym.make(config.env), gym.make(config.env)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    print(obs_dim, act_dim)
    
    # replay buffer
    buffer = PPOBuffer(obs_dim, act_dim, config.steps_per_epoch, config.gamma, config.lam)
    
    # Prepare for interaction with environment
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0  # obs, ep_totoal_reward, ep_length
    
    # setup agent
    agent = PPOAgent(env.observation_space, 
                     env.action_space, 
                     config.train_pi_iters, 
                     config.train_v_iters, 
                     config.clip_ratio, 
                     config.pi_lr, 
                     config.vf_lr, 
                     config.target_kl)
    
    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(config.epochs):
        for t in range(config.steps_per_epoch):
            a, v, logp = agent.ac.step(torch.as_tensor(o, dtype=torch.float32))

            next_o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            # save and log
            buffer.store(o, a, r, v, logp)
            wandb.log({"VVlas":v})
            
            # Update obs (critical!)
            o = next_o

            timeout = ep_len == config.max_ep_len
            terminal = d or timeout
            epoch_ended = t==config.steps_per_epoch-1

            if terminal or epoch_ended:
                if epoch_ended and not(terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
                    
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, v, _ = agent.ac.step(torch.as_tensor(o, dtype=torch.float32))
                else:
                    v = 0
                
                # calculate the advantage and ret once an eposide is finished.
                buffer.finish_path(v)
                
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    wandb.log({"EpRet":ep_ret, "EpLen":ep_len})                   

                o, ep_ret, ep_len = env.reset(), 0, 0


        # Save model
        if (epoch % config.save_freq == 0) or (epoch == config.epochs-1):
            torch.save(agent.ac.state_dict(), "model.h5")
            wandb.save("model.h5")

        # Perform PPO update!
        data = buffer.get()
        agent.update(data)

#         # Log info about epoch
        
#         logger.log_tabular('Epoch', epoch)
#         logger.log_tabular('EpRet', with_min_and_max=True)
#         logger.log_tabular('EpLen', average_only=True)
#         logger.log_tabular('VVals', with_min_and_max=True)
#         logger.log_tabular('TotalEnvInteracts', (epoch+1)*steps_per_epoch)
#         logger.log_tabular('LossPi', average_only=True)
#         logger.log_tabular('LossV', average_only=True)
#         logger.log_tabular('DeltaLossPi', average_only=True)
#         logger.log_tabular('DeltaLossV', average_only=True)
#         logger.log_tabular('Entropy', average_only=True)
#         logger.log_tabular('KL', average_only=True)
#         logger.log_tabular('ClipFrac', average_only=True)
#         logger.log_tabular('StopIter', average_only=True)
#         logger.log_tabular('Time', time.time()-start_time)
#         logger.dump_tabular()




VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
VVlas,-28.6373
_step,8008.0
_runtime,16.0
_timestamp,1602789867.0
EpRet,-499.36128
EpLen,1000.0
LossPi,8.74086
LossV,1200.65735
KL,0.07369
Entropy,0.91894


0,1
VVlas,████████████████████▄▄▅▃▄▄▃▃▄▅▁▂▅▃▅▃▃▂▃▃
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▃▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇███
_timestamp,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▃▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇███
EpRet,█▇▅▆▆▇▇▁
EpLen,▁▁▁▁▁▁▁▁
LossPi,▁
LossV,▁
KL,▁
Entropy,▁


[34m[1mwandb[0m: wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


(17,) (6,)

Number of parameters: 	 pi: 71948, 	 v: 70657

Early stopping at step 1 due to reaching max kl.
Early stopping at step 2 due to reaching max kl.
Early stopping at step 49 due to reaching max kl.
Early stopping at step 6 due to reaching max kl.
Early stopping at step 6 due to reaching max kl.
Early stopping at step 3 due to reaching max kl.
Early stopping at step 2 due to reaching max kl.
Early stopping at step 4 due to reaching max kl.


KeyboardInterrupt: 