In [None]:
# Example for Vanilla-ES, directly inspired by the MAP-Elites example

import functools
import time
from typing import Dict

import jax
import jax.numpy as jnp

from qdax import environments
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.emitters.vanilla_es_emitter import VanillaESConfig, VanillaESEmitter
from qdax.core.map_elites import MAPElites
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.tasks.brax_envs import (
    make_policy_network_play_step_fn_brax,
    reset_based_scoring_function_brax_envs,
)
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results
%matplotlib inline

In [None]:
# ------------------------------------------------------------------
# Experiment configuration: every tunable value of the run lives here
# ------------------------------------------------------------------

# -- General --
env_name = "ant_uni"  # @param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 100  # steps simulated per episode
num_iterations = 100  # number of generations
seed = 42  # random seed
policy_hidden_layer_sizes = (128, 128)  # hidden layer sizes of the policy network

# -- MAP-Elites --
num_init_cvt_samples = 50_000  # samples used for the CVT initialization
num_centroids = 1024  # how many centroids to compute
min_bd = 0.0  # lower bound of the behavior descriptor
max_bd = 1.0  # upper bound of the behavior descriptor

# -- ES --
sample_number = 512  # population size
sample_sigma = 0.01  # standard deviation of the Gaussian distribution
sample_mirror = True  # turn on mirror sampling
sample_rank_norm = True  # turn on rank normalization
adam_optimizer = True  # True: Adam optimizer, False: SGD
learning_rate = 0.01  # learning rate for the Adam optimizer
l2_coefficient = 0.02  # L2 coefficient for the Adam optimizer

# -- NSES --
# WARNING: BD-based NSES
# NOTE(review): the warning above was left unfinished in the original
# notebook; confirm the intended caveat before enabling `nses_emitter`.
nses_emitter = False  # True: NSES, False: plain ES
novelty_nearest_neighbors = 10  # nearest neighbors used for the novelty computation

In [1]:
import jax.numpy as jnp
import jax

In [2]:
# Toy setup: `lam` equals the number of scores below; `mu` is the cut-off
# rank used by the weight computation later in the notebook.
mu, lam = 3, 6

scores = jnp.asarray([30, 40, 20, 10, 50, 60])
# Indices that would sort `scores` in ascending order (1-D, so the default
# axis is the only axis).
ranking_indices = jnp.argsort(scores)
ranking_indices

DeviceArray([3, 2, 0, 1, 4, 5], dtype=int32)

In [3]:
# The argsort of an argsort gives each score's 0-based position in ascending
# order; subtracting it from `lam` flips the scale so that the highest score
# ends up with rank 1 and the lowest with rank `lam`.
ranks = lam - jnp.argsort(ranking_indices, axis=0)
ranks

DeviceArray([4, 3, 5, 6, 2, 1], dtype=int32)

In [5]:
# Log-rank weights: a sample whose rank is within the top `mu` receives
# log(mu + 0.5) - log(rank); every other sample receives zero.
w = jnp.where(ranks <= mu, jnp.log(mu + 0.5) - jnp.log(ranks), 0.0)
print(w)  # unnormalized weights
# Rescale so that the weights sum to one.
w = w / jnp.sum(w)
w

[0.         0.15415072 0.         0.         0.55961585 1.252763  ]


DeviceArray([0.        , 0.07838719, 0.        , 0.        , 0.28457028,
             0.6370426 ], dtype=float32)