In [None]:
# Reload edited Python modules automatically before each cell executes, so
# changes to locally edited packages are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
#@title Installs and Imports
# ipympl provides matplotlib's interactive `%matplotlib widget` backend; it is
# only used if the commented-out widget lines below are re-enabled.
!pip install ipympl |tail -n 1
# %matplotlib widget
# from google.colab import output
# output.enable_custom_widget_manager()

import os

from IPython.display import clear_output
import matplotlib.pyplot as plt
import matplotlib as mpl
import functools
import time

import jax
import jax.numpy as jnp

try:
    import brax
except:
    !pip install git+https://github.com/google/brax.git@main |tail -n 1
    import brax

try:
    import qdax
except:
    !pip install --no-deps git+https://github.com/instadeepai/QDax@2-instadeep-new-structure-suggestion |tail -n 1
    import qdax


from qdax.core.map_elites import MAPElites
from qdax.core.containers.repertoire import compute_cvt_centroids
from qdax import environments
from qdax.core.neuroevolution.mdp_utils import scoring_function
from qdax.core.neuroevolution.buffers.buffers import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.utils.plotting import plot_2d_map_elites_grid

from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGEmitter
from qdax.utils.metrics import CSVLogger


# Colab TPU runtimes expose COLAB_TPU_ADDR; only on such runtimes do we run
# JAX's Colab TPU setup helper.
if os.environ.get("COLAB_TPU_ADDR") is not None:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

# Clear this cell's output (pip install logs, etc.).
clear_output()

In [None]:
#@title QD Training Definitions Fields
#@markdown ---
# The dropdown options must be QDax environment names: the same keys used by
# `environments.create`, `environments.reward_offset` and
# `environments.behavior_descriptor_extractor` further below.
# NOTE(review): confirm this list against the installed QDax version.
env_name = 'walker2d_uni' #@param ['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 100 #@param {type:"integer"}
num_iterations = 1000 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_bd = 0. #@param {type:"number"}
max_bd = 1.0 #@param {type:"number"}

#@markdown ### PGA-ME Emitter Definitions Fields
# Presumably the share of offspring produced by GA variation rather than by
# policy-gradient updates -- see PGAMEConfig to confirm.
proportion_mutation_ga = 0.5 #@param {type:"number"}

# TD3 params
# Sizes and step counts are used as integer counts downstream
# (e.g. jax.random.split(num=env_batch_size), jnp.repeat(repeats=env_batch_size)),
# so they are exposed as integer form fields.
env_batch_size = 100 #@param {type:"integer"}
replay_buffer_size = 1000000 #@param {type:"integer"}
critic_hidden_layer_size = (256, 256) #@param {type:"raw"}
critic_learning_rate = 3e-4 #@param {type:"number"}
greedy_learning_rate = 3e-4 #@param {type:"number"}
policy_learning_rate = 1e-3 #@param {type:"number"}
noise_clip = 0.5 #@param {type:"number"}
policy_noise = 0.2 #@param {type:"number"}
discount = 0.99 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
transitions_batch_size = 256 #@param {type:"integer"}
soft_tau_update = 0.005 #@param {type:"number"}
num_critic_training_steps = 300 #@param {type:"integer"}
num_pg_training_steps = 100 #@param {type:"integer"}
#@markdown ---

## Initialize the environment, the policy network and the initial population parameters (the initial environment states are created further below)

In [None]:
# Create the QDax environment selected in the form and seed the root PRNG key.
env = environments.create(env_name)
random_key = jax.random.PRNGKey(seed)

# Policy network: the hidden layers from the form, then an output layer with
# one unit per action dimension; tanh squashes the outputs into [-1, 1].
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    final_activation=jnp.tanh,
    kernel_init=jax.nn.initializers.lecun_uniform(),
)

# Initial population: one parameter set per individual, obtained by vmapping
# `policy_network.init` over `env_batch_size` distinct keys and an all-zeros
# dummy observation batch (presumably only its shape matters for init).
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=env_batch_size)
fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

## Define how the policy interacts with the environment

In [None]:
# Define the function that plays one step of the policy in the environment
def play_step_fn(
    env_state,
    policy_params,
    random_key,
):
    """
    Play a single environment step with the current policy.

    Args:
        env_state: current environment state; its ``obs`` is fed to the policy.
        policy_params: parameters passed to ``policy_network.apply``.
        random_key: PRNG key; not consumed here and returned unchanged.

    Returns:
        ``(next_state, policy_params, random_key, transition)``, where
        ``transition`` is a QDTransition recording this step (observations,
        reward, done/truncation flags, action and state descriptors).
    """
    action = policy_network.apply(policy_params, env_state.obs)
    new_state = env.step(env_state, action)
    step_info = new_state.info

    return new_state, policy_params, random_key, QDTransition(
        obs=env_state.obs,
        next_obs=new_state.obs,
        rewards=new_state.reward,
        dones=new_state.done,
        actions=action,
        truncations=step_info["truncation"],
        state_desc=env_state.info["state_descriptor"],
        next_state_desc=step_info["state_descriptor"],
    )

# Environment-specific reward offset, used below to shift fitnesses so that
# the QD score stays positive.
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
def metrics_function(repertoire):
    """Summarise a repertoire as QD score, maximum fitness and coverage.

    A cell is considered empty when its fitness is ``-inf``.

    Returns:
        dict with keys ``"qd_score"`` (sum of filled-cell fitnesses plus
        ``reward_offset * episode_length`` per filled cell), ``"max_fitness"``
        and ``"coverage"`` (percentage of filled cells).
    """
    fitnesses = repertoire.fitnesses
    is_empty = fitnesses == -jnp.inf
    num_filled = jnp.sum(1.0 - is_empty)

    # Offset every filled cell so the QD score is positive.
    qd_score = jnp.sum(fitnesses, where=~is_empty) + reward_offset * episode_length * num_filled

    return {
        "qd_score": qd_score,
        "max_fitness": jnp.max(fitnesses),
        "coverage": 100 * jnp.mean(1.0 - is_empty),
    }

## Replace the emitter with the policy-gradient (PGA-ME) emitter

In [None]:
# PGA-ME emitter configuration; every value comes from the form fields above.
pga_emitter_config = PGAMEConfig(
    # batch sizes and GA proportion
    env_batch_size=env_batch_size,
    batch_size=transitions_batch_size,
    proportion_mutation_ga=proportion_mutation_ga,
    # critic
    critic_hidden_layer_size=critic_hidden_layer_size,
    critic_learning_rate=critic_learning_rate,
    num_critic_training_steps=num_critic_training_steps,
    # learning rates / steps for the policy updates
    greedy_learning_rate=greedy_learning_rate,
    policy_learning_rate=policy_learning_rate,
    num_pg_training_steps=num_pg_training_steps,
    # TD3 hyper-parameters
    noise_clip=noise_clip,
    policy_noise=policy_noise,
    discount=discount,
    reward_scaling=reward_scaling,
    soft_tau_update=soft_tau_update,
    # replay buffer
    replay_buffer_size=replay_buffer_size,
)

In [None]:
# Get the emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=0.05, line_sigma=0.1
)

pg_emitter = PGEmitter(
    config=pga_emitter_config,
    policy_network=policy_network,
    env=env,
    variation_fn=variation_fn,
)

In [None]:
# Initial environment states: every batch slot receives the SAME reset key, so
# (assuming reset is a pure function of the key) all individuals are scored
# from identical initial states. `env.reset` is vmapped over the batch and jitted.
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(subkey[None], repeats=env_batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))
init_states = reset_fn(keys)

# Bind the evaluation settings (initial states, episode length, step function,
# environment-specific descriptor extractor) into QDax's scoring_function.
bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
scoring_fn = functools.partial(
    scoring_function,
    episode_length=episode_length,
    init_states=init_states,
    behavior_descriptor_extractor=bd_extraction_fn,
    play_step_fn=play_step_fn,
)

In [None]:
# Compute the centroids
centroids = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_bd,
    maxval=max_bd,
)


# Instantiate MAP Elites
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=pg_emitter,
    metrics_function=metrics_function,
)

repertoire, emitter_state, random_key = map_elites.init(
    init_variables, centroids, random_key
)

In [None]:
@jax.jit
def update_scan_fn(carry, unused):
    """Run one MAP-Elites iteration, shaped as a ``jax.lax.scan`` step.

    The carry is ``(repertoire, emitter_state, random_key)``; the per-step
    output is the metrics returned by ``map_elites.update``. ``unused`` is the
    (ignored) scanned-over element.
    """
    repertoire, emitter_state, random_key = carry
    repertoire, emitter_state, metrics, random_key = map_elites.update(
        repertoire, emitter_state, random_key
    )
    return (repertoire, emitter_state, random_key), metrics

In [None]:
log_period = 10
# Number of scan chunks; a remainder of num_iterations % log_period is not run.
num_loops = num_iterations // log_period

csv_logger = CSVLogger(
    "pgame-logs.csv",
    header=["loop", "iteration", "qd_score", "max_fitness", "coverage", "time"]
)
# Per-loop metric chunks; concatenated once after the loop instead of
# re-copying the growing arrays on every loop iteration.
metrics_history = {}

# main loop
for i in range(num_loops):
    start_time = time.time()
    # Run `log_period` MAP-Elites iterations in a single scan.
    (repertoire, emitter_state, random_key), metrics = jax.lax.scan(
        update_scan_fn,
        (repertoire, emitter_state, random_key),
        (),
        length=log_period,
    )
    # JAX dispatches asynchronously: wait for the results so the measured time
    # covers the computation (the first loop also includes compilation).
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)
    timelapse = time.time() - start_time

    # The logged values are those of the LAST iteration in this chunk,
    # i.e. after (i + 1) * log_period iterations.
    logged_metrics = {"time": timelapse, "loop": 1 + i, "iteration": (i + 1) * log_period}
    for key, value in metrics.items():
        logged_metrics[key] = value[-1]
        metrics_history.setdefault(key, []).append(value)

    csv_logger.log(logged_metrics)

# Full per-iteration history of every metric, used by the plots below.
all_metrics = {key: jnp.concatenate(chunks) for key, chunks in metrics_history.items()}

In [None]:
#@title Visualization

# Customize matplotlib params
font_size = 16
params = {
    "axes.labelsize": font_size,
    "axes.titlesize": font_size,
    "legend.fontsize": font_size,
    "xtick.labelsize": font_size,
    "ytick.labelsize": font_size,
    "text.usetex": False,
    "axes.titlepad": 10,
}

mpl.rcParams.update(params)

# Visualize the training evolution (three curves) and the final repertoire (last panel).
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(40, 10))

# x-axis: environment steps after each recorded MAP-Elites iteration, assuming
# each iteration scores `env_batch_size` policies over `episode_length` steps
# (steps spent scoring the initial population are not counted). The number of
# points comes from the recorded metrics rather than `num_iterations`, so the
# x and y arrays always have the same length.
num_recorded = all_metrics["coverage"].shape[0]
env_steps = jnp.arange(1, num_recorded + 1) * episode_length * env_batch_size

curves = [
    ("coverage", "Coverage in %", "Coverage evolution during training"),
    ("max_fitness", "Maximum fitness", "Maximum fitness evolution during training"),
    ("qd_score", "QD Score", "QD Score evolution during training"),
]
for ax, (metric_name, ylabel, title) in zip(axes[:3], curves):
    ax.plot(env_steps, all_metrics[metric_name])
    ax.set_xlabel("Environment steps")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Keep each curve panel close to square regardless of the data ranges.
    ax.set_aspect(0.95 / ax.get_data_ratio(), adjustable="box")

plot_2d_map_elites_grid(
    centroids=centroids,
    grid_fitness=repertoire.fitnesses,
    minval=min_bd,
    maxval=max_bd,
    grid_descriptors=repertoire.descriptors,
    ax=axes[3],
)
plt.show()