In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

from qdax.core.cmaes import CMAES

In [None]:
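# Debugging aid: uncomment the next line to disable JIT compilation,
# so that the code runs eagerly and is easier to step through.
# jax.config.update('jax_disable_jit', True)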

### Define fitness function

In [None]:
def fitness_func(x):
    """Negated sphere objective centred at 5.

    Returns minus the sum of squared distances of ``x`` from 5.0, reduced over
    the last axis, so a ``(..., dim)`` input yields a ``(...)`` result and the
    maximum value 0 is reached when every coordinate equals 5.

    Note: this name is rebound further down the notebook to the vmapped
    Rastrigin score.
    """
    offset = x - 5.0
    return -jnp.sum(offset * offset, axis=-1)

In [None]:
# Experiment configuration read by the cells below.
# The trailing comments on each line are the author's notes of alternative values.

# Upper bound on the number of iterations of the optimisation loop.
num_iterations = 7000 # 70000 #70000 #10000
# Dimension of the search space: passed to CMAES as `search_dim` and used to size the initial mean.
num_dimensions = 100 #100 #1000 #@param {type:"integer"} # try 20 and 100
# NOTE(review): not referenced anywhere else in the visible notebook.
grid_shape = (500, 500) # (500, 500) 
# Passed to CMAES as `population_size`.
batch_size = 500 #@param {type:"integer"}
# Passed to CMAES as `init_sigma`.
sigma_g = 0.5 # 0.5 #@param {type:"number"}
# Read by `rastrigin_scoring`, which shifts every coordinate by `minval * 0.4`.
minval = -5.12
# Number of top-ranked samples kept per iteration; also passed to CMAES as `num_best`.
num_best = 250 #36

def rastrigin_scoring(x: jnp.ndarray) -> jnp.ndarray:
    """Negated, shifted Rastrigin score (larger is better).

    Every coordinate is shifted by ``minval * 0.4`` (``minval`` is the
    notebook-level constant) before the Rastrigin formula is applied, so the
    score reaches its maximum of 0 where ``x == -0.4 * minval``.

    Args:
        x: array whose last axis holds the coordinates of one solution. Any
            leading axes are treated as batch axes.

    Returns:
        The score with the last axis reduced away: a scalar for a 1-D input,
        shape ``(batch,)`` for a ``(batch, dim)`` input.
    """
    shifted = x + minval * 0.4
    first_term = 10 * x.shape[-1]
    # Reduce over the coordinate axis only. Without `axis=-1` a batched input
    # would be summed into a single scalar while `first_term` only accounts
    # for one row. For 1-D inputs (e.g. under jax.vmap) the result is unchanged.
    second_term = jnp.sum(shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=-1)
    # Negated so that the optimiser can treat the value as a fitness to maximise.
    return -(first_term + second_term)

# Batched fitness: jax.vmap maps `rastrigin_scoring` over the leading axis, so each
# row of a (batch, dim) array is scored independently.
# This rebinds `fitness_func`, replacing the sphere objective defined earlier in the notebook.
fitness_func = jax.vmap(rastrigin_scoring)

### Instantiate CMA-ES object

In [None]:
# The search distribution starts centred on the origin.
initial_mean = jnp.zeros((num_dimensions,))

# Author's notes of alternative settings that were left next to the arguments:
# search_dim=2, mean_init=jnp.asarray([-1.0, -1.0]), init_sigma=1.0,
# init_step_size=0.5, delay_eigen_decomposition=False.
cmaes = CMAES(
    # problem definition
    search_dim=num_dimensions,
    fitness_function=fitness_func,
    # population settings
    population_size=batch_size,
    num_best=num_best,
    # initial search distribution
    mean_init=initial_mean,
    init_sigma=sigma_g,
    init_step_size=0.5,
    # algorithm option
    delay_eigen_decomposition=True,
)

### Initialization

In [None]:
# Initial optimiser state; the loop below reads its `mean`, `cov_matrix` and `step_size` fields.
state = cmaes.init()
# Fixed seed: the same PRNG key is created on every run of this cell.
key = jax.random.PRNGKey(0)

In [None]:
# Debugging aid: display the private `_weights` attribute of the CMAES object
# (presumably the recombination weights — confirm in qdax.core.cmaes).
cmaes._weights

### Iterations

In [None]:
%%time

# History of the search distribution, seeded with the initial state.
means, covs = [state.mean], [state.cov_matrix]

for _ in range(num_iterations):
    # Draw a new population; `sample` also hands back the PRNG key for the next draw.
    samples, key = cmaes.sample(state, key)

    # `fitnesses` holds the NEGATED fitness, so an ascending argsort puts the
    # fittest samples first.
    fitnesses = -fitness_func(samples)
    idx_sorted = jnp.argsort(fitnesses)
    elite_idx = idx_sorted[:num_best]
    sorted_candidates = samples[elite_idx]

    # Update the distribution from the `num_best` top-ranked samples.
    state = cmaes.update_state(state, sorted_candidates)

    stop_condition = cmaes.stop_condition(state)
    print("Step size: ", state.step_size)
    print("Stop condition: ", stop_condition)
    if stop_condition:
        # The state that triggered the stop is deliberately not appended below.
        break

    means.append(state.mean)
    covs.append(state.cov_matrix)

In [None]:
# Values for the last sampled population. The loop stores `-fitness_func(samples)`
# under this name, so these are negated fitnesses (lower is better).
fitnesses

In [None]:
# Reference scores on the `rastrigin_scoring` scale (higher is better):
# - worst: every coordinate at -5.12, used as the reference worst point;
# - best: every coordinate at 5.12 * 0.4, which cancels the shift applied inside
#   `rastrigin_scoring` and therefore scores 0.
worst_objective = rastrigin_scoring(-jnp.ones(num_dimensions) * 5.12)
# worst_objective = rastrigin_scoring(jnp.zeros(num_dimensions))
best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4)

# `fitnesses` was stored by the optimisation loop as the NEGATED fitness, while the
# two references above are un-negated scores. Flip the sign back so that all three
# quantities are on the same scale; the result is then a percentage where 100 means
# "as good as the best reference" and 0 means "as bad as the worst reference".
(-fitnesses - worst_objective) * 100 / (best_objective - worst_objective)

In [None]:
# Last recorded mean of the search distribution. If the loop ended through the
# stop condition, the state that triggered the stop was not appended, so this is
# the mean from the iteration before it.
means[-1]

### Visualization

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# Sample random 2-D points to show the fitness landscape. The landscape is the
# fitness function evaluated on 2-D inputs, independent of `num_dimensions`.
x = jax.random.uniform(key, minval=-6, maxval=6, shape=(100000, 2))
f_x = fitness_func(x)

# plot fitness landscape
points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)
fig.colorbar(points, label="fitness_func(x)")

# Plot the CMA-ES trajectory. Only the first two coordinates of each mean are
# drawn. The ellipse is a rough visual cue: its width and height are the raw
# variances cov[0, 0] and cov[1, 1], and off-diagonal terms are ignored.
for mean, cov in zip(means, covs):
    ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')
    ax.add_patch(ellipse)
    ax.plot(mean[0], mean[1], color='k', marker='x')

# Label the figure so it can be read on its own.
ax.set_xlabel("x[0]")
ax.set_ylabel("x[1]")
ax.set_title("CMA-ES trajectory (first two coordinates)");

In [None]:
# Scratch example: initial CMA-ES quantities for a small problem.
N = 4

pc = jnp.zeros((N, 1))
ps = jnp.zeros((N, 1))  # evolution paths for C and sigma
B = jnp.eye(N, N)  # B defines the coordinate system
D = jnp.ones((N, 1))  # diagonal D defines the scaling

# C = B · diag(D²) · Bᵀ and C^(-1/2) = B · diag(D⁻¹) · Bᵀ as true matrix products.
# Two things matter here:
# - `*` on JAX arrays is element-wise, so the matrix product must use `@`;
# - jnp.diag only BUILDS a diagonal matrix from a 1-D input (for a 2-D input such
#   as the (N, 1) column `D` it extracts the diagonal), hence the `ravel()`.
C = B @ jnp.diag(D.ravel() ** 2) @ B.T  # covariance matrix C
invsqrtC = B @ jnp.diag(D.ravel() ** (-1)) @ B.T

In [None]:
# Evolution path for C: an (N, 1) array of zeros at this point.
pc

In [None]:
# Evolution path for sigma: an (N, 1) array of zeros at this point.
ps

In [None]:
# B (the coordinate system) is the N x N identity matrix here.
B

In [None]:
# D (the scaling) is stored as an (N, 1) column of ones, not as a 1-D vector.
D

In [None]:
# Gotcha: `D` is 2-D with shape (N, 1), so jnp.diag EXTRACTS its main diagonal
# (a length-1 array) instead of building an N x N diagonal matrix.
jnp.diag(D)

In [None]:
# Element-wise square of D; the result keeps D's shape.
D**2

In [None]:
# Covariance B · diag(D²) · Bᵀ as a true matrix product: `@` instead of the
# element-wise `*`, and `D` flattened so that jnp.diag builds an N x N diagonal
# matrix rather than extracting the diagonal of a 2-D input.
B @ jnp.diag(D.ravel() ** 2) @ B.T

In [None]:
# Inverse square root B · diag(D⁻¹) · Bᵀ as a true matrix product: `@` instead of
# the element-wise `*`, and `D` flattened so that jnp.diag builds an N x N diagonal
# matrix rather than extracting the diagonal of a 2-D input.
B @ jnp.diag(D.ravel() ** (-1)) @ B.T

In [None]:
# Covariance matrix built from B and D in the scratch cell above.
C

In [None]:
# Upper-triangular part of C (entries below the main diagonal are zeroed).
jnp.triu(C)

In [None]:
# 4 x 4 test matrix holding 1..16 in row-major order.
A = jnp.array([[4 * row + col for col in range(1, 5)] for row in range(4)])

In [None]:
# Display the 4 x 4 test matrix.
A

In [None]:
# Upper-triangular part of A: on a non-symmetric matrix it is easy to see which
# entries (those below the main diagonal) are zeroed.
jnp.triu(A)

In [None]:
# With a 1-D input, jnp.diag builds a square matrix with the values on the diagonal.
weights = jnp.array(list(range(1, 6)))
jnp.diag(weights)