# <h1><center>Evolution Strategy Tutorial</center></h1>

This notebook provides an introductory tutorial to Evolution Strategies (ES). Specifically, we will apply the [OpenAI-ES](https://arxiv.org/abs/1703.03864) algorithm to optimise the gait of a brittle star robot. The [brittle star robot and its environment](https://github.com/Co-Evolve/brt/tree/main/biorobot/brittle_star) are part of the [**Bio-inspired Robotics Testbed (BRT)**](https://github.com/Co-Evolve/brt). Instead of directly evolving joint-level actions, we will evolve modulation parameters for a Central Pattern Generator (CPG), which in turn outputs the joint-level actions.

# Evolution Strategies

Start by reading this [blogpost](https://blog.otoro.net/2017/10/29/visual-evolution-strategies/) for a good visual introduction to evolution strategies.
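
To make the ask-evaluate-tell idea from the blogpost concrete, here is a minimal sketch of an OpenAI-ES-style update on a toy problem. The fitness function `sphere`, the function `simple_es` and all hyperparameter values are hypothetical and chosen only for illustration. Practical implementations typically add fitness shaping and an adaptive optimiser, which are left out here.

```python
import numpy as np


def sphere(x: np.ndarray) -> float:
    # Hypothetical toy fitness: higher is better, with the optimum at x = 3.
    return -float(np.sum((x - 3.0) ** 2))


def simple_es(num_dims=5, popsize=50, sigma=0.1, lr=0.05, generations=300, seed=0):
    rng = np.random.default_rng(seed)
    mean = np.zeros(num_dims)
    for _ in range(generations):
        # Ask: sample mirrored (antithetic) Gaussian perturbations around the mean.
        half = rng.standard_normal((popsize // 2, num_dims))
        noise = np.concatenate([half, -half])
        population = mean + sigma * noise
        # Evaluate: compute the fitness of every candidate solution.
        fitness = np.array([sphere(candidate) for candidate in population])
        # Tell: estimate the fitness gradient from the samples and move the mean uphill.
        gradient = noise.T @ fitness / (popsize * sigma)
        mean = mean + lr * gradient
    return mean


best_mean = simple_es()
print(best_mean)
```

The search distribution's mean should end up close to the optimum at 3 in every dimension, without the optimiser ever computing an analytical gradient of `sphere`.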

Instead of implementing our own Evolution Strategy, we will use the [EvoSax](https://github.com/RobertTLange/evosax) Python library. Take a look at the README to get an idea of how its interface works.
Before running the next cell, make sure you have activated the appropriate Jupyter kernel!

In [None]:
import sys
!{sys.executable} -m pip install --user evosax
import evosax

# Case study: Evolving CPG modulation parameters for brittle star locomotion

Evolution strategies are a type of black-box optimizer that are particularly effective for optimizing continuous parameters. To apply these strategies, two essential components are required: first, a clear definition of a candidate solution—essentially a list of parameters representing a potential answer to the optimization problem—and second, an evaluation function that can assess the performance of each candidate solution.

In this case study, we utilize the OpenAI-ES algorithm to optimize the modulation parameters of a Central Pattern Generator (CPG), with the objective of enabling the brittle star robot to move as far from its starting position as possible. The CPG system employed here is the same as the one described in the CPG tutorial, which consists of a network of $N$ coupled oscillators controlled by both amplitude and offset parameters. Each arm of the brittle star is actuated by two oscillators—one managing the in-plane motions and another responsible for the out-of-plane motions, with the outputs shared across all segments of the arm.
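
For reference, the oscillator dynamics implemented by the `CPG` class further down in this notebook can be summarised as follows. The symbols $\theta_i$ (phase), $r_i$ (amplitude), $x_i$ (offset), $y_i$ (output), $w_{ij}$ (coupling weight) and $a_r$, $a_x$ (amplitude and offset gains) are notation chosen here to mirror that code; $\omega_i$, $R_i$, $X_i$ and $\rho_{ij}$ are the modulation parameters.

$$
\begin{aligned}
\dot{\theta}_i &= \omega_i + \sum_j w_{ij}\, r_j \sin(\theta_j - \theta_i - \rho_{ij}) \\
\ddot{r}_i &= a_r \left( \frac{a_r}{4} (R_i - r_i) - \dot{r}_i \right) \\
\ddot{x}_i &= a_x \left( \frac{a_x}{4} (X_i - x_i) - \dot{x}_i \right) \\
y_i &= x_i + r_i \cos(\theta_i)
\end{aligned}
$$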

The topology of the CPG is illustrated in the diagram below. Rather than manually defining the phase biases between oscillators to achieve a rowing-like behavior, this notebook optimizes these phase biases along with the other parameters. We thus deviate from the CPG and Q-Learning tutorials: we do not manually define CPG parameters based on biological observations, but optimise them completely. Specifically, the parameters subject to optimization are a single shared frequency ($\omega$); for each oscillator, an amplitude ($R$) and an offset ($X$); and for each coupling between oscillators, a phase bias ($\rho_{ij}$), where the biases satisfy the condition $\rho_{ij} = -\rho_{ji}$ as described in the CPG tutorial.

Given that the system comprises 10 oscillators and 15 bi-directional couplings, the overall optimization problem involves tuning a total of $1 + (10 \times 2) + (15 \times 1) = 36$ parameters. 
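
The count above can be checked with a short sketch that rebuilds the coupling topology (one coupling within each arm, plus a ring of in-plane and a ring of out-of-plane oscillators) and counts the couplings. The variable names are illustrative and use NumPy rather than JAX for brevity.

```python
import numpy as np

num_arms = 5
num_oscillators = 2 * num_arms  # one in-plane (IP) and one out-of-plane (OOP) oscillator per arm

adjacency = np.zeros((num_oscillators, num_oscillators), dtype=int)
for arm in range(num_arms):
    ip, oop = 2 * arm, 2 * arm + 1
    next_ip = (ip + 2) % num_oscillators
    next_oop = (oop + 2) % num_oscillators
    adjacency[ip, oop] = 1        # coupling within an arm
    adjacency[ip, next_ip] = 1    # ring of IP oscillators
    adjacency[oop, next_oop] = 1  # ring of OOP oscillators
adjacency = np.maximum(adjacency, adjacency.T)  # make every coupling bi-directional

# Count each bi-directional coupling once by looking only at the lower triangle.
num_couplings = int(np.tril(adjacency).sum())
# One shared frequency, an amplitude and an offset per oscillator, one phase bias per coupling.
num_parameters = 1 + 2 * num_oscillators + num_couplings
print(num_couplings, num_parameters)  # 15 36
```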

It is important to note that we are evolving an open-loop controller: the controller is non-adaptive and does not use sensory feedback to modify its operation in real time. Later in the project, you will close the loop, for instance by optimising a separate ANN that outputs CPG modulations based on sensory inputs.

![](assets/brittle_star_cpg.png)

### Environment setup
First things first, let's set up our brittle star simulation environment. We will use the undirected locomotion variant. The following cell will first do some preliminary checks to make sure that the underlying physics engine (MuJoCo) is correctly loaded and to make sure that JAX can access the GPU.

In [None]:
import os
import subprocess
import logging
import jax

try:
    if subprocess.run('nvidia-smi').returncode:
        raise RuntimeError(
                'Cannot communicate with GPU. '
                'Make sure you are using a GPU Colab runtime. '
                'Go to the Runtime menu and select Choose runtime type.'
                )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
    if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
        with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
            f.write(
                    """{
                            "file_format_version" : "1.0.0",
                            "ICD" : {
                                "library_path" : "libEGL_nvidia.so.0"
                            }
                        }
                        """
                    )

    # Configure MuJoCo to use the EGL rendering backend (requires GPU)
    print('Setting environment variable to use GPU rendering:')
    %env MUJOCO_GL=egl

    # xla_flags = os.environ.get('XLA_FLAGS', '')
    # xla_flags += ' --xla_gpu_triton_gemm_any=True'
    # os.environ['XLA_FLAGS'] = xla_flags

    print(jax.devices('gpu'))

except Exception:
    logging.warning("Failed to initialize GPU. Everything will run on the cpu.")

print(jax.devices())

try:
    print('Checking that the mujoco installation succeeded:')
    import mujoco

    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise RuntimeError(
            'Something went wrong during installation. Check the shell output above '
            'for more information.\n'
            'If using a hosted Colab runtime, make sure you enable GPU acceleration '
            'by going to the Runtime menu and selecting "Choose runtime type".'
            ) from e

print('MuJoCo installation successful.')

This next cell (similar to previous tutorials) defines the `morphology_specification` (i.e. the brittle star morphology), the `arena_configuration` (i.e. some settings w.r.t. the aquarium in which we place the brittle star) and the `environment_configuration` (which defines and configures the undirected locomotion task). The cell also implements some utility functions for visualization.

In [None]:
from biorobot.brittle_star.environment.undirected_locomotion.dual import BrittleStarUndirectedLocomotionEnvironment
from biorobot.brittle_star.environment.undirected_locomotion.shared import \
    BrittleStarUndirectedLocomotionEnvironmentConfiguration
import numpy as np
from moojoco.environment.base import MuJoCoEnvironmentConfiguration
from typing import List
import mediapy as media
from biorobot.brittle_star.mjcf.morphology.morphology import MJCFBrittleStarMorphology
from biorobot.brittle_star.mjcf.morphology.specification.default import default_brittle_star_morphology_specification
from biorobot.brittle_star.mjcf.arena.aquarium import AquariumArenaConfiguration, MJCFAquariumArena

morphology_specification = default_brittle_star_morphology_specification(
        num_arms=5, num_segments_per_arm=3, use_p_control=True, use_torque_control=False
        )
arena_configuration = AquariumArenaConfiguration(
        size=(3, 3), sand_ground_color=False, attach_target=False, wall_height=1.5, wall_thickness=0.1
        )
environment_configuration = BrittleStarUndirectedLocomotionEnvironmentConfiguration(
        joint_randomization_noise_scale=0.0,
        render_mode="rgb_array",
        simulation_time=5,
        num_physics_steps_per_control_step=10,
        time_scale=2,
        camera_ids=[0, 1],
        render_size=(480, 640)
        )


def create_environment() -> BrittleStarUndirectedLocomotionEnvironment:
    morphology = MJCFBrittleStarMorphology(
            specification=morphology_specification
            )
    arena = MJCFAquariumArena(
            configuration=arena_configuration
            )
    env = BrittleStarUndirectedLocomotionEnvironment.from_morphology_and_arena(
            morphology=morphology, arena=arena, configuration=environment_configuration, backend="MJX"
            )
    return env


def post_render(
        render_output: List[np.ndarray],
        environment_configuration: MuJoCoEnvironmentConfiguration
        ) -> np.ndarray:
    num_cameras = len(environment_configuration.camera_ids)
    num_envs = len(render_output) // num_cameras

    if num_cameras > 1:
        # Horizontally stack frames of the same environment
        frames_per_env = np.array_split(render_output, num_envs)
        render_output = [np.concatenate(env_frames, axis=1) for env_frames in frames_per_env]

    # Vertically stack frames of different environments
    render_output = np.concatenate(render_output, axis=0)

    return render_output[:, :, ::-1]  # RGB to BGR


def show_video(
        images: List[np.ndarray | None],
        sim_time: float,
        path: str | None = None
        ) -> str | None:
    if path:
        media.write_video(path=path, images=images)
    return media.show_video(images=images, fps=len(images)//sim_time)

Now we can create our environment and `jax.jit` the `step` and `reset` functions. 

In [None]:
rng = jax.random.PRNGKey(seed=0)
env = create_environment()
env_reset_fn = jax.jit(env.reset)
env_step_fn = jax.jit(env.step)

The next cell prints out the environment's observation space, action space and the `info` dictionary that our environment updates every step. It also renders a single frame, showing the initial state of the environment after a reset.

In [None]:
print("Observation space:")
print(env.observation_space)
print()
print("Action space:")
print(env.action_space)
rng, sub_rng = jax.random.split(rng, 2)
env_state = env_reset_fn(rng=sub_rng)
print("env_state.info:")
print(env_state.info)
media.show_image(post_render(env.render(env_state), environment_configuration=env.environment_configuration))

### CPG model
The next cell copies the CPG implementation, the CPG creation function, and the CPG readout function (which maps the CPG state to joint-level actuator actions) from the CPG tutorial.
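
As a quick intuition for the amplitude and offset dynamics in that implementation, the sketch below integrates the same second-order equation for a single scalar amplitude with explicit Euler steps. The gain, time step and target value are example numbers, not the notebook's settings.

```python
def second_order_de(gain: float, modulator: float, value: float, dot_value: float) -> float:
    # Critically damped second-order dynamics that pull `value` towards `modulator`.
    return gain * ((gain / 4) * (modulator - value) - dot_value)


gain, dt, target_amplitude = 20.0, 0.01, 0.8  # example values
amplitude, dot_amplitude = 0.0, 0.0

for _ in range(200):  # 200 steps of 0.01 s = 2 simulated seconds
    # Explicit Euler: both updates use the values from the previous step.
    new_dot_amplitude = dot_amplitude + dt * second_order_de(
        gain, target_amplitude, amplitude, dot_amplitude
    )
    amplitude = amplitude + dt * dot_amplitude
    dot_amplitude = new_dot_amplitude

print(amplitude)
```

The amplitude approaches the target smoothly, which is why changing $R$ or $X$ leads to a gradual rather than an abrupt change in the oscillator output.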

In [None]:
import functools
from typing import Callable, Tuple

import chex
import jax
import jax.numpy as jnp
from flax import struct


def euler_solver(
        current_time: float,
        y: float,
        derivative_fn: Callable[[float, float], float],
        delta_time: float
        ) -> float:
    slope = derivative_fn(current_time, y)
    next_y = y + delta_time * slope
    return next_y


@struct.dataclass
class CPGState:
    time: float
    adjacency: jnp.ndarray
    weights: jnp.ndarray
    phases: jnp.ndarray
    dot_amplitudes: jnp.ndarray  # first order derivative of the amplitude
    amplitudes: jnp.ndarray
    dot_offsets: jnp.ndarray  # first order derivative of the offset 
    offsets: jnp.ndarray
    outputs: jnp.ndarray

    # We'll make these modulatory parameters part of the state as they will change as well
    R: jnp.ndarray
    X: jnp.ndarray
    omegas: jnp.ndarray
    rhos: jnp.ndarray


class CPG:
    def __init__(
            self,
            adjacency: jnp.ndarray,
            amplitude_gain: float = 20,
            offset_gain: float = 20,
            dt: float = 0.01
            ) -> None:
        self._adjacency = adjacency
        self._amplitude_gain = amplitude_gain
        self._offset_gain = offset_gain
        self._dt = dt
        self._solver = euler_solver

    @staticmethod
    def phase_de(
            weights: jnp.ndarray,
            amplitudes: jnp.ndarray,
            phases: jnp.ndarray,
            phase_biases: jnp.ndarray,
            omegas: jnp.ndarray
            ) -> jnp.ndarray:
        @jax.vmap  # vectorizes this function for us over an additional batch dimension (in this case over all oscillators)
        def sine_term(
                phase_i: float,
                phase_biases_i: float
                ) -> jnp.ndarray:
            return jnp.sin(phases - phase_i - phase_biases_i)

        couplings = jnp.sum(weights * amplitudes * sine_term(phase_i=phases, phase_biases_i=phase_biases), axis=1)
        return omegas + couplings

    @staticmethod
    def second_order_de(
            gain: jnp.ndarray,
            modulator: jnp.ndarray,
            values: jnp.ndarray,
            dot_values: jnp.ndarray
            ) -> jnp.ndarray:
        return gain * ((gain / 4) * (modulator - values) - dot_values)

    @staticmethod
    def first_order_de(
            dot_values: jnp.ndarray
            ) -> jnp.ndarray:
        return dot_values

    @staticmethod
    def output(
            offsets: jnp.ndarray,
            amplitudes: jnp.ndarray,
            phases: jnp.ndarray
            ) -> jnp.ndarray:
        return offsets + amplitudes * jnp.cos(phases)

    def reset(
            self,
            rng: chex.PRNGKey
            ) -> CPGState:
        num_oscillators = self._adjacency.shape[0]
        
        # noinspection PyArgumentList
        state = CPGState(
                adjacency=self._adjacency,
                phases=jax.random.uniform(
                        key=rng, shape=(num_oscillators,), dtype=jnp.float32, minval=-0.001, maxval=0.001
                        ),
                amplitudes=jnp.zeros(num_oscillators),
                offsets=jnp.zeros(num_oscillators),
                dot_amplitudes=jnp.zeros(num_oscillators),
                dot_offsets=jnp.zeros(num_oscillators),
                outputs=jnp.zeros(num_oscillators),
                time=0.0,
                R=jnp.zeros(num_oscillators),
                X=jnp.zeros(num_oscillators),
                omegas=jnp.zeros(num_oscillators),
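                # Coupling weights start at zero; the coupling term in phase_de (and hence rhos)
                # only affects the phases once non-zero weights are set on the state.
                weights=jnp.zeros_like(self._adjacency),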
                rhos=jnp.zeros_like(self._adjacency)
                )
        return state

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
            self,
            state: CPGState
            ) -> CPGState:
        # Update phase
        new_phases = self._solver(
                current_time=state.time,
                y=state.phases,
                derivative_fn=lambda
                    t,
                    y: self.phase_de(
                        omegas=state.omegas,
                        amplitudes=state.amplitudes,
                        phases=y,
                        phase_biases=state.rhos,
                        weights=state.weights
                        ),
                delta_time=self._dt
                )
        new_dot_amplitudes = self._solver(
                current_time=state.time,
                y=state.dot_amplitudes,
                derivative_fn=lambda
                    t,
                    y: self.second_order_de(
                        gain=self._amplitude_gain, modulator=state.R, values=state.amplitudes, dot_values=y
                        ),
                delta_time=self._dt
                )
        new_amplitudes = self._solver(
                current_time=state.time,
                y=state.amplitudes,
                derivative_fn=lambda
                    t,
                    y: self.first_order_de(dot_values=state.dot_amplitudes),
                delta_time=self._dt
                )
        new_dot_offsets = self._solver(
                current_time=state.time,
                y=state.dot_offsets,
                derivative_fn=lambda
                    t,
                    y: self.second_order_de(
                        gain=self._offset_gain, modulator=state.X, values=state.offsets, dot_values=y
                        ),
                delta_time=self._dt
                )
        new_offsets = self._solver(
                current_time=state.time,
                y=state.offsets,
                derivative_fn=lambda
                    t,
                    y: self.first_order_de(dot_values=state.dot_offsets),
                delta_time=self._dt
                )

        new_outputs = self.output(offsets=new_offsets, amplitudes=new_amplitudes, phases=new_phases)
        # noinspection PyUnresolvedReferences
        return state.replace(
                phases=new_phases,
                dot_amplitudes=new_dot_amplitudes,
                amplitudes=new_amplitudes,
                dot_offsets=new_dot_offsets,
                offsets=new_offsets,
                outputs=new_outputs,
                time=state.time + self._dt
                )


def create_cpg() -> CPG:
    ip_oscillator_indices = jnp.arange(0, 10, 2)
    oop_oscillator_indices = jnp.arange(1, 10, 2)

    adjacency_matrix = jnp.zeros((10, 10))
    # Connect oscillators within an arm
    adjacency_matrix = adjacency_matrix.at[ip_oscillator_indices, oop_oscillator_indices].set(1)
    # Connect IP oscillators of neighbouring arms
    adjacency_matrix = adjacency_matrix.at[
        ip_oscillator_indices, jnp.concatenate((ip_oscillator_indices[1:], jnp.array([ip_oscillator_indices[0]])))].set(
            1
            )
    # Connect OOP oscillators of neighbouring arms
    adjacency_matrix = adjacency_matrix.at[oop_oscillator_indices, jnp.concatenate(
            (oop_oscillator_indices[1:], jnp.array([oop_oscillator_indices[0]]))
            )].set(1)

    # Make adjacency matrix symmetric (i.e. make all connections bi-directional)
    adjacency_matrix = jnp.maximum(adjacency_matrix, adjacency_matrix.T)

    return CPG(
            adjacency=adjacency_matrix,
            amplitude_gain=40,
            offset_gain=40,
            dt=environment_configuration.control_timestep
            )


def get_oscillator_indices_for_arm(
        arm_index: int
        ) -> Tuple[int, int]:
    return arm_index * 2, arm_index * 2 + 1

@jax.jit
def map_cpg_outputs_to_actions(
        cpg_state: CPGState
        ) -> jnp.ndarray:
    num_arms = morphology_specification.number_of_arms
    num_oscillators_per_arm = 2
    num_segments_per_arm = morphology_specification.number_of_segments_per_arm[0]

    cpg_outputs_per_arm = cpg_state.outputs.reshape((num_arms, num_oscillators_per_arm))
    cpg_outputs_per_segment = cpg_outputs_per_arm.repeat(num_segments_per_arm, axis=0)

    actions = cpg_outputs_per_segment.flatten()
    return actions

### Evaluation function

Now we can write our evaluation function: the function that takes a candidate solution (here, the CPG modulation parameters), evaluates it by running the brittle star simulation with the modulated CPG generating the actions, and returns a fitness score (the brittle star's distance from the origin at the end of the simulation, read from `info["xy_distance_from_origin"]`).

Let's start by implementing a helper function that takes in parameters and actually modulates the CPG with them.
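
The least obvious part of this helper is turning a flat vector of phase biases into a matrix that satisfies $\rho_{ij} = -\rho_{ji}$. The sketch below shows that construction on a hypothetical 3-oscillator network using NumPy. The notebook's JAX version follows the same idea but uses `jnp.where(..., size=...)` and `.at[...].set(...)`, because JAX needs static shapes and immutable updates.

```python
import numpy as np

# Hypothetical network: three oscillators, each coupled to the other two.
adjacency = np.array(
    [[0, 1, 1],
     [1, 0, 1],
     [1, 1, 0]], dtype=float
)
rhos = np.array([0.5, 1.0, 1.5])  # one phase bias per bi-directional coupling

# Keep a single direction per coupling by zeroing everything above the diagonal.
single_direction = np.tril(adjacency)
rows, cols = np.nonzero(single_direction)

# Write the phase biases into the lower triangle ...
rho_matrix = np.zeros_like(adjacency)
rho_matrix[rows, cols] = rhos
# ... and fill the opposite direction with the negated values.
rho_matrix = rho_matrix - rho_matrix.T

print(rho_matrix)
```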

In [None]:
def modulate_cpg(cpg_state: CPGState, parameters: jnp.ndarray) -> CPGState:
    num_oscillators = cpg_state.R.shape[0]

    R = parameters[:num_oscillators]
    X = parameters[num_oscillators : 2 * num_oscillators]
    omegas = parameters[2 * num_oscillators] * jnp.ones(num_oscillators)
    rhos = parameters[2 * num_oscillators + 1 :]

    # The rho's (phase biases) need to be reshaped into the same shape as the adjacency matrix

    # First we need to get coupling indices out of the adjacency matrix
    # We only want a single index per bi-directional coupling (the inverse direction is the negation)
    # So let's first get an adjacency matrix with all elements above the diagonal set to zero
    single_direction_adjacency = jnp.tril(cpg_state.adjacency)

    # Now we want to get the indices of elements that are 1 (i.e., indices of the couplings in the adjacency matrix)
    # Because JAX requires static shapes, we need to explicitly pass the size argument here.
    coupling_indices = jnp.where(single_direction_adjacency == 1, size=15)
    # With these indices, we can set the phase biases of the couplings
    rho_matrix = single_direction_adjacency.at[coupling_indices].set(rhos)

    # Make bidirectional with negation
    rho_matrix = rho_matrix - rho_matrix.T

    cpg_state = cpg_state.replace(R=R, X=X, omegas=omegas, rhos=rho_matrix)

    return cpg_state

Now we'll write two variants of our evaluation function: one that just runs the simulation to be used during optimisation (making efficient use of `jax.lax.scan`), and one that also visualises the simulation to be used for analysis.
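
If `jax.lax.scan` is new to you, the sketch below contrasts it with an ordinary Python loop on a toy rollout. The `step` function and its dynamics are hypothetical. The point is only that `scan` threads a carry through a fixed number of steps inside a single traced computation, which makes the whole rollout easy to `jit` and `vmap`.

```python
import jax
import jax.numpy as jnp


def step(state, _):
    # Hypothetical dynamics: move a 2D point 0.1 units along x every step.
    new_state = state + jnp.array([0.1, 0.0])
    return new_state, None  # (carry, per-step output); only the carry is needed here


def rollout_scan(initial_state, num_steps):
    final_state, _ = jax.lax.scan(f=step, init=initial_state, xs=None, length=num_steps)
    return final_state


def rollout_loop(initial_state, num_steps):
    state = initial_state
    for _ in range(num_steps):
        state, _ = step(state, None)
    return state


start = jnp.zeros(2)
# The number of steps must be static so that scan knows its length at trace time.
final_scan = jax.jit(rollout_scan, static_argnums=1)(start, 50)
final_loop = rollout_loop(start, 50)
print(final_scan, final_loop)
```

Both rollouts produce the same final state. The Python loop remains useful when you need side effects at every step, such as rendering frames.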

In [None]:
def evaluate_parameters(
        rng: chex.PRNGKey,
        parameters: jnp.ndarray
        ) -> jnp.ndarray:
    rng, env_rng = jax.random.split(key=rng, num=2)
    env_state = env_reset_fn(env_rng)
   
    cpg = create_cpg()
    rng, cpg_rng = jax.random.split(key=rng, num=2)
    cpg_state = cpg.reset(rng=cpg_rng)
    
    cpg_state = modulate_cpg(cpg_state=cpg_state, parameters=parameters)

    def _step(carry, _):
        _env_state, _cpg_state = carry
        
        _cpg_state = cpg.step(state=_cpg_state)
        _actions = map_cpg_outputs_to_actions(cpg_state=_cpg_state)
        _env_state = env_step_fn(_env_state, _actions)
        
        return (_env_state, _cpg_state), None
    
    (final_env_state, _), _ = jax.lax.scan(
        f=_step,
        init=(env_state, cpg_state),
        length=env.environment_configuration.total_num_control_steps
        )
   
    fitness = final_env_state.info["xy_distance_from_origin"]
    return fitness 
    

def evaluate_parameters_visual(
        rng: chex.PRNGKey,
        parameters: jnp.ndarray,
        ) -> float:
    rng, env_rng = jax.random.split(key=rng, num=2)
    env_state = env_reset_fn(env_rng)
   
    cpg = create_cpg()
    rng, cpg_rng = jax.random.split(key=rng, num=2)
    cpg_state = cpg.reset(rng=cpg_rng)

    cpg_state = modulate_cpg(cpg_state=cpg_state, parameters=parameters)

    frames = []
    
    while not (env_state.terminated | env_state.truncated):
        cpg_state = cpg.step(state=cpg_state)
        actions = map_cpg_outputs_to_actions(cpg_state=cpg_state)
        env_state = env_step_fn(env_state, actions)
        frame = post_render(env.render(state=env_state), environment_configuration=environment_configuration)
        frames.append(frame)        

    show_video(images=frames, sim_time=env.environment_configuration.simulation_time)
    
    fitness = env_state.info["xy_distance_from_origin"]
    return fitness 

### Applying EvoSax

Now we have everything we need to apply [EvoSax's OpenES](https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/open_es.py).

Similar to the Q-Learning tutorial, we will use [WandB](https://wandb.ai/) for logging.
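
The optimisation loop evaluates the whole population in parallel by wrapping the evaluation function in `jax.vmap`. A minimal sketch of that pattern, with a hypothetical noisy toy fitness standing in for the simulation, looks like this:

```python
import jax
import jax.numpy as jnp


def evaluate(rng, parameters):
    # Hypothetical fitness: negative squared norm plus a little evaluation noise.
    noise = 0.01 * jax.random.normal(rng)
    return -jnp.sum(parameters ** 2) + noise


pop_size, num_dims = 8, 36
rng = jax.random.PRNGKey(0)
rng, rng_pop, rng_eval = jax.random.split(rng, 3)

# A stand-in population: one row per candidate solution.
population = jax.random.normal(rng_pop, (pop_size, num_dims))
# One PRNG key per candidate, so every evaluation gets its own randomness.
eval_keys = jax.random.split(rng_eval, pop_size)

# vmap maps over the leading axis of both arguments; jit compiles the batched function.
batched_evaluate = jax.jit(jax.vmap(evaluate))
fitness = batched_evaluate(eval_keys, population)
print(fitness.shape)  # (8,)
```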

In [None]:
from evosax import OpenES
from tqdm import tqdm
import wandb 

NUM_GENERATIONS = 100
NUM_PARAMETERS = 36
POP_SIZE = 100
wandb.init(
        project="SEL3-2024-ES-Tutorial"
        )

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = OpenES(popsize=POP_SIZE, num_dims=NUM_PARAMETERS, maximize=True)
es_params = strategy.default_params
es_state = strategy.initialize(rng, es_params)

# Important: We parallelise the evaluation using jax.vmap!
evaluate_fn = jax.jit(jax.vmap(evaluate_parameters))
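# Run the ask-eval-tell loop
for generation in tqdm(range(NUM_GENERATIONS), desc="Generation: "):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    # Ask: sample a population of candidate parameter vectors
    x, es_state = strategy.ask(rng_gen, es_state, es_params)

    # Evaluate: one PRNG key per candidate, all candidates in parallel
    rng_eval = jax.random.split(key=rng_eval, num=POP_SIZE)
    fitness = evaluate_fn(rng_eval, x)
    # Tell: update the search distribution with the observed fitness values
    es_state = strategy.tell(x, fitness, es_state, es_params)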
    
    wandb.log({"max_fitness": jnp.max(fitness), "mean_fitness": jnp.mean(fitness)}, step=generation) 
     
# Get the best overall population member and its fitness
print(es_state.best_member, es_state.best_fitness)

wandb.finish()
evaluate_parameters_visual(rng=rng, parameters=es_state.best_member)

# Exercises and next steps

In general: try to improve the optimisation as well as possible so that you can generate better and more realistic gaits. Always try to reason about and predict the influence of a certain modification before optimisation, and compare your predictions with the actual results afterward! This is the best and fastest way to **improve your intuition, which is the main goal of this tutorial**!

We applied an Evolution Strategy (ES) to optimize the modulation parameters of our CPG. Remember that ES is a **black-box optimizer**, meaning it could also optimize **neural network weights**, providing an alternative to Reinforcement Learning for adaptive robot control.

* Some next steps:
    * Take another look at the [EvoSax](https://github.com/RobertTLange/evosax) library and check out which hyperparameters our evolution strategy has, and what the `FitnessShaper` can do.
    * Currently, `strategy.ask` can propose any value for a parameter, which can lead to strange values such as negative amplitudes. Mitigate this by clipping the values of the candidate solutions to $[-1, 1]$ using [EvoParams](https://github.com/RobertTLange/evosax/blob/cae0e9271794f4702ba16a8c6fcaaf8595f4a2f3/evosax/strategies/open_es.py#L28). This bounds the parameters proposed by `strategy.ask`. In `modulate_cpg`, you can then rescale the different parameters to appropriate ranges (e.g. amplitudes are bounded between $(0, 1)$, while offsets are bounded between $(-1, 1)$).
        * Similarly, rescale and bound your CPG outputs to the actual range of motion of each joint.
    * Enrich the fitness function to stimulate more realistic gaits (e.g. by incorporating an energy penalty in the fitness calculation). To do so, take a closer look at which observations the environment returns.
    * Instead of evolving a gait that can move our brittle star the furthest in any direction, optimise it towards maximising distance along the x-axis. Does left-right symmetry arise?
    * Undo the sharing of oscillators between the segments of a single arm, and instead use one oscillator per joint. This will allow more complex behaviours to be evolved.
    * Increase the optimisation scope and also optimise the coupling weights.
    * Improve the logging to WandB:
        * Log a video of the current best candidate solution every $N$ generations
        * Log a checkpoint of the current best candidate solution every $N$ generations
        * Log parameters related to the Evolution Strategy (e.g. learning rate)
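
As a starting point for the rescaling exercise above, here is a minimal sketch that maps a candidate solution clipped to $[-1, 1]$ onto per-parameter ranges. The helper `rescale` is hypothetical, and the ranges used for the frequency and the phase biases are assumptions you should tune yourself. Only the amplitude and offset ranges come from the exercise text. NumPy is used for brevity; the same expressions work with `jax.numpy`.

```python
import numpy as np


def rescale(values, low, high):
    # Linearly map values from [-1, 1] to [low, high], clipping anything outside.
    return low + (np.clip(values, -1.0, 1.0) + 1.0) * 0.5 * (high - low)


num_oscillators = 10
# Stand-in for one candidate solution proposed by the evolution strategy.
parameters = np.random.default_rng(0).uniform(-1.0, 1.0, size=36)

# Same parameter layout as modulate_cpg: amplitudes, offsets, shared frequency, phase biases.
R = rescale(parameters[:num_oscillators], 0.0, 1.0)
X = rescale(parameters[num_oscillators:2 * num_oscillators], -1.0, 1.0)
omega = rescale(parameters[2 * num_oscillators], 0.0, 2 * np.pi)      # assumed range
rhos = rescale(parameters[2 * num_oscillators + 1:], -np.pi, np.pi)   # assumed range

print(R.min(), R.max(), omega, rhos.shape)
```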