# ***Parallel PPO 2***

In [1]:
from typing import Sequence

import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper
from dataclasses import dataclass
import time
import pandas as pd
import plotly.express as px

from utils import Transition

In [2]:
class ActorCritic(nn.Module):
    """Actor-critic network with separate actor and critic MLP branches.

    Both branches read the same input ``x`` and use two hidden layers of 64
    units each. The actor branch produces ``action_dim`` logits wrapped in a
    ``distrax.Categorical``; the critic branch produces one value per input,
    with the trailing singleton axis squeezed away.

    Attributes:
        action_dim: Number of discrete actions (output size of the actor
            head). It is passed to ``nn.Dense`` as the feature count, so it
            is an ``int``, not a sequence.
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
    """

    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Return ``(pi, value)`` for the observation(s) ``x``."""
        act_fn = nn.relu if self.activation == "relu" else nn.tanh

        def dense(features, scale):
            # Every layer uses orthogonal kernel init and zero bias init.
            return nn.Dense(
                features, kernel_init=orthogonal(scale), bias_init=constant(0.0)
            )

        # NOTE: layers are created in the same order as before (actor first,
        # then critic) so the auto-generated names Dense_0..Dense_5 keep
        # referring to the same layers.
        actor_hidden = act_fn(dense(64, np.sqrt(2))(x))
        actor_hidden = act_fn(dense(64, np.sqrt(2))(actor_hidden))
        logits = dense(self.action_dim, 0.01)(actor_hidden)
        pi = distrax.Categorical(logits=logits)

        critic_hidden = act_fn(dense(64, np.sqrt(2))(x))
        critic_hidden = act_fn(dense(64, np.sqrt(2))(critic_hidden))
        value = dense(1, 1.0)(critic_hidden)

        return pi, jnp.squeeze(value, axis=-1)


In [3]:
@dataclass
class Args:
    """Configuration for the PPO experiment in this notebook.

    Only some of these fields are read by the code visible in this file;
    the others are presumably consumed elsewhere (TODO confirm).
    """

    # Run / logging options
    seed: int = 0  # seed intended for the PRNG key
    save_model: bool = False
    log_results: bool = False

    wandb_project_name: str = "improved-gradient-steps"
    wandb_entity: str = "rpegoud"
    logging_dir: str = "."

    # Algorithm specific arguments
    trainer: str = "base_ppo"
    env_name: str = "CartPole-v1"  # environment id passed to gymnax.make
    # An int literal rather than the float 5e4, so the value matches its
    # annotation and the update count derived by floor division stays an int.
    total_timesteps: int = 50_000
    learning_rate: float = 2.5e-4  # initial (or constant) Adam learning rate
    n_agents: int = 16
    num_envs: int = 4  # environments stepped in parallel via vmap
    num_steps: int = 128  # rollout length (scan iterations) per update
    update_epochs: int = 4  # with num_minibatches: optimizer steps per update
    num_minibatches: int = 4
    gamma: float = 0.99  # discount factor used in the GAE recursion
    gae_lambda: float = 0.95  # lambda of the GAE recursion
    clip_eps: float = 0.2  # clip range for the ratio and the value update
    ent_coef: float = 0.01  # weight of the entropy term in the total loss
    vf_coef: float = 0.5  # weight of the value loss in the total loss
    max_grad_norm: float = 0.5  # global-norm gradient clipping threshold
    alpha: float = 0.2
    activation: str = "tanh"  # "relu" or anything else for tanh
    anneal_lr: bool = True  # use the linear learning-rate schedule
    debug: bool = False

In [6]:
args = Args()
# Number of rollout/update iterations that fit in the timestep budget. The
# int() cast keeps this an integer even when total_timesteps is given as a
# float such as 5e4.
NUM_UPDATES = int(args.total_timesteps) // args.num_steps // args.num_envs
# Transitions per minibatch: one rollout split into num_minibatches parts.
MINIBATCH_SIZE = args.num_envs * args.num_steps // args.num_minibatches
env, env_params = gymnax.make(args.env_name)
env = FlattenObservationWrapper(env)
env = LogWrapper(env)
# Seed from the configuration (default 0) rather than a hard-coded constant,
# so that changing Args.seed actually changes the run.
rng = jax.random.PRNGKey(args.seed)

def linear_schedule(count):
    """Return the learning rate for optimizer step ``count``.

    ``count`` is floor-divided by the number of optimizer steps per update
    (``num_minibatches * update_epochs``) to get the number of completed
    updates; the learning rate decays linearly from ``args.learning_rate``
    towards zero as that number approaches ``NUM_UPDATES``.
    """
    steps_per_update = args.num_minibatches * args.update_epochs
    completed_updates = count // steps_per_update
    remaining_fraction = 1.0 - completed_updates / NUM_UPDATES
    return args.learning_rate * remaining_fraction


# INIT NETWORK
# Build the actor-critic and initialise its parameters from an all-zero
# input shaped like the environment's observation space.
network = ActorCritic(env.action_space(env_params).n, activation=args.activation)
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_space(env_params).shape)
network_params = network.init(_rng, init_x)

# Optimizer: global-norm gradient clipping followed by Adam. The two modes
# differ only in the learning rate: the linear schedule or a constant.
tx = optax.chain(
    optax.clip_by_global_norm(args.max_grad_norm),
    optax.adam(
        learning_rate=linear_schedule if args.anneal_lr else args.learning_rate,
        eps=1e-5,
    ),
)
train_state = TrainState.create(apply_fn=network.apply, params=network_params, tx=tx)

# INIT ENV
# Draw one reset key per environment and reset them all in a single vmapped
# call; env_params is shared across environments (in_axes=None).
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, args.num_envs)
obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
rng, _rng = jax.random.split(rng)
# Carry for the rollout scan: (train state, env state, last observation, rng key).
runner_state = (train_state, env_state, obsv, _rng)
# TRAIN LOOP


def _env_step(runner_state, unused):
    """
    Advance all ``num_envs`` environments by one step.

    Written as a ``jax.lax.scan`` body: ``runner_state`` is the carry
    ``(train_state, env_state, last_obs, rng)`` and ``unused`` is the ignored
    per-iteration input. Actions are sampled from the current policy.
    Returns the updated carry and the ``Transition`` recorded for this step.
    """
    train_state, env_state, last_obs, rng = runner_state

    # SELECT ACTION
    rng, action_key = jax.random.split(rng)
    pi, value = network.apply(train_state.params, last_obs)
    actions = pi.sample(seed=action_key)
    log_prob = pi.log_prob(actions)

    # STEP ENV
    rng, step_key = jax.random.split(rng)
    # One key per environment; env_params is shared (in_axes=None).
    batched_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    next_obs, next_env_state, reward, done, info = batched_step(
        jax.random.split(step_key, args.num_envs), env_state, actions, env_params
    )

    # The transition stores the observation the action was taken from.
    transition = Transition(done, actions, value, reward, log_prob, last_obs, info)
    return (train_state, next_env_state, next_obs, rng), transition


# Collect a rollout by scanning _env_step for args.num_steps iterations;
# traj_batch stacks the per-step transitions along the leading axis.
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, args.num_steps)

# CALCULATE ADVANTAGE
train_state, env_state, last_obs, rng = runner_state
# get the last value estimate to initialize gae computation
_, last_val = network.apply(train_state.params, last_obs)


def _calculate_gae(traj_batch, last_val):
    """
    Generalized advantage estimation for a trajectory batch.

    ``traj_batch`` provides per-step ``done``, ``value`` and ``reward``;
    ``last_val`` is the value estimate following the final step and
    bootstraps the recursion. Returns ``(advantages, targets)`` where
    ``targets = advantages + traj_batch.value``.
    """

    def _backward_step(carry, transition):
        """
        One step of the backward recursion, used as a ``lax.scan`` body.
        The carry is ``(gae, next_value)``: the advantage and the value
        estimate of the following timestep.
        """
        running_gae, value_after = carry
        # Zeroes both the bootstrap value and the accumulated advantage
        # wherever ``done`` is set.
        not_done = 1 - transition.done
        # td-error
        td_error = (
            transition.reward + args.gamma * value_after * not_done - transition.value
        )
        # generalized advantage in recursive form
        running_gae = td_error + args.gamma * args.gae_lambda * not_done * running_gae
        # this step's value becomes the "next value" of the preceding step
        return (running_gae, transition.value), running_gae

    initial_carry = (jnp.zeros_like(last_val), last_val)
    _, advantages = jax.lax.scan(
        _backward_step,
        initial_carry,
        traj_batch,
        # processed last-to-first: the advantage at time t depends on the
        # advantages of later timesteps
        reverse=True,
        # unroll 16 scan iterations per loop iteration, so the default
        # 128 steps take 8 loop iterations
        unroll=16,
    )
    value_targets = advantages + traj_batch.value
    return advantages, value_targets

# Advantages and value targets (advantages + recorded values) for the rollout.
advantages, targets = _calculate_gae(traj_batch, last_val)

In [11]:
def _loss_fn(params, traj_batch, gae, targets):
    """
    Scalar clipped PPO loss over a batch of transitions.

    Re-evaluates the network with ``params`` on the stored observations and
    returns ``(total_loss, (value_loss, loss_actor, entropy))``, where each
    component is averaged over the batch. The auxiliary tuple makes this
    suitable for ``jax.value_and_grad(..., has_aux=True)``.
    """
    # RERUN NETWORK
    pi, value = network.apply(params, traj_batch.obs)
    new_log_prob = pi.log_prob(traj_batch.action)

    # CALCULATE VALUE LOSS
    # The new prediction may move at most clip_eps away from the recorded
    # value; the larger of the clipped and unclipped squared errors is used.
    old_value = traj_batch.value
    clipped_value = old_value + (value - old_value).clip(-args.clip_eps, args.clip_eps)
    unclipped_sq_err = jnp.square(value - targets)
    clipped_sq_err = jnp.square(clipped_value - targets)
    value_loss = 0.5 * jnp.maximum(unclipped_sq_err, clipped_sq_err).mean()

    # CALCULATE ACTOR LOSS
    ratio = jnp.exp(new_log_prob - traj_batch.log_prob)
    # Advantages are standardised over the batch passed in.
    norm_adv = (gae - gae.mean()) / (gae.std() + 1e-8)
    lower, upper = 1.0 - args.clip_eps, 1.0 + args.clip_eps
    surrogate = ratio * norm_adv
    clipped_surrogate = jnp.clip(ratio, lower, upper) * norm_adv
    loss_actor = (-jnp.minimum(surrogate, clipped_surrogate)).mean()
    entropy = pi.entropy().mean()

    total_loss = loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy
    return total_loss, (value_loss, loss_actor, entropy)


grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
# NOTE: with has_aux=True the first output is the pair (loss, aux), so
# `total_loss` here is (total_loss, (value_loss, loss_actor, entropy)).
total_loss, grads = grad_fn(train_state.params, traj_batch, advantages, targets)
train_state = train_state.apply_gradients(grads=grads)
# Show the shape of every gradient leaf. `jax.tree_map` is deprecated and
# removed in recent JAX releases; `jax.tree_util.tree_map` is the stable
# equivalent.
jax.tree_util.tree_map(lambda x: x.shape, grads)

{'params': {'Dense_0': {'bias': (64,), 'kernel': (4, 64)},
  'Dense_1': {'bias': (64,), 'kernel': (64, 64)},
  'Dense_2': {'bias': (2,), 'kernel': (64, 2)},
  'Dense_3': {'bias': (64,), 'kernel': (4, 64)},
  'Dense_4': {'bias': (64,), 'kernel': (64, 64)},
  'Dense_5': {'bias': (1,), 'kernel': (64, 1)}}}

In [12]:
import jax
import jax.numpy as jnp


def _loss_fn(params, traj_batch, gae, targets):
    """
    Unreduced PPO loss terms for the transitions passed in.

    Returns ``(value_loss, loss_actor, entropy)`` element-wise: nothing is
    averaged and the terms are not combined into a total loss.

    NOTE(review): the advantage mean/std are taken over whatever ``gae``
    this call receives. When the function is vmapped over the leading axis,
    that is each mapped slice rather than the whole batch — confirm this is
    intended.
    """
    pi, value = network.apply(params, traj_batch.obs)
    eps = args.clip_eps

    # Clipped value loss: larger of the clipped / unclipped squared errors.
    value_delta = (value - traj_batch.value).clip(-eps, eps)
    value_pred_clipped = traj_batch.value + value_delta
    value_loss = 0.5 * jnp.maximum(
        jnp.square(value - targets),
        jnp.square(value_pred_clipped - targets),
    )

    # Clipped surrogate objective on standardised advantages.
    ratio = jnp.exp(pi.log_prob(traj_batch.action) - traj_batch.log_prob)
    normalized_gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    clipped_ratio = jnp.clip(ratio, 1.0 - eps, 1.0 + eps)
    loss_actor = -jnp.minimum(ratio * normalized_gae, clipped_ratio * normalized_gae)

    return value_loss, loss_actor, pi.entropy()


# Vectorize the loss function over the leading axis of the trajectory batch,
# advantages and targets; the parameters are shared (in_axes=None).
# This yields per-slice losses, not per-sample gradients.
_loss_fn_vmap = jax.vmap(_loss_fn, in_axes=(None, 0, 0, 0))


# Wrapper function to compute total loss and individual losses
def compute_losses_and_grads(params, traj_batch, gae, targets):
    """
    Compute the vmapped loss terms and the gradient of each term's sum.

    Returns a 7-tuple ``(total_losses, value_losses, actor_losses,
    entropies, value_grads, actor_grads, entropy_grads)``. The first four
    are the outputs of ``_loss_fn_vmap`` (plus their weighted combination);
    each ``*_grads`` entry is the gradient w.r.t. ``params`` of the *sum* of
    the corresponding loss term.
    """

    def _losses(p):
        # (value_losses, actor_losses, entropies) evaluated at parameters p
        return _loss_fn_vmap(p, traj_batch, gae, targets)

    def _summed_component_grad(index):
        # Gradient of the summed `index`-th loss term w.r.t. the parameters.
        return jax.grad(lambda p: jnp.sum(_losses(p)[index]))(params)

    value_losses, actor_losses, entropies = _losses(params)
    weighted_value = args.vf_coef * value_losses
    weighted_entropy = args.ent_coef * entropies
    total_losses = actor_losses + weighted_value - weighted_entropy

    value_grads = _summed_component_grad(0)
    actor_grads = _summed_component_grad(1)
    entropy_grads = _summed_component_grad(2)

    return (
        total_losses,
        value_losses,
        actor_losses,
        entropies,
        value_grads,
        actor_grads,
        entropy_grads,
    )


# Evaluate the loss terms and the gradients of their sums on the rollout.
(
    total_losses,
    value_losses,
    actor_losses,
    entropies,
    value_grads,
    actor_grads,
    entropy_grads,
) = compute_losses_and_grads(train_state.params, traj_batch, advantages, targets)

In [13]:
import jax
import jax.numpy as jnp


def _loss_fn(params, traj_batch, gae, targets):
    """
    Element-wise PPO loss terms (no reduction, no total).

    Re-runs the network with ``params`` on ``traj_batch.obs`` and returns
    ``(value_loss, loss_actor, entropy)``, each with one entry per
    transition passed in.

    NOTE(review): advantages are standardised with the mean/std of the
    ``gae`` argument itself, so under ``jax.vmap`` the statistics are per
    mapped slice — confirm this is intended.
    """
    # RERUN NETWORK
    pi, value = network.apply(params, traj_batch.obs)
    new_log_prob = pi.log_prob(traj_batch.action)
    clip = args.clip_eps

    # CALCULATE VALUE LOSS
    recorded_value = traj_batch.value
    bounded_value = recorded_value + (value - recorded_value).clip(-clip, clip)
    sq_err = jnp.square(value - targets)
    sq_err_bounded = jnp.square(bounded_value - targets)
    value_loss = 0.5 * jnp.maximum(sq_err, sq_err_bounded)

    # CALCULATE ACTOR LOSS
    ratio = jnp.exp(new_log_prob - traj_batch.log_prob)
    standardized = (gae - gae.mean()) / (gae.std() + 1e-8)
    unclipped_term = ratio * standardized
    clipped_term = jnp.clip(ratio, 1.0 - clip, 1.0 + clip) * standardized
    loss_actor = -jnp.minimum(unclipped_term, clipped_term)

    entropy = pi.entropy()
    return value_loss, loss_actor, entropy


def compute_per_sample_grads(params, traj_batch, gae, targets):
    """
    Compute vmapped loss terms plus gradients of the total and of each term.

    ``_loss_fn`` is vmapped over the leading axis of ``traj_batch``, ``gae``
    and ``targets`` (``params`` are shared).

    Returns a 7-tuple ``(total_grads, value_losses, actor_losses, entropies,
    value_grads, actor_grads, entropy_grads)``:

    * ``total_grads``: gradient w.r.t. ``params`` of the summed total loss
      ``actor + vf_coef * value - ent_coef * entropy``.
    * ``value_losses``, ``actor_losses``, ``entropies``: the vmapped loss
      terms.
    * ``value_grads``, ``actor_grads``, ``entropy_grads``: gradient w.r.t.
      ``params`` of the sum of each individual term.

    Note that, despite the name, every gradient is taken of a *sum* over
    the mapped axis; no per-sample gradient is returned.
    """
    per_slice_loss = jax.vmap(_loss_fn, in_axes=(None, 0, 0, 0))

    def component_losses(p):
        # (value_losses, actor_losses, entropies) evaluated at parameters p
        return per_slice_loss(p, traj_batch, gae, targets)

    def total_loss_fn(p):
        value_losses, actor_losses, entropies = component_losses(p)
        total_losses = (
            actor_losses + args.vf_coef * value_losses - args.ent_coef * entropies
        )
        return total_losses.sum(), (value_losses, actor_losses, entropies)

    grad_fn = jax.value_and_grad(total_loss_fn, has_aux=True)
    (_, (value_losses, actor_losses, entropies)), total_grads = grad_fn(params)

    def component_grad(index):
        # The loss must be recomputed *from p* inside the differentiated
        # function. Differentiating a closure over already-computed loss
        # arrays ignores p and yields all-zero gradients.
        return jax.grad(lambda p: jnp.sum(component_losses(p)[index]))(params)

    value_grads = component_grad(0)
    actor_grads = component_grad(1)
    entropy_grads = component_grad(2)

    return (
        total_grads,
        value_losses,
        actor_losses,
        entropies,
        value_grads,
        actor_grads,
        entropy_grads,
    )


# Evaluate the total-loss gradient, the loss terms and the per-term gradients
# on the collected rollout.
(
    total_grads,
    value_losses,
    actor_losses,
    entropies,
    value_grads,
    actor_grads,
    entropy_grads,
) = compute_per_sample_grads(train_state.params, traj_batch, advantages, targets)

In [19]:
import chex

In [21]:
# Sanity check: the three gradient pytrees must agree in leaf shapes and dtypes.
# This checks structure only, not the gradient values.
chex.assert_trees_all_equal_shapes_and_dtypes(value_grads, actor_grads, entropy_grads)