Skip to content

Commit b71c5d7

Browse files
authored
FIX: pin jax=0.6.2 to all workflows (#498)
1 parent 84a6c0c commit b71c5d7

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

.github/workflows/cache.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
shell: bash -l {0}
2525
run: |
2626
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
27-
pip install --upgrade "jax[cuda12-local]"
27+
pip install --upgrade "jax[cuda12-local]==0.6.2"
2828
pip install numpyro
2929
python scripts/test-jax-install.py
3030
- name: Check nvidia drivers

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
shell: bash -l {0}
2424
run: |
2525
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26-
pip install --upgrade "jax[cuda12-local]"
26+
pip install --upgrade "jax[cuda12-local]==0.6.2"
2727
pip install numpyro
2828
python scripts/test-jax-install.py
2929
- name: Check nvidia drivers

0 commit comments

Comments
 (0)