-
Notifications
You must be signed in to change notification settings - Fork 41
/
mome_repertoire.py
415 lines (340 loc) · 15 KB
/
mome_repertoire.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
"""This file contains the class to define the repertoire used to
store individuals in the Multi-Objective MAP-Elites algorithm as
well as several variants."""
from __future__ import annotations
from functools import partial
from typing import Any, Tuple
import jax
import jax.numpy as jnp
from qdax.core.containers.repertoire import MapElitesRepertoire, get_cells_indices
from qdax.types import Centroid, Descriptor, Fitness, Genotype, RNGKey
from qdax.utils.pareto_front import compute_masked_pareto_front
class MOMERepertoire(MapElitesRepertoire):
    """Repertoire used by Multi-Objective MAP-Elites (MOME).

    This class inherits from MapElitesRepertoire and stores the same fields
    (genotypes, fitnesses, descriptors, centroids), but every cell holds a
    fixed-size Pareto front instead of a single individual.

    Shapes (as built by ``init``):
        genotypes: PyTree whose leaves have shape
            (num_centroids, pareto_front_length, *leaf_dims).
        fitnesses: (num_centroids, pareto_front_length, num_criteria).
        descriptors: (num_centroids, pareto_front_length, num_descriptors).
        centroids: (num_centroids, num_descriptors).

    An unused slot of a front is marked by a fitness of ``-inf``; throughout
    this class a "mask" is a boolean vector that is True for unused slots.

    Inherited functions: save and load.
    """

    @property
    def repertoire_capacity(self) -> int:
        """Return the maximum number of solutions the repertoire can hold.

        This is the number of cells times the maximum Pareto front length,
        read from the first two axes of the first genotype leaf.

        Returns:
            The repertoire capacity.
        """
        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
        return int(first_leaf.shape[0] * first_leaf.shape[1])

    # NOTE: this used to be jitted with static_argnames=("num_samples",)
    # although the function has no such parameter; a plain jit is what is
    # actually needed.
    @jax.jit
    def _sample_in_masked_pareto_front(
        self,
        pareto_front_genotypes: Genotype,
        mask: jnp.ndarray,
        random_key: RNGKey,
    ) -> Genotype:
        """Sample one genotype from a single masked Pareto front.

        Note: no random key is returned because this function is meant to
        be vmapped; the public ``sample`` method handles key management.

        Args:
            pareto_front_genotypes: the genotypes of one Pareto front, with
                the front axis first in every leaf.
            mask: mask of the front, True for unused slots.
            random_key: a random key to handle stochastic operations.

        Returns:
            A PyTree with the sampled genotype; every leaf keeps a leading
            axis of size 1.
        """
        # uniform probability over the used slots, zero on the unused ones
        p = (1.0 - mask) / jnp.sum(1.0 - mask)

        # the same key is reused for every leaf, which keeps the leaves of
        # the sampled genotype aligned on the same slot
        return jax.tree_util.tree_map(
            lambda x: jax.random.choice(random_key, x, shape=(1,), p=p),
            pareto_front_genotypes,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(
        self, random_key: RNGKey, num_samples: int
    ) -> Tuple[Genotype, RNGKey]:
        """Sample genotypes from the repertoire.

        A non-empty cell is drawn uniformly first, then one genotype is
        drawn uniformly among the used slots of that cell's Pareto front.
        The repertoire is assumed to contain at least one individual
        (otherwise the cell probabilities are 0 / 0).

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: number of samples to retrieve from the repertoire.

        Returns:
            The sampled genotypes (leaves of shape (num_samples, *leaf_dims))
            and a new random key.
        """
        # True where a slot is unused - shape (num_cells, front_length)
        repertoire_empty = jnp.any(self.fitnesses == -jnp.inf, axis=-1)

        # uniform sampling probability over the cells holding >= 1 solution
        occupied_cells = jnp.any(~repertoire_empty, axis=-1)
        p = occupied_cells / jnp.sum(occupied_cells)

        # possible indices - num cells
        indices = jnp.arange(start=0, stop=repertoire_empty.shape[0])

        # choose idx - among indices of cells that are not empty
        random_key, subkey = jax.random.split(random_key)
        cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)

        # get genotypes (front) from the chosen indices
        pareto_front_genotypes = jax.tree_util.tree_map(
            lambda x: x[cells_idx], self.genotypes
        )

        # one independent key per sampled front
        random_key, subkey = jax.random.split(random_key)
        subkeys = jax.random.split(subkey, num=num_samples)

        sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)
        sampled_genotypes = sample_in_fronts(  # type: ignore
            pareto_front_genotypes=pareto_front_genotypes,
            mask=repertoire_empty[cells_idx],
            random_key=subkeys,
        )

        # Drop only the size-1 axis introduced by the per-front sampling.
        # An unqualified squeeze() would also remove the batch axis when
        # num_samples == 1, or any genotype axis of size 1.
        sampled_genotypes = jax.tree_util.tree_map(
            lambda x: x.squeeze(axis=1), sampled_genotypes
        )

        return sampled_genotypes, random_key

    @jax.jit
    def _update_masked_pareto_front(
        self,
        pareto_front_fitnesses: Fitness,
        pareto_front_genotypes: Genotype,
        pareto_front_descriptors: Descriptor,
        mask: jnp.ndarray,
        new_batch_of_fitnesses: Fitness,
        new_batch_of_genotypes: Genotype,
        new_batch_of_descriptors: Descriptor,
        new_mask: jnp.ndarray,
    ) -> Tuple[Fitness, Genotype, Descriptor, jnp.ndarray]:
        """Merge new points into a fixed-size masked Pareto front.

        The current front and the candidates are concatenated, the Pareto
        front of the union is computed, and its members are packed at the
        beginning of the returned arrays. Unused slots are zero-filled. If
        the new front has more members than the front length, the surplus
        members (those with the largest concatenation indices) are dropped.

        Args:
            pareto_front_fitnesses: fitness of the pareto front
            pareto_front_genotypes: corresponding genotypes
            pareto_front_descriptors: corresponding descriptors
            mask: mask of the front, True for unused slots
            new_batch_of_fitnesses: new batch of fitness that is considered
                to be added to the pareto front
            new_batch_of_genotypes: corresponding genotypes
            new_batch_of_descriptors: corresponding descriptors
            new_mask: corresponding mask (no one is masked)

        Returns:
            The updated fitnesses, genotypes and descriptors, each with the
            same front length as the input front, and the updated mask
            (True for unused slots).
        """
        # get dimensions
        batch_size = new_batch_of_fitnesses.shape[0]
        pareto_front_len = pareto_front_fitnesses.shape[0]
        total_len = pareto_front_len + batch_size

        # gather all data
        cat_mask = jnp.concatenate([mask, new_mask], axis=-1)
        cat_fitnesses = jnp.concatenate(
            [pareto_front_fitnesses, new_batch_of_fitnesses], axis=0
        )
        cat_genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            pareto_front_genotypes,
            new_batch_of_genotypes,
        )
        cat_descriptors = jnp.concatenate(
            [pareto_front_descriptors, new_batch_of_descriptors], axis=0
        )

        # boolean vector over the concatenated points
        # (presumably True for unmasked, non-dominated points - see
        # qdax.utils.pareto_front.compute_masked_pareto_front)
        cat_bool_front = compute_masked_pareto_front(
            batch_of_criteria=cat_fitnesses, mask=cat_mask
        )

        # Front members keep their own index; every other entry is sent to
        # the last index so that sorting packs the front members first.
        indices = jnp.where(
            cat_bool_front,
            jnp.arange(start=0, stop=total_len),
            total_len - 1,
        )
        indices = jnp.sort(indices)

        # get new fitness, genotypes and descriptors
        new_front_fitness = jnp.take(cat_fitnesses, indices, axis=0)
        new_front_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.take(x, indices, axis=0), cat_genotypes
        )
        new_front_descriptors = jnp.take(cat_descriptors, indices, axis=0)

        # the first num_front_elements slots are used, truncated to the
        # fixed front length
        num_front_elements = jnp.sum(cat_bool_front)
        slot_is_used = jnp.arange(start=0, stop=total_len) < num_front_elements
        slot_is_used = slot_is_used[:pareto_front_len]

        def _truncate_and_clear(x: jnp.ndarray) -> jnp.ndarray:
            """Keep the first front-length rows and zero the unused ones."""
            x = x[:pareto_front_len]
            # broadcast the slot mask over the trailing axes of this leaf,
            # so leaves of any rank / trailing size are supported
            leaf_mask = slot_is_used.reshape(
                (pareto_front_len,) + (1,) * (x.ndim - 1)
            )
            # where() rather than a multiplication: the result stays zero
            # even if a discarded entry holds inf or nan
            return jnp.where(leaf_mask, x, jnp.zeros_like(x))

        new_front_fitness = _truncate_and_clear(new_front_fitness)
        new_front_genotypes = jax.tree_util.tree_map(
            _truncate_and_clear, new_front_genotypes
        )
        new_front_descriptors = _truncate_and_clear(new_front_descriptors)

        # back to the "True means unused" convention
        updated_mask = ~slot_is_used

        return (
            new_front_fitness,
            new_front_genotypes,
            new_front_descriptors,
            updated_mask,
        )

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
    ) -> MOMERepertoire:
        """Insert a batch of elements in the repertoire.

        Individuals are inserted one after the other (with a scan): each
        one is merged into the Pareto front of the cell its descriptor
        falls in.

        Shape of the batch_of_genotypes (if an array):
        (batch_size, genotypes_dim)
        Shape of the batch_of_descriptors: (batch_size, num_descriptors)
        Shape of the batch_of_fitnesses: (batch_size, num_criteria)

        Args:
            batch_of_genotypes: a batch of genotypes that we are trying to
                insert into the repertoire.
            batch_of_descriptors: the descriptors of the genotypes we are
                trying to add to the repertoire.
            batch_of_fitnesses: the fitnesses of the genotypes we are trying
                to add to the repertoire.

        Returns:
            The updated repertoire with potential new individuals.
        """
        # get the indices that corresponds to the descriptors in the repertoire
        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        def _add_one(
            carry: MOMERepertoire,
            data: Tuple[Genotype, Descriptor, Fitness, jnp.ndarray],
        ) -> Tuple[MOMERepertoire, Any]:
            # unwrap data
            genotype, descriptors, fitness, index = data
            index = index.astype(jnp.int32)

            # get cell data - indexing with the size-1 index array leaves a
            # leading axis of size 1 on everything
            cell_genotype = jax.tree_util.tree_map(
                lambda x: x[index], carry.genotypes
            )
            cell_fitness = carry.fitnesses[index]
            cell_descriptor = carry.descriptors[index]
            cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

            # Update the pareto front. Only the leading size-1 axis is
            # squeezed (a bare squeeze() would also collapse a front of
            # length 1 or a single criterion), and genotypes are handled
            # leaf by leaf so that PyTree genotypes are supported.
            (
                cell_fitness,
                cell_genotype,
                cell_descriptor,
                cell_mask,
            ) = self._update_masked_pareto_front(
                pareto_front_fitnesses=cell_fitness.squeeze(axis=0),
                pareto_front_genotypes=jax.tree_util.tree_map(
                    lambda x: x.squeeze(axis=0), cell_genotype
                ),
                pareto_front_descriptors=cell_descriptor.squeeze(axis=0),
                mask=cell_mask.squeeze(axis=0),
                new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
                new_batch_of_genotypes=jax.tree_util.tree_map(
                    lambda x: jnp.expand_dims(x, axis=0), genotype
                ),
                new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
                new_mask=jnp.zeros(shape=(1,), dtype=bool),
            )

            # Mark unused slots with -inf again. where() is used instead of
            # subtracting inf * mask, which relies on inf * 0 not being nan.
            cell_fitness = jnp.where(
                jnp.expand_dims(cell_mask, axis=-1), -jnp.inf, cell_fitness
            )

            # update grid
            new_genotypes = jax.tree_util.tree_map(
                lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
            )
            new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
            new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
            carry = carry.replace(  # type: ignore
                genotypes=new_genotypes,
                descriptors=new_descriptors,
                fitnesses=new_fitnesses,
            )

            # return new grid
            return carry, ()

        # scan the addition operation for all the data
        updated_repertoire, _ = jax.lax.scan(
            _add_one,
            self,
            (
                batch_of_genotypes,
                batch_of_descriptors,
                batch_of_fitnesses,
                batch_of_indices,
            ),
        )

        return updated_repertoire

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        pareto_front_max_length: int,
    ) -> MOMERepertoire:
        """Initialize a Multi-Objective MAP-Elites repertoire.

        An empty repertoire is created (fitnesses at -inf, genotypes and
        descriptors at zero) and the initial population is then inserted
        with ``add``. Centroids can be computed with any method such as CVT
        or Euclidean mapping.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape:
                (batch_size, num_criteria)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            pareto_front_max_length: maximum size of the pareto fronts

        Returns:
            An initialized MOME repertoire.
        """
        # get dimensions
        num_criteria = fitnesses.shape[1]
        num_descriptors = descriptors.shape[1]
        num_centroids = centroids.shape[0]

        # create default values
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(num_centroids, pareto_front_max_length, num_criteria)
        )
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(
                    num_centroids,
                    pareto_front_max_length,
                )
                + x.shape[1:]
            ),
            genotypes,
        )
        default_descriptors = jnp.zeros(
            shape=(num_centroids, pareto_front_max_length, num_descriptors)
        )

        # create repertoire with default values
        repertoire = MOMERepertoire(  # type: ignore
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
        )

        # add first batch of individuals in the repertoire
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses)

        return new_repertoire  # type: ignore

    @jax.jit
    def compute_global_pareto_front(
        self,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Merge all the pareto fronts of the repertoire into a single one
        called global pareto front.

        Returns:
            A tuple ``(pareto_front, pareto_bool)``. ``pareto_front`` has
            shape (num_centroids * pareto_front_length, num_criteria) and
            holds the stored fitnesses, with every point that is not on the
            global front set to -inf. ``pareto_bool`` is the boolean vector
            returned by ``compute_masked_pareto_front`` and is True for the
            points kept in ``pareto_front``.
        """
        # stack the fronts of all the cells - (num_cells * front_len, criteria)
        fitnesses = self.fitnesses.reshape((-1, self.fitnesses.shape[-1]))

        # unused slots are the ones holding -inf
        mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
        pareto_bool = compute_masked_pareto_front(fitnesses, mask)

        # broadcast over the criteria axis so that any number of criteria
        # is supported (not only two)
        pareto_front = jnp.where(
            jnp.expand_dims(pareto_bool, axis=-1), fitnesses, -jnp.inf
        )

        return pareto_front, pareto_bool