I am trying to reproduce the study of this work from Google DeepMind by running JAX on NVIDIA GPUs (driver 550.67, CUDA 12.4), but it prints:
"No GPU/TPU found, falling back to CPU."
I tried bumping the jax and jaxlib versions from 0.4.16 (the version listed in requirements.txt) to 0.4.28 (the latest version), and also upgraded flax from 0.7.4 to 0.8.3. These changes eliminate the warning, and JAX now seems to recognize the GPU devices, but the computation is still very slow. How do I fix this?
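"Seems to recognize the GPU devices" can be checked directly: look at JAX's default backend and at which device a jitted result actually lives on. The sketch below uses only public JAX calls (`jax.default_backend`, `jax.devices`, `jax.Array.devices`); the helper names `gpu_devices` and `placement_of` are hypothetical, and the 1024x1024 matmul is an arbitrary test workload, not the model from this issue.

```python
import jax
import jax.numpy as jnp


def gpu_devices():
    """Return JAX's GPU devices, or [] if no GPU backend could be initialized."""
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # Raised when jaxlib has no working GPU backend (the CPU-fallback case).
        return []


def placement_of(x):
    """Return the set of platform names ("gpu", "cpu", ...) holding array x."""
    return {d.platform for d in x.devices()}


print("default backend:", jax.default_backend())
print("GPU devices:", gpu_devices())

# Run a small jitted computation and check where its result was placed.
x = jnp.ones((1024, 1024), dtype=jnp.float32)
y = jax.jit(lambda a: a @ a)(x)
print("result placed on:", placement_of(y))
```

If the default backend prints `cpu`, or the result is placed on `{'cpu'}`, the GPU build is not being used for computation, whatever `nvidia-smi` shows.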
Hmm. I'm not sure this is an actionable report. You upgraded and the original problem was fixed, it seems?
You say the model is running slowly. Can you say more? I know nothing about that particular model: are you seeing different performance characteristics from what the original model authors report? How so?
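Answering "how slow?" with numbers requires timing JAX code correctly: the first call of a jitted function includes tracing and compilation, and JAX dispatches work asynchronously, so a naive timer can measure compile time or mere dispatch instead of execution. A minimal sketch, assuming a simple matmul as a stand-in workload (`time_jitted` is a hypothetical helper; `jax.jit` and `jax.block_until_ready` are real JAX APIs):

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, warmup=2, iters=10):
    """Average wall-clock seconds per call of jax.jit(fn), excluding compilation.

    JAX dispatches work asynchronously, so every call is followed by
    jax.block_until_ready() to wait until the device has actually finished.
    """
    jitted = jax.jit(fn)
    for _ in range(warmup):
        # The first call traces and compiles; later calls reuse the cached executable.
        jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / iters


x = jnp.ones((2048, 2048), dtype=jnp.float32)
avg = time_jitted(lambda a: a @ a, x)
print(f"{avg * 1e3:.2f} ms per 2048x2048 matmul on {jax.default_backend()}")
```

Running the same script with the environment variable `JAX_PLATFORMS=cpu` gives a CPU baseline; if the GPU number is not clearly better, the problem is likely environmental rather than specific to the model.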
System info (python version, jaxlib version, accelerator, etc.)
$ python3 -c "import jax; jax.print_environment_info()"
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1
platform: uname_result(system='Linux', node='nebula', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed May 15 09:53:43 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67 Driver Version: 550.67 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:1A:00.0 Off | Off |
| 30% 37C P2 73W / 300W | 37062MiB / 49140MiB | 17% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA RTX A6000 Off | 00000000:1B:00.0 Off | Off |
| 30% 50C P2 135W / 300W | 39244MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA RTX A6000 Off | 00000000:1C:00.0 Off | Off |
| 30% 39C P2 80W / 300W | 16274MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA RTX A6000 Off | 00000000:1D:00.0 Off | Off |
| 30% 56C P2 146W / 300W | 48601MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA RTX A6000 Off | 00000000:1E:00.0 Off | Off |
| 30% 57C P2 149W / 300W | 48037MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA RTX A6000 Off | 00000000:3D:00.0 Off | Off |
| 30% 49C P2 120W / 300W | 46176MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA RTX A6000 Off | 00000000:3E:00.0 Off | Off |
| 30% 52C P2 158W / 300W | 47214MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA RTX A6000 Off | 00000000:3F:00.0 Off | Off |
| 30% 57C P2 117W / 300W | 47026MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 8 NVIDIA RTX A6000 Off | 00000000:40:00.0 Off | Off |
| 30% 56C P2 159W / 300W | 32440MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 9 NVIDIA RTX A6000 Off | 00000000:41:00.0 Off | Off |
| 30% 53C P2 123W / 300W | 47000MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 185733 C python3 36788MiB |
| 0 N/A N/A 191989 C python3 262MiB |
| 1 N/A N/A 190770 C ...liu/anaconda3/envs/gemma/bin/python 38972MiB |
| 1 N/A N/A 191989 C python3 262MiB |
| 2 N/A N/A 176896 C python3 16000MiB |
| 2 N/A N/A 191989 C python3 262MiB |
| 3 N/A N/A 176896 C python3 262MiB |
| 3 N/A N/A 190769 C ...liu/anaconda3/envs/gemma/bin/python 48062MiB |
| 3 N/A N/A 191989 C python3 262MiB |
| 4 N/A N/A 176896 C python3 262MiB |
| 4 N/A N/A 190772 C ...liu/anaconda3/envs/gemma/bin/python 47498MiB |
| 4 N/A N/A 191989 C python3 262MiB |
| 5 N/A N/A 190771 C ...liu/anaconda3/envs/gemma/bin/python 45904MiB |
| 5 N/A N/A 191989 C python3 262MiB |
| 6 N/A N/A 190773 C ...liu/anaconda3/envs/gemma/bin/python 46942MiB |
| 6 N/A N/A 191989 C python3 262MiB |
| 7 N/A N/A 190774 C ...liu/anaconda3/envs/gemma/bin/python 46754MiB |
| 7 N/A N/A 191989 C python3 262MiB |
| 8 N/A N/A 190767 C ...liu/anaconda3/envs/gemma/bin/python 32360MiB |
| 8 N/A N/A 191989 C python3 262MiB |
| 9 N/A N/A 190768 C ...liu/anaconda3/envs/gemma/bin/python 46728MiB |
| 9 N/A N/A 191989 C python3 262MiB |
+-----------------------------------------------------------------------------------------+
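In the `nvidia-smi` output above, GPUs 1 and 3-9 report 100% utilization with most of their memory held by other processes (the `.../envs/gemma/bin/python` PIDs), while GPU 2 reports 0% utilization. If contention with those jobs is what slows the run down, one option is to restrict the JAX process to the least-loaded GPU before JAX initializes CUDA. The sketch below is a hypothetical helper, assuming `nvidia-smi` is on `PATH` and reports plain integers (it does not handle `[N/A]` fields); the `--query-gpu`/`--format` flags and the `CUDA_DEVICE_ORDER`/`CUDA_VISIBLE_DEVICES` variables are real. Utilization is a point-in-time sample, so the choice can go stale.

```python
import os
import subprocess

# nvidia-smi query; "nounits" drops the "%" and "MiB" suffixes from the values.
QUERY = [
    "nvidia-smi",
    "--query-gpu=index,utilization.gpu,memory.used,memory.total",
    "--format=csv,noheader,nounits",
]


def parse_gpu_stats(csv_text):
    """Parse lines such as '2, 0, 16274, 49140' into a list of dicts."""
    stats = []
    for line in csv_text.strip().splitlines():
        index, util, used, total = (int(field) for field in line.split(","))
        stats.append({"index": index, "util_pct": util,
                      "used_mib": used, "free_mib": total - used})
    return stats


def least_loaded_gpu(stats):
    """Pick the GPU with the lowest utilization, breaking ties by most free memory."""
    best = min(stats, key=lambda s: (s["util_pct"], -s["free_mib"]))
    return best["index"]


def pin_least_loaded_gpu():
    """Restrict this process to one GPU; call before JAX touches the GPU backend."""
    out = subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout
    index = least_loaded_gpu(parse_gpu_stats(out))
    # Number CUDA devices in PCI bus order so indices match nvidia-smi's.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(index)
    return index
```

Calling `pin_least_loaded_gpu()` at the very top of the script, before `import jax`, makes `jax.devices()` list a single GPU instead of all ten.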