# Solving a Brax Problem in EvoX

EvoX supports neuroevolution on Brax environments.
This example shows how to solve a Brax problem in EvoX.

In [13]:
# Install EvoX and Brax; skip this cell if they are already installed
from importlib.util import find_spec

if find_spec("evox") is None:
    %pip install evox

if find_spec("brax") is None:
    %pip install brax

In [14]:
# Packages and functions used in this example
import time

import torch
import torch.nn as nn

from evox.algorithms import PSO
from evox.problems.neuroevolution.brax import BraxProblem
from evox.utils import ParamsAndVector
from evox.workflows import EvalMonitor, StdWorkflow

## What is Brax

Brax is a fast and fully differentiable physics engine used for research and development of robotics, human perception, materials science, reinforcement learning, and other simulation-heavy applications. 

Here we will use the "hopper" environment of Brax.

For more information, see the [Brax repository on GitHub](https://github.com/google/brax).

## Design a neural network class

First, we need to decide which neural network to use as the policy.

Here we use a simple multilayer perceptron (MLP) class.

In [15]:
# Construct an MLP using PyTorch.
# This MLP has 3 layers: 11 inputs, 4 hidden units (Tanh), and 3 outputs.


class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(11, 4), nn.Tanh(), nn.Linear(4, 3))

    def forward(self, x):
        x = self.features(x)
        return x
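
As a quick sanity check on the size of the search space, the number of trainable parameters in this MLP can be worked out by hand. The sketch below uses only the layer shapes shown in `SimpleMLP`; the helper name `count_linear_params` is hypothetical and introduced just for illustration. The input and output widths (11 and 3) presumably match the hopper observation and action sizes, which is an assumption, not something the cell above verifies.

```python
# Layer shapes copied from SimpleMLP: nn.Linear(11, 4) followed by nn.Linear(4, 3).
# nn.Tanh() has no trainable parameters, so it is not listed.
layer_sizes = [(11, 4), (4, 3)]  # (in_features, out_features)


def count_linear_params(in_features, out_features):
    """Hypothetical helper: weight matrix entries plus one bias per output unit."""
    return in_features * out_features + out_features


per_layer = [count_linear_params(i, o) for i, o in layer_sizes]
total_params = sum(per_layer)
print(per_layer, total_params)
```

Each individual in the population is therefore a vector with one entry per trainable parameter.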

## Instantiate a model

Using the ``SimpleMLP`` class, we can instantiate an MLP model.

In [16]:
# Keep the model and all tensors on the same device, preferably the GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fix the random seed for reproducibility
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Initialize the MLP model
model = SimpleMLP().to(device)

### Create an adapter

An adapter converts data back and forth between the model's parameter dictionary and the flat vector that the algorithm optimizes.

In [17]:
adapter = ParamsAndVector(dummy_model=model)

With the adapter in place, we can set up the neuroevolution task.
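
To make the round trip concrete, here is a minimal pure-Python sketch of the idea behind such an adapter. It is not the `ParamsAndVector` implementation: the functions `flatten_params` and `unflatten_vector` are hypothetical, and each parameter is stored as an already-flattened list, with the shape used only to determine how many values belong to it.

```python
from math import prod

# Parameter names and shapes mirroring SimpleMLP (assumed layout, for illustration).
shapes = {
    "features.0.weight": (4, 11),
    "features.0.bias": (4,),
    "features.2.weight": (3, 4),
    "features.2.bias": (3,),
}


def flatten_params(params):
    """Concatenate the per-parameter value lists in a fixed key order."""
    vector = []
    for name in shapes:
        vector.extend(params[name])
    return vector


def unflatten_vector(vector):
    """Slice the flat vector back into one value list per parameter name."""
    params, start = {}, 0
    for name, shape in shapes.items():
        size = prod(shape)
        params[name] = vector[start:start + size]
        start += size
    return params


# Round trip: dict -> vector -> dict gives back the same values.
params = {name: [0.1 * i for i in range(prod(shape))] for name, shape in shapes.items()}
vector = flatten_params(params)
restored = unflatten_vector(vector)
print(len(vector), restored == params)
```

The optimizer only ever sees the flat vector, while the policy network only ever sees the named parameters.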

## Set up the running process

### Initiate an algorithm and a problem

We use the PSO algorithm, and the problem is a Brax problem in the "hopper" environment.

In [18]:
# Set the population size
POP_SIZE = 10

# Derive the PSO search bounds from the initial model parameters
model_params = dict(model.named_parameters())
pop_center = adapter.to_vector(model_params)
lower_bound = pop_center - 1
upper_bound = pop_center + 1

# Initialize PSO; any other algorithm can be used instead
algorithm = PSO(
    pop_size=POP_SIZE,
    lb=lower_bound,
    ub=upper_bound,
    device=device,
)
algorithm.setup()

# Initialize the Brax problem
problem = BraxProblem(
    policy=model,
    env_name="hopper",
    max_episode_length=1000,
    num_episodes=3,
    pop_size=POP_SIZE,
    device=device,
)

In this case, we will be using 1000 steps for each episode, and the average reward of 3 episodes will be returned as the fitness value.
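
The fitness computation described above can be sketched with a toy stand-in for the environment. This is only an illustration of "average the return over several episodes": `rollout`, `fitness`, and the `step_reward` callables are hypothetical, and the sketch ignores early termination, which a real environment may apply.

```python
import random


def rollout(step_reward, max_episode_length):
    """Sum per-step rewards over one episode (no early termination here)."""
    return sum(step_reward() for _ in range(max_episode_length))


def fitness(step_reward, num_episodes=3, max_episode_length=1000):
    """Average the episode returns over num_episodes rollouts."""
    returns = [rollout(step_reward, max_episode_length) for _ in range(num_episodes)]
    return sum(returns) / num_episodes


# A constant reward of 1.0 per step gives an average return equal to the episode length.
constant = fitness(lambda: 1.0)

# A noisy reward in [0, 1] shows why averaging over episodes is useful.
rng = random.Random(0)
noisy = fitness(lambda: rng.uniform(0.0, 1.0))
print(constant, noisy)
```

Averaging over several episodes reduces the variance of the fitness estimate at the cost of proportionally more simulation time.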

### Set up a monitor

In [19]:
# Set up a monitor that records the top 3 fitness values
pop_monitor = EvalMonitor(
    topk=3,
    device=device,
)
pop_monitor.setup()

EvalMonitor()
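
Conceptually, keeping the top-k results across generations amounts to merging each new batch of evaluations with the best ones seen so far. The sketch below illustrates that idea in plain Python; `TopKTracker` is a hypothetical class, not the `EvalMonitor` implementation, and it assumes that lower fitness is better.

```python
import heapq


class TopKTracker:
    """Hypothetical tracker that keeps the k best (fitness, solution) pairs."""

    def __init__(self, topk=3):
        self.topk = topk
        self.best = []  # sorted by fitness, lowest (best) first

    def update(self, fitnesses, solutions):
        # Merge the new batch with the current best and keep the k smallest.
        candidates = self.best + list(zip(fitnesses, solutions))
        self.best = heapq.nsmallest(self.topk, candidates, key=lambda pair: pair[0])


tracker = TopKTracker(topk=3)
tracker.update([5.0, 1.0, 3.0, 2.0], ["a", "b", "c", "d"])  # generation 0
tracker.update([0.0, 4.0], ["e", "f"])                      # generation 1
print(tracker.best)
```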

### Create a workflow

In [20]:
# Create a workflow that maximizes the fitness
workflow = StdWorkflow(opt_direction="max")
workflow.setup(
    algorithm=algorithm,
    problem=problem,
    solution_transform=adapter,
    monitor=pop_monitor,
    device=device,
)
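
A common way to handle `opt_direction="max"` with an optimizer that minimizes is to negate the fitness before handing it to the algorithm. Whether EvoX does exactly this internally is an assumption here, although negative values in a logged top fitness would be consistent with it. The function `to_minimization` and the reward values below are hypothetical.

```python
def to_minimization(rewards, opt_direction="max"):
    """Flip the sign for maximization so that 'smaller is better' holds."""
    sign = -1.0 if opt_direction == "max" else 1.0
    return [sign * r for r in rewards]


# Hypothetical episode rewards for three individuals (higher is better).
rewards = [10.0, 30.0, 20.0]

# After negation, sorting in ascending order ranks the highest reward first.
internal = sorted(to_minimization(rewards))
print(internal)
```

Under this convention, the original reward is recovered by flipping the sign of the recorded value again.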

### Run the workflow

Run the workflow and see the magic!

```{note}
The following block will take around 1 minute to run.
The time may vary depending on your hardware.
```

In [21]:
# Set the maximum number of generations
max_generation = 3

# Run the workflow
for index in range(max_generation):
    print(f"In generation {index}:")
    t = time.time()
    workflow.step()
    print(f"\tTime elapsed: {time.time() - t: .4f}(s).")
    monitor: EvalMonitor = workflow.get_submodule("monitor")
    print(f"\tTop fitness: {monitor.topk_fitness}")
    best_params = adapter.to_params(monitor.topk_solutions[0])
    print(f"\tBest params: {best_params}")

In generation 0:
	Time elapsed:  21.9649(s).
	Top fitness: tensor([-637.1375, -624.4020, -600.7485])
	Best params: {'features.0.weight': tensor([[ 0.6129, -0.9546,  0.3324,  0.0898, -0.4953, -0.0171,  0.1294, -0.2014,
         -0.8625, -0.4629,  1.0194],
        [-0.8231, -0.1714,  0.6062,  0.5981, -0.6424, -0.8224,  0.6978,  0.5227,
          0.7899,  0.6351, -0.9706],
        [ 0.1818,  0.0418,  0.4349, -0.2376,  0.9384, -0.0688, -0.5130,  0.3086,
          0.0793, -0.7233, -0.2675],
        [-0.7382,  0.2115,  0.3164, -0.0696,  0.0377, -0.2678, -1.2471, -0.0779,
          0.6473, -0.7199,  0.2839]], grad_fn=<ViewBackward0>), 'features.0.bias': tensor([-0.6246, -0.0321,  0.1846, -0.7549], grad_fn=<ViewBackward0>), 'features.2.weight': tensor([[-0.4736,  0.2928,  0.1783,  1.1226],
        [ 0.5829,  0.8406, -0.2133, -1.0378],
        [-0.9943, -1.0495, -0.1016, -0.3519]], grad_fn=<ViewBackward0>), 'features.2.bias': tensor([-0.0952,  0.0042, -0.0048], grad_fn=<ViewBackward0>)}
In gene

```{note}
PSO is not specialized for this type of task, so its limited performance here is expected. This example is only meant as a demonstration.
```
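
For readers unfamiliar with PSO, the following is a minimal pure-Python sketch of a textbook particle swarm update on a toy objective. It is not the EvoX `PSO` implementation; the coefficients `w`, `c1`, and `c2`, the `sphere` objective, and the bound-clipping strategy are all assumptions chosen for illustration.

```python
import random


def sphere(x):
    """Toy objective to minimize: sum of squares."""
    return sum(v * v for v in x)


def pso(objective, lb, ub, pop_size=10, generations=30, w=0.7, c1=1.5, c2=1.5, seed=0):
    rng = random.Random(seed)
    dim = len(lb)
    pos = [[rng.uniform(lb[d], ub[d]) for d in range(dim)] for _ in range(pop_size)]
    vel = [[0.0] * dim for _ in range(pop_size)]
    pbest = [p[:] for p in pos]                      # personal best positions
    pbest_fit = [objective(p) for p in pos]
    g = min(range(pop_size), key=pbest_fit.__getitem__)
    gbest, gbest_fit = pbest[g][:], pbest_fit[g]     # global best so far
    history = [gbest_fit]
    for _ in range(generations):
        for i in range(pop_size):
            for d in range(dim):
                r1, r2 = rng.random(), rng.random()
                # Inertia + pull toward personal best + pull toward global best.
                vel[i][d] = (w * vel[i][d]
                             + c1 * r1 * (pbest[i][d] - pos[i][d])
                             + c2 * r2 * (gbest[d] - pos[i][d]))
                # Move, then clip to the search bounds.
                pos[i][d] = min(max(pos[i][d] + vel[i][d], lb[d]), ub[d])
            fit = objective(pos[i])
            if fit < pbest_fit[i]:
                pbest[i], pbest_fit[i] = pos[i][:], fit
                if fit < gbest_fit:
                    gbest, gbest_fit = pos[i][:], fit
        history.append(gbest_fit)
    return gbest, gbest_fit, history


best, best_fit, history = pso(sphere, [-1.0] * 5, [1.0] * 5)
print(best_fit)
```

The update relies only on fitness comparisons, which makes it easy to apply to a policy search but also means it does not exploit any structure of the control task.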

We hope you have fun solving Brax problems in EvoX!