In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

EvoTF_ES With MLP Controller

In [2]:
import jax
import time
import numpy as np
from evojax.obs_norm import ObsNormalizer
from evojax.sim_mgr import SimManager
from evojax.task.brax_task import BraxTask
from evojax.policy import MLPPolicy

from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

In [4]:
def get_brax_task(
    env_name: str = "ant",
    hidden_dims=(32, 32, 32, 32),
):
    """Build the train/test Brax tasks and a matching MLP policy.

    Args:
        env_name: Environment name forwarded to ``BraxTask`` (default ``"ant"``).
        hidden_dims: Hidden-layer sizes of the MLP controller. Any iterable of
            ints is accepted. The default is an immutable tuple so that calls
            never share (or accidentally mutate) one default list object.

    Returns:
        A ``(train_task, test_task, policy)`` tuple:
        the task built with ``test=False``, the task built with ``test=True``,
        and an ``MLPPolicy`` whose input/output sizes are taken from the first
        dimension of the training task's ``obs_shape`` and ``act_shape``.
    """
    train_task = BraxTask(env_name, test=False)
    test_task = BraxTask(env_name, test=True)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.act_shape[0],
        # Hand MLPPolicy a fresh list (the type it received before), so the
        # caller's sequence / the default is never aliased inside the policy.
        hidden_dims=list(hidden_dims),
    )
    return train_task, test_task, policy

In [5]:
# Show which accelerators JAX can see before launching the (GPU-heavy) run.
visible_devices = jax.devices()
print(visible_devices)

[CudaDevice(id=0)]


In [6]:
# Experiment configuration. POP_SIZE and SEED are shared by the ES solver and
# the simulation manager; defining them once keeps the two from drifting apart.
ENV_NAME = "ant"
POP_SIZE = 5
SEED = 0
NUM_GENERATIONS = 10
LOG_EVERY = 2  # test-evaluate after generation 1 and every LOG_EVERY generations

train_task, test_task, policy = get_brax_task(ENV_NAME)
solver = Evosax2JAX_Wrapper(
    Strategies["EvoTF_ES"],
    param_size=policy.num_params,
    pop_size=POP_SIZE,
    es_config={"maximize": True,
               "centered_rank": True},
    seed=SEED,
)
obs_normalizer = ObsNormalizer(
    # NOTE(review): dummy=False presumably enables real observation
    # normalisation (dummy=True would be a no-op) — confirm in evojax docs.
    obs_shape=train_task.obs_shape, dummy=False
)
sim_mgr = SimManager(
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=SEED,
    obs_normalizer=obs_normalizer,
    pop_size=POP_SIZE,
    # True: run the rollouts with a Python for loop; False: rely on JAX
    # vectorised/parallel evaluation instead. NOTE(review): presumably the loop
    # trades speed for lower GPU memory (the XLA rematerialization warnings in
    # the output indicate memory pressure) — confirm against evojax SimManager.
    use_for_loop=True,
    n_repeats=8,
    test_n_repeats=1,
    n_evaluations=32,
)

print(f"START EVOLVING {policy.num_params} PARAMS.")
# ES loop: ask for a population, score it on the train task, tell the scores back.
for gen_counter in range(NUM_GENERATIONS):
    start_time = time.perf_counter()  # monotonic clock, suited to interval timing

    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    if gen_counter == 0 or (gen_counter + 1) % LOG_EVERY == 0:
        test_scores, _ = sim_mgr.eval_params(
            params=solver.best_params, test=True
        )
        print(
            {
                "num_gens": gen_counter + 1,
            },
            {
                "train_perf": float(np.nanmean(scores)),
                "test_perf": float(np.nanmean(test_scores)),
            },
        )
        # Only logged generations are timed, and the figure includes the
        # test-task evaluation above. Converting the scores to Python floats
        # pulls results back to the host, so pending device work is finished.
        elapsed_time = time.perf_counter() - start_time
        print(f"Time taken for iteration {gen_counter + 1}: {elapsed_time/60} min")

MLPPolicy: 2024-08-30 21:37:48,034 [INFO] MLPPolicy.num_params = 4328


Loaded pretrained EvoTF model from ckpt: 2024_03_SNES_small.pkl


SimManager: 2024-08-30 21:37:48,380 [INFO] use_for_loop=True


START EVOLVING 4328 PARAMS.


2024-08-30 21:38:04.184620: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below -7.23GiB (-7763868160 bytes) by rematerialization; only reduced to 13.96GiB (14985267200 bytes), down from 13.96GiB (14985267200 bytes) originally
2024-08-30 21:38:04.288057: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below -7.23GiB (-7763868160 bytes) by rematerialization; only reduced to 27.94GiB (30002427920 bytes), down from 27.94GiB (30002427920 bytes) originally
2024-08-30 21:38:04.332230: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below -7.23GiB (-7763868160 bytes) by rematerialization; only reduced to 26.42MiB (27699200 bytes), down from 26.42MiB (27699200 bytes) originally
2024-08-30 21:38:04.348334: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below -7.23GiB (-7763868160 bytes) by rematerialization; only reduced to 13.96GiB (14985267200 bytes), down