Jax not recognizing GPU. #21240

Open
charlie-guan opened this issue May 15, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@charlie-guan

charlie-guan commented May 15, 2024

Description

I am trying to reproduce the results of this work from Google DeepMind by running JAX on NVIDIA GPUs (driver 550.67, CUDA 12.4), but JAX reports:

"No GPU/TPU found, falling back to CPU."

I tried bumping jax and jaxlib from 0.4.16 (the version listed in requirements.txt) to 0.4.28 (the latest version), and also upgraded flax from 0.7.4 to 0.8.3. These changes eliminate the warning, and JAX now seems to recognize the GPU devices, but the computation is still very slow. How do I fix this?
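For readers hitting the same warning, a quick way to confirm which backend JAX actually selected is sketched below. It uses only `jax.default_backend()`, `jax.devices()`, and `Array.devices()`; the helper name `describe_backend` and the matrix size are illustrative and not taken from the issue.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Return the default backend name and the devices JAX can see."""
    backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
    devices = jax.devices()
    return backend, devices


backend, devices = describe_backend()
print("default backend:", backend)
print("devices:", devices)

# Run a small computation and check where the result was placed.
x = jnp.ones((1024, 1024))
y = (x @ x).block_until_ready()
print("result placed on:", y.devices())
```

If the backend prints as `cpu` on a machine with GPUs, the installed jaxlib most likely lacks CUDA support or does not match the local CUDA setup.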

System info (python version, jaxlib version, accelerator, etc.)

$ python3 -c "import jax; jax.print_environment_info()"
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1
platform: uname_result(system='Linux', node='nebula', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed May 15 09:53:43 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67 Driver Version: 550.67 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:1A:00.0 Off | Off |
| 30% 37C P2 73W / 300W | 37062MiB / 49140MiB | 17% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA RTX A6000 Off | 00000000:1B:00.0 Off | Off |
| 30% 50C P2 135W / 300W | 39244MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA RTX A6000 Off | 00000000:1C:00.0 Off | Off |
| 30% 39C P2 80W / 300W | 16274MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA RTX A6000 Off | 00000000:1D:00.0 Off | Off |
| 30% 56C P2 146W / 300W | 48601MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA RTX A6000 Off | 00000000:1E:00.0 Off | Off |
| 30% 57C P2 149W / 300W | 48037MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA RTX A6000 Off | 00000000:3D:00.0 Off | Off |
| 30% 49C P2 120W / 300W | 46176MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA RTX A6000 Off | 00000000:3E:00.0 Off | Off |
| 30% 52C P2 158W / 300W | 47214MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA RTX A6000 Off | 00000000:3F:00.0 Off | Off |
| 30% 57C P2 117W / 300W | 47026MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 8 NVIDIA RTX A6000 Off | 00000000:40:00.0 Off | Off |
| 30% 56C P2 159W / 300W | 32440MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 9 NVIDIA RTX A6000 Off | 00000000:41:00.0 Off | Off |
| 30% 53C P2 123W / 300W | 47000MiB / 49140MiB | 100% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 185733 C python3 36788MiB |
| 0 N/A N/A 191989 C python3 262MiB |
| 1 N/A N/A 190770 C ...liu/anaconda3/envs/gemma/bin/python 38972MiB |
| 1 N/A N/A 191989 C python3 262MiB |
| 2 N/A N/A 176896 C python3 16000MiB |
| 2 N/A N/A 191989 C python3 262MiB |
| 3 N/A N/A 176896 C python3 262MiB |
| 3 N/A N/A 190769 C ...liu/anaconda3/envs/gemma/bin/python 48062MiB |
| 3 N/A N/A 191989 C python3 262MiB |
| 4 N/A N/A 176896 C python3 262MiB |
| 4 N/A N/A 190772 C ...liu/anaconda3/envs/gemma/bin/python 47498MiB |
| 4 N/A N/A 191989 C python3 262MiB |
| 5 N/A N/A 190771 C ...liu/anaconda3/envs/gemma/bin/python 45904MiB |
| 5 N/A N/A 191989 C python3 262MiB |
| 6 N/A N/A 190773 C ...liu/anaconda3/envs/gemma/bin/python 46942MiB |
| 6 N/A N/A 191989 C python3 262MiB |
| 7 N/A N/A 190774 C ...liu/anaconda3/envs/gemma/bin/python 46754MiB |
| 7 N/A N/A 191989 C python3 262MiB |
| 8 N/A N/A 190767 C ...liu/anaconda3/envs/gemma/bin/python 32360MiB |
| 8 N/A N/A 191989 C python3 262MiB |
| 9 N/A N/A 190768 C ...liu/anaconda3/envs/gemma/bin/python 46728MiB |
| 9 N/A N/A 191989 C python3 262MiB |
+-----------------------------------------------------------------------------------------+
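The `nvidia-smi` output above shows most GPUs already at 100% utilization with nearly full memory from other processes, while GPU 2 is idle. The thread never states what resolved the slowdown, so contention is only one possible factor. As a hedged sketch, a job could be pinned to the least-loaded GPU by setting `CUDA_VISIBLE_DEVICES` before JAX is imported. The helper `pick_least_loaded_gpu` is hypothetical, and the sample text is hand-copied from the table above in the shape that `nvidia-smi --query-gpu=index,memory.used,utilization.gpu --format=csv,noheader,nounits` would print.

```python
import os


def pick_least_loaded_gpu(csv_text):
    """Return the index of the GPU with the lowest utilization,
    breaking ties by the lowest memory use (MiB)."""
    rows = []
    for line in csv_text.strip().splitlines():
        index, mem_used_mib, util_pct = (int(f.strip()) for f in line.split(","))
        rows.append((util_pct, mem_used_mib, index))
    return min(rows)[2]


# index, memory.used [MiB], utilization.gpu [%], copied from the issue.
SAMPLE = """\
0, 37062, 17
1, 39244, 100
2, 16274, 0
3, 48601, 100
4, 48037, 100
5, 46176, 100
6, 47214, 100
7, 47026, 100
8, 32440, 100
9, 47000, 100
"""

best = pick_least_loaded_gpu(SAMPLE)
# Must be set before JAX initializes its backend to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = str(best)
print("selected GPU:", best)
```

With this variable set, a subsequently started JAX process sees only the chosen GPU, exposed as device 0.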

@charlie-guan added the bug (Something isn't working) label on May 15, 2024
@hawkinsp
Member

Hmm. I'm not sure this is an actionable report. You upgraded and the original problem was fixed, it seems?

You say the model is running slowly. Can you say more? Knowing nothing about that particular model: are you seeing different performance characteristics from those the original model authors reported? How so?
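To answer a question like this with numbers, it helps to separate compile time from steady-state time. JAX dispatches work asynchronously, so a timing loop needs `block_until_ready()` to measure the actual computation. The sketch below is a generic illustration; the function `step`, the helper `time_call`, and the array size are placeholders, not the model from the issue.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Stand-in workload; replace with the real training or inference step.
    return jnp.tanh(x @ x).sum()


def time_call(fn, arg, repeats=5):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(arg).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(arg).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.ones((512, 512))

t0 = time.perf_counter()
step(x).block_until_ready()  # first call: compile + run
first_call = time.perf_counter() - t0

steady = time_call(step, x)
print(f"first call (with compile): {first_call:.4f}s")
print(f"steady state per call:     {steady:.6f}s")
```

Reporting both numbers, together with the backend in use, makes a slowness report much easier to act on.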

@charlie-guan changed the title from "Jax not recognizing TPU." to "Jax not recognizing GPU." on May 15, 2024
@charlie-guan
Author

I resolved this issue.


2 participants