In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

from qdax.core.cmaes import CMAES

In [3]:
# Debugging aid: uncomment the next line to disable JIT compilation.
# jax.config.update('jax_disable_jit', True)

### Define fitness function

In [4]:
def fitness_func(x):
    """Return the negated squared distance of ``x`` to the point (5, ..., 5).

    The reduction runs over the last axis, so a batch of shape
    ``(batch, dim)`` yields ``(batch,)`` values. The maximum (0) is reached
    when every coordinate equals 5.

    NOTE: a later cell rebinds ``fitness_func`` to the Rastrigin objective.
    """
    offset = x - 5.0
    return -jnp.sum(jnp.square(offset), axis=-1)

In [5]:
# --- Experiment settings -----------------------------------------------------
num_iterations = 70000
# Dimensionality of the search space (try 20 and 100).
num_dimensions = 100  #@param {type:"integer"}
grid_shape = (500, 500)
batch_size = 36  #@param {type:"integer"}
sigma_g = 0.5  #@param {type:"number"}
minval = -5.12
num_best = 36


def rastrigin_scoring(x: jnp.ndarray) -> jnp.ndarray:
    """Negated, shifted Rastrigin function, reduced over the last axis.

    The input is shifted by ``minval * 0.4`` so the optimum sits at
    ``x = -0.4 * minval`` in every coordinate, where the score is 0. The
    Rastrigin value is negated so that higher scores are better.

    Args:
        x: array whose last axis holds the coordinates of one candidate.
            Leading axes, if any, are treated as batch axes.

    Returns:
        The score with the last axis reduced away (a scalar for 1-D input).
    """
    shifted = x + minval * 0.4
    first_term = 10 * x.shape[-1]
    # Reduce over the coordinate axis only, so that `first_term` (which counts
    # the last axis alone) stays consistent when `x` carries batch axes.
    second_term = jnp.sum(
        shifted ** 2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=-1
    )
    return -(first_term + second_term)


# Batched objective: maps (batch, dim) candidates to (batch,) scores.
fitness_func = jax.vmap(rastrigin_scoring)

### Instantiate CMA-ES object

In [6]:
# Build the optimiser from the hyper-parameters chosen in the settings cell;
# the initial mean is the all-zeros vector of length `num_dimensions`.
cmaes = CMAES(
    population_size=batch_size,
    num_best=num_best,
    search_dim=num_dimensions,
    fitness_function=fitness_func,
    mean_init=jnp.zeros(num_dimensions),
    init_sigma=sigma_g,
    init_step_size=0.5,
)

### Initialization

In [7]:
# Fixed seed for the random key, then the optimiser's starting state.
key = jax.random.PRNGKey(0)
state = cmaes.init()

In [8]:
# Inspect the optimiser's private `_weights` array (presumably the
# recombination weights, one per selected candidate -- verify against CMAES).
cmaes._weights

DeviceArray([0.10648119, 0.08596389, 0.07396204, 0.06544659, 0.0588415 ,
             0.05344474, 0.04888185, 0.04492929, 0.04144289, 0.0383242 ,
             0.03550299, 0.03292744, 0.03055816, 0.02836455, 0.02632234,
             0.02441199, 0.02261749, 0.02092559, 0.01932519, 0.0178069 ,
             0.0163627 , 0.01498569, 0.01366991, 0.01241014, 0.0112018 ,
             0.01004086, 0.00892373, 0.00784725, 0.00680854, 0.00580504,
             0.00483445, 0.00389469, 0.00298384, 0.00210018, 0.00124215,
             0.00040829], dtype=float32)

### Iterations

In [9]:
%%time

# Recording the (search_dim x search_dim) covariance at every iteration grows
# without bound (about 2.8 GB for 70000 steps at 100 dimensions) and makes the
# trajectory plot unusably slow, so keep ~`num_snapshots` evenly spaced records.
num_snapshots = 100
log_period = max(1, num_iterations // num_snapshots)

# History of the search distribution, starting with the initial state.
means = [state.mean]
covs = [state.cov_matrix]

for iteration in range(num_iterations):
    # Draw a new population; the sampler hands back the key to use next.
    samples, key = cmaes.sample(state, key)

    # `fitness_func` is to be maximised: negate it so that an ascending sort
    # puts the best candidates first.
    fitnesses = -fitness_func(samples)
    idx_sorted = jnp.argsort(fitnesses)
    sorted_candidates = samples[idx_sorted[:num_best]]

    state = cmaes.update_state(state, sorted_candidates)

    # Always record the final iteration so `means[-1]` is the end state of a
    # completed run.
    is_last = iteration == num_iterations - 1
    if (iteration + 1) % log_period == 0 or is_last:
        means.append(state.mean)
        covs.append(state.cov_matrix)

Cov:  Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/2)>
u before diag:  Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/2)>
u after diag:  Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/2)>
tmp_1:  Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/2)>
Cov:  Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/2)>
u before diag:  Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/2)>
u after diag:  Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/2)>
tmp_1:  Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/2)>


KeyboardInterrupt: 

In [10]:
# Objective values of the last sampled population (negated `fitness_func`
# output, so lower is better).
fitnesses

DeviceArray([106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234, 106.46234, 106.46234, 106.46234, 106.46234,
             106.46234], dtype=float32)

In [11]:
# Most recently recorded mean of the search distribution.
means[-1]

DeviceArray([1.0532064 , 0.05809125, 2.048349  , 0.05819275, 1.0531707 ,
             2.048232  , 0.05806085, 1.0532742 , 2.048295  , 3.0434341 ,
             2.048349  , 2.0482361 , 1.053242  , 3.0433483 , 1.0530953 ,
             2.0482621 , 2.0483422 , 2.0483575 , 1.05316   , 3.0432768 ,
             1.0532229 , 3.0434422 , 3.043326  , 3.043363  , 1.0532238 ,
             1.0531511 , 3.0434332 , 3.0433965 , 4.038295  , 3.0433767 ,
             1.053139  , 0.05811936, 0.05804817, 2.0482464 , 1.0531576 ,
             3.0433474 , 1.0531776 , 1.0531347 , 3.0433965 , 2.0482073 ,
             1.053119  , 3.0434198 , 3.043429  , 2.04819   , 2.048191  ,
             1.0531707 , 0.05798267, 0.05809572, 1.0531707 , 3.0433483 ,
             2.0482898 , 2.0483832 , 2.0482864 , 1.0532571 , 1.053123  ,
             3.0435386 , 2.0482316 , 3.0434432 , 1.0532467 , 3.043419  ,
             2.0481358 , 1.0531228 , 1.0531871 , 1.0532064 , 2.0483832 ,
             2.0483809 , 4.0383434 , 4.0383253 , 2.

### Visualization

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# Sample 2-D points to show the fitness landscape. Split off a sub-key first
# so the key carried by the sampling loop is never consumed twice.
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, minval=-20, maxval=20, shape=(100000, 2))
f_x = fitness_func(x)

# plot fitness landscape
points = ax.scatter(x[:, 0], x[:, 1], c=f_x)
fig.colorbar(points, label="fitness")

# Plot the CMA-ES trajectory, projected on the first two search coordinates.
for mean, cov in zip(means, covs):
    # NOTE(review): width/height are the raw diagonal covariance entries and
    # off-diagonal terms are ignored -- confirm this is the intended scale.
    ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')
    ax.add_patch(ellipse)
    ax.plot(mean[0], mean[1], color='k', marker='x')

ax.set(xlabel="x[0]", ylabel="x[1]", title="CMA-ES trajectory over the fitness landscape");

In [17]:
N = 4


def eig_compose(B, D, power):
    """Return the matrix product ``B @ diag(D ** power) @ B.T``.

    Args:
        B: (N, N) matrix whose columns define the coordinate system.
        D: N scaling factors, as an (N, 1) column or a flat (N,) vector.
        power: exponent applied element-wise to ``D`` (2 for the covariance,
            -1 for its inverse square root).

    Returns:
        The (N, N) matrix assembled from the decomposition.
    """
    # Flatten first: jnp.diag builds a diagonal matrix only from 1-D input;
    # given an (N, 1) column it would extract a length-1 diagonal instead.
    scales = jnp.ravel(D) ** power
    # `@` is the matrix product; `*` would multiply element-wise.
    return B @ jnp.diag(scales) @ B.T


pc = jnp.zeros((N, 1))
ps = jnp.zeros((N, 1))  # evolution paths for C and sigma
B = jnp.eye(N, N)  # B defines the coordinate system
D = jnp.ones((N, 1))  # diagonal D defines the scaling
C = eig_compose(B, D, 2)  # covariance matrix C
invsqrtC = eig_compose(B, D, -1)  # inverse square root of C

In [18]:
# Evolution path for C as initialised above: an (N, 1) column of zeros.
pc

DeviceArray([[0.],
             [0.],
             [0.],
             [0.]], dtype=float32)

In [19]:
# Evolution path for sigma as initialised above: an (N, 1) column of zeros.
ps

DeviceArray([[0.],
             [0.],
             [0.],
             [0.]], dtype=float32)

In [20]:
# Coordinate-system matrix B, initialised to the N x N identity.
B

DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float32)

In [21]:
# Scaling factors D, stored as an (N, 1) column of ones.
D

DeviceArray([[1.],
             [1.],
             [1.],
             [1.]], dtype=float32)

In [20]:
# Gotcha: for a 2-D input such as the (N, 1) column D, jnp.diag extracts the
# main diagonal (a single element here) instead of building an N x N matrix.
jnp.diag(D)

DeviceArray([1.], dtype=float32)

In [33]:
# Element-wise square; the result keeps the (N, 1) column shape.
D**2

DeviceArray([[1.],
             [1.],
             [1.],
             [1.]], dtype=float32)

In [29]:
# Covariance from its decomposition. Use `@` (matrix product; `*` is
# element-wise) and flatten D so jnp.diag builds a diagonal matrix rather than
# extracting the diagonal of the (N, 1) column.
B @ jnp.diag(jnp.ravel(D) ** 2) @ B.T

DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float32)

In [30]:
# Inverse square root of the covariance. Use `@` (matrix product; `*` is
# element-wise) and flatten D so jnp.diag builds a diagonal matrix rather than
# extracting the diagonal of the (N, 1) column.
B @ jnp.diag(jnp.ravel(D) ** (-1)) @ B.T

DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float32)

In [34]:
# Covariance matrix C from the setup cell (the identity for B = I, D = 1).
C

DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float32)

In [35]:
# Upper-triangular part of C; C is diagonal here, so it comes back unchanged.
jnp.triu(C)

DeviceArray([[1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]], dtype=float32)

In [39]:
# 4 x 4 test matrix holding the integers 1..16 in row-major order.
A = jnp.arange(1, 17).reshape(4, 4)

In [40]:
# Show the dense 4 x 4 test matrix.
A

DeviceArray([[ 1,  2,  3,  4],
             [ 5,  6,  7,  8],
             [ 9, 10, 11, 12],
             [13, 14, 15, 16]], dtype=int32)

In [41]:
# jnp.triu zeroes every entry below the main diagonal.
jnp.triu(A)

DeviceArray([[ 1,  2,  3,  4],
             [ 0,  6,  7,  8],
             [ 0,  0, 11, 12],
             [ 0,  0,  0, 16]], dtype=int32)

In [21]:
# A 1-D input makes jnp.diag build a square matrix with that vector on its
# main diagonal.
weights = jnp.arange(1, 6)
jnp.diag(weights)

DeviceArray([[1, 0, 0, 0, 0],
             [0, 2, 0, 0, 0],
             [0, 0, 3, 0, 0],
             [0, 0, 0, 4, 0],
             [0, 0, 0, 0, 5]], dtype=int32)