diff --git a/.github/workflows/openmdao_test_workflow.yml b/.github/workflows/openmdao_test_workflow.yml index e8e7bf6943..dbaa4c2619 100644 --- a/.github/workflows/openmdao_test_workflow.yml +++ b/.github/workflows/openmdao_test_workflow.yml @@ -60,7 +60,6 @@ jobs: # PAROPT: true # SNOPT: '7.7' OPTIONAL: '[all]' - JAX: '0.4.14' TESTS: true # test minimal install @@ -96,7 +95,6 @@ jobs: PYOPTSPARSE: 'v2.10.1' SNOPT: '7.7' OPTIONAL: '[all]' - JAX: '0.4.14' BUILD_DOCS: true runs-on: ${{ matrix.OS }} @@ -155,14 +153,6 @@ jobs: echo "=============================================================" python -m pip install .${{ matrix.OPTIONAL }} - - name: Install jax - if: matrix.JAX - run: | - echo "=============================================================" - echo "Install jax" - echo "=============================================================" - python -m pip install jaxlib=='${{ matrix.JAX }}' jax=='${{ matrix.JAX }}' - - name: Install PETSc if: matrix.PETSc run: | diff --git a/openmdao/components/explicit_func_comp.py b/openmdao/components/explicit_func_comp.py index 833fc03630..f3be776247 100644 --- a/openmdao/components/explicit_func_comp.py +++ b/openmdao/components/explicit_func_comp.py @@ -15,8 +15,7 @@ import jax from jax import jit import jax.numpy as jnp - from jax.config import config - config.update("jax_enable_x64", True) # jax by default uses 32 bit floats + jax.config.update("jax_enable_x64", True) # jax by default uses 32 bit floats except Exception: _, err, tb = sys.exc_info() if not isinstance(err, ImportError): diff --git a/openmdao/components/func_comp_common.py b/openmdao/components/func_comp_common.py index 0c896bc06a..64e2b669d7 100644 --- a/openmdao/components/func_comp_common.py +++ b/openmdao/components/func_comp_common.py @@ -9,9 +9,9 @@ import numpy as np try: + import jax from jax import vmap import jax.numpy as jnp - from jax.config import config # linear_util moved to jax.extend in jax 0.4.17, previous location is deprecated try: from jax.extend import linear_util @@ -19,7 +19,7 @@ from jax import linear_util from jax.api_util import argnums_partial from jax._src.api import _jvp, _vjp - config.update("jax_enable_x64", True) # jax by default uses 32 bit floats + jax.config.update("jax_enable_x64", True) # jax by default uses 32 bit floats except Exception: _, err, tb = sys.exc_info() if not isinstance(err, ImportError): diff --git a/openmdao/components/implicit_func_comp.py b/openmdao/components/implicit_func_comp.py index 555fb65050..a263a27a95 100644 --- a/openmdao/components/implicit_func_comp.py +++ b/openmdao/components/implicit_func_comp.py @@ -14,8 +14,7 @@ try: import jax from jax import jit, jacfwd, jacrev - from jax.config import config - config.update("jax_enable_x64", True) # jax by default uses 32 bit floats + jax.config.update("jax_enable_x64", True) # jax by default uses 32 bit floats except Exception: _, err, tb = sys.exc_info() if not isinstance(err, ImportError):