In [1]:
from typing import Tuple

import jax
import jax.numpy as jnp
from tqdm import tqdm

from qdax import environments
from qdax.baselines.pbt import PBT
from qdax.baselines.td3_pbt import PBTTD3, PBTTD3Config

In [2]:
# Set the JAX config option "jax_platform_name" to "cpu".
# NOTE(review): presumably this makes the CPU the default platform for work
# done outside of pmap, while accelerators are requested explicitly through
# jax.devices("gpu") in the next cell — confirm.
jax.config.update("jax_platform_name", "cpu")

In [3]:
# Prefer GPUs for the pmapped computations. jax.devices("gpu") raises a
# RuntimeError when no GPU backend is available, so fall back to the default
# devices in that case to keep the notebook runnable on any machine.
try:
    devices = jax.devices("gpu")
except RuntimeError:
    devices = jax.devices()
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")

Detected the following 1 device(s): [GpuDevice(id=0, process_index=0)]


In [4]:
# ---- Run / environment settings ----
seed = 0
env_name = "ant"
env_batch_size = 250
num_steps = 10_000
warmup_steps = 0
buffer_size = 100_000

# ---- Population layout (total size scales with the number of devices) ----
population_size_per_device = 3
population_size = num_devices * population_size_per_device

# ---- PBT settings ----
num_best_to_replace_from = 1
num_worse_to_replace = 1

# ---- TD3 settings ----
episode_length: int = 1000
batch_size: int = 256
policy_delay: int = 2
grad_updates_per_step: float = 1
soft_tau_update: float = 0.005
policy_hidden_layer_size: Tuple[int, ...] = (256, 256)
critic_hidden_layer_size: Tuple[int, ...] = (256, 256)

# ---- Outer training loop ----
num_loops = 10
print_freq = 1

In [5]:
%%time
# Build the training and evaluation environments from one shared set of
# arguments; each environment batches the episodes of every agent that
# lives on a single device.
env_kwargs = dict(
    env_name=env_name,
    batch_size=env_batch_size * population_size_per_device,
    episode_length=episode_length,
    auto_reset=True,
)

env = environments.create(**env_kwargs)

eval_env = environments.create(**env_kwargs)

CPU times: user 78.5 ms, sys: 1.27 ms, total: 79.8 ms
Wall time: 77.7 ms


In [6]:
@jax.jit
def init_environments(random_key):
    """Reset the training and evaluation environments and split them per agent.

    Args:
        random_key: JAX PRNG key, passed as ``rng`` to both ``env.reset`` and
            ``eval_env.reset``.

    Returns:
        A tuple ``(env_states, eval_env_first_states)``. In both pytrees the
        leading axis of every leaf has been reshaped into the two axes
        ``(population_size_per_device, env_batch_size)``.
    """

    def split_per_agent(tree):
        # jax.tree_util.tree_map is used instead of the jax.tree_map alias,
        # which is deprecated and no longer exists in recent JAX releases.
        return jax.tree_util.tree_map(
            lambda x: jnp.reshape(
                x, (population_size_per_device, env_batch_size) + x.shape[1:]
            ),
            tree,
        )

    # The whole function is already compiled by the decorator, so the resets
    # and the reshape do not need their own nested jax.jit wrappers.
    env_states = split_per_agent(env.reset(rng=random_key))
    eval_env_first_states = split_per_agent(eval_env.reset(rng=random_key))

    return env_states, eval_env_first_states

In [7]:
%%time
# Derive one PRNG key per device (plus a fresh main key) from the seed
key = jax.random.PRNGKey(seed)
split_keys = jax.random.split(key, num=num_devices + 1)
key, keys = split_keys[0], split_keys[1:]

# Reset the environments on every device in parallel
pmapped_init_environments = jax.pmap(
    init_environments, axis_name="p", devices=devices
)
env_states, eval_env_first_states = pmapped_init_environments(keys)

CPU times: user 8.82 s, sys: 400 ms, total: 9.22 s
Wall time: 6.65 s


In [8]:
# Collect the TD3 settings, then build the PBT-TD3 agent from them
td3_settings = dict(
    episode_length=episode_length,
    batch_size=batch_size,
    policy_delay=policy_delay,
    soft_tau_update=soft_tau_update,
    policy_hidden_layer_size=policy_hidden_layer_size,
    critic_hidden_layer_size=critic_hidden_layer_size,
)
config = PBTTD3Config(**td3_settings)

agent = PBTTD3(config=config, action_size=env.action_size)

In [9]:
%%time
# Build the per-device init function, then run it on every device to obtain
# the initial training states and replay buffers (the keys are updated too)
agent_init_fn = agent.get_init_fn(
    population_size=population_size_per_device,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=buffer_size,
)
pmapped_agent_init_fn = jax.pmap(agent_init_fn, axis_name="p", devices=devices)
keys, training_states, replay_buffers = pmapped_agent_init_fn(keys)

CPU times: user 920 ms, sys: 2.74 ms, total: 923 ms
Wall time: 726 ms


In [10]:
# Get the evaluation function and parallelise it over the devices
eval_fn = agent.get_eval_fn(eval_env)
eval_policy = jax.pmap(eval_fn, axis_name="p", devices=devices)

In [11]:
%%time
# Evaluate the population once before any training has happened
population_returns, _ = eval_policy(training_states, eval_env_first_states)
# Collapse the per-device layout into one entry per agent
population_returns = population_returns.reshape((population_size,))
initial_report = (
    f"Evaluation over {env_batch_size} episodes,"
    f" Population mean return: {jnp.mean(population_returns)},"
    f" max return: {jnp.max(population_returns)}"
)
print(initial_report)

Evaluation over 250 episodes, Population mean return: 741.47314453125, max return: 789.097900390625
CPU times: user 7.63 s, sys: 241 ms, total: 7.87 s
Wall time: 4.88 s


In [12]:
# Number of training iterations per outer loop
num_iterations = num_steps // env_batch_size

# Build the training function and parallelise it over the devices
train_fn = jax.pmap(
    agent.get_train_fn(
        env=env,
        num_iterations=num_iterations,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
    ),
    axis_name="p",
    devices=devices,
)

In [13]:
# The global replacement counts are divided by the number of devices before
# being handed to PBT (presumably because the selection step runs per device
# under pmap — confirm). Plain integer division yields 0 as soon as
# num_devices exceeds the configured count (e.g. 1 // 2), so each count is
# clamped to at least one individual per device.
# NOTE(review): with the clamp, the effective global count becomes
# num_devices whenever the configured value is smaller than num_devices.
pbt = PBT(
    population_size=population_size,
    num_best_to_replace_from=max(1, num_best_to_replace_from // num_devices),
    num_worse_to_replace=max(1, num_worse_to_replace // num_devices),
)
select_fn = jax.pmap(pbt.update_states_and_buffer_pmap, axis_name="p", devices=devices)

In [14]:
def unshard_fn(sharded_tree):
    """Move a pmapped pytree to the CPU and merge its two leading axes.

    Args:
        sharded_tree: pytree of arrays. The first two axes of every leaf are
            collapsed into a single axis of length ``population_size``
            (presumably they are (num_devices, population_size_per_device)
            — confirm against the callers).

    Returns:
        A pytree with the same structure whose leaves are placed on the first
        CPU device and have shape ``(population_size,) + x.shape[2:]``.
    """
    # jax.device_put expects a Device (or Sharding), not a platform string,
    # so the CPU device is looked up explicitly. The transfer is done eagerly
    # (no jit) so that it is an actual copy rather than a traced operation.
    cpu_device = jax.devices("cpu")[0]

    def to_cpu_and_merge(x):
        x = jax.device_put(x, cpu_device)
        return jnp.reshape(x, (population_size,) + x.shape[2:])

    # jax.tree_util.tree_map is available in both old and recent JAX
    # releases, unlike the deprecated jax.tree_map alias.
    return jax.tree_util.tree_map(to_cpu_and_merge, sharded_tree)

In [15]:
%%time
for i in tqdm(range(num_loops), total=num_loops):

    # Run one round of training on every device
    train_carry, metrics = train_fn(training_states, env_states, replay_buffers)
    training_states, env_states, replay_buffers = train_carry

    # Evaluate the freshly trained population
    population_returns, _ = eval_policy(training_states, eval_env_first_states)
    population_returns_flatten = jnp.reshape(population_returns, (population_size,))

    should_report = i % print_freq == 0
    if should_report:
        print(
            f"Evaluation over {env_batch_size} episodes,"
            f" Population mean return: {jnp.mean(population_returns_flatten)},"
            f" max return: {jnp.max(population_returns_flatten)}"
        )

    # No PBT selection after the final round of training
    is_last_loop = i >= (num_loops - 1)
    if is_last_loop:
        continue

    # PBT selection (uses the un-flattened, per-device returns)
    keys, training_states, replay_buffers = select_fn(
        keys, population_returns, training_states, replay_buffers
    )

  0%|          | 0/10 [00:00<?, ?it/s]

Evaluation over 250 episodes, Population mean return: -2377.849365234375, max return: -1136.852294921875


 20%|██        | 2/10 [00:23<01:32, 11.60s/it]

Evaluation over 250 episodes, Population mean return: -1639.5887451171875, max return: -840.1343994140625


 30%|███       | 3/10 [00:29<01:03,  9.10s/it]

Evaluation over 250 episodes, Population mean return: -799.96142578125, max return: -691.269287109375


 40%|████      | 4/10 [00:35<00:47,  7.91s/it]

Evaluation over 250 episodes, Population mean return: -562.3890380859375, max return: -247.29576110839844


 50%|█████     | 5/10 [00:41<00:36,  7.26s/it]

Evaluation over 250 episodes, Population mean return: 63.16473388671875, max return: 419.86932373046875


 60%|██████    | 6/10 [00:47<00:27,  6.86s/it]

Evaluation over 250 episodes, Population mean return: 527.791748046875, max return: 691.606201171875


 70%|███████   | 7/10 [00:53<00:19,  6.62s/it]

Evaluation over 250 episodes, Population mean return: 751.8699951171875, max return: 811.7660522460938


 80%|████████  | 8/10 [01:00<00:12,  6.46s/it]

Evaluation over 250 episodes, Population mean return: 662.638916015625, max return: 739.002685546875


 90%|█████████ | 9/10 [01:06<00:06,  6.35s/it]

Evaluation over 250 episodes, Population mean return: 959.3629760742188, max return: 1050.8975830078125


100%|██████████| 10/10 [01:12<00:00,  7.23s/it]

Evaluation over 250 episodes, Population mean return: 1183.953125, max return: 1248.028076171875
CPU times: user 1min 13s, sys: 5.63 s, total: 1min 19s
Wall time: 1min 12s



