In [None]:
!pip install einops
import jax
from jax import vmap, jit
from jax.lax import scan
try:
    # jax >= 0.2.25: optimizers live here; jax.experimental.optimizers was later removed.
    from jax.example_libraries import optimizers
except ImportError:  # older jax releases
    from jax.experimental import optimizers
import jax.random as random
import jax.numpy as np
from einops import repeat

In [2]:
def f_opt(w):
    """Toy objective: negative squared distance to a fixed target vector.

    Computes ``-sum((w - solution)**2)`` with ``solution = [23, 0.2, -19.91]``.
    The value is never positive and reaches its maximum of 0 exactly at
    ``w == solution``, so this is a fitness to be *maximized* (higher is
    better), not a loss to be minimized.

    Args:
        w: length-3 parameter vector.

    Returns:
        Scalar fitness value.
    """
    solution = np.array([23, 0.2, -19.91])
    return -np.sum((w - solution)**2)

In [3]:
def es(key, min_fn, num_params, num_workers=128, num_gens=1000, sigma=0.1, learning_rate=1e-1):
    """ Runs antithetic Evolution Strategies with an Adam optimizer.
        PDF: https://arxiv.org/abs/1703.03864

        Despite its name, ``min_fn`` is *maximized*. The population gradient
        estimate is negated before being handed to Adam, which takes descent
        steps, so the parameters move uphill on ``min_fn``. To minimize a
        loss, pass its negation.

        Args:
            key             : PRNG key for deterministic calculation
            min_fn          : Objective mapping a (num_params,) vector to a scalar;
                              it is maximized (see above) and must be vmap-able
            num_params      : Number of parameters to pass into min_fn
            num_workers     : Number of antithetic pairs; 2 * num_workers
                              evaluations of min_fn are made per generation
            num_gens        : Number of generations to evaluate
            sigma           : Standard deviation of the Gaussian parameter jitter
            learning_rate   : Adam optimizer step size

        Returns:
            A (num_params,) array of evolved parameters after num_gens generations
    """
    batched_func = jit(vmap(min_fn))
    opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)

    # Use separate keys for the initial parameters and for the evolution loop,
    # so no key is consumed twice.
    key, init_key = random.split(key)
    theta = random.normal(init_key, (num_params,))
    opt_state = opt_init(theta)

    def es_step(carry, step):
        '''One generation of antithetic ES followed by an Adam update.

        ``step`` is the generation index supplied by ``scan``.
        '''
        key, opt_state = carry
        # Carry ``key`` forward untouched and consume only ``key_eps`` here.
        key, key_eps = random.split(key)
        eps = random.normal(key_eps, (num_workers, num_params))

        theta = get_params(opt_state)
        # Broadcasting adds theta to every row; the second half is the mirrored (-eps) population.
        theta_eps = np.concatenate([theta + sigma * eps, theta - sigma * eps])

        evals = batched_func(theta_eps)
        # Fitness shaping: standardize the evaluations. If every evaluation is
        # identical, std is 0 and plain division would give 0/0 = NaN; divide by 1 instead.
        centered = evals - np.mean(evals)
        std = np.std(evals)
        evals = centered / np.where(std > 0, std, 1.0)

        # ES gradient estimate of E[min_fn], negated so Adam's descent step ascends min_fn.
        pop_grad = - 1 / (2 * num_workers * sigma) * np.dot(np.concatenate([eps, -eps]).T, evals)
        # Adam's bias correction depends on the step index, so pass the real one.
        opt_state = opt_update(step, pop_grad, opt_state)

        return (key, opt_state), None

    # Evolve for num_gens generations; scan compiles es_step and feeds it the step index.
    (_, opt_state), _ = scan(es_step, (key, opt_state), np.arange(num_gens))
    return get_params(opt_state)

In [4]:
# Evolve the 3 parameters of the toy objective with a fixed seed for reproducibility.
SEED = 42
key = random.PRNGKey(SEED)
theta = es(
    key,
    f_opt,
    num_params=3,
    num_workers=5096,
    num_gens=250,
    sigma=0.5,
    learning_rate=1e-1,
)
final_fitness = f_opt(theta)
print(f'Final parameters: {theta}\nFinal Loss: {final_fitness:.3f}')

Final parameters: [ 23.011965     0.19921508 -19.910608  ]
Final Loss: -0.000
