In [1]:
# ---- Experiment identity ----
exp_name = 'PPO'  # unique experiment name
seed = 1  # seed for reproducible benchmarks
env_id = "CartPole-v1"  # id of the environment
capture_video = False  # whether to save video of agent gameplay

# ---- Rollout collection ----
total_timesteps = 500_000  # total environment timesteps of the whole experiment
num_envs = 1  # number of parallel environments
num_steps = 128  # steps run in each environment per policy rollout

# ---- Return / advantage estimation ----
gamma = 0.99  # discount factor
gae_lambda = 0.95  # lambda for generalized advantage estimation (GAE)

# ---- Optimisation ----
learning_rate = 3e-4  # initial learning rate of the optimizer
num_minibatches = 4  # number of minibatches each rollout batch is split into
update_epochs = 4  # K epochs of policy updates per rollout
clip_coef = 0.2  # surrogate (and value) clipping coefficient
ent_coef = 0.01  # weight of the entropy bonus
vf_coef = 0.5  # weight of the value-function loss
max_grad_norm = 0.5  # maximum global gradient norm for clipping

# ---- Derived quantities ----
batch_size = num_envs * num_steps  # transitions gathered by one rollout
minibatch_size = batch_size // num_minibatches  # transitions per minibatch
num_updates = total_timesteps // batch_size  # number of rollout + update cycles

In [2]:
import string
import gymnasium as gym
import numpy as np

def make_env(env_id: str, idx: int, capture_video: bool, run_name: str):
    """Return a zero-argument factory that builds one Gymnasium environment.

    The factory form is what ``gym.vector.SyncVectorEnv`` consumes: one callable
    per sub-environment.

    Args:
        env_id: environment id passed to ``gym.make``.
        idx: index of this sub-environment; only index 0 records video.
        capture_video: if True, create the env with ``render_mode='rgb_array'`` and,
            for ``idx == 0``, wrap it in ``RecordVideo`` writing to ``videos/<env_id>``.
        run_name: currently unused by the factory (kept for call-site compatibility).

    Returns:
        A callable that creates the environment wrapped in ``RecordEpisodeStatistics``.
    """
    # NOTE: the annotations were previously the `string` *module*; `str` is the type.
    def thunk():
        if capture_video:
            env = gym.make(env_id, render_mode='rgb_array')
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{env_id}")
        else:
            env = gym.make(env_id)
        # Tracks per-episode return/length and exposes them through the step info dict.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env

    return thunk

In [3]:
# SyncVectorEnv steps its sub-environments sequentially in this process.
# AsyncVectorEnv is faster, but a single sub-environment cannot be extracted from it.
envs = gym.vector.SyncVectorEnv(
    [make_env(env_id, env_idx, capture_video, exp_name) for env_idx in range(num_envs)]
)

# The actor below emits categorical logits, so only Discrete action spaces are usable.
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

# Initial observation batch, used to initialise the network parameters.
obs, _ = envs.reset()

In [4]:
import flax.linen as nn

# Helper function to quickly declare linear layer with weight and bias initializers
def linear_layer_init(features, std=np.sqrt(2), bias_const=0.0):
    """Build an ``nn.Dense`` layer with orthogonal kernel init and constant bias init.

    Args:
        features: number of output features of the layer.
        std: gain (scale) of the orthogonal kernel initializer; defaults to sqrt(2).
        bias_const: constant value every bias entry is initialised to.

    Returns:
        An unbound ``nn.Dense`` module.
    """
    return nn.Dense(
        features=features,
        kernel_init=nn.initializers.orthogonal(std),
        bias_init=nn.initializers.constant(bias_const),
    )

In [5]:
from jax import Array
import jax.numpy as jnp

class Actor(nn.Module):
    """Policy network: maps observations to unnormalised logits over discrete actions."""

    action_n: int  # number of discrete actions (width of the logits output)

    @nn.compact
    def __call__(self, x: Array):
        # Two tanh-activated hidden layers of width 64.
        hidden_stack = [
            linear_layer_init(64),
            nn.tanh,
            linear_layer_init(64),
            nn.tanh,
        ]
        # Output head initialised with a small gain (std=0.01).
        logits_head = linear_layer_init(self.action_n, std=0.01)
        policy_net = nn.Sequential(hidden_stack + [logits_head])
        return policy_net(x)

class Critic(nn.Module):
    """Value network: maps observations to a single output feature (the state value)."""

    @nn.compact
    def __call__(self, x: Array):
        # Same tanh MLP trunk shape as the actor, but with its own parameters.
        trunk = [
            linear_layer_init(64),
            nn.tanh,
            linear_layer_init(64),
            nn.tanh,
        ]
        value_head = linear_layer_init(1, std=1.0)
        return nn.Sequential(trunk + [value_head])(x)

critic = Critic()
# The actor's output width is the number of discrete actions offered by the env.
actor = Actor(action_n=envs.single_action_space.n)

In [6]:
import jax.random as random

# Seed NumPy's global RNG and JAX's explicit PRNG so runs are reproducible.
np.random.seed(seed)
key = random.PRNGKey(seed)

# One split provides independent key streams for each consumer below.
key, actor_key, critic_key, action_key, permutation_key = random.split(key, num=5)

# Initialise network parameters by tracing each module on the observation batch `obs`.
critic_params = critic.init(critic_key, obs)
actor_params = actor.init(actor_key, obs)

2023-06-21 14:41:07.317529: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-06-21 14:41:07.317574: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 27983872 bytes free, 11712004096 bytes total.
2023-06-21 14:41:07.317608: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 530.41.3


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [None]:
import optax

def linear_schedule(count):
    """Linearly anneal the learning rate from ``learning_rate`` towards zero.

    ``count`` is the optimizer's step counter. Each PPO update performs
    ``num_minibatches * update_epochs`` optimizer steps (when the batch splits evenly
    into minibatches), so integer division by that product gives the number of
    completed PPO updates.
    """
    steps_per_update = num_minibatches * update_epochs
    completed_updates = count // steps_per_update
    return learning_rate * (1.0 - completed_updates / num_updates)

# NOTE(review): optax.adamw adds decoupled weight decay at its default rate on top of
# Adam; reference PPO implementations use plain Adam. Confirm this is intended.
tx = optax.chain(
    # Clip the global gradient norm before the optimizer update.
    optax.clip_by_global_norm(max_grad_norm),
    # inject_hyperparams lets the learning rate be a schedule evaluated at each step.
    optax.inject_hyperparams(optax.adamw)(learning_rate=linear_schedule, eps=1e-5),
)

In [None]:
from flax.core import FrozenDict
from flax.struct import dataclass

@dataclass
class AgentParams:
    """Container bundling the actor and critic parameter trees into one structure.

    It is passed as ``params`` to ``AgentState.create``, so one optimizer (``tx``)
    updates both networks together.
    """
    actor_params: FrozenDict  # variables returned by actor.init(...)
    critic_params: FrozenDict  # variables returned by critic.init(...)

In [9]:
from jax import jit
from typing import Callable
from flax.training.train_state import TrainState
from flax import struct

# Wrap each module's apply function in jit. The functions that call them later are
# themselves decorated with @jit, so inside those traces this extra jit is likely redundant.
critic.apply = jit(critic.apply)
actor.apply = jit(actor.apply)

class AgentState(TrainState):
    """TrainState that also carries the separate actor and critic apply functions.

    Both extra fields are declared with ``pytree_node=False``, so they are treated as
    static metadata rather than traced array leaves when an AgentState is passed into
    a jitted function.
    """
    # Static (non-pytree) fields. They have no default values, so they must be passed
    # to AgentState.create(...).
    actor_fn: Callable = struct.field(pytree_node=False)
    critic_fn: Callable = struct.field(pytree_node=False)

# Actor and critic are applied through the dedicated actor_fn / critic_fn fields,
# so the generic TrainState.apply_fn slot is intentionally left empty.
agent_state = AgentState.create(
    apply_fn=None,
    actor_fn=actor.apply,
    critic_fn=critic.apply,
    params=AgentParams(actor_params=actor_params, critic_params=critic_params),
    tx=tx,
)

In [10]:
# Rebuild the optimizer and the agent state from the agent's current parameters.
# NOTE(review): this repeats the previous cells' construction. Because a fresh
# AgentState.create(...) builds new optimizer state, re-running this cell presumably
# keeps the (possibly trained) actor/critic parameters but resets the optimizer's
# accumulated state and step counter, which restarts the learning-rate schedule.
# Confirm this reset is intended; otherwise this cell can be dropped.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.inject_hyperparams(optax.adamw)(
        learning_rate=linear_schedule,
        eps=1e-5
    )
)

agent_state = AgentState.create(
    params=AgentParams(
        actor_params=agent_state.params.actor_params,
        critic_params=agent_state.params.critic_params
    ),
    tx=tx,
    # Actor/critic are applied via actor_fn / critic_fn, so apply_fn is unused.
    apply_fn=None,
    actor_fn=actor.apply,
    critic_fn=critic.apply
)

In [11]:
@dataclass
class Storage:
    """Immutable rollout buffer holding one PPO rollout as JAX arrays.

    Every field is allocated with leading axes ``(num_steps, num_envs)`` in the training
    cell; ``obs`` and ``actions`` additionally carry the observation / action space shape.
    Rows are written functionally: ``storage.replace(f=storage.f.at[step].set(...))``.
    """
    # Annotated with the jax.Array type (previously the jnp.array *function*).
    obs: Array  # observations fed to the policy at each step
    actions: Array  # actions sampled by the policy
    logprobs: Array  # log-probabilities of the sampled actions under the rollout policy
    dones: Array  # done flags accompanying each stored observation
    values: Array  # critic value estimates of each stored observation
    advantages: Array  # GAE advantages, filled by compute_gae
    returns: Array  # advantages + values (value-function targets), filled by compute_gae
    rewards: Array  # rewards returned by envs.step

In [12]:
from numpy import ndarray
#import tensorflow_probability.substrates.jax.distributions as tfp
import distrax

@jit
def get_action_and_value(agent_state: AgentState, next_obs: ndarray, next_done: ndarray, storage: Storage, step: int, key: Array):
    """Sample actions for the current observations and record the step in ``storage``.

    Args:
        agent_state: holds the actor/critic apply functions and their parameters.
        next_obs: current observations, one row per sub-environment.
        next_done: done flags accompanying ``next_obs``.
        storage: rollout buffer; row ``step`` of obs/dones/actions/logprobs/values is written.
        step: index of the rollout step being recorded.
        key: JAX PRNG key. It is split; one half seeds sampling and the other is returned.

    Returns:
        ``(storage, action, key)``: the updated buffer, the sampled actions, and the
        PRNG key to use for the next call.
    """
    # `key` is annotated with jax.Array: jax.random.PRNGKeyArray is deprecated/removed
    # in newer JAX releases, where evaluating that annotation raises AttributeError.
    action_logits = agent_state.actor_fn(agent_state.params.actor_params, next_obs)
    value = agent_state.critic_fn(agent_state.params.critic_params, next_obs)

    # Sample discrete actions from a categorical distribution over the actor's logits.
    probs = distrax.Categorical(action_logits)
    key, subkey = random.split(key)
    action = probs.sample(seed=subkey)
    logprob = probs.log_prob(action)
    storage = storage.replace(
        obs=storage.obs.at[step].set(next_obs),
        dones=storage.dones.at[step].set(next_done),
        actions=storage.actions.at[step].set(action),
        logprobs=storage.logprobs.at[step].set(logprob),
        # Critic output has a trailing feature axis of size 1; drop it before storing.
        values=storage.values.at[step].set(value.squeeze()),
    )
    return storage, action, key

In [14]:
@jit
def get_action_and_value2(agent_state: AgentState, params: AgentParams, obs: ndarray, action: ndarray):
    """Re-evaluate stored actions under ``params`` (which may differ from agent_state.params).

    Returns:
        ``(log_prob(action), entropy, value)`` where value has its trailing size-1
        critic axis squeezed away.
    """
    dist = distrax.Categorical(agent_state.actor_fn(params.actor_params, obs))
    state_values = agent_state.critic_fn(params.critic_params, obs).squeeze()
    return dist.log_prob(action), dist.entropy(), state_values

In [17]:
#from flax.metrics.tensorboard import SummaryWriter
from jax import device_get

def rollout(
        agent_state: AgentState,
        next_obs: ndarray,
        next_done: ndarray,
        storage: Storage,
        key: Array,
        global_step: int,
        episode_returns: "list | None" = None,
):
    """Collect ``num_steps`` transitions from ``envs`` into ``storage``.

    Args:
        agent_state: current policy/value parameters used to act.
        next_obs: observations to act on at the first step.
        next_done: done flags accompanying ``next_obs``.
        storage: rollout buffer, written row by row.
        key: JAX PRNG key threaded through action sampling.
        global_step: running environment-step counter; advanced by ``num_envs`` per step.
        episode_returns: optional list; when given, the return of every episode that
            finishes during this rollout is appended to it. Omit it to keep the
            previous behaviour of collecting nothing.

    Returns:
        ``(next_obs, next_done, storage, key, global_step)`` to continue from.
    """
    for step in range(num_steps):
        global_step += num_envs
        storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)
        next_obs, reward, terminated, truncated, infos = envs.step(device_get(action))
        next_done = terminated | truncated
        storage = storage.replace(rewards=storage.rewards.at[step].set(reward))

        # "final_info" is only present when at least one sub-environment finished.
        if episode_returns is None or "final_info" not in infos:
            continue

        for info in infos["final_info"]:
            # Entries are None for sub-environments that did not finish this step.
            if info is None:
                continue
            # RecordEpisodeStatistics (applied in make_env) reports the episode return
            # under info["episode"]["r"]. NOTE(review): assumes the Gymnasium 0.2x
            # vector-env info layout; it may be a scalar or a size-1 array, hence ravel.
            episode_returns.extend(np.ravel(info["episode"]["r"]).tolist())

    return next_obs, next_done, storage, key, global_step

In [18]:
@jit
def compute_gae(
        agent_state: AgentState,
        next_obs: ndarray,
        next_done: ndarray,
        storage: Storage
):
    """Fill ``storage.advantages`` with GAE estimates and ``storage.returns`` with targets.

    Walks the rollout backwards. The last step bootstraps from the critic's value of
    ``next_obs`` (masked by ``next_done``). Every earlier step bootstraps from the stored
    value of the following step (masked by that step's done flag).

    Returns:
        The updated storage, with ``returns = advantages + values``.
    """
    bootstrap_value = agent_state.critic_fn(agent_state.params.critic_params, next_obs).squeeze()
    advantages = jnp.zeros_like(storage.advantages)
    running_gae = 0
    # Python loop: unrolled at trace time, num_steps is a static constant.
    for t in range(num_steps - 1, -1, -1):
        is_last = t == num_steps - 1
        not_done = 1.0 - (next_done if is_last else storage.dones[t + 1])
        value_tp1 = bootstrap_value if is_last else storage.values[t + 1]
        td_error = storage.rewards[t] + gamma * value_tp1 * not_done - storage.values[t]
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        advantages = advantages.at[t].set(running_gae)
    return storage.replace(advantages=advantages, returns=advantages + storage.values)

In [19]:
from jax.lax import stop_gradient

@jit
def ppo_loss(
        agent_state: AgentState,
        params: AgentParams,
        obs: ndarray,
        act: ndarray,
        logp: ndarray,
        adv: ndarray,
        ret: ndarray,
        val: ndarray,
):
    """Clipped PPO objective for one minibatch.

    Args:
        agent_state: supplies the actor/critic apply functions.
        params: parameters being differentiated (argnums=1 in update_ppo).
        obs, act: minibatch observations and the actions taken.
        logp: log-probabilities of ``act`` under the rollout policy.
        adv: GAE advantages (normalised here per minibatch).
        ret: value targets.
        val: value estimates recorded during the rollout.

    Returns:
        ``(loss, (pg_loss, v_loss, entropy_loss, approx_kl))``.
    """
    new_logp, entropy, new_value = get_action_and_value2(agent_state, params, obs, act)
    log_ratio = new_logp - logp
    ratio = jnp.exp(log_ratio)

    # Approximate KL between old and new policy, via the (r - 1) - log r estimator.
    approx_kl = ((ratio - 1) - log_ratio).mean()

    norm_adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    # Clipped surrogate policy loss (pessimistic maximum of the two negated objectives).
    clipped_ratio = jnp.clip(ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = jnp.maximum(-norm_adv * ratio, -norm_adv * clipped_ratio).mean()

    # Clipped value loss: the new value may move at most clip_coef from the stored value.
    value_clipped = val + jnp.clip(new_value - val, -clip_coef, clip_coef)
    v_loss = 0.5 * jnp.maximum((new_value - ret) ** 2, (value_clipped - ret) ** 2).mean()

    entropy_loss = entropy.mean()

    loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef
    return loss, (pg_loss, v_loss, entropy_loss, stop_gradient(approx_kl))

In [20]:
from jax import value_and_grad


# Build the jitted value-and-grad of ppo_loss once. Previously it was created inside
# update_ppo, which produced a new function object on every call; jit's cache is keyed
# on the function, so every PPO update re-traced and re-compiled the loss.
_ppo_loss_grad_fn = jit(value_and_grad(ppo_loss, argnums=1, has_aux=True))


def update_ppo(
        agent_state: AgentState,
        storage: Storage,
        key: Array
):
    """Run ``update_epochs`` epochs of minibatch PPO updates on one rollout.

    Args:
        agent_state: agent whose parameters are optimised.
        storage: rollout buffer with advantages/returns already filled by compute_gae.
        key: JAX PRNG key used to shuffle minibatch indices each epoch.

    Returns:
        ``(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, explained_var, key)``.
        The loss terms are from the last minibatch processed. ``explained_var`` compares
        the rollout's value estimates against the returns, and is NaN when the returns
        have zero variance.
    """
    # Flatten (num_steps, num_envs, ...) into a single batch axis.
    obs_shape = envs.single_observation_space.shape
    act_shape = envs.single_action_space.shape
    b_obs = storage.obs.reshape((-1,) + obs_shape)
    b_logprobs = storage.logprobs.reshape(-1)
    b_actions = storage.actions.reshape((-1,) + act_shape)
    b_advantages = storage.advantages.reshape(-1)
    b_returns = storage.returns.reshape(-1)
    b_values = storage.values.reshape(-1)

    for _ in range(update_epochs):
        key, subkey = random.split(key)
        b_inds = random.permutation(subkey, batch_size, independent=True)
        for start in range(0, batch_size, minibatch_size):
            mb_inds = b_inds[start:start + minibatch_size]
            (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = _ppo_loss_grad_fn(
                agent_state,
                agent_state.params,
                b_obs[mb_inds],
                b_actions[mb_inds],
                b_logprobs[mb_inds],
                b_advantages[mb_inds],
                b_returns[mb_inds],
                b_values[mb_inds],
            )
            agent_state = agent_state.apply_gradients(grads=grads)

    # How well the rollout-time value estimates explain the computed returns.
    y_pred, y_true = b_values, b_returns
    var_y = jnp.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - jnp.var(y_true - y_pred) / var_y
    return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, explained_var, key

In [22]:
import time
import signal
from tqdm.notebook import tqdm

# Make a kernel interrupt (SIGINT) raise KeyboardInterrupt so the loop can stop cleanly.
signal.signal(signal.SIGINT, signal.default_int_handler)

run_name = f"{exp_name}_{seed}_{time.asctime(time.localtime(time.time())).replace('  ', ' ').replace(' ', '_')}"

# Rollout buffer: every field has leading axes (num_steps, num_envs).
storage = Storage(
    obs=jnp.zeros((num_steps, num_envs) + envs.single_observation_space.shape),
    actions=jnp.zeros((num_steps, num_envs) + envs.single_action_space.shape),
    logprobs=jnp.zeros((num_steps, num_envs)),
    dones=jnp.zeros((num_steps, num_envs)),
    values=jnp.zeros((num_steps, num_envs)),
    advantages=jnp.zeros((num_steps, num_envs)),
    returns=jnp.zeros((num_steps, num_envs)),
    rewards=jnp.zeros((num_steps, num_envs)),
)
global_step = 0
start_time = time.time()
next_obs, _ = envs.reset(seed=seed)
# rollout() returns the boolean `terminated | truncated` array. Starting with the same
# dtype avoids an extra re-trace of the jitted step function after the first step
# (a float32 initial value made the first and later calls differ in dtype).
next_done = np.zeros(num_envs, dtype=bool)

try:
    for update in tqdm(range(1, num_updates + 1)):
        next_obs, next_done, storage, action_key, global_step = rollout(agent_state, next_obs, next_done, storage, action_key, global_step)
        storage = compute_gae(agent_state, next_obs, next_done, storage)
        agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, explained_var, permutation_key = update_ppo(agent_state, storage, permutation_key)

        if update % 10 == 0:
            # storage.returns holds this rollout's GAE value targets (advantages + values),
            # not completed-episode returns, so report it under an accurate name.
            avg_return_target = np.mean(device_get(storage.returns))
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"update={update}, avg_return_target={avg_return_target}")

    print('Training complete!')
except KeyboardInterrupt:
    print('Training interrupted!')
finally:
    envs.close()

  0%|          | 0/3906 [00:00<?, ?it/s]

Training interrupted!
