From 6f78a4ad4ca3e5bbc51687b66fe3e9930bae4dca Mon Sep 17 00:00:00 2001 From: Bryan Lim <46229436+limbryan@users.noreply.github.com> Date: Thu, 30 Mar 2023 00:01:24 +0100 Subject: [PATCH 1/6] chore: uniform sac network sizes (#145) * Uniformize architecture definition across RL algs. Separate actor and critic architecture definition for all SAC related networks including DADS and DIAYN too --- examples/dads.ipynb | 18 +++++++++++------- examples/diayn.ipynb | 18 +++++++++++------- examples/me_sac_pbt.ipynb | 18 +++++++++++------- examples/sac_pbt.ipynb | 15 +++++++++++---- examples/smerl.ipynb | 18 +++++++++++------- qdax/baselines/dads.py | 2 ++ qdax/baselines/diayn.py | 3 ++- qdax/baselines/sac.py | 7 +++++-- qdax/baselines/sac_pbt.py | 6 ++++-- .../neuroevolution/networks/dads_networks.py | 11 ++++++----- .../neuroevolution/networks/diayn_networks.py | 11 ++++++----- .../neuroevolution/networks/sac_networks.py | 9 +++++---- tests/baselines_test/dads_smerl_test.py | 6 ++++-- tests/baselines_test/dads_test.py | 6 ++++-- tests/baselines_test/diayn_smerl_test.py | 6 ++++-- tests/baselines_test/diayn_test.py | 6 ++++-- tests/baselines_test/me_pbt_sac_test.py | 6 ++++-- tests/baselines_test/pbt_sac_test.py | 6 ++++-- tests/baselines_test/sac_test.py | 6 ++++-- 19 files changed, 113 insertions(+), 65 deletions(-) diff --git a/examples/dads.ipynb b/examples/dads.ipynb index f64f4685..72380b34 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -116,7 +116,8 @@ "alpha_init = 1.0 #@param {type:\"number\"}\n", "discount = 0.97 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", - "hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "fix_alpha = False #@param {type:\"boolean\"}\n", "normalize_observations = False #@param {type:\"boolean\"}\n", "# DADS config\n", @@ -202,7 +203,8 @@ " alpha_init=alpha_init,\n", " discount=discount,\n", " reward_scaling=reward_scaling,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", " # DADS config\n", " num_skills=num_skills,\n", @@ -520,11 +522,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -538,7 +537,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 10cfda49..c48ab765 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -116,7 +116,8 @@ "alpha_init = 1.0 #@param {type:\"number\"}\n", "discount = 0.97 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", - "hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "fix_alpha = False #@param {type:\"boolean\"}\n", "normalize_observations = False #@param {type:\"boolean\"}\n", "# DIAYN config\n", @@ -200,7 +201,8 @@ 
" alpha_init=alpha_init,\n", " discount=discount,\n", " reward_scaling=reward_scaling,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", " # DIAYN config\n", " num_skills=num_skills,\n", @@ -510,11 +512,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -528,7 +527,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 3f856c3d..6d4dfdfe 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -63,7 +63,8 @@ "episode_length = 1000\n", "tau = 0.005\n", "alpha_init = 1.0\n", - "hidden_layer_sizes = (256, 256)\n", + "critic_hidden_layer_size = (256, 256) \n", + "policy_hidden_layer_size = (256, 256) \n", "fix_alpha = False\n", "normalize_observations = False\n", "\n", @@ -148,7 +149,8 @@ " tau=tau,\n", " normalize_observations=normalize_observations,\n", " alpha_init=alpha_init,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", ")\n", "\n", @@ -527,11 +529,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -545,7 +544,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 37bc06d7..4f225667 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -95,7 +95,8 @@ "grad_updates_per_step = 1.0\n", "tau = 0.005\n", "alpha_init = 1.0\n", - "hidden_layer_sizes = (256, 256)\n", + "critic_hidden_layer_size = (256, 256) \n", + "policy_hidden_layer_size = (256, 256)\n", "fix_alpha = False\n", "normalize_observations = False\n", "\n", @@ -217,7 +218,8 @@ " tau=tau,\n", " normalize_observations=normalize_observations,\n", " alpha_init=alpha_init,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", ")\n", "\n", @@ -544,7 +546,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -558,7 +560,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 
8042c8cf..98d57b94 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -117,7 +117,8 @@ "alpha_init = 1.0 #@param {type:\"number\"}\n", "discount = 0.97 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", - "hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "policy_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "fix_alpha = False #@param {type:\"boolean\"}\n", "normalize_observations = False #@param {type:\"boolean\"}\n", "# DIAYN config\n", @@ -212,7 +213,8 @@ " alpha_init=alpha_init,\n", " discount=discount,\n", " reward_scaling=reward_scaling,\n", - " hidden_layer_sizes=hidden_layer_sizes,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " policy_hidden_layer_size=policy_hidden_layer_size,\n", " fix_alpha=fix_alpha,\n", " # DIAYN config\n", " num_skills=num_skills,\n", @@ -525,11 +527,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.9.2 64-bit ('3.9.2')", "language": "python", "name": "python3" }, @@ -543,7 +542,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "1508da3da994e8b4133db52fbfc99c200ce19b8717cb4612fe84174533968534" + } } }, "nbformat": 4, diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 310a9c8b..41f2ff08 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -84,6 +84,8 @@ def __init__(self, config: DadsConfig, action_size: int, descriptor_size: int): action_size=action_size, descriptor_size=descriptor_size, omit_input_dynamics_dim=config.omit_input_dynamics_dim, + policy_hidden_layer_size=config.policy_hidden_layer_size, + critic_hidden_layer_size=config.critic_hidden_layer_size, ) # define the action distribution diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index e2def709..c03cfb3f 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -78,7 +78,8 @@ def __init__(self, config: DiaynConfig, action_size: int): self._policy, self._critic, self._discriminator = make_diayn_networks( num_skills=self._config.num_skills, action_size=action_size, - hidden_layer_sizes=self._config.hidden_layer_sizes, + policy_hidden_layer_size=self._config.policy_hidden_layer_size, + critic_hidden_layer_size=self._config.critic_hidden_layer_size, ) # define the action distribution diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index 90a823b6..a5ce15c5 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -71,7 +71,8 @@ class SacConfig: alpha_init: float = 1.0 discount: float = 0.97 reward_scaling: float = 1.0 - hidden_layer_sizes: tuple = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha: bool = False @@ -82,7 +83,9 @@ def __init__(self, config: SacConfig, action_size: int) -> None: # define the networks self._policy, self._critic = make_sac_networks( - action_size=action_size, hidden_layer_sizes=self._config.hidden_layer_sizes + action_size=action_size, + critic_hidden_layer_size=self._config.critic_hidden_layer_size, + policy_hidden_layer_size=self._config.policy_hidden_layer_size, ) # define the action distribution diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index f5fd24c1..9aa2ff4c 100644 --- 
a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -110,7 +110,8 @@ class PBTSacConfig: tau: float = 0.005 normalize_observations: bool = False alpha_init: float = 1.0 - hidden_layer_sizes: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) fix_alpha: bool = False @@ -123,7 +124,8 @@ def __init__(self, config: PBTSacConfig, action_size: int) -> None: tau=config.tau, normalize_observations=config.normalize_observations, alpha_init=config.alpha_init, - hidden_layer_sizes=config.hidden_layer_sizes, + policy_hidden_layer_size=config.policy_hidden_layer_size, + critic_hidden_layer_size=config.critic_hidden_layer_size, fix_alpha=config.fix_alpha, # unused default values for parameters that will be learnt as part of PBT learning_rate=3e-4, diff --git a/qdax/core/neuroevolution/networks/dads_networks.py b/qdax/core/neuroevolution/networks/dads_networks.py index 785e4ef3..beb4b77a 100644 --- a/qdax/core/neuroevolution/networks/dads_networks.py +++ b/qdax/core/neuroevolution/networks/dads_networks.py @@ -126,7 +126,8 @@ def __call__( def make_dads_networks( action_size: int, descriptor_size: int, - hidden_layer_sizes: Tuple[int, ...] = (256, 256), + critic_hidden_layer_size: Tuple[int, ...] = (256, 256), + policy_hidden_layer_size: Tuple[int, ...] = (256, 256), omit_input_dynamics_dim: int = 2, identity_covariance: bool = True, dynamics_initializer: Optional[Initializer] = None, @@ -155,7 +156,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [2 * action_size], + list(policy_hidden_layer_size) + [2 * action_size], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -167,7 +168,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network1 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -176,7 +177,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network2 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -191,7 +192,7 @@ def _dynamics_fn( obs: StateDescriptor, skill: Skill, target: StateDescriptor ) -> jnp.ndarray: dynamics_network = DynamicsNetwork( - hidden_layer_sizes, + critic_hidden_layer_size, descriptor_size, omit_input_dynamics_dim=omit_input_dynamics_dim, identity_covariance=identity_covariance, diff --git a/qdax/core/neuroevolution/networks/diayn_networks.py b/qdax/core/neuroevolution/networks/diayn_networks.py index dc45d298..c656cace 100644 --- a/qdax/core/neuroevolution/networks/diayn_networks.py +++ b/qdax/core/neuroevolution/networks/diayn_networks.py @@ -10,7 +10,8 @@ def make_diayn_networks( action_size: int, num_skills: int, - hidden_layer_sizes: Tuple[int, ...] = (256, 256), + critic_hidden_layer_size: Tuple[int, ...] = (256, 256), + policy_hidden_layer_size: Tuple[int, ...] = (256, 256), ) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: """Creates networks used in DIAYN. 
@@ -30,7 +31,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [2 * action_size], + list(policy_hidden_layer_size) + [2 * action_size], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -42,7 +43,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network1 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -51,7 +52,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network2 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -66,7 +67,7 @@ def _discriminator_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [num_skills], + list(critic_hidden_layer_size) + [num_skills], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), diff --git a/qdax/core/neuroevolution/networks/sac_networks.py b/qdax/core/neuroevolution/networks/sac_networks.py index be6db1b2..dcadfaa2 100644 --- a/qdax/core/neuroevolution/networks/sac_networks.py +++ b/qdax/core/neuroevolution/networks/sac_networks.py @@ -9,7 +9,8 @@ def make_sac_networks( action_size: int, - hidden_layer_sizes: Tuple[int, ...] = (256, 256), + critic_hidden_layer_size: Tuple[int, ...] = (256, 256), + policy_hidden_layer_size: Tuple[int, ...] = (256, 256), ) -> Tuple[hk.Transformed, hk.Transformed]: """Creates networks used in SAC. @@ -27,7 +28,7 @@ def _actor_fn(obs: Observation) -> jnp.ndarray: network = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [2 * action_size], + list(policy_hidden_layer_size) + [2 * action_size], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -39,7 +40,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network1 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), @@ -48,7 +49,7 @@ def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: network2 = hk.Sequential( [ hk.nets.MLP( - list(hidden_layer_sizes) + [1], + list(critic_hidden_layer_size) + [1], w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), activation=jax.nn.relu, ), diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 767407ab..1e782f2a 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -31,7 +31,8 @@ def test_dads_smerl() -> None: tau = 0.005 grad_updates_per_step = 0.25 normalize_observations = False - hidden_layer_sizes = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) alpha_init = 1.0 fix_alpha = False discount = 0.97 @@ -102,7 +103,8 @@ def test_dads_smerl() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DADS config num_skills=num_skills, diff --git a/tests/baselines_test/dads_test.py 
b/tests/baselines_test/dads_test.py index 54f486d6..0b9af46e 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -30,7 +30,8 @@ def test_dads() -> None: tau = 0.005 grad_updates_per_step = 0.25 normalize_observations = False - hidden_layer_sizes = (256, 256) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) alpha_init = 1.0 fix_alpha = False discount = 0.97 @@ -93,7 +94,8 @@ def test_dads() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DADS config num_skills=num_skills, diff --git a/tests/baselines_test/diayn_smerl_test.py b/tests/baselines_test/diayn_smerl_test.py index bb75c37e..abd94b45 100644 --- a/tests/baselines_test/diayn_smerl_test.py +++ b/tests/baselines_test/diayn_smerl_test.py @@ -34,7 +34,8 @@ def test_diayn_smerl() -> None: alpha_init = 1.0 discount = 0.97 reward_scaling = 1.0 - hidden_layer_sizes = (64, 64) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha = False normalize_observations = False # DIAYN config @@ -100,7 +101,8 @@ def test_diayn_smerl() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DIAYN config num_skills=num_skills, diff --git a/tests/baselines_test/diayn_test.py b/tests/baselines_test/diayn_test.py index c0dd8c09..3492d9c1 100644 --- a/tests/baselines_test/diayn_test.py +++ b/tests/baselines_test/diayn_test.py @@ -33,7 +33,8 @@ def test_diayn() -> None: alpha_init = 1.0 discount = 0.97 reward_scaling = 1.0 - hidden_layer_sizes = (64, 64) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha = False normalize_observations = False # DIAYN config @@ -85,7 +86,8 @@ def test_diayn() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, # DIAYN config num_skills=num_skills, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 5bbc8241..079fde45 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -27,7 +27,8 @@ def test_me_pbt_sac() -> None: episode_length = 100 tau = 0.005 alpha_init = 1.0 - hidden_layer_sizes = (64, 64) + policy_hidden_layer_size = (64, 64) + critic_hidden_layer_size = (64, 64) fix_alpha = False normalize_observations = False @@ -79,7 +80,8 @@ def test_me_pbt_sac() -> None: tau=tau, normalize_observations=normalize_observations, alpha_init=alpha_init, - hidden_layer_sizes=hidden_layer_sizes, + policy_hidden_layer_size=policy_hidden_layer_size, + critic_hidden_layer_size=critic_hidden_layer_size, fix_alpha=fix_alpha, ) diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index d8cbcf58..c83f277c 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -31,7 +31,8 @@ def test_pbt_sac() -> None: grad_updates_per_step = 1.0 tau = 0.005 alpha_init = 1.0 - hidden_layer_sizes = (64, 64) + 
policy_hidden_layer_size = (64, 64) + critic_hidden_layer_size = (64, 64) fix_alpha = False normalize_observations = False @@ -89,7 +90,8 @@ def init_environments(random_key): # type: ignore tau=tau, normalize_observations=normalize_observations, alpha_init=alpha_init, - hidden_layer_sizes=hidden_layer_sizes, + policy_hidden_layer_size=policy_hidden_layer_size, + critic_hidden_layer_size=critic_hidden_layer_size, fix_alpha=fix_alpha, ) diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index f4029beb..c667aa66 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -31,7 +31,8 @@ def test_sac() -> None: alpha_init = 1.0 discount = 0.95 reward_scaling = 10.0 - hidden_layer_sizes = (64, 64) + critic_hidden_layer_size: tuple = (256, 256) + policy_hidden_layer_size: tuple = (256, 256) fix_alpha = False # Initialize environments @@ -73,7 +74,8 @@ def test_sac() -> None: alpha_init=alpha_init, discount=discount, reward_scaling=reward_scaling, - hidden_layer_sizes=hidden_layer_sizes, + critic_hidden_layer_size=critic_hidden_layer_size, + policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, ) From 26c010e1698fb0be773bd7ef9b0ec78f89380ca3 Mon Sep 17 00:00:00 2001 From: Bryan Lim <46229436+limbryan@users.noreply.github.com> Date: Fri, 28 Apr 2023 16:35:23 +0100 Subject: [PATCH 2/6] chore: remove singularity tools and update haiku version (#147) * update haiku version * update jaxlib and optax dependencies * remove singularity scripts and remove singularity from documentation --- .readthedocs.yaml | 4 +- README.md | 2 +- docs/installation.md | 53 ----- environment.yaml | 1 - requirements.txt | 4 +- setup.py | 1 + singularity/build_final_image | 347 ------------------------------- singularity/build_final_image.py | 347 ------------------------------- singularity/singularity.def | 41 ---- singularity/start_container | 184 ---------------- singularity/start_container.py | 184 ---------------- 11 files changed, 7 insertions(+), 1161 deletions(-) delete mode 100755 singularity/build_final_image delete mode 100755 singularity/build_final_image.py delete mode 100644 singularity/singularity.def delete mode 100755 singularity/start_container delete mode 100755 singularity/start_container.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2ef47062..7eec359d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -18,7 +18,7 @@ mkdocs: # Optionally declare the Python requirements required to build your docs python: install: - - requirements: requirements.txt - - requirements: docs/requirements.txt - method: pip path: . + - requirements: requirements.txt + - requirements: docs/requirements.txt diff --git a/README.md b/README.md index 8b321c6b..2477348d 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main ``` Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To use QDax with NVidia GPUs, you must first install [CUDA, CuDNN, and JAX with GPU support](https://github.com/google/jax#installation). -However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/). +However, we also provide and recommend using either Docker or conda environments to use the repository which by default provides GPU support. 
Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/). ## Basic API Usage For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/mapelites.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default). diff --git a/docs/installation.md b/docs/installation.md index 319a88c7..2c273da1 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -56,59 +56,6 @@ sudo docker run --rm -it -v $QDAX_PATH:/app instadeep/qdax:$USER /bin/bash sudo docker run --rm -it --gpus '"device=0,1"' -v $QDAX_PATH:/app instadeep/qdax:$USER /bin/bash ``` - - -### Using singularity - -First, follow these initial steps: - -1. If it is not already done, install Singularity, following [these instructions](https://docs.sylabs.io/guides/3.0/user-guide/installation.html). - -2. Clone `qdax` -```zsh -git clone git@github.com:adaptive-intelligent-robotics/QDax.git -``` - -3. Enter the singularity folder -```zsh -cd qdax/singularity/ -``` - -You can build two distinct types of images with singularity: "final images" or "sandbox images". -A final image is a single file with the `.sif` extension, it is immutable. -On the contrary, a sandbox image is not a file but a folder, it allows you to develop inside the singularity container to test your code while writing it. - -To build a final image, execute the `build_final_image` script: -```zsh -./build_final_image -``` -It will generate a `.sif` file: `[image_name].sif`. If you execute this file using singularity, as follows, it will run the default application of the image, defined in the `singularity.def` file that you can find in the `singularity` folder as well. At the moment, this is just running the MAP-Elites algorithm on a simple task. -```zsh -singularity run --nv [image_name].sif -``` - -!!! warning "Using GPU" - The `--nv` flag of the `singularity run` command allows the container to use the GPU, it is thus important to use it for QDax. - - -To build a sandbox image, execute the `start_container` script: -```zsh -./start_container -n -``` - -!!! warning "Using GPU" - The `-n` flag of the `start_container` command allow the container to use the GPU, it is thus important to use it for QDax. - -This command will generate a sandbox container `qdax.sif/` and enter it. If you execute this command again later, it will not generate a new container but enter directly the existing one. -Once inside the sandbox container, enter the qdax development folder: -```zsh -cd /git/exp/qdax -``` -This folder is linked with the `qdax` folder on your machine, meaning that any modification inside the container will directly modify the files on your machine. You can now use this development environment to develop your own QDax-based code. - - - - ### Using conda 1. 
If it is not already done, install conda from [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html) diff --git a/environment.yaml b/environment.yaml index 78058b9d..e46c034e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,5 @@ dependencies: - conda>=4.9.2 - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - - jaxlib==0.3.15 - -r requirements.txt - -r requirements-dev.txt diff --git a/requirements.txt b/requirements.txt index b97297fa..718d6213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,16 @@ absl-py==1.0.0 brax==0.0.15 chex==0.1.5 -dm-haiku==0.0.5 +dm-haiku==0.0.9 flax==0.6.0 gym==0.23.1 ipython jax==0.3.17 +jaxlib==0.3.15 jumanji==0.1.3 jupyter numpy==1.22.3 +optax==0.1.4 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 diff --git a/setup.py b/setup.py index 2e50e0ea..a71f3174 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", + "optax>=0.1, <0.1.5", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], diff --git a/singularity/build_final_image b/singularity/build_final_image deleted file mode 100755 index 99911407..00000000 --- a/singularity/build_final_image +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import os -import subprocess -import sys -import time -from typing import Tuple, Union - -SINGULARITY_DEFINITION_FILE_NAME = "singularity.def" - - -class BColors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -def error_print(message: str) -> None: - print(f"{BColors.FAIL}{message}{BColors.ENDC}", file=sys.stderr) - - -def bold(message: str) -> str: - return f"{BColors.BOLD}{message}{BColors.ENDC}" - - -def load_singularity_file(path_to_singularity_definition_file: str) -> str: - try: - # read input file - fin = open(path_to_singularity_definition_file, "rt") - - except IOError: - error_print(f"ERROR, {path_to_singularity_definition_file} file not found!") - - finally: - data = fin.read() - # close the input file - fin.close() - return data - - -def get_repo_address() -> str: - # Search projects - command = os.popen("git config --local remote.origin.url") - url = command.read()[:-1] - - # if it is using the ssh protocal, we need to convert it into an address - # compatible with https as the key is not available inside the container - if url.startswith("git@"): - url = url.replace(":", "/") - url = url.replace("git@", "") - - if url.startswith("https://"): - url = url[len("https://") :] # Removing the https header - - return url - - -def get_commit_sha_and_branch_name( - project_commit_sha_to_consider: str, -) -> Tuple[str, str]: - # Search projects - command = os.popen(f"git rev-parse --short {project_commit_sha_to_consider}") - sha = command.read()[:-1] - command = os.popen(f"git rev-parse --abbrev-ref {project_commit_sha_to_consider}") - branch = command.read()[:-1] - - return sha, branch - - -def check_local_changes() -> None: - command = os.popen("git status --porcelain --untracked-files=no") - output = command.read()[:-1] - if output: - error_print("WARNING: There are currently unpushed changes:") - error_print(output) - - -def check_local_commit_is_pushed(project_commit_ref_to_consider: str) -> None: - command = os.popen(f"git branch -r --contains {project_commit_ref_to_consider}") - remote_branches_containing_commit = command.read()[:-1] - - if not 
remote_branches_containing_commit: - error_print( - f"WARNING: local commit {project_commit_ref_to_consider} not pushed, " - f"build is likely to fail!" - ) - - -def get_project_folder_name() -> str: - return ( - os.path.basename(os.path.dirname(os.getcwd())).strip().lower().replace(" ", "_") - ) - - -def clone_commands( - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> str: - repo_address = get_repo_address() - sha, branch = get_commit_sha_and_branch_name(project_commit_ref_to_consider) - - if ci_job_token: # we are in a CI environment - repo_address = f"http://gitlab-ci-token:{ci_job_token}@{repo_address}" - elif personal_token: # if a personal token is available - repo_address = f"https://oauth:{personal_token}@{repo_address}" - else: - repo_address = f"https://{repo_address}" - - print( - f"Building final image using branch: {bold(branch)} with sha: {bold(sha)} \n" - f"URL: {bold(repo_address)}" - ) - - if not no_check: - code_block = f""" - if [ ! -d {project_name} ] - then - echo 'ERROR: you are probably not cloning your project in the right directory' - echo 'Consider using the --project option of build_final_image' - echo 'with one of the folders shown below:' - ls - echo 'if you want to build your image anyway, use the --no-check option' - exit 1 - fi - - """ - else: - code_block = "" - - code_block += f""" - git clone --recurse-submodules --shallow-submodules {repo_address} {project_name} - cd {project_name} - git checkout {sha} - git submodule update - cd .. - """ - - return code_block - - -def apply_changes( - original_file: str, - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> None: - fout = open("./tmp.def", "w") - for line in original_file.splitlines(): - if "#NOTFORFINAL" in line: - continue - if "#CLONEHERE" in line: - line = clone_commands( - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - fout.write(line + "\n") - fout.close() - - -def compile_container( - project_name: str, image_name: Union[str, None], debug: bool -) -> None: - if not image_name: - image_name = f"final_{project_name}_{time.strftime('%Y-%m-%d_%H_%M_%S')}.sif" - subprocess.run( - ["singularity", "build", "--force", "--fakeroot", image_name, "./tmp.def"] - ) - if not debug: - os.remove("./tmp.def") - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a read-only final container " - "in which the entire project repository is cloned", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file.", - ) - - parser.add_argument( - "--commit-ref", - "-c", - required=False, - type=str, - default="HEAD", - help="commit/branch/tag to consider in the project repository " - "(useful only when using #CLONEHERE).", - ) - - parser.add_argument( - "--ci-job-token", - required=False, - type=str, - default=get_ci_job_token(), - help="Gitlab CI job token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "CI_JOB_TOKEN, if it exists. 
" - "If the environment variable SINGULARITYENV_CI_JOB_TOKEN is not set yet, " - "then it is set the value provided.", - ) - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=get_personal_token(), - help="Gitlab Personal token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "--project", - required=False, - type=str, - default=get_project_folder_name(), - help="Specify the name of the project. This corresponds to: " - "(1) Name of the folder in which the current repository will be cloned " - "(useful only when using #CLONEHERE); " - "(2) the name in the final singularity image " - '"final__YYYY_mm_DD_HH_MM_SS.sif". ' - "By default, it uses the name of the parent folder, as it is considered that " - "the script is executed in the 'singularity/' folder of the project.", - ) - - parser.add_argument( - "--image", - "-i", - required=False, - type=str, - default=None, - help="Name of the image to create. By default: " - '"final__YYYY_mm_DD_HH_MM_SS.sif"', - ) - - parser.add_argument( - "--no-check", - action="store_true", - help="Avoids standard verifications (checking if the repository is " - "cloned at the right place).", - ) - - parser.add_argument( - "--debug", - "-d", - action="store_true", - help="Shows debugging information. Temporary files are not removed.", - ) - - args = parser.parse_args() - return args - - -def get_ci_job_token() -> Union[str, None]: - if "CI_JOB_TOKEN" in os.environ: - return os.getenv("CI_JOB_TOKEN") - else: - return None - - -def get_personal_token() -> Union[str, None]: - if "PERSONAL_TOKEN" in os.environ: - return os.getenv("PERSONAL_TOKEN") - else: - return None - - -def generate_singularity_environment_variables( - ci_job_token: Union[str, None], - personal_token: Union[str, None], - project_folder: Union[str, None], -) -> None: - key_singularityenv_ci_job_token = "SINGULARITYENV_CI_JOB_TOKEN" - if ci_job_token and key_singularityenv_ci_job_token not in os.environ: - os.environ[key_singularityenv_ci_job_token] = ci_job_token - - key_singularityenv_personal_token = "SINGULARITYENV_PERSONAL_TOKEN" - if personal_token and key_singularityenv_personal_token not in os.environ: - os.environ[key_singularityenv_personal_token] = personal_token - - key_singularityenv_project_folder = "SINGULARITYENV_PROJECT_FOLDER" - if project_folder and key_singularityenv_project_folder not in os.environ: - os.environ[key_singularityenv_project_folder] = project_folder - - -def main() -> None: - args = get_args() - - path_to_singularity_definition_file = args.path_def - project_commit_ref_to_consider = args.commit_ref - ci_job_token = args.ci_job_token - personal_token = args.personal_token - project_name = args.project - debug = args.debug - image_name = args.image - no_check = args.no_check - - # doing some checks and print warnings - check_local_changes() - check_local_commit_is_pushed(project_commit_ref_to_consider) - - # getting the orignal singularity file - data = load_singularity_file(path_to_singularity_definition_file) - - # appling the changes and writing this in ./tmp.def - apply_changes( - data, - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - - # Create environment variables for singularity - generate_singularity_environment_variables( - 
ci_job_token, personal_token, project_folder=project_name - ) - - # compiling and deleting ./tmp.def - compile_container(project_name, image_name, debug) - - -if __name__ == "__main__": - main() diff --git a/singularity/build_final_image.py b/singularity/build_final_image.py deleted file mode 100755 index 99911407..00000000 --- a/singularity/build_final_image.py +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import os -import subprocess -import sys -import time -from typing import Tuple, Union - -SINGULARITY_DEFINITION_FILE_NAME = "singularity.def" - - -class BColors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -def error_print(message: str) -> None: - print(f"{BColors.FAIL}{message}{BColors.ENDC}", file=sys.stderr) - - -def bold(message: str) -> str: - return f"{BColors.BOLD}{message}{BColors.ENDC}" - - -def load_singularity_file(path_to_singularity_definition_file: str) -> str: - try: - # read input file - fin = open(path_to_singularity_definition_file, "rt") - - except IOError: - error_print(f"ERROR, {path_to_singularity_definition_file} file not found!") - - finally: - data = fin.read() - # close the input file - fin.close() - return data - - -def get_repo_address() -> str: - # Search projects - command = os.popen("git config --local remote.origin.url") - url = command.read()[:-1] - - # if it is using the ssh protocal, we need to convert it into an address - # compatible with https as the key is not available inside the container - if url.startswith("git@"): - url = url.replace(":", "/") - url = url.replace("git@", "") - - if url.startswith("https://"): - url = url[len("https://") :] # Removing the https header - - return url - - -def get_commit_sha_and_branch_name( - project_commit_sha_to_consider: str, -) -> Tuple[str, str]: - # Search projects - command = os.popen(f"git rev-parse --short {project_commit_sha_to_consider}") - sha = command.read()[:-1] - command = os.popen(f"git rev-parse --abbrev-ref {project_commit_sha_to_consider}") - branch = command.read()[:-1] - - return sha, branch - - -def check_local_changes() -> None: - command = os.popen("git status --porcelain --untracked-files=no") - output = command.read()[:-1] - if output: - error_print("WARNING: There are currently unpushed changes:") - error_print(output) - - -def check_local_commit_is_pushed(project_commit_ref_to_consider: str) -> None: - command = os.popen(f"git branch -r --contains {project_commit_ref_to_consider}") - remote_branches_containing_commit = command.read()[:-1] - - if not remote_branches_containing_commit: - error_print( - f"WARNING: local commit {project_commit_ref_to_consider} not pushed, " - f"build is likely to fail!" 
- ) - - -def get_project_folder_name() -> str: - return ( - os.path.basename(os.path.dirname(os.getcwd())).strip().lower().replace(" ", "_") - ) - - -def clone_commands( - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> str: - repo_address = get_repo_address() - sha, branch = get_commit_sha_and_branch_name(project_commit_ref_to_consider) - - if ci_job_token: # we are in a CI environment - repo_address = f"http://gitlab-ci-token:{ci_job_token}@{repo_address}" - elif personal_token: # if a personal token is available - repo_address = f"https://oauth:{personal_token}@{repo_address}" - else: - repo_address = f"https://{repo_address}" - - print( - f"Building final image using branch: {bold(branch)} with sha: {bold(sha)} \n" - f"URL: {bold(repo_address)}" - ) - - if not no_check: - code_block = f""" - if [ ! -d {project_name} ] - then - echo 'ERROR: you are probably not cloning your project in the right directory' - echo 'Consider using the --project option of build_final_image' - echo 'with one of the folders shown below:' - ls - echo 'if you want to build your image anyway, use the --no-check option' - exit 1 - fi - - """ - else: - code_block = "" - - code_block += f""" - git clone --recurse-submodules --shallow-submodules {repo_address} {project_name} - cd {project_name} - git checkout {sha} - git submodule update - cd .. - """ - - return code_block - - -def apply_changes( - original_file: str, - project_commit_ref_to_consider: str, - ci_job_token: str, - personal_token: str, - project_name: str, - no_check: bool = False, -) -> None: - fout = open("./tmp.def", "w") - for line in original_file.splitlines(): - if "#NOTFORFINAL" in line: - continue - if "#CLONEHERE" in line: - line = clone_commands( - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - fout.write(line + "\n") - fout.close() - - -def compile_container( - project_name: str, image_name: Union[str, None], debug: bool -) -> None: - if not image_name: - image_name = f"final_{project_name}_{time.strftime('%Y-%m-%d_%H_%M_%S')}.sif" - subprocess.run( - ["singularity", "build", "--force", "--fakeroot", image_name, "./tmp.def"] - ) - if not debug: - os.remove("./tmp.def") - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a read-only final container " - "in which the entire project repository is cloned", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file.", - ) - - parser.add_argument( - "--commit-ref", - "-c", - required=False, - type=str, - default="HEAD", - help="commit/branch/tag to consider in the project repository " - "(useful only when using #CLONEHERE).", - ) - - parser.add_argument( - "--ci-job-token", - required=False, - type=str, - default=get_ci_job_token(), - help="Gitlab CI job token (useful in particular when using #CLONEHERE). " - "If not specified, it takes the value of the environment variable " - "CI_JOB_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_CI_JOB_TOKEN is not set yet, " - "then it is set the value provided.", - ) - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=get_personal_token(), - help="Gitlab Personal token (useful in particular when using #CLONEHERE). 
" - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "--project", - required=False, - type=str, - default=get_project_folder_name(), - help="Specify the name of the project. This corresponds to: " - "(1) Name of the folder in which the current repository will be cloned " - "(useful only when using #CLONEHERE); " - "(2) the name in the final singularity image " - '"final__YYYY_mm_DD_HH_MM_SS.sif". ' - "By default, it uses the name of the parent folder, as it is considered that " - "the script is executed in the 'singularity/' folder of the project.", - ) - - parser.add_argument( - "--image", - "-i", - required=False, - type=str, - default=None, - help="Name of the image to create. By default: " - '"final__YYYY_mm_DD_HH_MM_SS.sif"', - ) - - parser.add_argument( - "--no-check", - action="store_true", - help="Avoids standard verifications (checking if the repository is " - "cloned at the right place).", - ) - - parser.add_argument( - "--debug", - "-d", - action="store_true", - help="Shows debugging information. Temporary files are not removed.", - ) - - args = parser.parse_args() - return args - - -def get_ci_job_token() -> Union[str, None]: - if "CI_JOB_TOKEN" in os.environ: - return os.getenv("CI_JOB_TOKEN") - else: - return None - - -def get_personal_token() -> Union[str, None]: - if "PERSONAL_TOKEN" in os.environ: - return os.getenv("PERSONAL_TOKEN") - else: - return None - - -def generate_singularity_environment_variables( - ci_job_token: Union[str, None], - personal_token: Union[str, None], - project_folder: Union[str, None], -) -> None: - key_singularityenv_ci_job_token = "SINGULARITYENV_CI_JOB_TOKEN" - if ci_job_token and key_singularityenv_ci_job_token not in os.environ: - os.environ[key_singularityenv_ci_job_token] = ci_job_token - - key_singularityenv_personal_token = "SINGULARITYENV_PERSONAL_TOKEN" - if personal_token and key_singularityenv_personal_token not in os.environ: - os.environ[key_singularityenv_personal_token] = personal_token - - key_singularityenv_project_folder = "SINGULARITYENV_PROJECT_FOLDER" - if project_folder and key_singularityenv_project_folder not in os.environ: - os.environ[key_singularityenv_project_folder] = project_folder - - -def main() -> None: - args = get_args() - - path_to_singularity_definition_file = args.path_def - project_commit_ref_to_consider = args.commit_ref - ci_job_token = args.ci_job_token - personal_token = args.personal_token - project_name = args.project - debug = args.debug - image_name = args.image - no_check = args.no_check - - # doing some checks and print warnings - check_local_changes() - check_local_commit_is_pushed(project_commit_ref_to_consider) - - # getting the orignal singularity file - data = load_singularity_file(path_to_singularity_definition_file) - - # appling the changes and writing this in ./tmp.def - apply_changes( - data, - project_commit_ref_to_consider, - ci_job_token, - personal_token, - project_name, - no_check, - ) - - # Create environment variables for singularity - generate_singularity_environment_variables( - ci_job_token, personal_token, project_folder=project_name - ) - - # compiling and deleting ./tmp.def - compile_container(project_name, image_name, debug) - - -if __name__ == "__main__": - main() diff --git a/singularity/singularity.def b/singularity/singularity.def deleted file mode 100644 index 
ee8ca27a..00000000 --- a/singularity/singularity.def +++ /dev/null @@ -1,41 +0,0 @@ -Bootstrap: library -From: airl_lab/default/airl_env:qdax_f57720d0 - -%labels - Author adaptive.intelligent.robotics@gmail.com - Version v0.0.1 - -%environment - export PYTHONPATH=$PYTHONPATH:/workspace/lib/python3.8/site-packages/ - export LD_LIBRARY_PATH="/workspace/lib:$LD_LIBRARY_PATH" - export PATH=$PATH:/usr/local/go/bin - -%post - export LD_LIBRARY_PATH="/workspace/lib:$LD_LIBRARY_PATH" - apt-get update -y - pip3 install --upgrade pip - - # Create working directory - mkdir -p /git/exp/qdax/ - - #================================================================================== - exit 0 #NOTFORFINAL - the lines below this "exit" will be executed only when building the final image - #================================================================================== - - # Enter working directory - cd /git/exp/ - - #CLONEHERE - -%runscript - # Entering directory - cd /git/exp/qdax/ - - # Running the test file as a demo - echo - echo 'Running the test of MAP-Elites algorithm as a demo' - echo - pytest tests/core_test/map_elites_test.py - -%help - This is the development and running environment of QDax diff --git a/singularity/start_container b/singularity/start_container deleted file mode 100755 index 09b8dd9b..00000000 --- a/singularity/start_container +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import subprocess -import tempfile - -import build_final_image - -EXP_PATH = "git/exp/" -ABSOLUTE_EXP_PATH = "/" + EXP_PATH - - -def get_default_image_name() -> str: - return f"{build_final_image.get_project_folder_name()}.sif" - - -def build_sandbox(path_singularity_def: str, image_name: str) -> None: - # check if the sandbox has already been created - if os.path.exists(image_name): - return - - print(f"{image_name} does not exist, building it now from {path_singularity_def}") - assert os.path.exists( - path_singularity_def - ) # exit if path_singularity_definition_file is not found - - # run commands - command = ( - f"singularity build --force --fakeroot --sandbox {image_name} " - f"{path_singularity_def}" - ) - subprocess.run(command.split()) - - -def run_container( - nvidia: bool, - use_no_home: bool, - use_tmp_home: bool, - image_name: str, - binding_folder_inside_container: str, -) -> None: - additional_args = "" - - if nvidia: - print("Nvidia runtime ON") - additional_args += " " + "--nv" - - if use_no_home: - print("Using --no-home") - additional_args += " " + "--no-home --containall" - - if use_tmp_home: - tmp_home_folder = tempfile.mkdtemp(dir="/tmp") - additional_args += " " + f"--home {tmp_home_folder}" - build_final_image.error_print( - f"Warning: The HOME folder is a temporary directory located in " - f"{tmp_home_folder}! Do not store any result there!" - ) - - if not binding_folder_inside_container: - binding_folder_inside_container = build_final_image.get_project_folder_name() - - path_folder_binding_in_container = os.path.join( - image_name, EXP_PATH, binding_folder_inside_container - ) - if not os.path.exists(path_folder_binding_in_container): - list_possible_folder_binding_in_container = next( - os.walk(os.path.join(image_name, EXP_PATH)) - )[1] - list_possible_options = [ - f" --binding-folder {existing_folder}" - for existing_folder in list_possible_folder_binding_in_container - ] - build_final_image.error_print( - f"Warning: The folder " - f"{os.path.join(ABSOLUTE_EXP_PATH, binding_folder_inside_container)} " - f"does not exist in the container. 
The Binding between your project folder " - f"and your container is likely to be unsuccessful.\n" - f"You may want to consider adding one of the following options to the " - f"'start_container' command:\n" + "\n".join(list_possible_options) - ) - - command = ( - f"singularity shell -w {additional_args} " - f"--bind {os.path.dirname(os.getcwd())}:" - f"{ABSOLUTE_EXP_PATH}/{binding_folder_inside_container} " - f"{image_name}" - ) - subprocess.run(command.split()) - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a sandbox container and shell into it.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-n", "--nv", action="store_true", help="enable experimental Nvidia support" - ) - parser.add_argument( - "--no-home", action="store_true", help='apply --no-home to "singularity shell"' - ) - parser.add_argument( - "--tmp-home", - action="store_true", - help="binds HOME directory of the singularity container to a temporary folder", - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=build_final_image.SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file", - ) - - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=build_final_image.get_personal_token(), - help="Gitlab Personal token. " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "-b", - "--binding-folder", - required=False, - type=str, - default=build_final_image.get_project_folder_name(), - help=f"If specified, it corresponds to the name folder in {ABSOLUTE_EXP_PATH} " - f"from which the binding is performed to the current project source code. 
" - f"By default, it corresponds to the image name (without the .sif extension)", - ) - - parser.add_argument( - "-i", - "--image", - required=False, - type=str, - default=get_default_image_name(), - help="name of the sandbox image to start", - ) - - args = parser.parse_args() - - return args - - -def main() -> None: - args = get_args() - - enable_nvidia_support = args.nv - use_no_home = args.no_home - use_tmp_home = args.tmp_home - path_singularity_definition_file = args.path_def - image_name = args.image - binding_folder_inside_container = args.binding_folder - personal_token = args.personal_token - - # Create environment variables for singularity - build_final_image.generate_singularity_environment_variables( - ci_job_token=None, - personal_token=personal_token, - project_folder=binding_folder_inside_container, - ) - - build_sandbox(path_singularity_definition_file, image_name) - run_container( - enable_nvidia_support, - use_no_home, - use_tmp_home, - image_name, - binding_folder_inside_container, - ) - - -if __name__ == "__main__": - main() diff --git a/singularity/start_container.py b/singularity/start_container.py deleted file mode 100755 index 09b8dd9b..00000000 --- a/singularity/start_container.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import subprocess -import tempfile - -import build_final_image - -EXP_PATH = "git/exp/" -ABSOLUTE_EXP_PATH = "/" + EXP_PATH - - -def get_default_image_name() -> str: - return f"{build_final_image.get_project_folder_name()}.sif" - - -def build_sandbox(path_singularity_def: str, image_name: str) -> None: - # check if the sandbox has already been created - if os.path.exists(image_name): - return - - print(f"{image_name} does not exist, building it now from {path_singularity_def}") - assert os.path.exists( - path_singularity_def - ) # exit if path_singularity_definition_file is not found - - # run commands - command = ( - f"singularity build --force --fakeroot --sandbox {image_name} " - f"{path_singularity_def}" - ) - subprocess.run(command.split()) - - -def run_container( - nvidia: bool, - use_no_home: bool, - use_tmp_home: bool, - image_name: str, - binding_folder_inside_container: str, -) -> None: - additional_args = "" - - if nvidia: - print("Nvidia runtime ON") - additional_args += " " + "--nv" - - if use_no_home: - print("Using --no-home") - additional_args += " " + "--no-home --containall" - - if use_tmp_home: - tmp_home_folder = tempfile.mkdtemp(dir="/tmp") - additional_args += " " + f"--home {tmp_home_folder}" - build_final_image.error_print( - f"Warning: The HOME folder is a temporary directory located in " - f"{tmp_home_folder}! Do not store any result there!" - ) - - if not binding_folder_inside_container: - binding_folder_inside_container = build_final_image.get_project_folder_name() - - path_folder_binding_in_container = os.path.join( - image_name, EXP_PATH, binding_folder_inside_container - ) - if not os.path.exists(path_folder_binding_in_container): - list_possible_folder_binding_in_container = next( - os.walk(os.path.join(image_name, EXP_PATH)) - )[1] - list_possible_options = [ - f" --binding-folder {existing_folder}" - for existing_folder in list_possible_folder_binding_in_container - ] - build_final_image.error_print( - f"Warning: The folder " - f"{os.path.join(ABSOLUTE_EXP_PATH, binding_folder_inside_container)} " - f"does not exist in the container. 
The Binding between your project folder " - f"and your container is likely to be unsuccessful.\n" - f"You may want to consider adding one of the following options to the " - f"'start_container' command:\n" + "\n".join(list_possible_options) - ) - - command = ( - f"singularity shell -w {additional_args} " - f"--bind {os.path.dirname(os.getcwd())}:" - f"{ABSOLUTE_EXP_PATH}/{binding_folder_inside_container} " - f"{image_name}" - ) - subprocess.run(command.split()) - - -def get_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Build a sandbox container and shell into it.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-n", "--nv", action="store_true", help="enable experimental Nvidia support" - ) - parser.add_argument( - "--no-home", action="store_true", help='apply --no-home to "singularity shell"' - ) - parser.add_argument( - "--tmp-home", - action="store_true", - help="binds HOME directory of the singularity container to a temporary folder", - ) - - parser.add_argument( - "--path-def", - required=False, - type=str, - default=build_final_image.SINGULARITY_DEFINITION_FILE_NAME, - help="path to singularity definition file", - ) - - parser.add_argument( - "--personal-token", - required=False, - type=str, - default=build_final_image.get_personal_token(), - help="Gitlab Personal token. " - "If not specified, it takes the value of the environment variable " - "PERSONAL_TOKEN, if it exists. " - "If the environment variable SINGULARITYENV_PERSONAL_TOKEN is not set yet, " - "then it is set the value provided.", - ) - - parser.add_argument( - "-b", - "--binding-folder", - required=False, - type=str, - default=build_final_image.get_project_folder_name(), - help=f"If specified, it corresponds to the name folder in {ABSOLUTE_EXP_PATH} " - f"from which the binding is performed to the current project source code. 
" - f"By default, it corresponds to the image name (without the .sif extension)", - ) - - parser.add_argument( - "-i", - "--image", - required=False, - type=str, - default=get_default_image_name(), - help="name of the sandbox image to start", - ) - - args = parser.parse_args() - - return args - - -def main() -> None: - args = get_args() - - enable_nvidia_support = args.nv - use_no_home = args.no_home - use_tmp_home = args.tmp_home - path_singularity_definition_file = args.path_def - image_name = args.image - binding_folder_inside_container = args.binding_folder - personal_token = args.personal_token - - # Create environment variables for singularity - build_final_image.generate_singularity_environment_variables( - ci_job_token=None, - personal_token=personal_token, - project_folder=binding_folder_inside_container, - ) - - build_sandbox(path_singularity_definition_file, image_name) - run_container( - enable_nvidia_support, - use_no_home, - use_tmp_home, - image_name, - binding_folder_inside_container, - ) - - -if __name__ == "__main__": - main() From a07397d933477b815bdd49c0bc0bc98efa2bd105 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 12:48:39 +0200 Subject: [PATCH 3/6] fix: dependencies in notebook examples --- examples/cmame.ipynb | 8 +- examples/cmamega.ipynb | 8 +- examples/dads.ipynb | 8 +- examples/diayn.ipynb | 8 +- examples/distributed_mapelites.ipynb | 140 ++++++--------------------- examples/mapelites.ipynb | 8 +- examples/mees.ipynb | 8 +- examples/mome.ipynb | 8 +- examples/nsga2_spea2.ipynb | 6 ++ examples/omgmega.ipynb | 6 ++ examples/pgame.ipynb | 8 +- examples/qdpg.ipynb | 8 +- examples/smerl.ipynb | 10 +- 13 files changed, 112 insertions(+), 122 deletions(-) mode change 100755 => 100644 examples/mees.ipynb diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index c9d6f67e..1d7337d4 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -49,7 +49,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index 509e52ea..2e00d660 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -43,7 +43,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index f64f4685..deba8835 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -45,10 +45,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import haiku\n", "except:\n", " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 10cfda49..c725da4b 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -45,10 +45,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install 
git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import haiku\n", "except:\n", " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index b8a08b52..434725a3 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -2,22 +2,14 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/distributed_mapelites.ipynb)" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "# Optimizing with MAP-Elites in Jax (multi-devices example)\n", "\n", @@ -34,11 +26,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "#@title Installs and Imports\n", @@ -61,10 +49,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -93,22 +87,14 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Setup and get devices" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "Setup the default platform where the MAP-Elites will be stored and MAP-Elite updates will happen. 
" ] @@ -116,11 +102,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "default_device = 'cpu'\n", @@ -130,11 +112,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", @@ -146,11 +124,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Setup run parameters" ] @@ -158,11 +132,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "#@title QD Training Definitions Fields\n", @@ -185,11 +155,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Init environment, policy, population params, init states of the env\n", "\n", @@ -199,11 +165,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -237,11 +199,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the way the policy interacts with the env\n", "\n", @@ -251,11 +209,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Define the fonction to play a step with the policy in the environment\n", @@ -289,11 +243,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the scoring function and the way metrics are computed\n", "\n", @@ -303,11 +253,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Prepare the scoring function\n", @@ -332,11 +278,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Define the emitter\n", "\n", @@ -346,11 +288,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Define emitter\n", @@ -367,11 +305,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Instantiate and initialise the MAP Elites algorithm" ] @@ -379,11 +313,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -423,11 +353,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "## Launch MAP-Elites iterations" ] @@ -435,11 +361,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "log_period = 10\n", @@ -493,11 +415,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Get the repertoire from the first device\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index c456cf5b..18728e73 100644 --- 
a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -49,10 +49,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb old mode 100755 new mode 100644 index 8f1dc444..ab5fad93 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -54,10 +54,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@feat/add-algo-mees |tail -n 1\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 6a6f7d39..a4ca36a6 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -49,7 +49,13 @@ "except:\n", " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1\n", " import chex\n", - " \n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", "try:\n", " import qdax\n", "except:\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 5cbe02a2..51c5f5bd 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -52,6 +52,12 @@ " import chex\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index d75a0077..0a28876a 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -47,6 +47,12 @@ " import chex\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 24222ddf..7a51a0bd 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -48,10 +48,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 5642fd3b..d778ad1d 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -48,10 +48,16 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 
|tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", " import brax\n", "\n", "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 8042c8cf..47ff96e9 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -45,8 +45,14 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", - " import \n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", " \n", "try:\n", " import haiku\n", From 5531831288607edbae08214caec8dfb8b0ec6ea4 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 14:14:49 +0200 Subject: [PATCH 4/6] make all tests pass --- .readthedocs.yaml | 4 ++-- environment.yaml | 1 - setup.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2ef47062..7eec359d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -18,7 +18,7 @@ mkdocs: # Optionally declare the Python requirements required to build your docs python: install: - - requirements: requirements.txt - - requirements: docs/requirements.txt - method: pip path: . + - requirements: requirements.txt + - requirements: docs/requirements.txt diff --git a/environment.yaml b/environment.yaml index 78058b9d..e46c034e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,5 @@ dependencies: - conda>=4.9.2 - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - - jaxlib==0.3.15 - -r requirements.txt - -r requirements-dev.txt diff --git a/setup.py b/setup.py index 2e50e0ea..a71f3174 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", + "optax>=0.1, <0.1.5", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], From b07d1a7bc29da73b5780c07b662f0a7b31bd42f1 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 17 May 2023 14:27:51 +0200 Subject: [PATCH 5/6] make all tests pass --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index b97297fa..16c91bc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,11 @@ flax==0.6.0 gym==0.23.1 ipython jax==0.3.17 +jaxlib==0.3.15 jumanji==0.1.3 jupyter numpy==1.22.3 +optax==0.1.4 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 From 79939ee0f8e6dabccc67baf0ef300eb5906ba314 Mon Sep 17 00:00:00 2001 From: Bryon Tjanaka <38124174+btjanaka@users.noreply.github.com> Date: Wed, 23 Aug 2023 09:13:34 -0400 Subject: [PATCH 6/6] feat(algo): Add MAP-Elites Low-Spread (#152) * Add MELS Repertoire * Create MELS Algorithm class * Introduce Spread type * Add multi_sample_scoring_function Authored-by: b-tjanaka@wings --- README.md | 1 + docs/api_documentation/core/mels.md | 7 + examples/mels.ipynb | 559 ++++++++++++++++++ mkdocs.yml | 1 + qdax/core/containers/mels_repertoire.py | 311 ++++++++++ qdax/core/mels.py | 104 ++++ qdax/types.py | 1 + qdax/utils/sampling.py | 81 ++- .../containers_test/mels_repertoire_test.py | 236 ++++++++ tests/core_test/mels_test.py | 156 +++++ 10 files changed, 1444 insertions(+), 13 deletions(-) create mode 100644 
docs/api_documentation/core/mels.md create mode 100644 examples/mels.ipynb create mode 100644 qdax/core/containers/mels_repertoire.py create mode 100644 qdax/core/mels.py create mode 100644 tests/core_test/containers_test/mels_repertoire_test.py create mode 100644 tests/core_test/mels_test.py diff --git a/README.md b/README.md index 2477348d..0881da1d 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ QDax currently supports the following algorithms: | [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) | | [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) | | [MAP-Elites PBT (ME-PBT)](https://openreview.net/forum?id=CBfYffLqWqb) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_sac_pbt.ipynb) | +| [MAP-Elites Low-Spread (ME-LS)](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_ls.ipynb) | diff --git a/docs/api_documentation/core/mels.md b/docs/api_documentation/core/mels.md new file mode 100644 index 00000000..3aa212b5 --- /dev/null +++ b/docs/api_documentation/core/mels.md @@ -0,0 +1,7 @@ +# MAP-Elites Low-Spread (ME-LS) + +[ME-LS](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) is a variant of +MAP-Elites that thrives the search process towards solutions that are consistent +in the behavior space for uncertain domains. + +::: qdax.core.mels.MELS diff --git a/examples/mels.ipynb b/examples/mels.ipynb new file mode 100644 index 00000000..1fcd6c42 --- /dev/null +++ b/examples/mels.ipynb @@ -0,0 +1,559 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mels.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing Uncertain Domains with ME-LS in JAX\n", + "\n", + "This notebook shows how to discover controllers that achieve consistent performance in MDP domains using the [MAP-Elites Low-Spread](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) algorithm. It can be run locally or on Google Colab. We recommend to use a GPU. 
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an ME-LS instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.mels import MELS\n", + "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", + "from qdax.core.containers.mels_repertoire import MELSRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "from jax.flatten_util import ravel_pytree\n", + "\n", + "from IPython.display import HTML\n", + "from brax.io import html\n", + "\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "num_samples = 5 #@param {type:\"number\"}\n", + "episode_length = 100 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. 
#@param {type:\"number\"}\n", + "max_bd = 1.0 #@param {type:\"number\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we consider the problem where each controller is evaluated `num_samples` times, each time in a different environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers. There are batch_size controllers, and each\n", + "# controller will be evaluated num_samples times.\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env\n", + "\n", + "Now that the environment and policy has been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data. This is identical to the function in the MAP-Elites tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define the function to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " random_key,\n", + "):\n", + " \"\"\"Play an environment step and return the updated state and the\n", + " transition.\"\"\"\n", + "\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + "\n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. Note that while the MAP-Elites tutorial uses `scoring_function_brax_envs` as the basis for the scoring function, we use `reset_based_scoring_function_brax_envs`. The difference is that `reset_based_scoring_function_brax_envs` generates initial states randomly instead of taking in a fixed set of initial states. This is necessary since we are evaluating each controller across sampled initial states. 
If the initial states were kept the same for all evaluations, there would be no stochasticity in the behavior." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "scoring_fn = functools.partial(\n", + " reset_based_scoring_function_brax_envs,\n", + " episode_length=episode_length,\n", + " play_reset_fn=env.reset,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "metrics_fn = functools.partial(\n", + " default_qd_metrics,\n", + " qd_offset=reward_offset * episode_length,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "mixing_emitter = MixingEmitter(\n", + " mutation_fn=None, \n", + " variation_fn=variation_fn, \n", + " variation_percentage=1.0, \n", + " batch_size=batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the ME-LS algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Instantiate ME-LS.\n", + "mels = MELS(\n", + " scoring_function=scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=metrics_fn,\n", + " num_samples=num_samples,\n", + ")\n", + "\n", + "# Compute the centroids\n", + "centroids, random_key = compute_cvt_centroids(\n", + " num_descriptors=env.behavior_descriptor_length,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids,\n", + " minval=min_bd,\n", + " maxval=max_bd,\n", + " random_key=random_key,\n", + ")\n", + "\n", + "# Compute initial repertoire and emitter state\n", + "repertoire, emitter_state, random_key = mels.init(init_variables, centroids, random_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch ME-LS iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "log_period = 10\n", + "num_loops = int(num_iterations / log_period)\n", + "\n", + "csv_logger = CSVLogger(\n", + " \"mapelites-logs.csv\",\n", + " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + ")\n", + "all_metrics = {}\n", + "\n", + "# main loop\n", + "mels_scan_update = mels.scan_update\n", + "for i in range(num_loops):\n", + " start_time = time.time()\n", + " # main iterations\n", + " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " mels_scan_update,\n", + " (repertoire, emitter_state, random_key),\n", + " (),\n", + " length=log_period,\n", + " )\n", + " timelapse = time.time() - start_time\n", + "\n", + " # log metrics\n", + " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", + " for key, value in metrics.items():\n", + 
" # take last value\n", + " logged_metrics[key] = value[-1]\n", + "\n", + " # take all values\n", + " if key in all_metrics.keys():\n", + " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", + " else:\n", + " all_metrics[key] = value\n", + "\n", + " csv_logger.log(logged_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# create the x-axis array\n", + "env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "\n", + "# create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to save/load a repertoire\n", + "\n", + "The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the final repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire_path = \"./last_repertoire/\"\n", + "os.makedirs(repertoire_path, exist_ok=True)\n", + "repertoire.save(path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build the reconstruction function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init population of policies\n", + "random_key, subkey = jax.random.split(random_key)\n", + "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", + "fake_params = policy_network.init(subkey, fake_batch)\n", + "\n", + "_, reconstruction_fn = ravel_pytree(fake_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the reconstruction function to load and re-create the repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire = MELSRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get the best individual of the repertoire\n", + "\n", + "Note that in ME-LS, the individual's cell is computed by finding its most frequent archive cell among its `num_samples` behavior descriptors. Thus, the descriptor associated with each individual in the archive is not its mean descriptor. Rather, we set the descriptor in the archive to be the centroid of the individual's most frequent archive cell." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best_idx = jnp.argmax(repertoire.fitnesses)\n", + "best_fitness = jnp.max(repertoire.fitnesses)\n", + "best_bd = repertoire.descriptors[best_idx]\n", + "best_spread = repertoire.spreads[best_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\"\n", + " f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\"\n", + " f\"Spread of the best individual in the repertoire: {best_spread}\\n\"\n", + " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_params = jax.tree_util.tree_map(\n", + " lambda x: x[best_idx],\n", + " repertoire.genotypes\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Play some steps in the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jit_env_reset = jax.jit(env.reset)\n", + "jit_env_step = jax.jit(env.step)\n", + "jit_inference_fn = jax.jit(policy_network.apply)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rollout = []\n", + "rng = jax.random.PRNGKey(seed=1)\n", + "state = jit_env_reset(rng=rng)\n", + "while not state.done:\n", + " rollout.append(state)\n", + " action = jit_inference_fn(my_params, state.obs)\n", + " state = jit_env_step(state, action)\n", + "\n", + "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mkdocs.yml b/mkdocs.yml index 702a474a..2c0bbdb6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -140,6 +140,7 @@ nav: - MOME: api_documentation/core/mome.md - ME ES: api_documentation/core/mees.md - ME PBT: api_documentation/core/me_pbt.md + - ME LS: api_documentation/core/mels.md - Baseline algorithms: - SMERL: api_documentation/core/smerl.md - DIAYN: api_documentation/core/diayn.md diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py new file mode 100644 index 00000000..a2e99971 --- /dev/null +++ b/qdax/core/containers/mels_repertoire.py @@ -0,0 +1,311 @@ +"""This file contains the class to define the repertoire used to +store individuals in the Multi-Objective MAP-Elites algorithm as +well as several variants.""" + +from __future__ import annotations + +from typing import Callable, Optional + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from qdax.core.containers.mapelites_repertoire import ( + MapElitesRepertoire, + 
get_cells_indices, +) +from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread + + +def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray: + """Computes dispersion of a batch of num_samples descriptors. + + Args: + descriptors: (num_samples, num_descriptors) array of descriptors. + Returns: + The float dispersion of the descriptors (this is represented as a scalar + jnp.ndarray). + """ + + # Pairwise distances between the descriptors. + dists = jnp.linalg.norm(descriptors[:, None] - descriptors, axis=2) + + # Compute dispersion -- this is the mean of the unique pairwise distances. + # + # Zero out the duplicate distances since the distance matrix is diagonal. + # Setting k=1 will also remove entries on the diagonal since they are zero. + dists = jnp.triu(dists, k=1) + + num_samples = len(descriptors) + n_pairwise = num_samples * (num_samples - 1) / 2.0 + + return jnp.sum(dists) / n_pairwise + + +def _mode(x: jnp.ndarray) -> jnp.ndarray: + """Computes mode (most common item) of an array. + + The return type is a scalar ndarray. + """ + unique_vals, counts = jnp.unique(x, return_counts=True, size=x.size) + return unique_vals[jnp.argmax(counts)] + + +class MELSRepertoire(MapElitesRepertoire): + """Class for the repertoire in MAP-Elites Low-Spread. + + This class inherits from MapElitesRepertoire. In addition to the stored data in + MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire + also maintains an array of spreads. We overload the save, load, add, and + init_default methods of MapElitesRepertoire. + + Refer to Mace 2023 for more info on MAP-Elites Low-Spread: + https://dl.acm.org/doi/abs/10.1145/3583131.3590433 + + Args: + genotypes: a PyTree containing all the genotypes in the repertoire ordered + by the centroids. Each leaf has a shape (num_centroids, num_features). The + PyTree can be a simple Jax array or a more complex nested structure such + as to represent parameters of neural network in Flax. + fitnesses: an array that contains the fitness of solutions in each cell of the + repertoire, ordered by centroids. The array shape is (num_centroids,). + descriptors: an array that contains the descriptors of solutions in each cell + of the repertoire, ordered by centroids. The array shape + is (num_centroids, num_descriptors). + centroids: an array that contains the centroids of the tessellation. The array + shape is (num_centroids, num_descriptors). + spreads: an array that contains the spread of solutions in each cell of the + repertoire, ordered by centroids. The array shape is (num_centroids,). + """ + + spreads: Spread + + def save(self, path: str = "./") -> None: + """Saves the repertoire on disk in the form of .npy files. + + Flattens the genotypes to store it with .npy format. Supposes that + a user will have access to the reconstruction function when loading + the genotypes. + + Args: + path: Path where the data will be saved. Defaults to "./". 
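+
+        The files written are genotypes.npy, fitnesses.npy, descriptors.npy,
+        centroids.npy and spreads.npy.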
+ """ + + def flatten_genotype(genotype: Genotype) -> jnp.ndarray: + flatten_genotype, _ = ravel_pytree(genotype) + return flatten_genotype + + # flatten all the genotypes + flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes) + + # save data + jnp.save(path + "genotypes.npy", flat_genotypes) + jnp.save(path + "fitnesses.npy", self.fitnesses) + jnp.save(path + "descriptors.npy", self.descriptors) + jnp.save(path + "centroids.npy", self.centroids) + jnp.save(path + "spreads.npy", self.spreads) + + @classmethod + def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire: + """Loads a MAP-Elites Low-Spread Repertoire. + + Args: + reconstruction_fn: Function to reconstruct a PyTree + from a flat array. + path: Path where the data is saved. Defaults to "./". + + Returns: + A MAP-Elites Low-Spread Repertoire. + """ + + flat_genotypes = jnp.load(path + "genotypes.npy") + genotypes = jax.vmap(reconstruction_fn)(flat_genotypes) + + fitnesses = jnp.load(path + "fitnesses.npy") + descriptors = jnp.load(path + "descriptors.npy") + centroids = jnp.load(path + "centroids.npy") + spreads = jnp.load(path + "spreads.npy") + + return cls( + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + spreads=spreads, + ) + + @jax.jit + def add( + self, + batch_of_genotypes: Genotype, + batch_of_descriptors: Descriptor, + batch_of_fitnesses: Fitness, + batch_of_extra_scores: Optional[ExtraScores] = None, + ) -> MELSRepertoire: + """ + Add a batch of elements to the repertoire. + + The key difference between this method and the default add() in + MapElitesRepertoire is that it expects each individual to be evaluated + `num_samples` times, resulting in `num_samples` fitnesses and + `num_samples` descriptors per individual. + + If multiple individuals may be added to a single cell, this method will + arbitrarily pick one -- the exact choice depends on the implementation of + jax.at[].set(), which can be non-deterministic: + https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html + We do not currently check if one of the multiple individuals dominates the + others (dominate means that the individual has both highest fitness and lowest + spread among the individuals for that cell). + + If `num_samples` is only 1, the spreads will default to 0. + + Args: + batch_of_genotypes: a batch of genotypes to be added to the repertoire. + Similarly to the self.genotypes argument, this is a PyTree in which + the leaves have a shape (batch_size, num_features) + batch_of_descriptors: an array that contains the descriptors of the + aforementioned genotypes over all evals. Its shape is + (batch_size, num_samples, num_descriptors). Note that we "aggregate" + descriptors by finding the most frequent cell of each individual. Thus, + the actual descriptors stored in the repertoire are just the coordinates + of the centroid of the most frequent cell. + batch_of_fitnesses: an array that contains the fitnesses of the + aforementioned genotypes over all evals. Its shape is (batch_size, + num_samples) + batch_of_extra_scores: unused tree that contains the extra_scores of + aforementioned genotypes. + + Returns: + The updated repertoire. + """ + batch_size, num_samples = batch_of_fitnesses.shape + + # Compute indices/cells of all descriptors. + batch_of_all_indices = get_cells_indices( + batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids + ).reshape((batch_size, num_samples)) + + # Compute most frequent cell of each solution. 
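+        # batch_of_all_indices has shape (batch_size, num_samples): one cell index per
+        # evaluation. Taking the mode of each row keeps the cell the solution reaches
+        # most often, and [:, None] restores a trailing axis for take_along_axis below.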
+ batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None] + + # Compute dispersion / spread. The dispersion is set to zero if + # num_samples is 1. + batch_of_spreads = jax.lax.cond( + num_samples == 1, + lambda desc: jnp.zeros(batch_size), + lambda desc: jax.vmap(_dispersion)( + desc.reshape((batch_size, num_samples, -1)) + ), + batch_of_descriptors, + ) + batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1) + + # Compute canonical descriptors as the descriptor of the centroid of the most + # frequent cell. Note that this line redefines the earlier batch_of_descriptors. + batch_of_descriptors = jnp.take_along_axis( + self.centroids, batch_of_indices, axis=0 + ) + + # Compute canonical fitnesses as the average fitness. + # + # Shape: (batch_size, 1) + batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True) + + num_centroids = self.centroids.shape[0] + + # get current repertoire fitnesses and spreads + repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1) + current_fitnesses = jnp.take_along_axis( + repertoire_fitnesses, batch_of_indices, 0 + ) + + repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1) + current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0) + + # get addition condition + addition_condition_fitness = batch_of_fitnesses > current_fitnesses + addition_condition_spread = batch_of_spreads <= current_spreads + addition_condition = jnp.logical_and( + addition_condition_fitness, addition_condition_spread + ) + + # assign fake position when relevant : num_centroids is out of bound + batch_of_indices = jnp.where( + addition_condition, x=batch_of_indices, y=num_centroids + ) + + # create new repertoire + new_repertoire_genotypes = jax.tree_util.tree_map( + lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[ + batch_of_indices.squeeze(axis=-1) + ].set(new_genotypes), + self.genotypes, + batch_of_genotypes, + ) + + # compute new fitness and descriptors + new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set( + batch_of_fitnesses.squeeze(axis=-1) + ) + new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set( + batch_of_descriptors + ) + new_spreads = self.spreads.at[batch_of_indices.squeeze(axis=-1)].set( + batch_of_spreads.squeeze(axis=-1) + ) + + return MELSRepertoire( + genotypes=new_repertoire_genotypes, + fitnesses=new_fitnesses, + descriptors=new_descriptors, + centroids=self.centroids, + spreads=new_spreads, + ) + + @classmethod + def init_default( + cls, + genotype: Genotype, + centroids: Centroid, + ) -> MELSRepertoire: + """Initialize a MAP-Elites Low-Spread repertoire with an initial population of + genotypes. Requires the definition of centroids that can be computed with any + method such as CVT or Euclidean mapping. + + Note: this function has been kept outside of the object MELS, so + it can be called easily called from other modules. + + Args: + genotype: the typical genotype that will be stored. + centroids: the centroids of the repertoire. + + Returns: + A repertoire filled with default values. 
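+
+        Fitnesses are initialised to -inf and spreads to +inf so that any
+        evaluated solution can displace the placeholder in its cell.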
+ """ + + # get number of centroids + num_centroids = centroids.shape[0] + + # default fitness is -inf + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + + # default genotypes is all 0 + default_genotypes = jax.tree_util.tree_map( + lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype), + genotype, + ) + + # default descriptor is all zeros + default_descriptors = jnp.zeros_like(centroids) + + # default spread is inf so that any spread will be less + default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf) + + return cls( + genotypes=default_genotypes, + fitnesses=default_fitnesses, + descriptors=default_descriptors, + centroids=centroids, + spreads=default_spreads, + ) diff --git a/qdax/core/mels.py b/qdax/core/mels.py new file mode 100644 index 00000000..6c06b785 --- /dev/null +++ b/qdax/core/mels.py @@ -0,0 +1,104 @@ +"""Core components of the MAP-Elites Low-Spread algorithm.""" +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional, Tuple + +import jax + +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.map_elites import MAPElites +from qdax.types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Metrics, + RNGKey, +) +from qdax.utils.sampling import multi_sample_scoring_function + + +class MELS(MAPElites): + """Core elements of the MAP-Elites Low-Spread algorithm. + + Most methods in this class are inherited from MAPElites. + + The same scoring function can be passed into both MAPElites and this class. + We have overridden __init__ such that it takes in the scoring function and + wraps it such that every solution is evaluated `num_samples` times. + + We also overrode the init method to use the MELSRepertoire instead of + MapElitesRepertoire. + """ + + def __init__( + self, + scoring_function: Callable[ + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + ], + emitter: Emitter, + metrics_function: Callable[[MELSRepertoire], Metrics], + num_samples: int, + ) -> None: + self._scoring_function = partial( + multi_sample_scoring_function, + scoring_fn=scoring_function, + num_samples=num_samples, + ) + self._emitter = emitter + self._metrics_function = metrics_function + self._num_samples = num_samples + + @partial(jax.jit, static_argnames=("self",)) + def init( + self, + init_genotypes: Genotype, + centroids: Centroid, + random_key: RNGKey, + ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: + """Initialize a MAP-Elites Low-Spread repertoire with an initial + population of genotypes. Requires the definition of centroids that can + be computed with any method such as CVT or Euclidean mapping. + + Args: + init_genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + centroids: tessellation centroids of shape (batch_size, num_descriptors) + random_key: a random key used for stochastic operations. + + Returns: + A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter + state, JAX random key). 
+ """ + # score initial genotypes + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + init_genotypes, random_key + ) + + # init the repertoire + repertoire = MELSRepertoire.init( + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # get initial state of the emitter + emitter_state, random_key = self._emitter.init( + init_genotypes=init_genotypes, random_key=random_key + ) + + # update emitter state + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + repertoire=repertoire, + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + return repertoire, emitter_state, random_key diff --git a/qdax/types.py b/qdax/types.py index 67fbb8a0..5000869b 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -26,6 +26,7 @@ Genotype: TypeAlias = ArrayTree Descriptor: TypeAlias = jnp.ndarray Centroid: TypeAlias = jnp.ndarray +Spread: TypeAlias = jnp.ndarray Gradient: TypeAlias = jnp.ndarray Skill: TypeAlias = jnp.ndarray diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index 88b6286e..a25e190f 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -8,7 +8,7 @@ from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey -@partial(jax.jit, static_argnames=("num_samples")) +@partial(jax.jit, static_argnames=("num_samples",)) def dummy_extra_scores_extractor( extra_scores: ExtraScores, num_samples: int, @@ -29,6 +29,60 @@ def dummy_extra_scores_extractor( return extra_scores +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_samples", + ), +) +def multi_sample_scoring_function( + policies_params: Genotype, + random_key: RNGKey, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_samples: int, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Wrap scoring_function to perform sampling. + + This function returns the fitnesses, descriptors, and extra_scores computed + over num_samples evaluations with the scoring_fn. + + Args: + policies_params: policies to evaluate + random_key: JAX random key + scoring_fn: scoring function used for evaluation + num_samples: number of samples to generate for each individual + + Returns: + (n, num_samples) array of fitnesses, + (n, num_samples, num_descriptors) array of descriptors, + dict with num_samples extra_scores per individual, + JAX random key + """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=num_samples) + + # evaluate + sample_scoring_fn = jax.vmap( + scoring_fn, + # vectorizing over axis 0 vectorizes over the num_samples random keys + in_axes=(None, 0), + # indicates that the vectorized axis will become axis 1, i.e., the final + # output is shape (batch_size, num_samples, ...) + out_axes=1, + ) + all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( + policies_params, keys + ) + + return all_fitnesses, all_descriptors, all_extra_scores, random_key + + @partial( jax.jit, static_argnames=( @@ -49,14 +103,16 @@ def sampling( [ExtraScores, int], ExtraScores ] = dummy_extra_scores_extractor, ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """ - Wrap scoring_function to perform sampling. + """Wrap scoring_function to perform sampling. + + This function averages the fitnesses and descriptors for each individual + over `num_samples` evaluations. 
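+    It evaluates the policies with multi_sample_scoring_function and then takes
+    the mean of the per-sample fitnesses and descriptors.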
Args: policies_params: policies to evaluate - random_key + random_key: JAX random key scoring_fn: scoring function used for evaluation - num_samples + num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from multiple samples of the same policy. @@ -65,14 +121,13 @@ def sampling( The extra_score extract from samples with extra_scores_extractor A new random key """ - - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, num=num_samples) - - # evaluate - sample_scoring_fn = jax.vmap(scoring_fn, (None, 0), 1) - all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( - policies_params, keys + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params, random_key, scoring_fn, num_samples ) # average results diff --git a/tests/core_test/containers_test/mels_repertoire_test.py b/tests/core_test/containers_test/mels_repertoire_test.py new file mode 100644 index 00000000..2fb1bd76 --- /dev/null +++ b/tests/core_test/containers_test/mels_repertoire_test.py @@ -0,0 +1,236 @@ +import jax.numpy as jnp +import pytest + +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.types import ExtraScores + + +def test_add_to_mels_repertoire() -> None: + """Test several additions to the MELSRepertoire, including adding a solution + and overwriting it by adding multiple solutions.""" + genotype_size = 12 + num_centroids = 4 + num_descriptors = 2 + + # create a repertoire instance + repertoire = MELSRepertoire( + genotypes=jnp.zeros(shape=(num_centroids, genotype_size)), + fitnesses=jnp.ones(shape=(num_centroids,)) * (-jnp.inf), + descriptors=jnp.zeros(shape=(num_centroids, num_descriptors)), + centroids=jnp.array( + [ + [1.0, 1.0], + [2.0, 1.0], + [2.0, 2.0], + [1.0, 2.0], + ] + ), + spreads=jnp.full(shape=(num_centroids,), fill_value=jnp.inf), + ) + + # + # Test 1: Insert a single solution. + # + + # create fake genotypes and scores to add + fake_genotypes = jnp.ones(shape=(1, genotype_size)) + # each solution gets two fitnesses and two descriptors + fake_fitnesses = jnp.array([[0.0, 0.0]]) + fake_descriptors = jnp.array([[[0.0, 1.0], [1.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # check that the repertoire looks as expected + expected_genotypes = jnp.array( + [ + [1.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses = jnp.array([0.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads = jnp.array([1.0, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume(jnp.allclose(repertoire.genotypes, expected_genotypes, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.fitnesses, expected_fitnesses, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.descriptors, expected_descriptors, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads, atol=1e-6)) + + # + # Test 2: Adding solutions into the same cell as above. 
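+    # (The repertoire currently holds the Test 1 solution, with fitness 0.0 and
+    # spread 1.0, in the cell whose centroid is [1.0, 1.0].)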
+ # + + # create fake genotypes and scores to add + fake_genotypes = jnp.concatenate( + ( + jnp.full(shape=(1, genotype_size), fill_value=2.0), + jnp.full(shape=(1, genotype_size), fill_value=3.0), + ), + axis=0, + ) + # Each solution gets two fitnesses and two descriptors (i.e. num_evals = 2). One + # solution has fitness 1.0 and spread 0.75, while the other has fitness 0.5 and + # spread 0.5. Thus, neither solution dominates the other (by having both higher + # fitness and lower spread). However, both solutions would be valid candidates for + # the archive due to dominating the current solution there. + fake_fitnesses = jnp.array([[1.0, 1.0], [0.5, 0.5]]) + fake_descriptors = jnp.array([[[1.0, 0.25], [1.0, 1.0]], [[1.0, 0.5], [1.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # Either solution may be added due to the behavior of jax.at[].set(): + # https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html + # Thus, we provide possible values for each scenario. + + # check that the repertoire looks like expected + expected_genotypes_1 = jnp.array( + [ + [2.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses_1 = jnp.array([1.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors_1 = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads_1 = jnp.array([0.75, jnp.inf, jnp.inf, jnp.inf]) + + expected_genotypes_2 = jnp.array( + [ + [3.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses_2 = jnp.array([0.5, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors_2 = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads_2 = jnp.array([0.5, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume( + jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6) + or jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6) + ) + + if jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6): + pytest.assume( + jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.fitnesses, expected_fitnesses_1, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.descriptors, expected_descriptors_1, atol=1e-6) + ) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads_1, atol=1e-6)) + elif jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6): + pytest.assume( + jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.fitnesses, expected_fitnesses_2, atol=1e-6) + ) + pytest.assume( + jnp.allclose(repertoire.descriptors, expected_descriptors_2, atol=1e-6) + ) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads_2, atol=1e-6)) + + +def test_add_with_single_eval() -> None: + """Tries adding with a single evaluation. + + This is a special case because the spread defaults to 0. 
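+
+    MELSRepertoire.add switches the spread to a constant zero (via jax.lax.cond)
+    when num_samples == 1, since there are no pairwise distances to average.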
+ """ + genotype_size = 12 + num_centroids = 4 + num_descriptors = 2 + + # create a repertoire instance + repertoire = MELSRepertoire( + genotypes=jnp.zeros(shape=(num_centroids, genotype_size)), + fitnesses=jnp.ones(shape=(num_centroids,)) * (-jnp.inf), + descriptors=jnp.zeros(shape=(num_centroids, num_descriptors)), + centroids=jnp.array( + [ + [1.0, 1.0], + [2.0, 1.0], + [2.0, 2.0], + [1.0, 2.0], + ] + ), + spreads=jnp.full(shape=(num_centroids,), fill_value=jnp.inf), + ) + + # Insert a single solution with only one eval. + + # create fake genotypes and scores to add + fake_genotypes = jnp.ones(shape=(1, genotype_size)) + # the solution gets one fitness and one descriptor. + fake_fitnesses = jnp.array([[0.0]]) + fake_descriptors = jnp.array([[[0.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # check that the repertoire looks as expected + expected_genotypes = jnp.array( + [ + [1.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses = jnp.array([0.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + # Spread should be 0 since there's only one eval. + expected_spreads = jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume(jnp.allclose(repertoire.genotypes, expected_genotypes, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.fitnesses, expected_fitnesses, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.descriptors, expected_descriptors, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads, atol=1e-6)) diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py new file mode 100644 index 00000000..21f90517 --- /dev/null +++ b/tests/core_test/mels_test.py @@ -0,0 +1,156 @@ +"""Tests MAP-Elites Low-Spread implementation.""" + +import functools +from typing import Dict, Tuple + +import jax +import jax.numpy as jnp +import pytest + +from qdax import environments +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.mels import MELS +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs +from qdax.types import EnvState, Params, RNGKey + + +@pytest.mark.parametrize( + "env_name, batch_size", + [("walker2d_uni", 1), ("walker2d_uni", 10), ("hopper_uni", 10)], +) +def test_mels(env_name: str, batch_size: int) -> None: + batch_size = batch_size + env_name = env_name + num_samples = 5 + episode_length = 100 + num_iterations = 5 + seed = 42 + policy_hidden_layer_sizes = (64, 64) + num_init_cvt_samples = 1000 + num_centroids = 50 + min_bd = 0.0 + max_bd = 1.0 + + # Init environment + env = environments.create(env_name, episode_length=episode_length) + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init policy network + policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + 
kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + # Init population of controllers. There are batch_size controllers, and each + # controller will be evaluated num_samples times. + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=batch_size) + fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) + init_variables = jax.vmap(policy_network.init)(keys, fake_batch) + + # Define the function to play a step with the policy in the environment + def play_step_fn( + env_state: EnvState, + policy_params: Params, + random_key: RNGKey, + ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: + """Play an environment step and return the updated state and the + transition.""" + + actions = policy_network.apply(policy_params, env_state.obs) + + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = QDTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + actions=actions, + truncations=next_state.info["truncation"], + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + ) + + return next_state, policy_params, random_key, transition + + # Prepare the scoring function + bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + scoring_fn = functools.partial( + reset_based_scoring_function_brax_envs, + episode_length=episode_length, + play_reset_fn=env.reset, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + # Define emitter + variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Get minimum reward value to make sure qd_score are positive + reward_offset = environments.reward_offset[env_name] + + # Define a metrics function + def metrics_fn(repertoire: MELSRepertoire) -> Dict: + # Get metrics + grid_empty = repertoire.fitnesses == -jnp.inf + qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty) + # Add offset for positive qd_score + qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty) + coverage = 100 * jnp.mean(1.0 - grid_empty) + max_fitness = jnp.max(repertoire.fitnesses) + + return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + + # Instantiate ME-LS. + mels = MELS( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + num_samples=num_samples, + ) + + # Compute the centroids + centroids, random_key = compute_cvt_centroids( + num_descriptors=env.behavior_descriptor_length, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=min_bd, + maxval=max_bd, + random_key=random_key, + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = mels.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + mels.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_mels(env_name="pointmaze", batch_size=10)