Skip to content

Commit

Permalink
Dropped support for jax before v0.2.17
Browse files Browse the repository at this point in the history
  • Loading branch information
lumip committed Nov 21, 2022
1 parent bba2e4b commit db47884
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', '3.11']
cuda-setup: [
# [cuda-version, cuda-link, gcc-version]
[10.0.130, "https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux", 7],
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/jax_compatibility_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ jobs:
fail-fast: false
matrix:
jax-version: [
0.2.12, 0.2.13, 0.2.14, 0.2.15, 0.2.16,
0.2.17, 0.2.18, 0.2.19, 0.2.20, 0.2.21,
0.2.22, 0.2.27, 0.3.1, 0.3.13, 0.3.15,
0.3.17, 0.3.23, 0.3.25
Expand All @@ -36,7 +35,7 @@ jobs:
python -m pip install pytest
- name: Install dependencies
run: |
python -m pip install "jax[minimum-jaxlib]==${{ matrix.jax-version }}"
python -m pip install "jax[cpu]==${{ matrix.jax-version }}"
python -m pip install .
- name: Test with pytest
run: |
Expand Down
1 change: 1 addition & 0 deletions ChangeLog.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
- HEAD:
- Added: support for jax up to v0.3.25
- Added: support for Python v3.11
- Removed: support for jax before v0.2.17
- 2.0.0-rc.3:
- Fix: OpenMP is no longer a strict requirement for installation.
- Added: chacha.native.openmp_accelerated, returns True if CPU kernels are parallelised using OpenMP.
Expand Down
11 changes: 1 addition & 10 deletions chacha/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,7 @@
from chacha import defs

# importing canonicalize_shape function
try:
# pre jax v0.2.14 location
_canonicalize_shape = jax.abstract_arrays.canonicalize_shape # type: ignore
except (AttributeError, ImportError): # pragma: no cover
# post jax v0.2.14 location
try:
_canonicalize_shape = jax.core.canonicalize_shape # type: ignore
except (AttributeError, ImportError): # pragma: no cover
raise ImportError("Cannot import canonicalize_shape routine. "
"You are probably using an incompatible version of jax.")
_canonicalize_shape = jax.core.canonicalize_shape # type: ignore

# importing _UINT_DTYPES
try:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def build_extension(self, ext):
version_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(version_module)

_jax_version_lower_constraint = ' >= 0.2.12'
_jax_version_lower_constraint = ' >= 0.2.17'
_jax_version_optimistic_upper_constraint = ', <= 2.0.0'
_jax_version_upper_constraint = ', <= 0.3.25'

Expand All @@ -94,7 +94,7 @@ def build_extension(self, ext):
],
extras_require={
"tests": [
f"jax[minimum-jaxlib]",
f"jax[cpu]",
"pytest"
],
"compatible-jax": [f"jax{_jax_version_lower_constraint}{_jax_version_upper_constraint}"]
Expand Down

0 comments on commit db47884

Please sign in to comment.