Skip to content

Commit

Permalink
feat(algo): add DCG-MAP-Elites (#167)
Browse files Browse the repository at this point in the history
* a new method for MAP-Elites repertoire, that enables to samples individuals with their corresponding descriptors.
* a new output extra_info for Emitter.emit methods that is similar to the extra_scores of the scoring function, and that enables to pass information from the emit step to the state_update (necessary for DCG-MAP-Elites).
* a new DCGTransition that add desc and desc_prime to the QDTransition.
* descriptor-conditioned TD3 loss, descriptor-conditioned scoring functions, descriptor-conditioned MLP
* two new reward wrappers to clip and offset the reward (necessary for DCG-MAP-Elites).
  • Loading branch information
maxencefaldor committed Jan 9, 2024
1 parent 82c0437 commit b4125c3
Show file tree
Hide file tree
Showing 40 changed files with 1,798 additions and 192 deletions.
2 changes: 1 addition & 1 deletion examples/distributed_mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@
"repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n",
" centroids=centroids,\n",
" devices=devices,\n",
")(init_genotypes=init_variables, random_key=random_key)"
")(genotypes=init_variables, random_key=random_key)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@
"# initialize map-elites\n",
"repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n",
" devices=devices, centroids=centroids\n",
")(init_genotypes=training_states, random_key=keys)"
")(genotypes=training_states, random_key=keys)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/me_td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@
"# initialize map-elites\n",
"repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n",
" devices=devices, centroids=centroids\n",
")(init_genotypes=training_states, random_key=keys)"
")(genotypes=training_states, random_key=keys)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions examples/mome.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@
"# initial population\n",
"random_key = jax.random.PRNGKey(42)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"init_genotypes = jax.random.uniform(\n",
"genotypes = jax.random.uniform(\n",
" random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n",
")\n",
"\n",
Expand Down Expand Up @@ -303,7 +303,7 @@
"outputs": [],
"source": [
"repertoire, emitter_state, random_key = mome.init(\n",
" init_genotypes,\n",
" genotypes,\n",
" centroids,\n",
" pareto_front_max_length,\n",
" random_key\n",
Expand Down
6 changes: 3 additions & 3 deletions examples/nsga2_spea2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
"# Initial population\n",
"random_key = jax.random.PRNGKey(0)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"init_genotypes = jax.random.uniform(\n",
"genotypes = jax.random.uniform(\n",
" subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n",
")\n",
"\n",
Expand Down Expand Up @@ -238,7 +238,7 @@
"\n",
"# init nsga2\n",
"repertoire, emitter_state, random_key = nsga2.init(\n",
" init_genotypes,\n",
" genotypes,\n",
" population_size,\n",
" random_key\n",
")"
Expand Down Expand Up @@ -303,7 +303,7 @@
"\n",
"# init spea2\n",
"repertoire, emitter_state, random_key = spea2.init(\n",
" init_genotypes,\n",
" genotypes,\n",
" population_size,\n",
" num_neighbours,\n",
" random_key\n",
Expand Down
21 changes: 8 additions & 13 deletions qdax/baselines/genetic_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def __init__(

@partial(jax.jit, static_argnames=("self", "population_size"))
def init(
self, init_genotypes: Genotype, population_size: int, random_key: RNGKey
self, genotypes: Genotype, population_size: int, random_key: RNGKey
) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]:
"""Initialize a GARepertoire with an initial population of genotypes.
Args:
init_genotypes: the initial population of genotypes
genotypes: the initial population of genotypes
population_size: the maximal size of the repertoire
random_key: a random key to handle stochastic operations
Expand All @@ -54,26 +54,21 @@ def init(

# score initial genotypes
fitnesses, extra_scores, random_key = self._scoring_function(
init_genotypes, random_key
genotypes, random_key
)

# init the repertoire
repertoire = GARepertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
)

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
random_key=random_key,
repertoire=repertoire,
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
Expand Down Expand Up @@ -108,7 +103,7 @@ def update(
"""

# generate offsprings
genotypes, random_key = self._emitter.emit(
genotypes, extra_info, random_key = self._emitter.emit(
repertoire, emitter_state, random_key
)

Expand All @@ -127,7 +122,7 @@ def update(
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
extra_scores={**extra_scores, **extra_info},
)

# update the metrics
Expand Down
15 changes: 10 additions & 5 deletions qdax/baselines/nsga2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,36 @@ class NSGA2(GeneticAlgorithm):

@partial(jax.jit, static_argnames=("self", "population_size"))
def init(
self, init_genotypes: Genotype, population_size: int, random_key: RNGKey
self, genotypes: Genotype, population_size: int, random_key: RNGKey
) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]:

# score initial genotypes
fitnesses, extra_scores, random_key = self._scoring_function(
init_genotypes, random_key
genotypes, random_key
)

# init the repertoire
repertoire = NSGA2Repertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
)

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
random_key=random_key,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
extra_scores=extra_scores,
)
Expand Down
15 changes: 10 additions & 5 deletions qdax/baselines/spea2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,40 @@ class SPEA2(GeneticAlgorithm):
)
def init(
self,
init_genotypes: Genotype,
genotypes: Genotype,
population_size: int,
num_neighbours: int,
random_key: RNGKey,
) -> Tuple[SPEA2Repertoire, Optional[EmitterState], RNGKey]:

# score initial genotypes
fitnesses, extra_scores, random_key = self._scoring_function(
init_genotypes, random_key
genotypes, random_key
)

# init the repertoire
repertoire = SPEA2Repertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
num_neighbours=num_neighbours,
)

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
random_key=random_key,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
extra_scores=extra_scores,
)
Expand Down
24 changes: 11 additions & 13 deletions qdax/core/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def container_size_control(

def init(
self,
init_genotypes: Genotype,
genotypes: Genotype,
aurora_extra_info: AuroraExtraInfo,
l_value: jnp.ndarray,
max_size: int,
Expand All @@ -128,7 +128,7 @@ def init(
genotypes. Also performs the first training of the AURORA encoder.
Args:
init_genotypes: initial genotypes, pytree in which leaves
genotypes: initial genotypes, pytree in which leaves
have shape (batch_size, num_features)
aurora_extra_info: information to perform AURORA encodings,
such as the encoder parameters
Expand All @@ -141,7 +141,7 @@ def init(
the emitter, and the updated information to perform AURORA encodings
"""
fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
init_genotypes,
genotypes,
random_key,
)

Expand All @@ -150,7 +150,7 @@ def init(
descriptors = self._encoder_fn(observations, aurora_extra_info)

repertoire = UnstructuredRepertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
observations=observations,
Expand All @@ -160,13 +160,9 @@ def init(

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
genotypes=init_genotypes,
random_key=random_key,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
Expand Down Expand Up @@ -208,9 +204,10 @@ def update(
a new key
"""
# generate offsprings with the emitter
genotypes, random_key = self._emitter.emit(
genotypes, extra_info, random_key = self._emitter.emit(
repertoire, emitter_state, random_key
)

# scores the offsprings
fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
genotypes,
Expand All @@ -232,10 +229,11 @@ def update(
# update emitter state after scoring is made
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
extra_scores=extra_scores | extra_info,
)

# update the metrics
Expand Down
32 changes: 32 additions & 0 deletions qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,38 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey

return samples, random_key

@partial(jax.jit, static_argnames=("num_samples",))
def sample_with_descs(
self,
random_key: RNGKey,
num_samples: int,
) -> Tuple[Genotype, Descriptor, RNGKey]:
"""Sample elements in the repertoire.
Args:
random_key: a jax PRNG random key
num_samples: the number of elements to be sampled
Returns:
samples: a batch of genotypes sampled in the repertoire
random_key: an updated jax PRNG random key
"""

repertoire_empty = self.fitnesses == -jnp.inf
p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

random_key, subkey = jax.random.split(random_key)
samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.genotypes,
)
descs = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.descriptors,
)

return samples, descs, random_key

@jax.jit
def add(
self,
Expand Down
Loading

0 comments on commit b4125c3

Please sign in to comment.