diff --git a/.github/workflows/cache.yml b/.github/workflows/cache.yml index 00c94fed2..25e71528f 100644 --- a/.github/workflows/cache.yml +++ b/.github/workflows/cache.yml @@ -20,12 +20,11 @@ jobs: python-version: "3.13" environment-file: environment.yml activate-environment: quantecon - - name: Install JAX, Numpyro, PyTorch + - name: Install JAX and Numpyro shell: bash -l {0} run: | - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install --upgrade "jax[cuda12-local]==0.6.2" - pip install numpyro + pip install -U "jax[cuda13]" + pip install numpyro python scripts/test-jax-install.py - name: Check nvidia drivers shell: bash -l {0} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 80e4d9a83..eb1a324d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,17 +28,15 @@ jobs: python-version: "3.13" environment-file: environment.yml activate-environment: quantecon - - name: Install JAX, Numpyro, PyTorch - shell: bash -l {0} - run: | - # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 - # pip install pyro-ppl - pip install "jax[cuda12-local]==0.6.2" - pip install numpyro pyro-ppl - python scripts/test-jax-install.py - name: Check nvidia Drivers shell: bash -l {0} run: nvidia-smi + - name: Install JAX and Numpyro + shell: bash -l {0} + run: | + pip install -U "jax[cuda13]" + pip install numpyro + python scripts/test-jax-install.py - name: Display Conda Environment Versions shell: bash -l {0} run: conda list diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e5ca1bd12..3a87a11f6 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -19,12 +19,11 @@ jobs: python-version: "3.13" environment-file: environment.yml activate-environment: quantecon - - name: Install JAX, Numpyro, PyTorch + - name: Install JAX and Numpyro shell: bash -l {0} run: | - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install --upgrade "jax[cuda12-local]==0.6.2" - pip install numpyro + pip install -U "jax[cuda13]" + pip install numpyro python scripts/test-jax-install.py - name: Check nvidia drivers shell: bash -l {0}