
dev: gym dynamic work distribution instead of fixed env_per_worker
BillHuang2001 committed Nov 30, 2023
1 parent eefd1e9 commit fd17a46
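In short: rather than requiring the population size to be exactly num_workers * env_per_worker, the controller now splits whatever population it receives across the workers at evaluation time, and each worker grows its environment pool to match the slice it is given. A minimal sketch of the splitting behavior this relies on (illustrative values only, not code from the repository):

    import jax.numpy as jnp

    pop = jnp.arange(10).reshape(10, 1)  # population of 10 flat parameter vectors
    num_workers = 3

    # Old scheme: pop.reshape((num_workers, env_per_worker, -1)) requires
    # pop_size == num_workers * env_per_worker.
    # New scheme: uneven chunks are fine.
    chunks = jnp.array_split(pop, num_workers, axis=0)
    print([c.shape[0] for c in chunks])  # [4, 3, 3]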
Showing 2 changed files with 77 additions and 76 deletions.
123 changes: 58 additions & 65 deletions src/evox/problems/neuroevolution/reinforcement_learning/gym.py
@@ -6,11 +6,17 @@
import numpy as np
import ray
from jax import jit, vmap
from jax.tree_util import tree_map, tree_structure, tree_transpose
from jax.tree_util import tree_map, tree_structure, tree_transpose, tree_leaves

from evox import Problem, State, Stateful, jit_class, jit_method


@jit
def tree_batch_size(tree):
"""Get the batch size of a tree"""
return tree_leaves(tree)[0].shape[0]
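An aside, not part of the diff: the new tree_batch_size helper simply reads the leading axis of the first leaf, so for a population stored as a pytree of stacked parameters it returns the population size. A tiny hedged example of what it computes:

    # Illustration only: the batch size of a pytree of stacked parameters.
    import jax.numpy as jnp
    from jax.tree_util import tree_leaves

    params = {"w": jnp.zeros((8, 4, 4)), "b": jnp.zeros((8, 4))}
    assert tree_leaves(params)[0].shape[0] == 8  # 8 individuals in the population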


@jit_class
class Normalizer(Stateful):
def __init__(self):
@@ -52,15 +58,12 @@ def normalize_obvs(self, state, obvs)

@ray.remote(num_cpus=1)
class Worker:
def __init__(self, env_creator, num_env, policy=None, mo_keys=None):
self.num_env = num_env
self.envs = [env_creator() for _ in range(num_env)]
def __init__(self, env_creator, policy=None, mo_keys=None):
self.envs = []
self.env_creator = env_creator
self.policy = policy
self.mo_keys = mo_keys

self.seed2key = jit(vmap(jax.random.PRNGKey))
self.splitKey = jit(vmap(jax.random.split))

def step(self, actions):
for i, (env, action) in enumerate(zip(self.envs, actions)):
# take the action if not terminated
@@ -98,22 +101,28 @@ def get_rewards(self):
def get_episode_length(self):
return self.episode_length

def reset(self, seeds):
self.total_rewards = np.zeros((self.num_env,))
def reset(self, seed, num_env):
# create new envs if needed
while len(self.envs) < num_env:
self.envs.append(self.env_creator())

self.total_rewards = np.zeros((num_env,))
self.acc_mo_values = np.zeros((len(self.mo_keys),)) # accumulated mo_value
self.episode_length = np.zeros((self.num_env,))
self.terminated = np.zeros((self.num_env,), dtype=bool)
self.truncated = np.zeros((self.num_env,), dtype=bool)
self.episode_length = np.zeros((num_env,))
self.terminated = np.zeros((num_env,), dtype=bool)
self.truncated = np.zeros((num_env,), dtype=bool)
self.observations, self.infos = zip(
*[env.reset(seed=seed) for seed, env in zip(seeds, self.envs)]
*[env.reset(seed=seed) for env in self.envs[:num_env]]
)
self.observations, self.infos = list(self.observations), list(self.infos)
return self.observations
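The worker now sizes its environment pool lazily from the sub-population it receives, instead of pre-allocating a fixed number of environments in __init__. A rough sketch of that pattern in isolation (hypothetical helper, not code from the commit):

    # Hypothetical helper illustrating the lazy pool-growth idea used in reset().
    def ensure_envs(envs, env_creator, num_env):
        while len(envs) < num_env:  # grow the pool only when a larger batch arrives
            envs.append(env_creator())
        return envs[:num_env]  # only the first num_env environments are used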

def rollout(self, seeds, subpop, cap_episode_length):
def rollout(self, seed, subpop, cap_episode_length):
subpop = jax.device_put(subpop)
# num_env is the first dim of subpop
num_env = tree_batch_size(subpop)
assert self.policy is not None
self.reset(seeds)
self.reset(seed, num_env)
i = 0
while True:
observations = jnp.asarray(self.observations)
@@ -136,18 +145,15 @@ def __init__(
self,
policy,
num_workers,
env_per_worker,
env_creator,
worker_options,
batch_policy,
mo_keys,
):
self.num_workers = num_workers
self.env_per_worker = env_per_worker
self.workers = [
Worker.options(**worker_options).remote(
env_creator,
env_per_worker,
None if batch_policy else jit(vmap(policy)),
mo_keys,
)
@@ -162,12 +168,12 @@ def slice_pop(self, pop):
def reshape_weight(w):
# first dim is batch
weight_dim = w.shape[1:]
return list(w.reshape((self.num_workers, self.env_per_worker, *weight_dim)))
return jnp.array_split(w, self.num_workers, axis=0)

if isinstance(pop, jax.Array):
# first dim is batch
param_dim = pop.shape[1:]
pop = pop.reshape((self.num_workers, self.env_per_worker, *param_dim))
pop = jnp.array_split(pop, self.num_workers, axis=0)
else:
outer_treedef = tree_structure(pop)
inner_treedef = tree_structure([0 for _i in range(self.num_workers)])
@@ -176,58 +182,59 @@ def reshape_weight(w):

return pop
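For a pytree-structured population, slice_pop combines the per-leaf split with tree_transpose to turn "one list of chunks per leaf" into "one pytree per worker". A sketch of that transformation with made-up shapes (assumptions: a dict pytree and a population of 10 over 3 workers):

    # Sketch only: split each leaf along the batch axis, then transpose the
    # dict-of-lists into a list of per-worker dicts.
    import jax.numpy as jnp
    from jax.tree_util import tree_map, tree_structure, tree_transpose

    num_workers = 3
    pop = {"w": jnp.zeros((10, 4, 4)), "b": jnp.zeros((10, 4))}

    outer = tree_structure(pop)
    inner = tree_structure([0 for _ in range(num_workers)])
    chunks = tree_map(lambda leaf: jnp.array_split(leaf, num_workers, axis=0), pop)
    per_worker = tree_transpose(outer, inner, chunks)

    assert per_worker[0]["w"].shape == (4, 4, 4)  # the first worker gets the larger chunk
    assert per_worker[1]["w"].shape == (3, 4, 4)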

def _evaluate(self, seeds, pop, cap_episode_length):
def _evaluate(self, seed, pop, cap_episode_length):
sliced_pop = self.slice_pop(pop)
rollout_future = [
worker.rollout.remote(worker_seeds, subpop, cap_episode_length)
for worker_seeds, subpop, worker in zip(seeds, sliced_pop, self.workers)
worker.rollout.remote(seed, subpop, cap_episode_length)
for subpop, worker in zip(sliced_pop, self.workers)
]

rewards, acc_mo_values, episode_length = zip(*ray.get(rollout_future))
rewards = np.concatenate(rewards, axis=0)
acc_mo_values = np.concatenate(acc_mo_values, axis=0)
episode_length = np.concatenate(episode_length, axis=0)
acc_mo_values = np.array(acc_mo_values)
if acc_mo_values.size != 0:
acc_mo_values = acc_mo_values.reshape(-1, self.num_obj)
return (
np.array(rewards).reshape(-1),
acc_mo_values,
np.array(episode_length).reshape(-1),
)
return rewards, acc_mo_values, episode_length

@jit_method
def batch_policy_evaluation(self, observations, pop):
# the first two dims are num_workers and env_per_worker
observation_dim = observations.shape[2:]
actions = jax.vmap(self.policy)(
pop,
observations.reshape(
(self.num_workers * self.env_per_worker, *observation_dim)
),
observations,
)
# split in order to distribute to different workers
action_dim = actions.shape[1:]
actions = actions.reshape((self.num_workers, self.env_per_worker, *action_dim))
actions = jnp.array_split(actions, self.num_workers, axis=0)
return actions

def _batched_evaluate(self, seeds, pop, cap_episode_length):
def _batched_evaluate(self, seed, pop, cap_episode_length):
pop_size = tree_batch_size(pop)
env_per_worker = pop_size // self.num_workers
remainder = pop_size % self.num_workers
num_envs = [
env_per_worker + 1 if i < remainder else env_per_worker
for i in range(self.num_workers)
]
observations = ray.get(
[
worker.reset.remote(worker_seeds)
for worker_seeds, worker in zip(seeds, self.workers)
worker.reset.remote(seed, num_env)
for worker, num_env in zip(self.workers, num_envs)
]
)
terminated = False
episode_length = 0

i = 0
while True:
# flatten observations
observations = [obs for worker_obs in observations for obs in worker_obs]
observations = np.stack(observations, axis=0)
observations = jnp.asarray(observations)
# get action from policy
actions = self.batch_policy_evaluation(observations, pop)
# convert to numpy array
actions = np.asarray(actions)

futures = [
worker.step.remote(action)
worker.step.remote(np.asarray(action))
for worker, action in zip(self.workers, actions)
]
observations, terminated, truncated = zip(*ray.get(futures))
@@ -243,22 +250,18 @@ def _batched_evaluate(self, seeds, pop, cap_episode_length):
rewards, acc_mo_values = zip(
*ray.get([worker.get_rewards.remote() for worker in self.workers])
)
acc_mo_values = np.array(acc_mo_values)
if acc_mo_values.size != 0:
acc_mo_values = acc_mo_values.reshape(-1, self.num_obj)
rewards = np.concatenate(rewards, axis=0)
acc_mo_values = np.concatenate(acc_mo_values, axis=0)
episode_length = [worker.get_episode_length.remote() for worker in self.workers]
episode_length = ray.get(episode_length)
return (
np.array(rewards).reshape(-1),
acc_mo_values,
np.array(episode_length).reshape(-1),
)
episode_length = np.concatenate(episode_length, axis=0)
return rewards, acc_mo_values, episode_length
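In _batched_evaluate the per-worker environment counts are now derived from the population size at call time; a quick worked check of that arithmetic with illustrative numbers:

    # Illustration: distributing a population of 10 across 3 workers.
    pop_size, num_workers = 10, 3
    base, remainder = divmod(pop_size, num_workers)  # base = 3, remainder = 1
    num_envs = [base + 1 if i < remainder else base for i in range(num_workers)]
    assert num_envs == [4, 3, 3]  # matches the chunk sizes from jnp.array_split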

def evaluate(self, seeds, pop, cap_episode_length):
def evaluate(self, seed, pop, cap_episode_length):
if self.batch_policy:
return self._batched_evaluate(seeds, pop, cap_episode_length)
return self._batched_evaluate(seed, pop, cap_episode_length)
else:
return self._evaluate(seeds, pop, cap_episode_length)
return self._evaluate(seed, pop, cap_episode_length)


@jit_class
@@ -283,7 +286,6 @@ def __init__(
self,
policy: Callable,
num_workers: int,
env_per_worker: int,
env_name: Optional[str] = None,
env_options: dict = {},
env_creator: Optional[Callable] = None,
@@ -302,8 +304,6 @@ def __init__(
the first one is the parameter and the second is the input.
num_workers
Number of worker actors.
env_per_worker
Number of gym environment per worker.
env_name
The name of the gym environment.
env_options
@@ -323,7 +323,6 @@
set this field to::
{"num_gpus": 1}
worker_options
The runtime options for worker actors.
"""
@@ -336,14 +335,12 @@ def __init__(
self.controller = Controller.options(**controller_options).remote(
policy,
num_workers,
env_per_worker,
env_creator,
worker_options,
batch_policy,
mo_keys,
)
self.num_workers = num_workers
self.env_per_worker = env_per_worker
self.env_name = env_name
self.policy = policy
if init_cap is not None:
@@ -357,19 +354,15 @@ def setup(self, key):
def evaluate(self, state, pop):
key, subkey = jax.random.split(state.key)
# generate a random seed for gym
seeds = jax.random.randint(
subkey, (self.num_workers, self.env_per_worker), 0, jnp.iinfo(jnp.int32).max
)

seeds = seeds.tolist()
seed = jax.random.randint(subkey, (1,), 0, jnp.iinfo(jnp.int32).max).item()

cap_episode_length = None
if self.cap_episode:
cap_episode_length, state = self.cap_episode.get(state)
cap_episode_length = cap_episode_length.item()

rewards, acc_mo_values, episode_length = ray.get(
self.controller.evaluate.remote(seeds, pop, cap_episode_length)
self.controller.evaluate.remote(seed, pop, cap_episode_length)
)

# convert np.array -> jnp.array here
30 changes: 19 additions & 11 deletions tests/test_gym.py
@@ -29,8 +29,7 @@ def __call__(self, x):
problem = problems.neuroevolution.Gym(
env_name="CartPole-v1",
policy=jax.jit(model.apply),
num_workers=2,
env_per_worker=4,
num_workers=3,
worker_options={"num_gpus": 0, "num_cpus": 0},
controller_options={
"num_cpus": 0,
@@ -41,10 +40,12 @@ def __call__(self, x):
center = adapter.to_vector(params)
# create a workflow
workflow = workflows.UniWorkflow(
algorithm=algorithms.PGPE(
optimizer="adam",
center_init=center,
pop_size=8,
algorithm=algorithms.CSO(
lb=jnp.full_like(center, -10.0),
ub=jnp.full_like(center, 10.0),
mean=center,
stdev=0.1,
pop_size=16,
),
problem=problem,
monitor=monitor,
@@ -56,12 +57,19 @@ def __call__(self, x):
# init the workflow
state = workflow.init(workflow_key)

# run the workflow for 5 steps
for i in range(5):
# run the workflow for 2 steps
for i in range(2):
state = workflow.step(state)

monitor.close()
# the result should be close to 0
monitor.flush()
min_fitness = monitor.get_best_fitness()
# gym is deterministic, so the result should always be the same
assert min_fitness == 16.0
assert min_fitness == 40.0

# run the workflow for another 25 steps
for i in range(25):
state = workflow.step(state)

monitor.flush()
min_fitness = monitor.get_best_fitness()
assert min_fitness == 48.0
