Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compat option for jax devices #320

Merged
merged 2 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax<=0.4.15" "jaxlib<=0.4.15" "jax-md>=0.2.7" jaxopt pytest matplotlib
pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax<=0.4.15" "jaxlib<=0.4.15" "jax-md>=0.2.7" jaxopt pytest matplotlib "scipy<1.13"

- name: Install pysages
run: pip install .
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
furo
jaxlib
myst-parser
scipy<1.13
setuptools-scm
sphinx
sphinx-copybutton
Expand Down
12 changes: 10 additions & 2 deletions pysages/methods/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@
from pysages.methods.restraints import canonicalize
from pysages.methods.utils import ReplicasConfiguration
from pysages.typing import Callable, Optional, Union
from pysages.utils import ToCPU, copy, dispatch, dispatch_table, has_method, identity
from pysages.utils import (
ToCPU,
copy,
device_platform,
dispatch,
dispatch_table,
has_method,
identity,
)

# Base Classes
# ============
Expand Down Expand Up @@ -392,7 +400,7 @@ def _run( # noqa: F811 # pylint: disable=C0116,E0102
context_args["context"] = sampling_context.context
sampler = sampling_context.sampler
prev_snapshot = result.snapshots
if sampler.state.xi.device().platform == "cpu":
if device_platform(sampler.state.xi) == "cpu":
prev_snapshot = copy(prev_snapshot, ToCPU())
sampler.restore(prev_snapshot)
sampler.state = result.states
Expand Down
1 change: 1 addition & 0 deletions pysages/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .compat import (
check_device_array,
device_platform,
dispatch_table,
has_method,
is_generic_subclass,
Expand Down
15 changes: 15 additions & 0 deletions pysages/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ def prod(iterable, start=1):
return result


# Compatibility for jax >=0.4.27

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0427-may-7-2024

if _jax_version_tuple < (0, 4, 27):

def device_platform(array):
return array.device().platform

else:

def device_platform(array):
return next(iter(array.devices())).platform


# Compatibility for jax >=0.4.1

# https://github.com/google/jax/releases/tag/jax-v0.4.1
Expand Down
Loading