Skip to content

Commit

Permalink
Fixed jax config calls for compatibility with jax 0.4.25
Browse files Browse the repository at this point in the history
  • Loading branch information
swryan committed Feb 26, 2024
1 parent 6b42683 commit 44329e4
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 16 deletions.
10 changes: 0 additions & 10 deletions .github/workflows/openmdao_test_workflow.yml
Expand Up @@ -60,7 +60,6 @@ jobs:
# PAROPT: true
# SNOPT: '7.7'
OPTIONAL: '[all]'
JAX: '0.4.14'
TESTS: true

# test minimal install
Expand Down Expand Up @@ -96,7 +95,6 @@ jobs:
PYOPTSPARSE: 'v2.10.1'
SNOPT: '7.7'
OPTIONAL: '[all]'
JAX: '0.4.14'
BUILD_DOCS: true

runs-on: ${{ matrix.OS }}
Expand Down Expand Up @@ -155,14 +153,6 @@ jobs:
echo "============================================================="
python -m pip install .${{ matrix.OPTIONAL }}
- name: Install jax
if: matrix.JAX
run: |
echo "============================================================="
echo "Install jax"
echo "============================================================="
python -m pip install jaxlib=='${{ matrix.JAX }}' jax=='${{ matrix.JAX }}'
- name: Install PETSc
if: matrix.PETSc
run: |
Expand Down
3 changes: 1 addition & 2 deletions openmdao/components/explicit_func_comp.py
Expand Up @@ -15,8 +15,7 @@
import jax
from jax import jit
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True) # jax by default uses 32 bit floats
jax.config.update("jax_enable_x64", True) # jax by default uses 32 bit floats
except Exception:
_, err, tb = sys.exc_info()
if not isinstance(err, ImportError):
Expand Down
4 changes: 2 additions & 2 deletions openmdao/components/func_comp_common.py
Expand Up @@ -9,17 +9,17 @@

import numpy as np
try:
import jax
from jax import vmap
import jax.numpy as jnp
from jax.config import config
# linear_util moved to jax.extend in jax 0.4.17, previous location is deprecated
try:
from jax.extend import linear_util
except ImportError:
from jax import linear_util
from jax.api_util import argnums_partial
from jax._src.api import _jvp, _vjp
config.update("jax_enable_x64", True) # jax by default uses 32 bit floats
jax.config.update("jax_enable_x64", True) # jax by default uses 32 bit floats
except Exception:
_, err, tb = sys.exc_info()
if not isinstance(err, ImportError):
Expand Down
3 changes: 1 addition & 2 deletions openmdao/components/implicit_func_comp.py
Expand Up @@ -14,8 +14,7 @@
try:
import jax
from jax import jit, jacfwd, jacrev
from jax.config import config
config.update("jax_enable_x64", True) # jax by default uses 32 bit floats
jax.config.update("jax_enable_x64", True) # jax by default uses 32 bit floats
except Exception:
_, err, tb = sys.exc_info()
if not isinstance(err, ImportError):
Expand Down

0 comments on commit 44329e4

Please sign in to comment.