We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 949043e commit 7f7d6a5Copy full SHA for 7f7d6a5
.github/workflows/ci.yml
@@ -31,9 +31,9 @@ jobs:
31
- name: Install JAX, Numpyro, PyTorch
32
shell: bash -l {0}
33
run: |
34
- pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
35
- pip install pyro-ppl
36
- pip install --upgrade "jax[cuda12]"
+ # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+ # pip install pyro-ppl
+ pip install "jax[cuda12-local]==0.6.2"
37
pip install numpyro pyro-ppl
38
python scripts/test-jax-install.py
39
- name: Check nvidia Drivers
0 commit comments