# Working with extended applications

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EMI-Group/evox/blob/main/docs/source/guide/basics/2-problems.ipynb)

Working with extended applications in EvoX is easy.

In [1]:
# install evox, skip it if you have already installed evox
try:
    import evox
except ImportError:
    !pip install --disable-pip-version-check --upgrade -q evox brax
    import evox

In [2]:
from evox import algorithms, problems, workflows, monitors, utils

import jax.numpy as jnp
from jax import jit, random
from jax.tree_util import tree_map
from flax import linen as nn

from IPython.display import HTML, display

## Neuroevolution

Here we focus on neuroevolution tasks, where one needs to evolve a neural network that suits a certain task.

### Brax

To begin with, we will use Brax, a GPU-accelerated physics engine that is also written in JAX.
Since Brax is also built on JAX, running EvoX with Brax is straightforward.

We will demonstrate this using the "swimmer" environment in Brax.

First, we need to decide how we are going to evolve the neural network.
In this case, we will use a fixed-size ANN and evolve only its weights.

In [3]:
# construct an ANN using flax.
# "swimmer" environment has 8 observations and 2 actions
# and the actions are in (-1.0, 1.0)
class SwimmerPolicy(nn.Module):
    """A simple model for Swimmer"""

    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32)
        x = x.reshape(-1)
        x = nn.Dense(32)(x)
        x = nn.tanh(x)
        x = nn.Dense(32)(x)
        x = nn.tanh(x)
        x = nn.Dense(2)(x)
        x = nn.tanh(x)

        return x

model = SwimmerPolicy()
weights = model.init(random.PRNGKey(42), jnp.zeros((8, )))

In [4]:
print(tree_map(lambda x: x.shape, weights)) # print the structure of the weights

{'params': {'Dense_0': {'bias': (32,), 'kernel': (8, 32)}, 'Dense_1': {'bias': (32,), 'kernel': (32, 32)}, 'Dense_2': {'bias': (2,), 'kernel': (32, 2)}}}


However, if we check the weights of this network, we will see that they form a tree of parameter sets,
and EC algorithms cannot work directly with data in this format.

Thankfully, EvoX provides utilities to bridge this gap. In this case, `TreeAndVector` converts a tree-like structure into a vector and back.

In [5]:
adapter = utils.TreeAndVector(weights)

Now, `adapter` can convert the data back and forth.

- `to_vector` can convert a tree into a vector.
- `to_tree` can convert a vector back to a tree.

There are also batched versions of these conversions.

- `batched_to_vector` can convert a batch of trees into a batch of vectors.
- `batched_to_tree` can convert a batch of vectors into a batch of trees.
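
To make the idea concrete, here is a minimal sketch of the same kind of conversion using JAX's own `ravel_pytree`. It is an analogy, not EvoX's implementation: the small `tree` below is a made-up stand-in for the policy weights, and whether `TreeAndVector` works this way internally is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A small stand-in parameter tree (not the real policy weights).
tree = {
    "Dense_0": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))},
    "Dense_1": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))},
}

# Tree -> vector: every leaf is flattened and concatenated into one array.
vector, unravel = ravel_pytree(tree)

# Vector -> tree: `unravel` restores the original structure and shapes.
restored = unravel(vector)

# Batched vector -> tree: map `unravel` over the leading (population) axis.
population = jnp.stack([vector, 2.0 * vector, 3.0 * vector])
batched_tree = jax.vmap(unravel)(population)

print(vector.shape)
print(jax.tree_util.tree_map(lambda x: x.shape, batched_tree))
```

In the batched result every leaf gains a leading axis of size 3, one entry per individual in the population.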

In [6]:
vector_form_weights = adapter.to_vector(weights)
print(vector_form_weights.shape) # it's a single vector!

(1410,)
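
The length 1410 follows directly from the layer shapes printed earlier. The short check below recomputes it in plain Python; the helper name `dense_param_count` is introduced here only for illustration.

```python
# Layer sizes of the policy above: 8 inputs -> 32 -> 32 -> 2 outputs.
layer_sizes = [8, 32, 32, 2]


def dense_param_count(n_in: int, n_out: int) -> int:
    """Parameters of one dense layer: a kernel (n_in x n_out) plus a bias (n_out)."""
    return n_in * n_out + n_out


# One entry per Dense layer, in order.
per_layer = [
    dense_param_count(n_in, n_out)
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
]
total = sum(per_layer)
print(per_layer, total)  # [288, 1056, 66] 1410
```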


Now we can create an algorithm object.

In [7]:
# we wish the weights to be in the range [-10, 10]
lower_bound = jnp.full_like(vector_form_weights, -10.0)
upper_bound = jnp.full_like(vector_form_weights, 10.0)

# You can also use any other algorithm
algorithm = algorithms.PSO(
    lb=lower_bound,
    ub=upper_bound,
    pop_size=1024, # don't worry, it's fast
)

Now create a Brax-based problem.
`cap_episode` controls when to truncate an episode: here each rollout stops after at most 1000 steps.

Some EvoX versions also accept a `batch_size` argument that defines how many environments are evaluated in a single batch.
When it equals the population size (1024 here), the whole population is evaluated in a single pass!
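
As a rough illustration of episode truncation, the toy rollout below stops either when the environment terminates or when the step cap is reached. `ToyEnv` and `rollout` are hypothetical names written for this sketch and are not part of Brax or EvoX.

```python
class ToyEnv:
    """Hypothetical environment: reward 1.0 per step, terminates after `horizon` steps."""

    def __init__(self, horizon: int):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: float):
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done  # observation, reward, done


def rollout(env: ToyEnv, policy, cap_episode: int) -> float:
    """Sum rewards until the environment terminates or the cap is reached."""
    obs = env.reset()
    total_reward = 0.0
    for _ in range(cap_episode):
        obs, reward, done = env.step(policy(obs))
        total_reward += reward
        if done:
            break
    return total_reward


def policy(obs):
    # Trivial stand-in policy that always outputs the same action.
    return 0.0


capped = rollout(ToyEnv(horizon=5000), policy, cap_episode=1000)   # cut off by the cap
natural = rollout(ToyEnv(horizon=200), policy, cap_episode=1000)   # ends on its own
print(capped, natural)  # 1000.0 200.0
```

A cap keeps evaluation cost bounded even for policies that would otherwise never trigger termination.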

In [8]:
problem = problems.neuroevolution.Brax(
    env_name="swimmer",
    policy=jit(model.apply),
    cap_episode=1000,
)

Assemble our workflow and fire it!

Notice the `pop_transform` option.
It converts the population into the tree-like structure that represents a neural network's weights.
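
Conceptually, the transform sits between the algorithm, which proposes flat vectors, and the problem, which expects structured weights. The toy step below sketches that flow in plain Python. All names (`ask`, `to_tree`, `evaluate`, `step`) are hypothetical, the random sampling is a stand-in for a real algorithm, and this is not how EvoX implements its workflow.

```python
import random


def ask(rng: random.Random, pop_size: int, dim: int) -> list:
    """Hypothetical algorithm side: propose a population of flat vectors."""
    return [[rng.uniform(-10.0, 10.0) for _ in range(dim)] for _ in range(pop_size)]


def to_tree(vector: list) -> dict:
    """Hypothetical transform: split a flat vector into named parameter groups."""
    return {"kernel": vector[:2], "bias": vector[2:]}


def evaluate(tree: dict) -> float:
    """Hypothetical problem side: takes the structured form, returns a fitness (lower is better)."""
    return sum(w * w for w in tree["kernel"]) + sum(b * b for b in tree["bias"])


def step(rng: random.Random, pop_transform, pop_size: int = 8, dim: int = 3):
    population = ask(rng, pop_size, dim)                   # flat vectors from the algorithm
    transformed = [pop_transform(v) for v in population]   # structured form for the problem
    fitness = [evaluate(t) for t in transformed]
    best = min(range(pop_size), key=fitness.__getitem__)
    return population[best], fitness[best]


best_vector, best_fitness = step(random.Random(0), to_tree)
print(len(best_vector), best_fitness >= 0.0)  # 3 True
```

The algorithm only ever sees flat vectors, so any optimizer that works on vectors can be paired with a problem that needs tree-structured inputs.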

In [9]:
monitor = monitors.StdSOMonitor()
workflow = workflows.UniWorkflow(
    algorithm,
    problem,
    monitor,
    pop_transform=adapter.batched_to_tree,
    record_pop=True,
)

In [10]:
state = workflow.init(random.PRNGKey(123))

# run the workflow for 50 iterations
for i in range(50):
    state = workflow.step(state)

In [11]:
monitor.flush()
best_weight = monitor.get_best_solution()
# shout out to Brax's team for making the html renderer
html_result, state = problem.visualize(state, random.key(0), adapter.to_tree(best_weight))

In [12]:
display(HTML(html_result))