# 📸 A tour of pureRL: Vectorizable RL Algorithms in Jax

`pureRL` is a library of reinforcement learning algorithms which you can `jax.jit` and `jax.vmap`. In this notebook, I want to show you its key features:

- 🏃‍♀️ [Set up](#setup), [train](#training) and [evaluate](#evaluation) RL agents
- 🔙 Use [custom callbacks](#callbacks) to define how training curves are generated, log data to wandb and more
- 💫 [Vmap](#vmapping) the train function to train multiple agents in parallel
- 🎨 [Customize](#customizing) existing algorithms by overriding their methods
<br><br>

<a href="https://colab.research.google.com/github/kerajli/purerl/blob/master/examples/purerl_tour.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a> &nbsp; and make sure you set the runtime to GPU!

In [1]:
!pip install -q "jax[cuda12]==0.4.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q "flax>=0.7"
!pip install -q git+https://github.com/keraJLi/pureRL



## 🔨 Setting up the training configuration <a name="setup"></a>

1. Each algorithm comes with a config class that extends `flax.struct.PyTreeNode`.
2. `get_algo` is a convenient function to access algorithms easily.
3. `config_cls.create` creates an instance of the config class, and populates it with default values.

In [2]:
from purerl import get_algo

env_str = "CartPole-v1"  # @param ["CartPole-v1", "Acrobot-v1", "Pendulum-v1"]
algo_str = "ppo"  # @param ["ppo", "dqn", "sac", "td3"]

# We make some changes to the default hyperparameters to get better plots.
# All algorithms should work reasonably well on CartPole-v1 and Acrobot-v1.
CONFIGS = {
    "sac": {"target_entropy_ratio": 0.7, "num_envs": 10, "gradient_steps": 5},
    "ppo": {"total_timesteps": 3e5},
    "dqn": {},
    "td3": {"total_timesteps": 5e4, "normalize_observations": True},
}

algo, config_cls = get_algo(algo_str)
config = config_cls.create(env=env_str, learning_rate=0.001, **CONFIGS.get(algo_str, {}))

Let's look at the public methods of the algorithm and at the config we have created:

In [3]:
[fn for fn in dir(algo) if not fn.startswith("_")]

['calculate_gae',
 'collect_trajectories',
 'initialize_train_state',
 'make_act',
 'make_minibatches',
 'train',
 'train_iteration',
 'update',
 'update_actor',
 'update_critic']

In [4]:
config.__dict__

{'env': <gymnax.environments.classic_control.cartpole.CartPole at 0x78068ef8c250>,
 'env_params': EnvParams(max_steps_in_episode=500, gravity=9.8, masscart=1.0, masspole=0.1, total_mass=1.1, length=0.5, polemass_length=0.05, force_mag=10.0, tau=0.02, theta_threshold_radians=0.20943951023931953, x_threshold=2.4),
 'actor': DiscretePolicy(
     # attributes
     action_dim = 2
     hidden_layer_sizes = (64, 64)
     activation = silu
 ),
 'critic': VNetwork(
     # attributes
     hidden_layer_sizes = (64, 64)
     activation = silu
 ),
 'eval_callback': <function purerl.evaluate.make_evaluate.<locals>._evaluate(config, ts, rng)>,
 'learning_rate': 0.001,
 'gamma': 0.99,
 'gae_lambda': 0.95,
 'clip_eps': 0.2,
 'vf_coef': 0.5,
 'ent_coef': 0.01,
 'max_grad_norm': inf,
 'total_timesteps': 300000.0,
 'eval_freq': 10000,
 'num_envs': 100,
 'num_steps': 50,
 'num_epochs': 5,
 'num_minibatches': 10,
 'normalize_observations': False,
 'skip_initial_evaluation': False}

As you can see, we have a lot of hyperparameters and variables we can tune. You are free to modify these after creating the config:

In [5]:
config = config.replace(gamma=0.995)
print(f"New gamma: {config.gamma}")

New gamma: 0.995


**A few words about configs**

1. In pureRL, configs extend `flax.struct.PyTreeNode`. This allows you to jit and vmap over individual parameters while keeping others fixed. For example, you are free to vmap over `learning_rate`, but not over `total_timesteps`.

2. You are free to replace the config you pass to the training algorithm by any object that has the same (or necessary) attributes.
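To make the first point more concrete, here is a minimal sketch of the idea using plain `jax.tree_util` instead of flax. `ToyConfig` and `toy_train` are hypothetical names I made up for illustration; they are not part of pureRL. The learning rate is a PyTree leaf, so it can be vmapped, while the number of steps is static metadata and has to stay a Python int.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyConfig:
    """Hypothetical stand-in for a pureRL config (not the library's class)."""

    def __init__(self, learning_rate, num_steps):
        self.learning_rate = learning_rate  # dynamic leaf: can be traced and vmapped
        self.num_steps = num_steps  # static metadata: must stay a Python int

    def tree_flatten(self):
        # Children are traced by jit/vmap; aux data is treated as a constant.
        return (self.learning_rate,), self.num_steps

    @classmethod
    def tree_unflatten(cls, num_steps, children):
        return cls(children[0], num_steps)


def toy_train(config):
    # Gradient descent on f(x) = x**2; the loop length must be static.
    x = jnp.float32(1.0)
    for _ in range(config.num_steps):
        x = x - config.learning_rate * 2 * x
    return x


learning_rates = jnp.array([0.01, 0.1, 0.3])
configs = jax.vmap(lambda lr: ToyConfig(lr, num_steps=10))(learning_rates)
final_x = jax.jit(jax.vmap(toy_train))(configs)
print(final_x.shape)  # (3,)
```

Vmapping over `num_steps` would not work here, because the Python `for` loop needs a concrete length at trace time. I assume `total_timesteps` is static in pureRL for a similar reason.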

## 🏃‍♀️ Training the agent <a name="training"></a>

In [6]:
import jax
import time

# Set training seed and jit train function
rng = jax.random.PRNGKey(0)
train_fn = jax.jit(algo.train)

print("Starting to train")

# Train!
start = time.time()
train_state, evaluation = train_fn(config, rng)
time_elapsed = time.time() - start

sps = config.total_timesteps / time_elapsed
print(f"Finished training in {time_elapsed:g} seconds ({sps:g} steps/second).")


Whoa, that was pretty quick! Let's break down what just happened.
1. We jit `algo.train` to obtain `train_fn`, which allows for fast execution 🚀
2. `train_fn` created an initial `train_state`, which holds information about the current state of the algorithm such as the current environment step, replay buffer contents, network parameters and more.
3. The `train_state` was passed to the algorithm, which transforms it over the course of training. The final value of the `train_state` is returned, including the final network parameters.
4. Additionally, the returned `evaluation` value is a tuple of episode lengths and episodic returns. We will look at how to customize this value later.
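The pattern described in these steps can be sketched with `jax.lax.scan`. This is a toy illustration under my own assumptions, not pureRL's actual training loop: the "train state" is a plain dict, and `toy_train` and `train_iteration` are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


def toy_train(rng, num_iterations=5):
    # The "train state" here is just a dict; pureRL's real train state holds more.
    init_state = {"step": jnp.int32(0), "param": jnp.float32(0.0), "rng": rng}

    def train_iteration(state, _):
        rng, rng_noise = jax.random.split(state["rng"])
        # Noisy step towards the target value 1.0.
        noise = 0.01 * jax.random.normal(rng_noise)
        param = state["param"] + 0.5 * (1.0 - state["param"]) + noise
        new_state = {"step": state["step"] + 1, "param": param, "rng": rng}
        evaluation = -jnp.abs(1.0 - param)  # stand-in for an episodic return
        return new_state, evaluation

    # scan threads the state through every iteration and stacks the evaluations.
    return jax.lax.scan(train_iteration, init_state, None, length=num_iterations)


final_state, evaluations = jax.jit(toy_train)(jax.random.PRNGKey(0))
print(final_state["step"], evaluations.shape)
```

Because the whole loop is a pure function of `rng`, it can be jitted as one unit, which is where the speed comes from.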

Let's look at how well training worked by plotting the learning curve!

In [None]:
from matplotlib import pyplot as plt

episode_lengths, episode_returns = evaluation
mean_return = episode_returns.mean(axis=1)

plt.plot(jax.numpy.linspace(0, config.total_timesteps, len(mean_return)), mean_return)
plt.xlabel("Environment step")
plt.ylabel("Episodic return")
plt.title(f"Training agent for {env_str} using {algo_str}")
plt.show()

We should also take a look at the `train_state`, to see what the algorithm produced.

In [None]:
jax.tree_util.tree_map(lambda x: x.shape, train_state).__dict__

In actor-critic algorithms, the `train_state` has an `actor_ts` as well as a `critic_ts`, which are instances of `flax.training.train_state.TrainState`. These provide the network parameters of our final agent. Note that our `train_state` is also a `flax.struct.PyTreeNode`, but it is not itself a flax `TrainState`. Algorithms which use buffers also have a `replay_buffer` state.

## 🤖 Making and evaluating policies <a name="evaluation"></a>

As discussed above, algorithms return an agent's policy parameters in its train state. We can extract a policy of the type `Callable[[chex.Array, chex.PRNGKey], chex.Array]` which maps `(obs, rng) -> action` like this:

In [None]:
# Get policy and jit it
policy = algo.make_act(config, train_state)
policy = jax.jit(policy)

Let's evaluate the policy! For demonstration purposes, here is a full rollout like you would do with gym:

In [None]:
# For demonstration purposes, we do a manual rollout of the policy
import gymnax

rng = jax.random.PRNGKey(0)
env, params = gymnax.make(env_str)
step = jax.jit(env.step)

obs, state = env.reset(rng, params)
episode_return = 0
done = False

while not done:
    rng, rng_action, rng_step = jax.random.split(rng, 3)
    action = policy(obs, rng_action)
    obs, state, reward, done, info = step(rng_step, state, action, params)
    episode_return += reward

print(f"Return achieved in one episode of {env_str}: {episode_return}")

Alternatively, `pureRL` offers fast parallel evaluation of policies under `purerl.evaluate`:

In [None]:
from purerl.evaluate import evaluate

num_seeds = 200  # @param {type:"slider", min:1, max:500, step:1}
max_steps = params.max_steps_in_episode

rng = jax.random.PRNGKey(0)

start = time.time()

# Evaluation time!
episode_lengths, episode_returns = evaluate(policy, rng, env, params, num_seeds, max_steps)

time_elapsed = time.time() - start

print(
    f"Evaluated {num_seeds} episodes "
    f"with a total of {jax.numpy.sum(episode_lengths)} environment steps "
    f"in {time_elapsed:g} seconds."
)
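If you are curious how such a parallel evaluation can be built, here is a rough sketch on a toy problem. It is not the code behind `purerl.evaluate`; `toy_policy`, `toy_step`, `rollout` and `toy_evaluate` are hypothetical names. The idea is to write one rollout with `jax.lax.scan`, mask out everything after the episode has ended, and vmap it over seeds.

```python
import jax
import jax.numpy as jnp


def toy_policy(obs, rng):
    # Hypothetical policy with the (obs, rng) -> action signature: a coin flip.
    return jax.random.bernoulli(rng).astype(jnp.int32)


def toy_step(pos, action):
    # Hypothetical 1D environment: move left or right, fail when leaving [-3, 3].
    pos = pos + 2 * action - 1
    done = jnp.abs(pos) > 3
    return pos, 1.0, done


def rollout(rng, max_steps):
    def body(carry, rng_step):
        pos, done, length, ret = carry
        action = toy_policy(pos, rng_step)
        new_pos, reward, new_done = toy_step(pos, action)
        # Once the episode is over, freeze the statistics.
        length = length + jnp.where(done, 0, 1)
        ret = ret + jnp.where(done, 0.0, reward)
        pos = jnp.where(done, pos, new_pos)
        return (pos, done | new_done, length, ret), None

    init = (jnp.int32(0), jnp.bool_(False), jnp.int32(0), jnp.float32(0.0))
    (_, _, length, ret), _ = jax.lax.scan(body, init, jax.random.split(rng, max_steps))
    return length, ret


def toy_evaluate(rng, num_seeds, max_steps):
    # One independent rollout per seed, batched with vmap.
    rngs = jax.random.split(rng, num_seeds)
    return jax.vmap(lambda r: rollout(r, max_steps))(rngs)


lengths, returns = jax.jit(toy_evaluate, static_argnums=(1, 2))(
    jax.random.PRNGKey(0), 8, 20
)
print(lengths.shape, returns.shape)
```

Every rollout runs for `max_steps` iterations regardless of when it terminates. That wastes a bit of compute, but it keeps all shapes static so the whole thing can be jitted and vmapped.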

In [None]:
fig, axes = plt.subplots(figsize=(8, 4), ncols=2, sharey="row")

axes[0].hist(episode_lengths)
axes[0].set(title="Episode length", ylabel="Count")
axes[1].hist(episode_returns)
axes[1].set(title="Episode return")

# 😲 Advanced features

## 💫 Vmapping the train function <a name="vmapping"></a>

I want to fit an SAC agent for a discrete environment. However, the best `target_entropy_ratio` parameter varies a lot between environments. We can do the following:
1. `jax.vmap` across the config to run multiple values for the target entropy ratio in parallel
2. `jax.vmap` across the training seed to make sure our best run isn't a fluke

This can take a minute!

In [None]:
from purerl import SAC, SACConfig

num_seeds = 5  # @param {"type": "slider", "min": 1, "max": 10, "step": 1}
env_str = "Acrobot-v1"  # @param ["CartPole-v1", "Acrobot-v1", "MountainCar-v0"]

def make_config(ter: float):
    return SACConfig.create(
        env=env_str,
        target_entropy_ratio=ter,
        num_envs=10,
        gradient_steps=5,
        )

ters = jax.numpy.arange(0.1, 1, 0.1)
sac_config = jax.vmap(make_config)(ters)

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, num_seeds)

sac_train_fn = jax.jit(SAC.train)
sac_train_fn = jax.vmap(sac_train_fn, in_axes=(None, 0))  # Vmap across seeds
sac_train_fn = jax.vmap(sac_train_fn, in_axes=(0, None))  # Vmap across configs

start = time.time()
train_state, evaluation = sac_train_fn(sac_config, rngs)
time_elapsed = time.time() - start

sps = num_seeds * len(ters) * sac_config.total_timesteps / time_elapsed
print(
    f"Trained {num_seeds * len(ters)} agents "
    f"for {sac_config.total_timesteps} environment steps "
    f"in {time_elapsed // 60:.0f} minutes and {time_elapsed % 60:g} seconds."
)
print(f"That's {sps:g} steps / second!")

The resulting evaluation now has shape `(num_values, num_train_seeds, num_evals, num_eval_seeds)`:

In [None]:
_, episode_returns = evaluation
print(f"shape of evaluation: {episode_returns.shape}")
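The order of the leading axes follows from the order in which the two `jax.vmap` calls are stacked. Here is a small sketch with a hypothetical `toy_train` function (not pureRL code) that returns 4 "evaluations" of 3 "episodes" each:

```python
import jax
import jax.numpy as jnp


def toy_train(ratio, rng):
    # Hypothetical train function: returns 4 "evaluations" of 3 "episodes" each.
    return ratio + jax.random.normal(rng, (4, 3))


ratios = jnp.linspace(0.1, 0.9, 9)  # 9 hyperparameter values
rngs = jax.random.split(jax.random.PRNGKey(0), 5)  # 5 training seeds

fn = jax.vmap(toy_train, in_axes=(None, 0))  # inner: adds the seed axis
fn = jax.vmap(fn, in_axes=(0, None))  # outer: adds the value axis in front
out = jax.jit(fn)(ratios, rngs)
print(out.shape)  # (9, 5, 4, 3)
```

The outermost `vmap` always contributes the first axis. Since `in_axes=(0, None)` does not map over the seeds, every hyperparameter value is trained with the same set of seeds.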

In [None]:
t = jax.numpy.linspace(0, sac_config.total_timesteps, episode_returns.shape[2])

# Plot all with alpha
for i, returns_for_val in enumerate(episode_returns):
    # take mean across evaluation seeds and transpose for plotting
    returns_for_val = returns_for_val.mean(axis=2).T
    plt.plot(t, returns_for_val, c=f"C{i}", alpha=0.3)

# Plot mean across training seeds
for i, returns_for_val in enumerate(episode_returns):
    returns_for_val = returns_for_val.mean(axis=(0, 2))
    plt.plot(t, returns_for_val, c=f"C{i}", label=f"{ters[i]:.1f}")

plt.legend(title="target_entropy_ratio", bbox_to_anchor=(1.05, 0.5), loc="center left")
plt.xlabel("Environment step")
plt.ylabel("Episodic return")
plt.title(f"Different values for SAC target entropy ratio on {env_str}")

## 🔙 Using custom callbacks <a name="callbacks"></a>

If you were super alert, you might have noticed the `eval_callback` attribute of the training config. This can be any function of your choice, as long as it maps `(config, train_state, rng) -> evaluation: chex.ArrayTree`. Two additional attributes control the evaluation behavior:
- `eval_freq`: `eval_callback` is called every `eval_freq` environment steps
- `skip_initial_evaluation`: if true, don't evaluate the initialized policy (but start after `eval_freq` steps instead).

The return value of `eval_callback` should be a PyTree of Jax arrays, so it can be aggregated over the course of training. To show how this works out in practice, let's implement a custom callback which returns a dict with some information.

In [None]:
def dict_callback(config, train_state, rng):
    policy = algo.make_act(config, train_state)
    l, r = evaluate(policy, rng, env, params, num_seeds, max_steps)
    return {
        "episode_lengths": l,
        "episode_returns": r,
        "global_step": train_state.global_step,
    }

In [None]:
rng = jax.random.PRNGKey(0)

# Replace eval_callback by new one!
config = config.replace(eval_callback=dict_callback)
_, evaluations = train_fn(config, rng)

In [None]:
print(f"total_timesteps: {config.total_timesteps}")
print(f"eval_freq: {config.eval_freq}")
print(f"Evaluation: {jax.tree_util.tree_map(lambda x: x.shape, evaluations)}")

As you can see, the entries in the evaluation dict have shape `(total_timesteps / eval_freq + 1, num_seeds)`. The `+1` comes from the fact that we evaluate the initial policy as well by default.
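To see where this leading axis comes from, here is a toy sketch of the aggregation. I am assuming a scan-based loop with one extra evaluation before training; pureRL's internals may differ, and `toy_callback` and `toy_train` are hypothetical names.

```python
import jax
import jax.numpy as jnp

total_timesteps, eval_freq, num_seeds = 1000, 100, 4


def toy_callback(step, rng):
    # Hypothetical callback returning a PyTree (a dict of arrays).
    returns = step + jax.random.normal(rng, (num_seeds,))
    return {"episode_returns": returns, "global_step": step}


def toy_train(rng):
    rng_init, rng_loop = jax.random.split(rng)
    initial = toy_callback(jnp.int32(0), rng_init)  # evaluate before training

    def body(step, rng_eval):
        step = step + eval_freq  # pretend we trained for eval_freq steps
        return step, toy_callback(step, rng_eval)

    num_evals = total_timesteps // eval_freq
    _, later = jax.lax.scan(body, jnp.int32(0), jax.random.split(rng_loop, num_evals))
    # Prepend the initial evaluation to every leaf of the stacked PyTree.
    return jax.tree_util.tree_map(
        lambda first, rest: jnp.concatenate([first[None], rest]), initial, later
    )


evaluations = jax.jit(toy_train)(jax.random.PRNGKey(0))
print(jax.tree_util.tree_map(lambda x: x.shape, evaluations))
```

With 1000 timesteps and an evaluation every 100 steps, each leaf gets a leading axis of length `1000 / 100 + 1 = 11`. This stacking is why the callback has to return arrays of the same shape every time.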

Since `train_fn` is the jitted `algo.train`, the `eval_callback` is jitted as well, so it must be pure. However, we can escape this via one of Jax's callbacks. For example, we can print the current performance of the agent over the course of training!

In [None]:
def print_callback(config, train_state, rng):
    policy = algo.make_act(config, train_state)
    _, r = evaluate(policy, rng, env, params, num_seeds, max_steps)

    # Jax callback! This can be impure.
    jax.debug.print("step: {}, mean return: {}", train_state.global_step, r.mean())

    # Since we now print the result, we don't return it
    return ()

train_fn(config.replace(eval_callback=print_callback), rng)
None  # Suppress output

Feel free to use `jax.experimental.io_callback`, `jax.debug.callback` or `jax.pure_callback` for more fine-grained control over callbacks. For example, I've included an example of how to log to wandb in the repository's `examples` folder!
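As a small, self-contained illustration of `jax.debug.callback`, the sketch below appends values to a Python list from inside a jitted loop. The list stands in for an external logger; `log_to_host` and `toy_train` are hypothetical names, and this is not the wandb example from the repository.

```python
import jax
import jax.numpy as jnp

log = []  # lives on the host, outside of the jitted computation


def log_to_host(step, value):
    # Runs as ordinary Python, so it could also call a logger such as wandb.
    log.append((int(step), float(value)))


@jax.jit
def toy_train(x):
    def body(carry, step):
        carry = 0.5 * carry
        # ordered=True keeps the callbacks in program order.
        jax.debug.callback(log_to_host, step, carry, ordered=True)
        return carry, None

    final, _ = jax.lax.scan(body, x, jnp.arange(3))
    return final


result = toy_train(jnp.float32(8.0))
jax.effects_barrier()  # wait until all pending callbacks have run
print(log)
```

Note that ordered callbacks cannot be vmapped, so when training many agents in parallel you would have to drop `ordered=True` and accept that log entries may arrive out of order.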

## 🎨 Customize and extend existing algorithms to your liking <a name="customizing"></a>

The algorithms in `pureRL` are stateless classes, meaning all of their methods are class methods. You are free to subclass them and override these methods! This can be very helpful when trying to modify or extend existing algorithms.
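Here is the idea in miniature, with two hypothetical classes that have nothing to do with pureRL's real ones. Because `train` calls `cls.loss`, a subclass only has to override `loss` to change what is being optimized.

```python
import jax
import jax.numpy as jnp


class ToyAlgorithm:
    """Hypothetical stateless algorithm: every method is a class method."""

    @classmethod
    def loss(cls, param):
        return (param - 1.0) ** 2

    @classmethod
    def train(cls, param, num_steps=50, learning_rate=0.1):
        # cls.loss resolves to the subclass's version if it is overridden.
        for _ in range(num_steps):
            param = param - learning_rate * jax.grad(cls.loss)(param)
        return param


class ShiftedAlgorithm(ToyAlgorithm):
    @classmethod
    def loss(cls, param):
        return (param + 2.0) ** 2  # only the loss changes; train is inherited


trained_base = jax.jit(ToyAlgorithm.train)(jnp.float32(0.0))
trained_shifted = jax.jit(ShiftedAlgorithm.train)(jnp.float32(0.0))
print(trained_base, trained_shifted)
```

The first run converges towards 1 and the second towards -2. Since no instance state is involved, the class methods can be passed to `jax.jit` directly.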

Let us replace the actor loss of PPO to create an instance of [Discovered Policy Optimization](https://arxiv.org/abs/2210.05639) [1].
In a nutshell, DPO replaces the surrogate objective of PPO with

$$
L_\text{actor} = \sum_{s, a} \big[ r A - f(r, A) \big]
$$

where $A = A^{\pi_k}(s, a)$,  $r = \pi(a|s) / \pi_k(a|s)$ and

$$
f(r, A) =
\begin{cases}
\text{ReLU}((r - 1)A - \alpha \tanh((r - 1)A/\alpha)) & \text{if } A \geq 0, \\
\text{ReLU}(\log(r)A - \beta \tanh(\log(r)A/\beta)) & \text{if } A < 0.
\end{cases}
$$

For more detail, see: <br>Chris Lu, Jakub Grudzien Kuba, Alistair Letcher, Luke Metz, Christian Schroeder de Witt, & Jakob Foerster. (2022). Discovered Policy Optimisation.
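To get a feeling for $f$, we can evaluate it on a few values. The standalone helper `dpo_drift` below is my own transcription of the formula above, with $\alpha = 2$ and $\beta = 0.6$ set as illustrative defaults.

```python
import jax
import jax.numpy as jnp


def dpo_drift(ratio, advantage, alpha=2.0, beta=0.6):
    # Branch for A >= 0 uses (r - 1) * A, branch for A < 0 uses log(r) * A.
    pos = (ratio - 1.0) * advantage
    neg = jnp.log(ratio) * advantage
    drift_pos = jax.nn.relu(pos - alpha * jnp.tanh(pos / alpha))
    drift_neg = jax.nn.relu(neg - beta * jnp.tanh(neg / beta))
    return jnp.where(advantage >= 0, drift_pos, drift_neg)


ratios = jnp.array([0.5, 1.0, 1.5, 3.0])
drift_for_positive_adv = dpo_drift(ratios, jnp.float32(1.0))
drift_for_negative_adv = dpo_drift(ratios, jnp.float32(-1.0))
print(drift_for_positive_adv)
print(drift_for_negative_adv)
```

The drift is zero at $r = 1$ and never negative. It is only active when the ratio moves in the direction the advantage rewards, that is $r > 1$ for $A \geq 0$ and $r < 1$ for $A < 0$, so it plays a role similar to PPO's clipping.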

We can easily replace `PPO.update_actor` by this new loss function, while keeping the rest of the algorithm as is.

In [None]:
from purerl import PPO, PPOConfig

DPOConfig = PPOConfig

class DPO(PPO):
    @classmethod
    def update_actor(cls, config, ts, batch):
        def actor_loss_fn(params):
            # config.actor is a flax model on which we call the apply function
            log_prob, entropy = config.actor.apply(
                params,
                batch.trajectories.obs,
                batch.trajectories.action,
                method="log_prob_entropy",
            )
            entropy = entropy.mean()

            # We now calculate actor loss as in DPO
            alpha, beta = 2, 0.6
            ratio = jax.numpy.exp(log_prob - batch.trajectories.log_prob)
            advantages = (batch.advantages - batch.advantages.mean()) / (
                batch.advantages.std() + 1e-8
            )  # advantage normalization (optional, but improves performance)

            # drift1 covers the case A >= 0 and drift2 the case A < 0 in the
            # formula for f
            drift1 = jax.nn.relu(
                (ratio - 1) * advantages
                - alpha * jax.numpy.tanh((ratio - 1) * advantages / alpha),
            )
            drift2 = jax.nn.relu(
                jax.numpy.log(ratio) * advantages
                - beta * jax.numpy.tanh(jax.numpy.log(ratio) * advantages / beta),
            )
            drift = jax.numpy.where(advantages >= 0, drift1, drift2)

            # Finally, we calculate the actor loss, including an entropy bonus
            pi_loss = -(ratio * advantages - drift).mean()
            pi_loss = pi_loss - config.ent_coef * entropy
            return pi_loss

        grads = jax.grad(actor_loss_fn)(ts.actor_ts.params)
        return ts.replace(actor_ts=ts.actor_ts.apply_gradients(grads=grads))

Let's see if DPO can train an agent!

In [None]:
env_str = "Pendulum-v1"  # @param ["CartPole-v1", "Acrobot-v1", "Pendulum-v1"]

rng = jax.random.PRNGKey(0)
dpo_config = DPOConfig.create(
    env=env_str,
    total_timesteps=3e5,
    learning_rate=0.001,
    num_epochs=10,
)
dpo_train_fn = jax.jit(DPO.train)
_, (_, returns) = dpo_train_fn(dpo_config, rng)

In [None]:
plt.plot(
    jax.numpy.linspace(0, dpo_config.total_timesteps, len(returns)),
    returns.mean(axis=1),
)
plt.xlabel("Environment step")
plt.ylabel("Episodic return")
plt.title(f"Discovered Policy Optimization on {env_str}")

Seems like it!

🤗 I hope this tour was instructive. Let me know if I can help you use pureRL in any way!