jaxlibWithCuda on 23.11 does not use CUDA (#282184)
Let me know if I can provide more useful information. Any suggestion on how I can continue debugging this would also be very helpful. Thanks in advance! |
Hmm, interesting... out of curiosity, what GPU are you trying to use? And have you confirmed that the CUDA capabilities specified in your flake are compatible with your GPU? |
Also please gist the outputs of |
On this machine it is a 3060 Ti. I have been using |
Thanks. Let me put the stderr log here: The log is very huge. I tried to search for |
I think it is trying to load from … I tend to think the above is the GPU-enabled |
Also just to confirm: what's the output of echo $PYTHONPATH | awk 'BEGIN{RS=":"}{print}' | grep "jaxlib" ? Sometimes multiple jaxlib versions can get pulled in as transitive dependencies (though they shouldn't if everything is packaged correctly). Then python can end up loading jaxlib-without-cuda instead of jaxlib-with-cuda depending on the ordering in PYTHONPATH. From the looks of your flake this shouldn't be the case, but I'm guessing that this is a reproduction distilled from a more complex flake/project? |
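The shadowing failure mode described above can be checked mechanically from inside the interpreter. A minimal sketch, using the stdlib `json` package as a stand-in (substitute `"jaxlib"` when running inside the dev shell in question):

```python
import importlib.util
import os
import sys

def import_origin(name):
    """Return the file a package would actually be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin

def copies_on_path(name):
    """List every sys.path entry containing a copy of the package, in
    resolution order; more than one entry means one copy shadows another."""
    return [p for p in sys.path
            if p and os.path.isdir(os.path.join(p, name))]

print(import_origin("json"))
print(copies_on_path("json"))
```

If `copies_on_path("jaxlib")` returns more than one entry, the first one wins, which would explain a CPU-only jaxlib being loaded ahead of the CUDA-enabled one.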
Yes this is distilled from a real world flake.
It seems that in my dev shell

import sys
print(sys.path)

produces:

['', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python311.zip', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python3.11', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python3.11/lib-dynload', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python3.11/site-packages', '/nix/store/kwyj30nr4g40skc2m4nw0v2h1hh3miw5-python3-3.11.6-env/lib/python3.11/site-packages']

I looked into the env's site-packages:

breakds@samaritan {~/projects/nixvital.org/ml-pkgs}
$ lsd /nix/store/kwyj30nr4g40skc2m4nw0v2h1hh3miw5-python3-3.11.6-env/lib/python3.11/site-packages
__pycache__                                 flatbuffers-23.5.26.dist-info  ml_dtypes                  opt_einsum-3.3.0.dist-info  six-1.16.0.dist-info
_sysconfigdata__linux_x86_64-linux-gnu.py   jax                            ml_dtypes-0.3.1.dist-info  README.txt                 six.py
absl                                        jax-0.4.20.dist-info           numpy                      scipy
absl_py-1.4.0.dist-info                     jaxlib                         numpy-1.26.1.dist-info     scipy-1.11.3.dist-info
flatbuffers                                 jaxlib-0.4.20.dist-info        opt_einsum                 sitecustomize.py

Apparently there is only one jaxlib. |
Ok yeah that looks fine.. hmm |
Btw, I tried two more things
|
Attempting to repro this but currently blocked on nix-community/nixGL#157. |
Interesting. A while ago I tried to run my nix-based ML project on an Ubuntu machine but wasn't able to get the nix dev shell to recognize the GPU. I have been using NixOS for development since then... |
Oh yeah it definitely works... mostly. This is the first real issue I've had with nixGL haha |
Will give it a try! Meanwhile, I can try out any thoughts/ideas you might have to debug. Thanks! |
Yeah my hope upon getting a repro is to bisect nixpkgs commits to find a known-good commit and then hopefully isolate the issue in a commit range. If you're feeling up to the task it would be quite useful! @SomeoneSerge @ConnorBaker This would be a great use case for GPU-enabled tests and some CI infrastructure. tl;dr is that |
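The bisection plan above is, at its core, binary search over an ordered commit list. A self-contained illustration of that logic (the toy commit list and predicate below are fabricated for the example; in practice the predicate would build the flake at that commit and check `jax.devices()`):

```python
def first_bad(commits, is_good):
    """Given commits ordered oldest->newest, where every good commit
    precedes every bad one, return the index of the first bad commit
    (or None if all commits are good)."""
    lo, hi = 0, len(commits)          # answer lies somewhere in commits[lo:hi]
    while lo < hi:
        mid = (lo + hi) // 2
        if is_good(commits[mid]):
            lo = mid + 1              # first bad commit must be after mid
        else:
            hi = mid                  # mid is bad: first bad is mid or earlier
    return lo if lo < len(commits) else None

# Toy stand-in: commits 0..9, regression introduced at commit 6.
commits = list(range(10))
assert first_bad(commits, lambda c: c < 6) == 6
```

This is exactly what `git bisect` automates; the point of the sketch is that only O(log n) expensive CUDA builds are needed, not one per commit.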
Got it. This looks like something I can try. It would probably take a while as I'd imagine CUDA will be built again and again. |
Let me try to build from a few commits on |
Tried f195a5e7fad77b5128ebfaba0a6112cd9ddca0d2 from about 3 months ago. Got different behavior:

Python 3.10.12 (main, Jun 6 2023, 22:43:10) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705820658.921836 1555874 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
>>>
I0000 00:00:1705820728.067464 1555874 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

An older version 083e133 has a similar issue:

Python 3.10.12 (main, Jun 6 2023, 22:43:10) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)] |
Also confirmed that 5853814 can reproduce the problem:

Python 3.11.5 (main, Aug 24 2023, 12:23:19) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

Similarly, f697235 from 2023-12-01 can reproduce the issue:

Python 3.11.6 (main, Oct 2 2023, 13:45:54) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)] |
The above seems to suggest that it is likely that, in the past 3 months, |
I found a successful commit on … Please note that this is not a commit from …

>>> import jax
>>> jax.devices()
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
>>> import jax.lib
>>> backend = jax.lib.xla_bridge.get_backend()
>>> backend.platform
'gpu'
>>> backend.platform_version
'cuda 11070' |
Nice! Well I guess we have a (rough) commit window now to bisect on. Btw, I'm curious: does |
Let me try that. My understanding is that
And for end users of |
Tried jaxlib-bin:

>>> import jaxlib.xla_client
2024-01-22 16:27:55.773987: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. |
Dug into this a little bit. On the successful branch 23.05 with CUDA 11.7:

>>> import jaxlib
>>> import jaxlib.xla_client
>>> jaxlib.xla_client.make_gpu_client(platform_name="cuda")
2024-01-22 17:11:18.541943: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-01-22 17:11:18.542124: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0xabbb80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-01-22 17:11:18.542155: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): NVIDIA GeForce RTX 3060 Ti, Compute Capability 8.6
2024-01-22 17:11:18.542372: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2024-01-22 17:11:18.542407: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 6270959616 bytes on device 0 for BFCAllocator.
<jaxlib.xla_extension.Client object at 0x7ffee56837b0>

while on the failed branch 23.11:

>>> import jaxlib
>>> import jaxlib.xla_client
>>> jaxlib.xla_client.make_gpu_client(platform_name="cuda")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/nix/store/1z91x9lrhrv195vbhl4962qclmnmnwhv-python3-3.11.6-env/lib/python3.11/site-packages/jaxlib/xla_client.py", line 88, in make_gpu_client
config = _xla.GpuAllocatorConfig()
^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

This seems to suggest that the |
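The AttributeError above points at a quick feature probe for telling the two builds apart: per this thread, the CPU-only jaxlib ships an `xla_extension` module that lacks `GpuAllocatorConfig`. A minimal sketch of the idea (the dummy modules below are fabricated stand-ins for illustration):

```python
from types import SimpleNamespace

def looks_cuda_enabled(xla_extension):
    """Heuristic from this thread: a jaxlib built without CUDA support
    exposes an xla_extension module without GpuAllocatorConfig."""
    return hasattr(xla_extension, "GpuAllocatorConfig")

# Stand-ins for the real `jaxlib.xla_extension` module in each build:
cuda_build = SimpleNamespace(GpuAllocatorConfig=object)
cpu_build = SimpleNamespace()

assert looks_cuda_enabled(cuda_build)
assert not looks_cuda_enabled(cpu_build)
```

In a real shell one would pass `jaxlib.xla_extension` itself; this avoids having to trigger (and parse) the full `make_gpu_client` traceback.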
Yes, that's exactly right!
Nice find! Huh, very weird. We're building with all the GPU/CUDA flags turned on AFAIK. Also, it's surprising to me that jaxlib-bin isn't working on your system since upstream builds that and we make only minor modifications |
I checked their build script and found that there is a bazel flag called
Sorry if I did not make it clear. |
More on jaxlib-bin:

>>> import jaxlib.xla_client
2024-01-23 13:51:38.735183: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
>>> jaxlib.xla_client.make_gpu_client(platform_name="cuda")
2024-01-23 13:54:38.390979: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must b
e at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-01-23 13:54:38.391106: I external/xla/xla/service/service.cc:168] XLA service 0xb9f6f0 initialized for platform CUDA (this does not guarantee that XLA will be used
). Devices:
2024-01-23 13:54:38.391129: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3080, Compute Capability 8.6
2024-01-23 13:54:38.391319: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:738] Using BFC allocator.
2024-01-23 13:54:38.391342: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 7869579264 bytes on device 0 for BFCAllocator.
<jaxlib.xla_extension.Client object at 0x7ffee6a335b0>

This is … The …

>>> import jaxlib.xla_client
2024-01-23 14:10:18.300076: I external/tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used. |
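The "Could not find cuda drivers" message comes from the cudart stub failing to dlopen the CUDA runtime at import time. A rough way to check what the dynamic loader can resolve, independent of jax (the library names here are assumptions; `cudart` is the one the stub's message refers to):

```python
import ctypes
import ctypes.util

def loader_can_resolve(libname):
    """Return True if the dynamic loader can locate and open the shared
    library -- roughly what the cudart stub attempts at import time."""
    path = ctypes.util.find_library(libname)
    if path is None:
        return False
    try:
        ctypes.CDLL(path)
        return True
    except OSError:
        return False

# In the failing shell one would expect loader_can_resolve("cudart") to be
# False unless the CUDA runtime libraries are visible to the loader.
print(loader_can_resolve("cudart"))
```

On NixOS, whether this resolves typically depends on the driver libraries being exposed to the shell (e.g. via the system configuration), which is exactly the kind of environment difference this thread is chasing.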
I am working on the fix for
|
Another approach, IMHO a little easier but totally up to you:
Then you can run this for … That way you can run/test locally without committing changes. |
The PR is created at #285037 |
I am closing this as at least the |
Fixing the *-bin version is great, but let's keep this open until the source version is fixed as well |
FWIW, the Guix version (there is only a source version) is here: https://github.com/guix-science/guix-science-nonfree/blob/master/guix-science-nonfree/packages/machine-learning.scm#L216

The definition of the CUDA-less version that it references is here: …

Maybe there's an obvious difference somewhere in the arguments. |
Ok, I think I've found the culprit. Somewhere along the way we removed the --config=cuda build flag.

Working on a fix now... Currently blocked on jax-ml/jax#19811 (comment). AFAICT we do not package CUB and it is not provided in any of the cudaPackages that I've consulted thus far. If anyone has any ideas on the best way to proceed, do let me know! |
Pretty sure it's part of |
Ah, thanks for the tip @SomeoneSerge ! I wasn't able to find
Perhaps it has a different attribute name? Assuming a typo, I checked nccl which is already a dependency of jaxlib, but to no avail:
Perhaps I'm missing something? Btw, draft PR started here: #288857 |
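Hunting for the CUB headers across candidate package outputs can be scripted instead of eyeballed. A small sketch (the roots and the `cub/block/block_load.cuh` relative path are assumptions based on this discussion; in practice the roots would be `/nix/store` outputs of `cuda_cccl` and friends):

```python
import os

def find_header(roots, relative_path):
    """Walk candidate include roots and return every absolute path whose
    trailing components match relative_path, e.g. 'cub/block/block_load.cuh'."""
    wanted = relative_path.split("/")
    hits = []
    for root in roots:
        for dirpath, _dirs, files in os.walk(root):
            for f in files:
                full = os.path.join(dirpath, f)
                # Compare the tail of the full path against the wanted suffix.
                if full.split(os.sep)[-len(wanted):] == wanted:
                    hits.append(full)
    return hits
```

An empty result over all plausible roots would confirm the "we do not package CUB" observation above without relying on attribute-name guesses.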
Oop, found
|
Not "cublock_load.cuh"?
|
|
Adding |
Looking at your issue upstream, I think the names match? |
Oop, yeah you're right. I was confusing that with something else |
This change snowballed to cover a number of improvements: 1. Fix NixOS#282184. The `--config=cuda` flag was lost at some point, disabling CUDA builds even with `cudaSupport = true`. 2. Migrate to using the CUDA redist packages instead of cudatoolkit. 3. Unify stdenv behind `effectiveStdenv` following a pattern that has precedent in OpenCV's derivation and was recommended in NixOS#288857 (comment).
python3Packages.jaxlib: fix #282184 and migrate to cuda redist packages
This issue has been mentioned on NixOS Discourse. There might be relevant details there: https://discourse.nixos.org/t/jaxlibwithcuda-not-using-cuda/36873/5 |
If you're using NixOS, is there a platform-wide config to always have the CUDA libraries enabled? |
I'm trying to compile jax using the commit …:

[user@system:~/jax]$ nix-shell
[...snip...]
INFO: Analyzed 2 targets (240 packages loaded, 20380 targets configured).
checking cached actions
INFO: Found 2 targets...
INFO: Elapsed time: 115.749s, Critical Path: 0.00s
INFO: 0 processes.
INFO: Build completed successfully, 0 total actions
buildPhase completed in 1 minutes 56 seconds
Running phase: installPhase
installPhase completed in 49 seconds
error: hash mismatch in fixed-output derivation '/nix/store/yjm7vri9hzasddanrjhzd4fi1a9406ip-bazel-build-jaxlib-0.4.24-deps.tar.gz.drv':
specified: sha256-IEKoHjCOtKZKvU/DUUjbvXldORFJuyO1R3F6CZZDXxM=
got: sha256-DLx5NHoSP9cOCf6hBgLXqNDb/XhJ3OWxA8/HGiTBBUo=
error: 1 dependencies of derivation '/nix/store/dsd7ssdqbmi82fvj1fcf8pnzv95wf6c4-bazel-build-jaxlib-0.4.24.drv' failed to build
error: 1 dependencies of derivation '/nix/store/yivr8ykzvg43g33hlm87f400568vd9ly-python3.11-jaxlib-0.4.24.drv' failed to build
error: 1 dependencies of derivation '/nix/store/vhr8syx8h5s3j18bp4kw9ifr3wviqd8p-python3.11-jax-0.4.24.drv' failed to build

I even tried to use a later commit …:

[user@system:~/jax]$ TF_CPP_MIN_LOG_LEVEL=0 python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"
2024-06-17 06:08:48.914216: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-17 06:08:48.970135: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-17 06:08:48.970677: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-17 06:08:49.123745: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu

This is my shell.nix:

{ pkgs ? import ./pkgs.nix { inherit builtins; } }:
pkgs.mkShell {
name = "jax-gpu";
buildInputs = with pkgs; with pkgs.python3Packages; with pkgs.cudaPackages; [
python3
jax
jaxlib-bin
numpy
cuda_cccl.dev
];
}

And this is my pkgs.nix (for the newer commit):

{ builtins }:
let
revision = "d97b37430f8f0262f97f674eb357551b399c2003";
url = "https://github.com/nixos/nixpkgs/archive/${revision}.tar.gz";
nixpkgs = builtins.fetchTarball {
url = url;
sha256 = "0za8bwgvk83r18z1b7ll1hlh45ijimp388ll5r8z44yglfcp28gx";
};
in
import nixpkgs {
config = {
allowUnfree = true;
cudaSupport = true;
};
} |
This seems like a support question, you're more likely to find help on NixOS Discourse or Matrix. But the answer is
Interesting, this might be a reproducibility issue with the way we fetch bazel dependencies, in which case this deserves a separate issue |
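The specified/got values in the hash mismatch above are SRI-style sha256 strings. For reference, a minimal sketch of how such a string is formed (the function name is made up for the example):

```python
import base64
import hashlib

def sri_sha256(data: bytes) -> str:
    """Compute an SRI sha256 string of the form Nix prints:
    'sha256-' followed by the base64-encoded raw digest."""
    return "sha256-" + base64.b64encode(hashlib.sha256(data).digest()).decode()

# Well-known SRI value for empty input:
assert sri_sha256(b"") == "sha256-47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU="
```

A fixed-output derivation pins exactly such a hash of its output; if the fetched bazel dependency tarball is not byte-for-byte reproducible, the "got" hash drifts away from the "specified" one, which is the reproducibility issue suspected above.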
Describe the bug

When creating a dev shell with jaxlibWithCuda + jax, the resulting environment cannot use CUDA with jax.

Steps To Reproduce

Steps to reproduce the behavior:

1. Create a flake.nix. Note that I am using the d2003f2223cbb8cd95134e4a0541beea215c1073 commit from nixos-23.11, which is the latest at this moment. However, I have tried the commit from a month ago and the problem was there as well.
2. Now, run nix develop with the above flake.
3. Run python and check jax.devices(). Note that it reports that an NVIDIA GPU is found, but jaxlib does not seem to have CUDA enabled.

The above flake.nix can be found here.

Expected behavior

It is expected that jax and jaxlib run with CUDA and that the GPU appears in the device list.

Additional context

I am running this on my NixOS machine with an NVIDIA GPU, where the nvidia driver is properly installed. I am also a heavy pytorch user and it works perfectly. Also, I vaguely remember this combination working correctly a few months ago.

Notify maintainers

I am not particularly sure about the maintainers so I might be wrong. Based on jax's commit history and meta I am going to @SomeoneSerge and @samuela.

Metadata

Please run nix-shell -p nix-info --run "nix-info -m" and paste the result.
and paste the result.The text was updated successfully, but these errors were encountered: