# 06 - Stein Variational Evolution Strategy [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_sv_es.ipynb)

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, to install JAX with NVIDIA GPU (CUDA) support:

In [None]:
%pip install -U "jax[cuda]"

Then install `evosax` from PyPI:

In [None]:
%pip install -U "evosax[examples]"

## Import

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from evosax.algorithms import SV_CMA_ES as ES

In [3]:
# Global seed: used for the root PRNG key here and for the benchmark problem.
seed: int = 0
# Root PRNG key for the notebook; later cells split it into subkeys.
key = jax.random.key(seed=seed)

## Problem - Black-Box Optimization Benchmarking (BBOB)

In [4]:
from evosax.problems import BBOBProblem

fn_name = "rosenbrock"
num_dims = 2

# 2D Rosenbrock from the BBOB suite, without random rotations, with the
# optimal value shifted to 0 and seeded from the notebook-wide `seed`.
problem = BBOBProblem(
    fn_name=fn_name,
    num_dims=num_dims,
    f_opt=0,
    sample_rotations=False,
    seed=seed,
)

# Never consume `key` directly: always derive a fresh subkey, so that no PRNG
# key is used by more than one consumer.
key, subkey = jax.random.split(key)
problem_state = problem.init(subkey)

# Template solution: fixes the structure/shape of a candidate for the ES.
key, subkey = jax.random.split(key)
solution = problem.sample(subkey)

## Stein Variational CMA-ES

In [5]:
num_generations = 64
population_size = 128
num_populations = 128

# Stein variational CMA-ES built around the template `solution` sampled above.
es = ES(
    solution=solution,
    num_populations=num_populations,
    population_size=population_size,
)

# Start from the library defaults and override only `alpha`.
params = es.default_params.replace(alpha=2.0)

# One initial solution per population, each drawn from the problem's sampler.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_populations)
solutions = jax.vmap(problem.sample)(keys)

# Initialise the strategy state from those starting solutions.
key, subkey = jax.random.split(key)
state = es.init(subkey, solutions, params)

## Run

In [6]:
def step(carry, key):
    """Run one generation (ask -> clip -> evaluate -> tell) as a `jax.lax.scan` body.

    Args:
        carry: Tuple ``(state, params, problem_state)`` threaded between
            generations.
        key: PRNG key for this generation.

    Returns:
        The updated ``(state, params, problem_state)`` carry and the metrics
        returned by ``es.tell``.
    """
    state, params, problem_state = carry
    ask_key, eval_key, tell_key = jax.random.split(key, 3)

    # Sample candidate solutions from the current search distribution(s).
    population, state = es.ask(ask_key, state, params)
    # Keep candidates inside [-5, 5]; the clipped values are both evaluated
    # and passed to `tell`.
    population = jnp.clip(population, -5, 5)

    fitness, problem_state, _ = problem.eval(eval_key, population, problem_state)
    state, metrics = es.tell(tell_key, population, fitness, state, params)

    new_carry = (state, params, problem_state)
    return new_carry, metrics

In [7]:
# One PRNG key per generation; `jax.lax.scan` applies `step` once per key,
# threading the carry and stacking the per-generation metrics along axis 0.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_generations)
_, metrics = jax.lax.scan(step, init=(state, params, problem_state), xs=keys)

## Visualize

In [11]:
import matplotlib.animation as animation
import mediapy
from matplotlib.animation import PillowWriter

# Create figure and axis
# Per-generation means from `es.tell`, stacked over generations by `scan`.
# Presumably (num_generations, num_populations, num_dims) -- TODO confirm.
means = metrics["mean"]

# Single-axis figure that will host the animation.
fig, ax = plt.subplots(figsize=(6, 5))

# Background: the 2D objective landscape.
key, subkey = jax.random.split(key)
problem.visualize_2d(subkey, ax=ax)

# Initially empty scatter artist for the means, plus a generation label.
scatter = ax.scatter([], [], s=10, color="blue")
title = ax.set_title("Generation: 0")


def init():
    """Clear the scatter before the first frame; return the artists to redraw."""
    no_points = jnp.empty((0, 2))
    scatter.set_offsets(no_points)
    return scatter, title


def update(frame):
    """Show the means recorded at generation index `frame` and relabel the title.

    Returns the artists modified for this frame.
    """
    current_means = means[frame]
    scatter.set_offsets(current_means)
    title.set_text(f"Generation: {frame}")
    return scatter, title


# Create the animation
# Playback speed shared by the saved GIF and the inline preview, so both run
# at the same rate (PillowWriter would otherwise default to 5 fps).
fps = 20

# One frame per recorded generation.
anim = animation.FuncAnimation(
    fig, update, frames=len(means), init_func=init, blit=True
)

# Close this figure explicitly (presumably so the notebook does not also show
# a static copy); the animation still holds `fig` for saving.
plt.close(fig)

# Write the animation to disk as a GIF.
path = "anim.gif"
anim.save(path, writer=PillowWriter(fps=fps))

# Display the GIF in the notebook.
mediapy.show_video(mediapy.read_video(path), fps=fps, codec="gif")