-
Notifications
You must be signed in to change notification settings - Fork 38
/
repertoire.py
346 lines (282 loc) · 11.9 KB
/
repertoire.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
from __future__ import annotations

import math
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from sklearn.cluster import KMeans

from qdax.types import Centroid, Descriptor, Fitness, Genotype, RNGKey
def compute_cvt_centroids(
    num_descriptors: int,
    num_init_cvt_samples: int,
    num_centroids: int,
    minval: Union[float, List[float]],
    maxval: Union[float, List[float]],
    seed: Optional[int] = None,
) -> jnp.ndarray:
    """
    Compute centroids for a CVT (centroidal Voronoi) tessellation.

    Points are sampled uniformly in the unit hypercube [0, 1]^num_descriptors.
    They are clustered with k-means (k-means++ initialisation, a single run).
    The resulting cluster centres are rescaled to [minval, maxval].

    Args:
        num_descriptors: number of scalar descriptors
        num_init_cvt_samples: number of sampled points to be used for
            clustering to determine the centroids. The larger the number of
            centroids and the number of descriptors, the higher this value
            must be (e.g. 100000 for 1024 centroids and 100 descriptors).
        num_centroids: number of centroids
        minval: minimum descriptors value (scalar or one value per descriptor)
        maxval: maximum descriptors value (scalar or one value per descriptor)
        seed: optional integer seed. When given, both the sampled points and
            the k-means initialisation are reproducible. When None (default),
            NumPy's global random state and an unseeded KMeans are used.

    Returns:
        the centroids with shape (num_centroids, num_descriptors)
    """
    minval = jnp.array(minval)
    maxval = jnp.array(maxval)

    # Sample in [0, 1]^num_descriptors here and rescale at the end.
    if seed is None:
        x = np.random.rand(num_init_cvt_samples, num_descriptors)
    else:
        x = np.random.default_rng(seed).random(
            (num_init_cvt_samples, num_descriptors)
        )

    k_means = KMeans(
        init="k-means++",
        n_clusters=num_centroids,
        n_init=1,
        random_state=seed,
    )
    k_means.fit(x)
    centroids = k_means.cluster_centers_

    # Rescale from the unit hypercube to [minval, maxval].
    return jnp.asarray(centroids) * (maxval - minval) + minval
def compute_euclidean_centroids(
    num_descriptors: int,
    num_centroids: int,
    minval: Union[float, List[float]],
    maxval: Union[float, List[float]],
) -> jnp.ndarray:
    """
    Compute centroids for a square Euclidean (regular grid) tessellation.

    The unit square is split into a k x k grid of equally sized cells, where
    k = sqrt(num_centroids). Each centroid is the centre of one cell. The
    centroids are then rescaled from [0, 1] to [minval, maxval]. Along the
    returned rows the first descriptor varies fastest.

    Args:
        num_descriptors: number of scalar descriptors (only 2 is supported)
        num_centroids: number of centroids; must be a positive perfect square
        minval: minimum descriptors value (scalar or one value per descriptor)
        maxval: maximum descriptors value (scalar or one value per descriptor)

    Returns:
        the centroids with shape (num_centroids, num_descriptors)

    Raises:
        NotImplementedError: if num_descriptors is not 2.
        ValueError: if num_centroids is not a positive perfect square.
    """
    if num_descriptors != 2:
        raise NotImplementedError("This function supports 2 descriptors only for now.")

    # Without this guard, 0 would fail later with ZeroDivisionError (offset
    # computation) and a negative value with an opaque math domain error.
    if num_centroids <= 0:
        raise ValueError("Num centroids should be a positive number.")

    # Exact integer square root: no floating-point rounding in the square check.
    grid_size = math.isqrt(int(num_centroids))
    if grid_size * grid_size != num_centroids:
        raise ValueError("Num centroids should be a squared number.")

    # Cell centres of a grid_size x grid_size partition of [0, 1]^2.
    offset = 1 / (2 * grid_size)
    linspace = jnp.linspace(offset, 1.0 - offset, grid_size)
    meshes = jnp.meshgrid(linspace, linspace, sparse=False)
    centroids = jnp.stack([jnp.ravel(meshes[0]), jnp.ravel(meshes[1])], axis=-1)

    minval = jnp.array(minval)
    maxval = jnp.array(maxval)
    return jnp.asarray(centroids) * (maxval - minval) + minval
def get_cells_indices(
    batch_of_descriptors: jnp.ndarray, centroids: jnp.ndarray
) -> jnp.ndarray:
    """
    Map every descriptor vector of a batch to the index of its closest centroid.

    Closeness is measured with the squared Euclidean distance; ties resolve to
    the lowest centroid index (``jnp.argmin`` semantics).

    Args:
        batch_of_descriptors: descriptors of shape (batch_size, num_descriptors)
        centroids: tessellation centroids of shape (num_centroids, num_descriptors)

    Returns:
        the index of the closest centroid for each descriptor vector, with
        shape (batch_size,)
    """

    def _closest_centroid(
        single_descriptor: jnp.ndarray, all_centroids: jnp.ndarray
    ) -> jnp.ndarray:
        # Broadcasting one descriptor vector against every centroid yields one
        # squared distance per centroid.
        offsets = jnp.subtract(single_descriptor, all_centroids)
        squared_distances = jnp.sum(jnp.square(offsets), axis=-1)
        return jnp.argmin(squared_distances)

    # Vectorise over the batch axis only; the centroids are shared by all rows.
    batched_lookup = jax.vmap(_closest_centroid, in_axes=(0, None))
    return batched_lookup(batch_of_descriptors, centroids)
class MapElitesRepertoire(flax.struct.PyTreeNode):
    """
    Class for the repertoire in Map Elites.

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire
            ordered by the centroids. Each leaf has a shape
            (num_centroids, num_features). The PyTree can be a simple Jax array
            or a more complex nested structure such as to represent parameters
            of neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell
            of the repertoire, ordered by centroids. The array shape is
            (num_centroids,). Empty cells hold -jnp.inf.
        descriptors: an array that contains the descriptors of solutions in
            each cell of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation.
            The array shape is (num_centroids, num_descriptors).
    """

    genotypes: Genotype
    fitnesses: Fitness
    descriptors: Descriptor
    centroids: Centroid

    def save(self, path: str = "./") -> None:
        """Saves the grid on disk in the form of .npy files.

        Flattens the genotypes to store it with .npy format. Supposes that
        a user will have access to the reconstruction function when loading
        the genotypes.

        Four files are written: ``genotypes.npy``, ``fitnesses.npy``,
        ``descriptors.npy`` and ``centroids.npy``. Their names are appended to
        ``path`` verbatim, so a directory path must end with a separator.

        Args:
            path: Path where the data will be saved. Defaults to "./".
        """

        def _flatten(genotype: Genotype) -> jnp.ndarray:
            flat_genotype, _ = ravel_pytree(genotype)
            return flat_genotype

        # One flat vector per cell: shape (num_centroids, genotype_size).
        flat_genotypes = jax.vmap(_flatten)(self.genotypes)

        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "fitnesses.npy", self.fitnesses)
        jnp.save(path + "descriptors.npy", self.descriptors)
        jnp.save(path + "centroids.npy", self.centroids)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesRepertoire:
        """Loads a MAP Elites Grid written by ``save``.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array. It is vmapped over the saved genotypes.
            path: Path where the data is saved (same prefix as given to
                ``save``). Defaults to "./".

        Returns:
            A MAP Elites Repertoire (an instance of ``cls``).
        """
        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")
        descriptors = jnp.load(path + "descriptors.npy")
        centroids = jnp.load(path + "centroids.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """
        Sample elements in the grid, uniformly over the occupied cells
        (with replacement).

        If the repertoire is empty, all sampling probabilities are zero and
        the returned samples are not meaningful.

        Args:
            random_key: a jax PRNG random key
            num_samples: the number of elements to be sampled

        Returns:
            samples: a batch of genotypes sampled in the repertoire
            random_key: an updated jax PRNG random key
        """
        random_key, sub_key = jax.random.split(random_key)

        # Empty cells carry fitness -inf; every occupied cell gets equal weight.
        occupied = jnp.where(self.fitnesses == -jnp.inf, 0.0, 1.0)
        # Normalise by the number of OCCUPIED cells. The max(., 1) guard
        # avoids 0/0 when the grid is empty.
        p = occupied / jnp.maximum(jnp.sum(occupied), 1.0)

        # The same sub_key is reused for every leaf so that all leaves of one
        # sampled genotype come from the same cell.
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p),
            self.genotypes,
        )

        return samples, random_key

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
    ) -> MapElitesRepertoire:
        """
        Add a batch of elements to the repertoire.

        Each candidate is assigned to the cell of its closest centroid. Within
        the batch, only the best candidate of each cell is considered; ties are
        broken in favour of the earliest candidate in the batch. That candidate
        replaces the current elite only if its fitness is strictly greater.

        Args:
            batch_of_genotypes: a batch of genotypes to be added to the
                repertoire. Similarly to the self.genotypes argument, this is a
                PyTree in which the leaves have a shape
                (batch_size, num_features)
            batch_of_descriptors: an array that contains the descriptors of the
                aforementioned genotypes. Its shape is
                (batch_size, num_descriptors)
            batch_of_fitnesses: an array that contains the fitnesses of the
                aforementioned genotypes. Its shape is (batch_size,)

        Returns:
            The updated MAP-Elites repertoire.
        """
        num_centroids = self.centroids.shape[0]
        batch_size = batch_of_fitnesses.shape[0]
        batch_positions = jnp.arange(batch_size)

        # Cell index of each candidate, shape (batch_size,).
        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)

        # Best fitness reached in each cell by this batch. Entries for cells
        # that no candidate falls into are never read below.
        best_fitnesses = jax.ops.segment_max(
            batch_of_fitnesses, batch_of_indices, num_segments=num_centroids
        )
        is_best_in_cell = batch_of_fitnesses == best_fitnesses[batch_of_indices]

        # Keep only the earliest best candidate of each cell. Scatter order for
        # duplicate indices is implementation-defined. Letting several tied
        # candidates through could store one candidate's genotype and another
        # candidate's descriptor in the same cell.
        candidate_positions = jnp.where(is_best_in_cell, batch_positions, batch_size)
        winner_positions = jax.ops.segment_min(
            candidate_positions, batch_of_indices, num_segments=num_centroids
        )
        is_winner = batch_positions == winner_positions[batch_of_indices]

        # A winner replaces the current elite only if strictly better.
        current_fitnesses = self.fitnesses[batch_of_indices]
        addition_condition = is_winner & (batch_of_fitnesses > current_fitnesses)

        # Rejected candidates are redirected to the out-of-bounds index
        # num_centroids. JAX drops out-of-bounds updates in `.at[].set`.
        target_indices = jnp.where(
            addition_condition, batch_of_indices, num_centroids
        )

        new_genotypes = jax.tree_util.tree_map(
            lambda grid_leaf, new_leaf: grid_leaf.at[target_indices].set(new_leaf),
            self.genotypes,
            batch_of_genotypes,
        )
        new_fitnesses = self.fitnesses.at[target_indices].set(batch_of_fitnesses)
        new_descriptors = self.descriptors.at[target_indices].set(
            batch_of_descriptors
        )

        # `replace` keeps the centroids and preserves the concrete subclass.
        return self.replace(
            genotypes=new_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors,
        )

    @classmethod
    def init(
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
    ) -> MapElitesRepertoire:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MapElites, so it
        can be called easily from other modules.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape (batch_size,)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)

        Returns:
            an initialized MAP-Elite repertoire
        """
        num_centroids = centroids.shape[0]

        # Empty grid: every cell has fitness -inf and zero-filled content.
        default_fitnesses = jnp.full(shape=(num_centroids,), fill_value=-jnp.inf)
        # Keep each leaf's dtype (canonicalised for the current x64 setting)
        # so that e.g. integer genotypes are not silently stored as floats.
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(num_centroids,) + x.shape[1:],
                dtype=jax.dtypes.canonicalize_dtype(x.dtype),
            ),
            genotypes,
        )
        default_descriptors = jnp.zeros(shape=(num_centroids, centroids.shape[-1]))

        repertoire = cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
        )

        # Add initial values to the grid
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses)

        return new_repertoire  # type: ignore