In [6]:
# Example for Vanilla-ES directly inspired from the MAP-Elites example

import functools
import time
from typing import Dict, Tuple

import jax
import jax.numpy as jnp

from qdax import environments
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.emitters.vanilla_es_emitter import VanillaESConfig, VanillaESEmitter
from qdax.core.map_elites import MAPElites
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.tasks.brax_envs import (
    make_policy_network_play_step_fn_brax,
    reset_based_scoring_function_brax_envs,
)
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.core.rl_es_parts.es_utils import ES, default_es_metrics

from qdax.utils.plotting import plot_map_elites_results
%matplotlib inline

In [7]:
##############
# Parameters #

# General parameters
env_name = "ant_uni"  # @param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 100  # Number of steps per episode
num_iterations = 100  # Generations
seed = 42  # Random seed
policy_hidden_layer_sizes = (64, 64)  # Policy network hidden layer sizes

# MAP-Elites Parameters
num_init_cvt_samples = 50000 # Number of samples to use for CVT initialization
num_centroids = 1024  # Number of centroids
min_bd = 0.0  # Minimum value for the behavior descriptor
max_bd = 1.0  # Maximum value for the behavior descriptor

# ES Parameters
sample_number = 512  # Population size
sample_sigma = 0.01  # Standard deviation of the Gaussian distribution
sample_mirror = True  # Mirror sampling in ES
sample_rank_norm = True  # Rank normalization in ES
adam_optimizer = True  # Use Adam optimizer instead of SGD
learning_rate = 0.01  # Learning rate for Adam optimizer
l2_coefficient = 0.02  # L2 coefficient for Adam optimizer

# NSES Parameters
# WARNING: BD-based NSES 
nses_emitter = False  # Use NSES instead of ES
novelty_nearest_neighbors = 10  # Number of nearest neighbors to use for novelty computation

In [8]:
# TD3 config
episode_length: int = 1000
batch_size: int = 256
policy_delay: int = 2
grad_updates_per_step: float = 1
soft_tau_update: float = 0.005
critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
policy_hidden_layer_size: Tuple[int, ...] = (256, 256)

num_loops = 10
print_freq = 1

In [9]:
# Init environment
env = environments.create(env_name, episode_length=episode_length)

# Init a random key
random_key = jax.random.PRNGKey(seed)

# Init policy network
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

In [10]:

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=1)
fake_batch = jnp.zeros(shape=(1, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# Play reset fn
# WARNING: use "env.reset" for stochastic environment,
# use "lambda random_key: init_state" for deterministic environment
play_reset_fn = env.reset

# Prepare the scoring function
bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
scoring_fn = functools.partial(
    reset_based_scoring_function_brax_envs,
    episode_length=episode_length,
    play_reset_fn=play_reset_fn,
    play_step_fn=make_policy_network_play_step_fn_brax(env, policy_network),
    behavior_descriptor_extractor=bd_extraction_fn,
)

# Get minimum reward value to make sure qd_score are positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_function = functools.partial(
    default_es_metrics,
    qd_offset=reward_offset * episode_length,
)

In [11]:
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_map


# @jax.jit
def flatten(network):
    """Flatten a network into a single vector of floats"""
    flat_variables, tree_def = tree_flatten(network)
    # print("Flatten", flat_variables)
    shapes = [x.shape for x in flat_variables]
    print("Shapes", shapes)
    
    vect = jnp.concatenate([jnp.ravel(x) for x in flat_variables])
    sizes = [x.size for x in flat_variables]
    sizes = jnp.array(sizes)
    return vect, tree_def, sizes, shapes

# @jax.jit
def unflaten(genome, tree_def, sizes, shapes):
    """Unflatten a vector of floats into a network"""
    indices = jnp.cumsum(sizes)[:-1]
    indices = indices.tolist()
    print(indices)
    split_genome = jnp.split(genome, indices)
    # Reshape to the original shape
    split_genome = [jnp.reshape(x, s) for x, s in zip(split_genome, shapes)]

    # Unflatten the tree
    new_net = tree_unflatten(tree_def, split_genome)
    return new_net
    

genome, tree_def, sizes, shapes = flatten(init_variables)
net_size = len(genome)
print("Network size:", net_size)

random = jax.random.PRNGKey(42)
genome = jax.random.normal(random, (net_size,))

random_net = unflaten(genome, tree_def, sizes, shapes)

after_genome, after_tree_def, after_sizes, after_shapes = flatten(random_net)

assert jnp.all(genome == after_genome)

Shapes [(1, 64), (1, 87, 64), (1, 64), (1, 64, 64), (1, 8), (1, 64, 8)]
Network size: 10312
[64, 5632, 5696, 9792, 9800]
Shapes [(1, 64), (1, 87, 64), (1, 64), (1, 64, 64), (1, 8), (1, 64, 8)]


In [50]:
init_variables["params"]["Dense_0"]["kernel"].shape

(1, 87, 64)

In [53]:
genome, tree_def, sizes, shapes = flatten(init_variables)
new_net = unflaten(genome, tree_def, sizes, shapes)
new_net["params"]["Dense_0"]["kernel"].shape

Shapes [(1, 64), (1, 87, 64), (1, 64), (1, 64, 64), (1, 8), (1, 64, 8)]
[64, 5632, 5696, 9792, 9800]


(1, 87, 64)

In [1]:
import jax
import jax.numpy as jnp

# Define the function to be mapped
def my_func(x, y):
    return x + y

# Define the input array of arrays
arr = jnp.array([[1, 2, 3], [4, 5, 6]])

# Apply vmap to the function to loop over the first axis of the input array
mapped_func = jax.vmap(my_func, in_axes=(0, None))

# Apply the mapped function to the input array
result = mapped_func(arr, 1)

# Print the result
print("Result:", result)

Result: [[2 3 4]
 [5 6 7]]


In [27]:
tree_def

PyTreeDef(CustomNode(FrozenDict[()], [{'params': {'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}, 'Dense_2': {'bias': *, 'kernel': *}}}]))

In [12]:
from qdax.core.es_parts.open_es import OpenESEmitter, OpenESConfig

es_config = OpenESConfig(
    nses_emitter=nses_emitter,
    sample_number=sample_number,
    sample_sigma=sample_sigma,
    sample_mirror=sample_mirror,
    sample_rank_norm=sample_rank_norm,
    adam_optimizer=adam_optimizer,
    learning_rate=learning_rate,
    l2_coefficient=l2_coefficient,
    novelty_nearest_neighbors=novelty_nearest_neighbors,
)

es_emitter = OpenESEmitter(
    config=es_config,
    scoring_fn=scoring_fn,
    total_generations=num_iterations,
    num_descriptors=env.behavior_descriptor_length,
)

In [13]:
import numpy as np

In [14]:
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)


In [15]:
def f(genome, x):
    return genome + 1, genome + 1

init = 0
xs = np.arange(10)
carry, ys = scan(f, init, xs)
print(carry, ys)

10 [ 1  2  3  4  5  6  7  8  9 10]


In [7]:
import jax
import jax.numpy as jnp

def f(x, y, z):
    return x + y + z

def g(x, y, z):
    return -(x + y + z)

def c(x, y, z):
    key = jax.random.PRNGKey(0)
    cond = jax.random.choice(key, jnp.array([True, False]))

    return jax.lax.cond(cond, 
                 f,
                 g, 
                 x, 
                 y,
                 z)
    
c(1, 2, 3)

DeviceArray(6, dtype=int32, weak_type=True)

In [17]:
import numpy
# Random normal vector 
n = 25000
sigma = 0.1
x = numpy.random.normal(0, 1, n) * sigma
# get norm
norm = numpy.linalg.norm(x)
norm

15.753333543850154

In [24]:
import flax


class ESMetrics(flax.struct.PyTreeNode):
    logs: dict

    def __init__(self, **kwargs):
        print(self.__dict__)
        if kwargs:
            self.replace(dict=kwargs)
    
    def __repr__(self):
        return f"ESMetrics({self.logs})"
    
    def __str__(self):
        return self.__repr__()

    def replace(self, **kwargs):
        print("Replace", kwargs)
        self.logs.update(kwargs)
        return self
    
    def __getattr__(self, key):
        print("Get attr", key)
        if key == "logs":
            return self.__dict__.logs
        if key in self.logs:
            print("Found", key)
            return self.logs[key]
        else:
            raise AttributeError(f"Attribute {key} not found in ESMetrics")
        
    @property
    def __dataclass_fields__(self):
        return self.logs

In [26]:
metrics = ESMetrics(logs={})

print(metrics)

# metrics = ESMetrics()


{}
Get attr logs


AttributeError: 'dict' object has no attribute 'logs'

In [30]:
import jax
import jax.numpy as jnp

# Define a function that takes a pytree as input
def my_function(pytree, a):
  # Do some computation on the pytree
  print(jax.tree_map(lambda x: x.shape, pytree))
  return 4


In [31]:
jax.tree_map(lambda x: x.shape, init_variables)

FrozenDict({
    params: {
        Dense_0: {
            bias: (1, 64),
            kernel: (1, 87, 64),
        },
        Dense_1: {
            bias: (1, 64),
            kernel: (1, 64, 64),
        },
        Dense_2: {
            bias: (1, 8),
            kernel: (1, 64, 8),
        },
    },
})

In [32]:
# Use vmap to map over the first dimension of the resulting array
jax.vmap(my_function, in_axes=[0, None])(init_variables, jnp.array([1, 2, 3]))


FrozenDict({
    params: {
        Dense_0: {
            bias: (64,),
            kernel: (87, 64),
        },
        Dense_1: {
            bias: (64,),
            kernel: (64, 64),
        },
        Dense_2: {
            bias: (8,),
            kernel: (64, 8),
        },
    },
})


DeviceArray([4], dtype=int32, weak_type=True)

In [1]:
import jax
import jax.numpy as jnp

# @jax.jit
def spearman(x, y):
    """Computes the Spearman correlation coefficient and p-value between two arrays.

    Args:
    x: A NumPy array of values.
    y: A NumPy array of values.

    Returns:
    A tuple of the Spearman correlation coefficient and p-value between x and y.
    """

    # Compute the ranks of x and y.
    x_ranks = jnp.argsort(x)
    y_ranks = jnp.argsort(y)

    # Compute the covariance of the ranks.
    covariance = jnp.cov(x_ranks, y_ranks)[0, 1]

    # Compute the standard deviation of the ranks.
    standard_deviation = jnp.std(x_ranks) * jnp.std(y_ranks)

    # Compute the Spearman correlation coefficient.
    r = covariance / standard_deviation

    # Compute the degrees of freedom.
    df = x.shape[0] - 2

    # Compute the critical value.
    critical_value = jnp.sqrt((1 - r**2) / (df * (1 - r**2)))

    # Return the Spearman correlation coefficient and p-value.
    return r, jnp.less(r, critical_value).astype(jnp.float32)

In [3]:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100,))

spearman(x, x)

(DeviceArray(1.0101011, dtype=float32), DeviceArray(0., dtype=float32))

In [7]:
def spearmanr_with_pvalue(x, y):
    # Compute the length of the arrays
    n = len(x)
    
    # Compute the ranks of the elements in x and y
    rank_x = jnp.argsort(jnp.argsort(x))
    rank_y = jnp.argsort(jnp.argsort(y))
    
    # Compute the squared differences between the ranks
    d = jnp.square(rank_x - rank_y)
    
    # Compute the t-statistic and p-value for testing non-correlation between two variables
    t = 1 - (6 * jnp.sum(d)) / (n * (n**2 - 1))
    p = 2 * (1 - jax.scipy.stats.norm.cdf(jnp.abs(t)))
    
    # Return the Spearman's rank correlation coefficient, p-value for the Spearman's rank correlation coefficient, and p-value for testing non-correlation between two variables
    return t, p


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100,))

spearmanr_with_pvalue(x, x)

(DeviceArray(1., dtype=float32), DeviceArray(0.31731057, dtype=float32))