Skip to content

Commit

Permalink
Update JAX implementation (#161)
Browse files Browse the repository at this point in the history
* Remove the implementation that supports JAX on Python 3.7 in the NVIDIA Isaac examples

* Get JAX device from string

* Catch invalid device index

* Perform JAX computation on the selected device

* Update CHANGELOG
  • Loading branch information
Toni-SM committed Jun 23, 2024
1 parent 932a56b commit 50f4a96
Show file tree
Hide file tree
Showing 48 changed files with 108 additions and 594 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Fixed
- Move the batch sampling inside gradient step loop for DDPG and TD3
- Perform JAX computation on the selected device

## [1.1.0] - 2024-02-12
### Added
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_ant_ddpg.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ddpg import DDPG, DDPG_DEFAULT_CONFIG
Expand All @@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
Expand All @@ -57,9 +40,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_ant_ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
Expand Down Expand Up @@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
Expand All @@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_ant_sac.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.sac import SAC, SAC_DEFAULT_CONFIG
Expand All @@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
Expand All @@ -58,9 +41,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_ant_td3.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.td3 import TD3, TD3_DEFAULT_CONFIG
Expand All @@ -42,9 +28,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.relu(nn.Dense(512)(inputs["states"]))
Expand All @@ -57,9 +40,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_cartpole_ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
Expand Down Expand Up @@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
Expand All @@ -59,9 +42,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_humanoid_ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
Expand Down Expand Up @@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(400)(inputs["states"]))
Expand All @@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(400)(inputs["states"]))
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_lift_franka_ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
Expand Down Expand Up @@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
Expand All @@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_reach_franka_ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
Expand Down Expand Up @@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
Expand All @@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(256)(inputs["states"]))
Expand Down
20 changes: 0 additions & 20 deletions docs/source/examples/isaaclab/jax_velocity_anymal_c_ppo.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
"""
Notes for Isaac Sim 2022.2.1 or earlier (Python 3.7 environment):
* Python 3.7 is only supported up to jax<=0.3.25.
See: https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-041-dec-13-2022.
* Builds for jaxlib<=0.3.25 are only available up to NVIDIA CUDA 11 and cuDNN 8.2 versions.
See: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and search for `cuda11/jaxlib-0.3.25+cuda11.cudnn82-cp37-cp37m-manylinux2014_x86_64.whl`.
* The `jax.Device = jax.xla.Device` statement is required by skrl to support jax<0.4.3.
* Models require overloading the `__hash__` method to avoid "TypeError: Failed to hash Flax Module".
"""

import flax.linen as nn
import jax
import jax.numpy as jnp


jax.Device = jax.xla.Device # for Isaac Sim 2022.2.1 or earlier

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
Expand Down Expand Up @@ -43,9 +29,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(inputs["states"]))
Expand All @@ -60,9 +43,6 @@ def __init__(self, observation_space, action_space, device=None, clip_actions=Fa
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

def __hash__(self): # for Isaac Sim 2022.2.1 or earlier
return id(self)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(inputs["states"]))
Expand Down
Loading

0 comments on commit 50f4a96

Please sign in to comment.