In [1]:
# Notebook-wide IPython setup; run this cell first on a fresh kernel.
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Load the autoreload extension, then select mode 2 so imported modules are
# re-imported before each cell runs (edits to local .py files are picked up).
%load_ext autoreload
%autoreload 2
# Ask the inline backend for 'retina' (high-DPI) figure output.
%config InlineBackend.figure_format = 'retina'

EvoTF_ES With MLP Controller

In [2]:
import jax
import numpy as np
from evojax.obs_norm import ObsNormalizer
from evojax.sim_mgr import SimManager
from evojax.task.brax_task import BraxTask
from evojax.policy import MLPPolicy

from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

In [3]:
def get_brax_task(
    env_name="humanoid",
    hidden_dims=(32, 32, 32, 32),
):
    """Build the train/test Brax tasks and a matching MLP policy.

    Args:
        env_name: Environment name forwarded to ``BraxTask``.
        hidden_dims: Sequence of hidden-layer widths for the MLP policy.
            The default is an immutable tuple so it cannot be shared and
            mutated across calls.

    Returns:
        A ``(train_task, test_task, policy)`` tuple, where ``train_task`` is
        created with ``test=False``, ``test_task`` with ``test=True``, and
        ``policy`` is an ``MLPPolicy`` whose input/output sizes are taken
        from the first entry of the train task's ``obs_shape``/``act_shape``.
    """
    train_task = BraxTask(env_name, test=False)
    test_task = BraxTask(env_name, test=True)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.act_shape[0],
        # Hand the policy its own list: neither the default nor a
        # caller-owned sequence is aliased by the policy object.
        hidden_dims=list(hidden_dims),
    )
    return train_task, test_task, policy

In [4]:
# Show which accelerators JAX can see before launching the experiment.
visible_devices = jax.devices()
print(visible_devices)

[cuda(id=0)]


In [5]:
# Experiment settings, gathered in one place so the solver and the simulation
# manager are guaranteed to use the same population size and seed.
POP_SIZE = 256
SEED = 0
N_REPEATS = 16
TEST_N_REPEATS = 1
N_EVALUATIONS = 128
NUM_GENERATIONS = 10
LOG_EVERY = 2  # besides the first generation, report every LOG_EVERY-th one

# NOTE(review): the saved run of this cell ended with an XLA
# RESOURCE_EXHAUSTED (out-of-memory) error on a single GPU. POP_SIZE and
# N_REPEATS are presumably the main memory knobs - confirm before re-running.

train_task, test_task, policy = get_brax_task("humanoid")

# No es_config / es_params overrides are passed to the wrapper.
solver = Evosax2JAX_Wrapper(
    Strategies["EvoTF_ES"],
    param_size=policy.num_params,
    pop_size=POP_SIZE,
    seed=SEED,
)
obs_normalizer = ObsNormalizer(obs_shape=train_task.obs_shape, dummy=False)
sim_mgr = SimManager(
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=SEED,
    obs_normalizer=obs_normalizer,
    pop_size=POP_SIZE,
    use_for_loop=False,
    n_repeats=N_REPEATS,
    test_n_repeats=TEST_N_REPEATS,
    n_evaluations=N_EVALUATIONS,
)

print(f"START EVOLVING {policy.num_params} PARAMS.")
# Ask/evaluate/tell loop of the evolution strategy.
for gen_counter in range(NUM_GENERATIONS):
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    num_gens = gen_counter + 1
    # Only the first generation and every LOG_EVERY-th one are reported.
    if num_gens != 1 and num_gens % LOG_EVERY != 0:
        continue

    # Evaluate the solver's current best parameters on the test task.
    test_scores, _ = sim_mgr.eval_params(params=solver.best_params, test=True)
    progress = {"num_gens": num_gens}
    metrics = {
        "train_perf": float(np.nanmean(scores)),
        "test_perf": float(np.nanmean(test_scores)),
    }
    print(progress, metrics)

2024-08-29 00:03:23.342155: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.5 which is older than the ptxas CUDA version (12.6.20). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
MLPPolicy: 2024-08-29 00:03:41,174 [INFO] MLPPolicy.num_params = 11569


Loaded pretrained EvoTF model from ckpt: 2024_03_SNES_small.pkl


2024-08-29 00:03:41.539173: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 402.82MiB (422389045 bytes) by rematerialization; only reduced to 5.58GiB (5994338228 bytes), down from 5.58GiB (5994338228 bytes) originally
SimManager: 2024-08-29 00:03:41,596 [INFO] use_for_loop=False


START EVOLVING 11569 PARAMS.


2024-08-29 00:03:41.778935: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below -4.92GiB (-5283442480 bytes) by rematerialization; only reduced to 5.62GiB (6029831912 bytes), down from 5.62GiB (6029878188 bytes) originally
2024-08-29 00:03:51.860608: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.52GiB (rounded to 5923328000)requested by op 
2024-08-29 00:03:51.860875: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] **********************************************************************************************______
E0829 00:03:51.860934   43722 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5923328000 bytes.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5923328000 bytes.