In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper
import plotly.graph_objects as go
import numpy as np
import pandas as pd

class ActorCritic(nn.Module):
    """Actor-critic MLP with separate (non-shared) actor and critic torsos.

    Each head uses two hidden ``Dense(64)`` layers. Kernels are orthogonally
    initialised with gain sqrt(2) for hidden layers, 0.01 for the policy logits
    and 1.0 for the value output; all biases start at zero.

    Attributes:
        action_dim: number of discrete actions (length of the logits vector).
        activation: ``"relu"`` selects ReLU; any other value falls back to tanh.
    """

    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Return ``(pi, value)`` for the observation input ``x``.

        ``pi`` is a ``distrax.Categorical`` built from ``action_dim`` logits and
        ``value`` is the critic output with its trailing size-1 axis squeezed.
        """
        activation = nn.relu if self.activation == "relu" else nn.tanh

        # Actor head (layers are created before the critic's, which fixes the
        # auto-generated Flax parameter names Dense_0 .. Dense_2).
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        # Small gain on the logits layer keeps the initial logits close to zero.
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        # Critic head (Dense_3 .. Dense_5).
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    """One environment step for all vectorised envs, as stored in the rollout.

    ``jax.lax.scan`` stacks these along a new leading time axis.
    """

    done: jnp.ndarray  # episode-termination flags returned by env.step
    action: jnp.ndarray  # actions sampled from the policy
    value: jnp.ndarray  # critic value of ``obs``
    reward: jnp.ndarray  # rewards returned by env.step
    log_prob: jnp.ndarray  # log-probability of ``action`` under the behaviour policy
    obs: jnp.ndarray  # observation the action was chosen from
    info: dict  # env info dict (e.g. "returned_episode_returns", "timestep")


def make_train(config):
    """Build a jittable PPO training function for a gymnax environment.

    The returned ``train(rng)`` runs ``config["NUM_UPDATES"]`` PPO updates. Each
    update collects ``NUM_STEPS`` steps from ``NUM_ENVS`` vectorised environments,
    computes GAE advantages, then runs ``UPDATE_EPOCHS`` epochs of minibatch
    updates. Gradients are computed per sample (``vmap`` over
    ``jax.value_and_grad``) and combined with per-sample weights, so alternative
    weightings can be experimented with.

    Note: ``config`` is mutated in place; the derived keys ``NUM_UPDATES`` and
    ``MINIBATCH_SIZE`` are added to it.

    Args:
        config: dict of hyper-parameters. Keys read: ``TOTAL_TIMESTEPS``,
            ``NUM_STEPS``, ``NUM_ENVS``, ``NUM_MINIBATCHES``, ``UPDATE_EPOCHS``,
            ``ENV_NAME``, ``LR``, ``ANNEAL_LR``, ``MAX_GRAD_NORM``,
            ``ACTIVATION``, ``GAMMA``, ``GAE_LAMBDA``, ``CLIP_EPS``, ``VF_COEF``,
            ``ENT_COEF`` and optionally ``DEBUG``.

    Returns:
        ``train(rng) -> {"runner_state": ..., "metrics": ...}`` where
        ``metrics`` is the env ``info`` dict collected at every rollout step.
    """
    # int() keeps NUM_UPDATES integral (it is a scan length) even when
    # TOTAL_TIMESTEPS is given as a float literal such as 5e5.
    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    # Samples per minibatch == number of per-sample gradients per update step.
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        """Learning rate decaying linearly from LR towards 0 over NUM_UPDATES.

        ``count`` is the optimiser step counter, which advances once per
        minibatch, hence the division by NUM_MINIBATCHES * UPDATE_EPOCHS.
        """
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        """Run the whole PPO training loop starting from PRNG key ``rng``."""
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        learning_rate = linear_schedule if config["ANNEAL_LR"] else config["LR"]
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=learning_rate, eps=1e-5),
        )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """One PPO update: collect a rollout, compute GAE, then optimise."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                """Sample actions and step all ``NUM_ENVS`` environments once.

                Returns the updated runner state and the resulting ``Transition``
                (which stores the observation the action was chosen from).
                """
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                actions = pi.sample(seed=_rng)
                log_prob = pi.log_prob(actions)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, actions, env_params)
                transition = Transition(
                    done, actions, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Value of the observation after the final rollout step bootstraps GAE.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                """Compute generalised advantage estimates and value targets.

                ``A_t = delta_t + (gamma * lambda) * A_{t+1}`` with
                ``delta_t = r_t + gamma * V_{t+1} - V_t``; both bootstrap terms
                are zeroed across episode ends via ``(1 - done)``.

                Returns ``(advantages, advantages + values)``.
                """

                def _get_advantages(gae_and_next_value, transition):
                    """One backward GAE step; the scan carry is (gae, next_value)."""
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # TD error
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    # recursive form of GAE
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae  # (carry, per-step output)

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    # Scan backwards: the advantage at t depends on later steps.
                    reverse=True,
                    # Unroll 16 steps per loop iteration (e.g. 128 steps -> 8 iterations).
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """One epoch: shuffle the rollout and update on every minibatch."""

                def _update_minbatch(train_state, batch_info):
                    """Apply one optimiser step built from per-sample gradients."""
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        """Clipped PPO loss: actor + VF_COEF * value - ENT_COEF * entropy.

                        Called through ``vmap`` below, so the inputs are a single
                        sample and the ``.mean()`` reductions act on scalars.
                        """
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # TODO: remove the standardisation and compare grad vs vmap(grad).
                        # gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    def get_per_sample_norms(grads):
                        """Per-sample L2 norm of the full gradient, normalised to sum to 1.

                        ``grads`` is a pytree whose leaves share a leading
                        per-sample axis of the same length.
                        """
                        leaves = jax.tree_util.tree_leaves(grads)
                        n_samples = leaves[0].shape[0]
                        # Sum of squares over every parameter, one value per sample.
                        sum_of_squares = sum(
                            jnp.sum(jnp.square(g.reshape(n_samples, -1)), axis=1)
                            for g in leaves
                        )
                        sample_norms = jnp.sqrt(sum_of_squares)
                        # TODO: in the future, try returning the raw sample norms,
                        # or (sample norms) ** alpha.
                        return sample_norms / sample_norms.sum()

                    def get_weighted_grads(grads, weights):
                        """Sum the per-sample gradients, each divided by its weight.

                        ``weights`` holds one entry per sample (the leading axis
                        of every leaf in ``grads``).
                        """

                        def _weighted_sum(g):
                            # Shape weights as (n, 1, ..., 1) to broadcast over leaf (n, ...).
                            w = weights.reshape((-1,) + (1,) * (g.ndim - 1))
                            return jnp.sum(g / w, axis=0)

                        return jax.tree_util.tree_map(_weighted_sum, grads)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # vmap over the leading (sample) axis of the minibatch with shared
                    # params: one loss and one gradient pytree per sample.
                    total_loss, per_sample_grads = jax.vmap(
                        grad_fn, in_axes=(None, 0, 0, 0)
                    )(train_state.params, traj_batch, advantages, targets)

                    # One weight per minibatch sample; uniform weights give a plain
                    # sum of per-sample gradients. Experimental norm reweighting:
                    #   weights = get_per_sample_norms(per_sample_grads)
                    # NOTE(review): the sum is MINIBATCH_SIZE times the gradient of
                    # the mean loss, which interacts with clip_by_global_norm;
                    # confirm this scaling is intended.
                    weights = jnp.ones(config["MINIBATCH_SIZE"])
                    weighted_grads = get_weighted_grads(per_sample_grads, weights)
                    train_state = train_state.apply_gradients(grads=weighted_grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                # Batching and shuffling: flatten (NUM_STEPS, NUM_ENVS) into one axis.
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Mini-batch updates: (NUM_MINIBATCHES, MINIBATCH_SIZE, ...).
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            # Updating training state and metrics
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            # Debugging mode: print returns of episodes that finished in this rollout.
            if config.get("DEBUG"):

                def callback(info):
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    for t in range(len(timesteps)):
                        print(
                            f"global step={timesteps[t]}, episodic return={return_values[t]}"
                        )

                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train

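Let's set up a single CartPole-v1 training run. The cell below defines the hyper-parameters and wraps the training function with `jax.jit`; nothing is compiled or executed until the function is actually called.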

In [None]:
env_name = "CartPole-v1"

# PPO hyper-parameters for a single CartPole-v1 run.
config = {
    "LR": 2.5e-4,
    "NUM_ENVS": 4,
    "NUM_STEPS": 128,
    # Integer so the derived NUM_UPDATES (a scan length) is an int, not a float.
    "TOTAL_TIMESTEPS": int(5e5),
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 4,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": env_name,
    "ANNEAL_LR": True,
}

rng = jax.random.PRNGKey(42)
# jax.jit is lazy: compilation and training only happen when train_jit(rng) is called.
train_jit = jax.jit(make_train(config))

In [None]:
import time

import matplotlib.pyplot as plt

# Set to True to train and plot a single seed with `train_jit` from the cell above.
RUN_SINGLE_SEED = False

rng = jax.random.PRNGKey(42)
if RUN_SINGLE_SEED:
    t0 = time.time()
    out = jax.block_until_ready(train_jit(rng))
    print(f"time: {time.time() - t0:.2f} s")
    plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
    plt.xlabel("Update Step")
    plt.ylabel("Return")
    plt.show()

Now let's train many seeds in parallel by `vmap`-ing the training function over a batch of PRNG keys — here `n_agents = 10` seeds. Note that the first call spends time JIT-compiling the code; subsequent calls reuse the compiled function and run faster.

In [None]:
n_agents = 10
rng = jax.random.PRNGKey(42)
rngs = jax.random.split(rng, n_agents)
# vmap over PRNG keys: one independent training run per key, compiled once.
train_vjit = jax.jit(jax.vmap(make_train(config)))
t0 = time.time()
outs = jax.block_until_ready(train_vjit(rngs))  # first call includes JIT compile time
print(f"time: {time.time() - t0:.2f} s")

# returned_episode_returns is (n_agents, NUM_UPDATES, NUM_STEPS, NUM_ENVS):
# average over envs and flatten updates x steps into one time axis per agent.
per_agent_returns = outs["metrics"]["returned_episode_returns"].mean(-1).reshape(
    n_agents, -1
)
fig, ax = plt.subplots()
for agent_returns in per_agent_returns:
    ax.plot(agent_returns)
ax.set(
    title=f"{env_name}: return per seed",
    xlabel="Rollout step (NUM_UPDATES x NUM_STEPS)",
    ylabel="Return",
)
plt.show()

In [None]:
# Mean and (pooled) std over agents, rollout steps and envs -> one value per update.
returns_all = outs["metrics"]["returned_episode_returns"]
averaged_rewards = pd.DataFrame(returns_all.mean(axis=(0, 2, 3)))
deviations = pd.DataFrame(returns_all.std(axis=(0, 2, 3)))

mean_curve = averaged_rewards[0].to_numpy()
std_curve = deviations[0].to_numpy()
fig, ax = plt.subplots()
ax.plot(mean_curve, label="mean")
ax.fill_between(
    np.arange(len(mean_curve)),
    mean_curve - std_curve,
    mean_curve + std_curve,
    alpha=0.3,
    label="±1 std",
)
ax.set(title=f"{env_name}: mean return over seeds", xlabel="Update Step", ylabel="Return")
ax.legend()
plt.show()

In [None]:
# Reshape the results into a 2D table for CSV export: one row per update, one
# column per (agent, rollout step, env) combination, agent-major.
returns_raw = outs["metrics"]["returned_episode_returns"]
n_updates = returns_raw.shape[1]
returns = pd.DataFrame(
    returns_raw.reshape(n_agents, n_updates, -1)
    .transpose(1, 0, 2)
    .reshape(n_updates, -1)
)

fig, ax = plt.subplots()
ax.plot(returns.mean(axis=1))
ax.set(title=f"{env_name}: mean return per update", xlabel="Update Step", ylabel="Return")
plt.show()

In [None]:
from pathlib import Path

LOG_DIR = Path("../logs")
LOG_DIR.mkdir(parents=True, exist_ok=True)  # to_csv fails if the directory is missing
returns.to_csv(LOG_DIR / f"{env_name}_parallel_ger.csv")