# Classic control with Gym

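In this notebook, we use Gym (installed here as the `gymnasium` package) to train agents that solve classic control problems. The weights of a small neural-network policy are evolved with CMA-ES.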

In [1]:
# install evox (together with gymnasium and flax) only if evox is not already available
try:
    import evox
except ImportError:
    !pip install --disable-pip-version-check --upgrade -q evox gymnasium flax
    import evox

In [2]:
from evox import workflows, algorithms, problems
from evox.monitors import StdSOMonitor
from evox.utils import TreeAndVector
import jax
import jax.numpy as jnp
from flax import linen as nn

In [3]:
gym_name = "Pendulum-v1"  # choose an environment; must be a key of policy_params below

def tanh2(x):
    # squash the output into [-2, 2], the torque range of Pendulum-v1
    return 2 * nn.tanh(x)

# environment name -> (number of policy outputs, observation shape, output head)
policy_params = {
    "Acrobot-v1": (3, (6,), jnp.argmax),
    "CartPole-v1": (2, (4,), jnp.argmax),
    "MountainCarContinuous-v0": (1, (2,), nn.tanh),
    "MountainCar-v0": (3, (2,), jnp.argmax),
    "Pendulum-v1": (1, (3,), tanh2),
}
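
Each entry of `policy_params` bundles the number of network outputs, the observation shape, and an output head that turns raw network outputs into an action. A minimal sketch of the two kinds of heads, assuming made-up network outputs and using `jnp.tanh` in place of flax's `nn.tanh`:

```python
import jax.numpy as jnp

def tanh2(x):
    # same scaling as tanh2 above, written with jnp.tanh
    return 2 * jnp.tanh(x)

# Discrete control (e.g. CartPole-v1): one score per action,
# argmax picks the index of the highest-scoring action.
logits = jnp.array([0.3, -1.2])  # made-up network output
discrete_action = int(jnp.argmax(logits))

# Continuous control (e.g. Pendulum-v1): a single output squashed into [-2, 2],
# so even a large raw output never exceeds the torque bound.
raw = jnp.array([5.0])  # made-up network output
torque = tanh2(raw)

print(discrete_action, torque)
```

Because the head is applied inside the policy, the evolved weights never have to learn the action bounds themselves.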

In [4]:
# define a policy model
class ClassicPolicy(nn.Module):
    """A simple model for Classic Control problem"""

    @nn.compact
    def __call__(self, x):
        x = x.at[1].multiply(10)  # rescale the 2nd observation feature (simple normalization heuristic)
        x = nn.Dense(16)(x)
        x = nn.relu(x)
        x = nn.Dense(policy_params[gym_name][0])(x)

        return policy_params[gym_name][2](x)
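
`x.at[1].multiply(10)` is JAX's functional index update: JAX arrays are immutable, so it returns a new array with element 1 multiplied by 10 and leaves `x` untouched. The same rescaling is applied whatever feature sits at index 1 in the chosen environment. In MountainCar, for example, index 1 is the velocity, which lies in [-0.07, 0.07] while the position lies in [-1.2, 0.6]. A small sketch with a made-up MountainCar-style observation:

```python
import jax.numpy as jnp

# made-up MountainCar observation: (position, velocity)
obs = jnp.array([-0.5, 0.02])

# functional update: returns a new array, `obs` itself is not modified
scaled = obs.at[1].multiply(10)

print(obs)     # still the original values
print(scaled)  # velocity scaled by 10, i.e. about 0.2
```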

In [5]:
key = jax.random.PRNGKey(42)
model_key, workflow_key = jax.random.split(key)

model = ClassicPolicy()
params = model.init(model_key, jnp.zeros(policy_params[gym_name][1]))
adapter = TreeAndVector(params)
monitor = StdSOMonitor()
problem = problems.neuroevolution.Gym(
    env_name=gym_name,
    policy=jax.jit(model.apply),
    num_workers=16,  # adjust to the number of CPU cores on your machine
    controller_options={
        "num_cpus": 0,
        "num_gpus": 0,
    },
    worker_options={"num_cpus": 1, "num_gpus": 1 / 16},
    batch_policy=False,
)
center = adapter.to_vector(params)
# create a workflow
workflow = workflows.StdWorkflow(
    algorithm=algorithms.CMAES(center_init=center, init_stdev=1, pop_size=64),
    problem=problem,
    pop_transform=adapter.batched_to_tree,
    monitor=monitor,
    opt_direction="max"
)

2023-10-24 15:54:46,501	INFO worker.py:1553 -- Started a local Ray instance.
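
CMA-ES works on flat vectors, while the Flax policy expects a nested parameter tree. In the cell above, `adapter.to_vector(params)` produces the initial center vector, and `pop_transform=adapter.batched_to_tree` turns each row of the population matrix back into a parameter tree. The sketch below illustrates the same idea with JAX's `jax.flatten_util.ravel_pytree` and `jax.vmap`. It is a conceptual analogue, not evox's actual `TreeAndVector` implementation, and the zero-filled `params` dict is a hand-written stand-in for the Pendulum-v1 policy parameters:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# hand-written stand-in for ClassicPolicy params on Pendulum-v1 (3 inputs, 16 hidden, 1 output)
params = {
    "Dense_0": {"kernel": jnp.zeros((3, 16)), "bias": jnp.zeros((16,))},
    "Dense_1": {"kernel": jnp.zeros((16, 1)), "bias": jnp.zeros((1,))},
}

# flatten the tree into one vector (analogous to adapter.to_vector)
center, unravel = ravel_pytree(params)
print(center.shape)  # (81,)

# a population is a (pop_size, dim) matrix; vmap the unravel function so every
# row becomes its own parameter tree (analogous to adapter.batched_to_tree)
pop = jnp.ones((64, center.shape[0]))
pop_trees = jax.vmap(unravel)(pop)
print(pop_trees["Dense_0"]["kernel"].shape)  # (64, 3, 16)
```

For Pendulum-v1 the vector has 3·16 + 16 + 16·1 + 1 = 81 entries, which is the dimension of the search space CMA-ES explores with `pop_size=64` candidates per generation.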


Now run the workflow.
You may see warnings like
```
CUDA backend failed to initialize: Unable to load CUDA.
```
This is expected behavior: a controller process manages a group of Gym workers, and the controller is configured not to use a GPU (`"num_gpus": 0` in `controller_options`).

If the program hangs, check whether `num_workers` is larger than the number of CPU cores available on your computer.
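
A minimal sketch for choosing `num_workers` so that it never exceeds the core count, using Python's `os.cpu_count()`. The helper `choose_num_workers` is hypothetical and not part of evox; its result would be passed as `num_workers=` to `problems.neuroevolution.Gym`:

```python
import os

def choose_num_workers(requested: int = 16) -> int:
    """Hypothetical helper: cap the requested worker count at the number of CPU cores.

    Each Gym worker in this notebook asks Ray for one CPU ("num_cpus": 1 in
    worker_options), so requesting more workers than cores can leave some
    of them waiting for resources.
    """
    cores = os.cpu_count() or 1  # os.cpu_count() may return None
    return max(1, min(requested, cores))

num_workers = choose_num_workers(16)
print(f"using {num_workers} Gym workers")
```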

In [6]:
# init the workflow
state = workflow.init(workflow_key)
# run the workflow for 100 steps
for i in range(100):
    state = workflow.step(state)

best_fitness = monitor.get_best_fitness()
print(best_fitness)

[2m[36m(Controller pid=641434)[0m CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


-0.114485
