jaxlibWithCuda on 23.11 does not use CUDA #282184

Closed · breakds opened this issue Jan 20, 2024 · 71 comments · Fixed by #288857

@breakds (Contributor) commented Jan 20, 2024

Describe the bug

When creating a dev shell with jaxlibWithCuda + jax, the resulting environment cannot use CUDA with jax.

Steps To Reproduce

Steps to reproduce the behavior:

  1. Create a flake.nix

    {
      description = "Provide extra Nix packages for Machine Learning and Data Science";
      inputs = {
        nixpkgs.url = "github:NixOS/nixpkgs/nixos-23.11";
    
        utils.url = "github:numtide/flake-utils";
      };
      outputs = { self, nixpkgs, ... }@inputs: inputs.utils.lib.eachSystem [
        "x86_64-linux"
      ] (system:
        let pkgs = import nixpkgs {
              inherit system;
              config = {
                allowUnfree = true;
                cudaSupport = true;
                cudaCapabilities = [ "7.5" "8.6" ];
                cudaForwardCompat = false;
              };
            };
        in rec {
          devShells.default = pkgs.mkShell {
            name = "jax-dev";
            packages = [
              (pkgs.python3.withPackages (py-pkgs: with py-pkgs; [
                jax
                jaxlibWithCuda
              ]))
            ];
          };
        });
    }

    Note that I am using commit d2003f2223cbb8cd95134e4a0541beea215c1073 from nixos-23.11, which is the latest at the moment. However, I have also tried a commit from a month ago and the problem was present there as well.

  2. Now, run nix develop with the above flake.

  3. Run python and

    Python 3.11.6 (main, Oct  2 2023, 13:45:54) [GCC 12.3.0] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import jax
    >>> jax.devices()
    An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
    [CpuDevice(id=0)]

    Note that it reports that an NVIDIA GPU may be present, but jaxlib does not seem to have CUDA enabled.

The above flake.nix can be found here.

Expected behavior

jax and jaxlib are expected to run with CUDA, and the GPU should appear in the device list.

Additional context

I am running this on my NixOS machine with an NVIDIA GPU, where the NVIDIA driver is properly installed. I am also a heavy PyTorch user, and PyTorch works perfectly on this machine.

I also vaguely remember this combination working correctly a few months ago.

Notify maintainers

I am not entirely sure who the maintainers are, so I might be wrong. Based on jax's commit history and meta, I am going to ping @SomeoneSerge and @samuela.

Metadata

Please run nix-shell -p nix-info --run "nix-info -m" and paste the result.

$ nix-shell -p nix-info --run "nix-info -m"
 - system: `"x86_64-linux"`
 - host os: `Linux 5.15.79, NixOS, 23.11 (Tapir), 23.11.20240105.c1be43e`
 - multi-user?: `yes`
 - sandbox: `yes`
 - version: `nix-env (Nix) 2.18.1`
 - nixpkgs: `/nix/var/nix/profiles/per-user/root/channels/nixos`
@breakds (Contributor, Author) commented Jan 20, 2024

Let me know if I can provide more useful information. Any suggestion on how I can continue debugging this would also be very helpful.

Thanks in advance!

@samuela (Member) commented Jan 20, 2024

hmm interesting... ooc what gpu are you trying to use? have you confirmed that the cuda capabilities specified in your flake are compatible with your gpu?

@SomeoneSerge (Contributor) commented:

Also please gist the outputs of LD_DEBUG=libs python ... demonstrating which implementation of jaxlib ends up being loaded, and whether a search for libcuda.so is attempted
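
For example, something along these lines should capture the trace (an illustrative command, not the exact one used here):

LD_DEBUG=libs python -c "import jax; print(jax.devices())" 2> jax-lddebug.log
grep -i "libcuda" jax-lddebug.log   # did the dynamic loader ever search for libcuda.so?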

@breakds (Contributor, Author) commented Jan 20, 2024

hmm interesting... ooc what gpu are you trying to use? have you confirmed that the cuda capabilities specified in your flake are compatible with your gpu?

On this machine it is a 3060 Ti. I have been using torchWithCuda on this machine with a similar dev shell, so I can confirm that the specified CUDA capabilities should work.

@breakds (Contributor, Author) commented Jan 20, 2024

Also please gist the outputs of LD_DEBUG=libs python ... demonstrating which implementation of jaxlib ends up being loaded, and whether a search for libcuda.so is attempted

Thanks. Let me put the stderr log here:

jax.log

The log is huge. I searched it for libcuda.so and found nothing. I am trying to figure out from it how jaxlib is loaded.

@breakds (Contributor, Author) commented Jan 20, 2024

I think it is trying to load from /nix/store/28r0fhcqqykrdfx5kqzwzqqjakn92n12-python3.11-jaxlib-0.4.20/, but I am not sure whether it is the GPU-enabled one or not.

I tend to think the above is the GPU-enabled jaxlib, because if I specify jaxlibWithoutCuda in the dev shell it loads /nix/store/pbr3712q0rc0aaw1yrw2mlgh690qic9c-python3.11-jaxlib-0.4.20 instead.
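
One quick way to confirm which copy actually gets imported (a minimal check; not part of the original log):

python -c "import jaxlib; print(jaxlib.__file__)"
# prints the /nix/store/... path of the jaxlib that Python resolves on sys.path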

@samuela (Member) commented Jan 20, 2024

Also just to confirm: what's the output of

echo $PYTHONPATH | awk 'BEGIN{RS=":"}{print}' | grep "jaxlib"

? Sometimes multiple jaxlib versions can get pulled in as transitive dependencies (though they shouldn't if everything is packaged correctly). Then python can end up loading jaxlib-without-cuda instead of jaxlib-with-cuda depending on the ordering in PYTHONPATH.

From the looks of your flake this shouldn't be the case, but I'm guessing that this is a reproduction distilled from a more complex flake/project?

@breakds (Contributor, Author) commented Jan 20, 2024

From the looks of your flake this shouldn't be the case, but I'm guessing that this is a reproduction distilled from a more complex flake/project?

Yes this is distilled from a real world flake.

Sometimes multiple jaxlib versions can get pulled in as transitive dependencies

It seems that in my dev shell PYTHONPATH is not set. However, running

import sys
print(sys.path)

produces:

['', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python311.zip', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python3.11', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python3.11/lib-dynload', '/nix/store/p1zbyfrpj3hq50mxh5hmxl3kqpa2b1am-python3-3.11.6/lib/python3.11/site-packages', '/nix/store/kwyj30nr4g40skc2m4nw0v2h1hh3miw5-python3-3.11.6-env/lib/python3.11/site-packages']

I looked into the site-packages of the python3-3.11.6-env:

breakds@samaritan {~/projects/nixvital.org/ml-pkgs} $ lsd /nix/store/kwyj30nr4g40skc2m4nw0v2h1hh3miw5-python3-3.11.6-env/lib/python3.11/site-packages
󰌠 __pycache__                                 flatbuffers-23.5.26.dist-info   ml_dtypes                   opt_einsum-3.3.0.dist-info   six-1.16.0.dist-info
 _sysconfigdata__linux_x86_64-linux-gnu.py   jax                             ml_dtypes-0.3.1.dist-info   README.txt                   six.py
 absl                                        jax-0.4.20.dist-info            numpy                       scipy
 absl_py-1.4.0.dist-info                     jaxlib                          numpy-1.26.1.dist-info      scipy-1.11.3.dist-info
 flatbuffers                                 jaxlib-0.4.20.dist-info         opt_einsum                  sitecustomize.py

Apparently there is only one jaxlib being pulled in here.

@samuela (Member) commented Jan 20, 2024

Ok yeah that looks fine.. hmm

@breakds (Contributor, Author) commented Jan 20, 2024

Btw, I tried two more things:

  1. Removed the capabilities constraint; still the same.
  2. Tried this on a 3090 machine, a 3080 machine, and a 4090 machine; all of them produce exactly the same behavior.

@samuela (Member) commented Jan 20, 2024

Attempting to repro this but currently blocked on nix-community/nixGL#157.

@breakds (Contributor, Author) commented Jan 20, 2024

Attempting to repro this but currently blocked on nix-community/nixGL#157.

Interesting. A while ago I tried to run my Nix-based ML project on an Ubuntu machine but wasn't able to get the Nix dev shell to recognize the GPU. I have been using NixOS for development since then ...

@samuela (Member) commented Jan 20, 2024

Oh yeah it definitely works... mostly. This is the first real issue I've had with nixGL haha

@breakds (Contributor, Author) commented Jan 21, 2024

Will give it a try! Meanwhile, I can try out any thoughts/ideas you might have to debug. Thanks!

@samuela (Member) commented Jan 21, 2024

Yeah my hope upon getting a repro is to bisect nixpkgs commits to find a known-good commit and then hopefully isolate the issue in a commit range. If you're feeling up to the task it would be quite useful!

@SomeoneSerge @ConnorBaker This would be a great use case for GPU-enabled tests and some CI infrastructure. tl;dr is that the jaxlibWithCuda build has been green for a while, but at some point it stopped actually using/recognizing GPUs.
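
For reference, a bisect driver could look roughly like the sketch below (the attribute path and the GPU check are assumptions; adjust as needed). Run it from a nixpkgs checkout with git bisect start <bad> <good> && git bisect run ./check-jax-cuda.sh:

#!/usr/bin/env bash
# check-jax-cuda.sh -- hypothetical bisect driver
set -eu
# Build a Python environment with jax + jaxlibWithCuda from the current checkout;
# exit 125 tells git bisect to skip commits that fail to build.
env=$(nix-build -E 'with import ./. { config = { allowUnfree = true; cudaSupport = true; }; };
  python3.withPackages (ps: [ ps.jax ps.jaxlibWithCuda ])') || exit 125
# Fail the commit if jax falls back to CPU.
"$env/bin/python" -c 'import jax; assert jax.devices()[0].platform != "cpu"'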

@breakds (Contributor, Author) commented Jan 21, 2024

Yeah my hope upon getting a repro is to bisect nixpkgs commits to find a known-good commit and then hopefully isolate the issue in a commit range. If you're feeling up to the task it would be quite useful!

Got it. This looks like something I can try. It would probably take a while as I'd imagine CUDA will be built again and again.

@breakds (Contributor, Author) commented Jan 21, 2024

Let me try to build from a few commits on master to start with.

@breakds (Contributor, Author) commented Jan 21, 2024

Tried f195a5e7fad77b5128ebfaba0a6112cd9ddca0d2 from about 3 months ago. Got different behavior:

Python 3.10.12 (main, Jun  6 2023, 22:43:10) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705820658.921836 1555874 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
>>>
I0000 00:00:1705820728.067464 1555874 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

An older commit, 083e133, has a similar issue:

Python 3.10.12 (main, Jun  6 2023, 22:43:10) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

@breakds (Contributor, Author) commented Jan 21, 2024

Also confirmed that 5853814 reproduces the problem:

Python 3.11.5 (main, Aug 24 2023, 12:23:19) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

Similarly, f697235 from 2023-12-01 can reproduce the issue:

Python 3.11.6 (main, Oct  2 2023, 13:45:54) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

@breakds (Contributor, Author) commented Jan 22, 2024

The above suggests that jaxlibWithCuda has likely not been working with CUDA for the past three months.

@breakds (Contributor, Author) commented Jan 22, 2024

I found a working commit on 23.05: 70bdade

Please note that this is not a commit from master, and I believe it is using CUDA 11.7:

>>> import jax
>>> jax.devices()
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

>>> import jax.lib
>>> backend = jax.lib.xla_bridge.get_backend()
>>> backend.platform
'gpu'
>>> backend.platform_version
'cuda 11070'

@samuela (Member) commented Jan 22, 2024

Nice! Well I guess we have a (rough) commit window now to bisect on.

Btw, I'm curious: does python3Packages.jaxlib-bin work any better for you on master? We don't have a CUDA alias for that one, so you'll need to set config.cudaSupport = true

@breakds (Contributor, Author) commented Jan 23, 2024

Nice! Well I guess we have a (rough) commit window now to bisect on.

Btw, I'm curious: does python3Packages.jaxlib-bin work any better for you on master? We don't have a CUDA alias for that one, so you'll need to set config.cudaSupport = true

Let me try that. My understanding is that:

  1. jaxlib-bin downloads the pre-built binary (wheel) and patches it, while
  2. jaxlibWithCuda spins up Bazel to build the wheel ourselves.

And for end users of jaxlib, they should behave pretty much the same. Is my understanding correct? Thanks!

@breakds (Contributor, Author) commented Jan 23, 2024

Btw, I'm curious: does python3Packages.jaxlib-bin work any better for you on master? We don't have a CUDA alias for that one, so you'll need to set config.cudaSupport = true

Tried jaxlib-bin. Seeing the following error:

>>> import jaxlib.xla_client
2024-01-22 16:27:55.773987: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.

@breakds (Contributor, Author) commented Jan 23, 2024

Dug a little bit into this.

On the successful branch 23.05 with CUDA 11.7:

>>> import jaxlib
>>> import jaxlib.xla_client
>>> jaxlib.xla_client.make_gpu_client(platform_name="cuda")
2024-01-22 17:11:18.541943: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-01-22 17:11:18.542124: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0xabbb80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-01-22 17:11:18.542155: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Ti, Compute Capability 8.6
2024-01-22 17:11:18.542372: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2024-01-22 17:11:18.542407: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 6270959616 bytes on device 0 for BFCAllocator.
<jaxlib.xla_extension.Client object at 0x7ffee56837b0>

while on the failing branch, 23.11:

>>> import jaxlib
>>> import jaxlib.xla_client
>>> jaxlib.xla_client.make_gpu_client(platform_name="cuda")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/1z91x9lrhrv195vbhl4962qclmnmnwhv-python3-3.11.6-env/lib/python3.11/site-packages/jaxlib/xla_client.py", line 88, in make_gpu_client
    config = _xla.GpuAllocatorConfig()
             ^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'

This seems to suggest that the jaxlib we have in this case is not enabled for cuda support.
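
A quick way to check the same thing without calling make_gpu_client (a small diagnostic sketch, assuming the attribute behaves the same way on both builds):

>>> import jaxlib.xla_extension as xe
>>> hasattr(xe, "GpuAllocatorConfig")   # True on the CUDA-enabled 23.05 build; False here, per the traceback above
False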

@samuela (Member) commented Jan 23, 2024

And for end users of jaxlib, they should behave pretty much the same. Is my understanding correct? Thanks!

Yes, that's exactly right!

This seems to suggest that the jaxlib we have in this case is not enabled for cuda support.

Nice find! Huh, very weird. We're building with all the GPU/CUDA flags turned on AFAIK. Also, it's surprising to me that jaxlib-bin isn't working on your system since upstream builds that and we make only minor modifications

@breakds (Contributor, Author) commented Jan 23, 2024

Nice find! Huh, very weird. We're building with all the GPU/CUDA flags turned on AFAIK.

I checked their build script and found that there is a Bazel flag called cuda_plugin. I don't know what it is (maybe related to this?), but jaxlib's default.nix does not have it. I tried adding it, and jaxlib (more specifically, the bazel-build inside it) no longer builds.

Also, it's surprising to me that jaxlib-bin isn't working on your system since upstream builds that and we make only minor modifications

Sorry if I did not make it clear. jaxlib-bin can be successfully built on my system. The problem is the same as with jaxlib: it cannot be used with CUDA. It just reports a different error.

@breakds (Contributor, Author) commented Jan 23, 2024

More on jaxlib-bin. Although it can successfully create the "client", cudart_stub.cc reports that it cannot find the CUDA drivers:

>>> import jaxlib.xla_client
2024-01-23 13:51:38.735183: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
>>> jaxlib.xla_client.make_gpu_client(platform_name="cuda")
2024-01-23 13:54:38.390979: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-01-23 13:54:38.391106: I external/xla/xla/service/service.cc:168] XLA service 0xb9f6f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-01-23 13:54:38.391129: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3080, Compute Capability 8.6
2024-01-23 13:54:38.391319: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:738] Using BFC allocator.
2024-01-23 13:54:38.391342: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 7869579264 bytes on device 0 for BFCAllocator.
<jaxlib.xla_extension.Client object at 0x7ffee6a335b0>

This is jaxlib-bin on 23.11.

The jaxlib-bin on master has been upgraded to 0.4.23, which has the same problem of not being able to find CUDA:

>>> import jaxlib.xla_client
2024-01-23 14:10:18.300076: I external/tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.

@breakds (Contributor, Author) commented Jan 30, 2024

I am working on the fix for jaxlib-bin. I am not sure what the best way to test it is, though. I am thinking about:

  1. Creating the fix in a branch
  2. Creating a flake.nix that points to that branch
  3. Testing jaxlib-bin with CUDA 11 and 12 on Python 3.9 through 3.12

@samuela (Member) commented Jan 30, 2024

Another approach, IMHO a little easier but totally up to you:

$ cd /path/to/your/nixpkgs
$ NIX_PATH=.. nix-shell -p python3XX python3XXPackages.jax "python3XXPackages.jaxlib-bin.override { cudaSupport = true; cudaPackagesGoogle = cudaPackages_XX; }"

Then you can run this for python39, python310, etc and cudaPackages_11 and cudaPackages_12.

That way you can run/test locally without committing changes.
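
For instance, filling in Python 3.11 and CUDA 12, the concrete command would look something like this (attribute names assumed for that nixpkgs revision):

$ cd /path/to/your/nixpkgs
$ NIX_PATH=.. nix-shell -p python311 python311Packages.jax "python311Packages.jaxlib-bin.override { cudaSupport = true; cudaPackagesGoogle = cudaPackages_12; }"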

@breakds (Contributor, Author) commented Jan 30, 2024

The PR is created at #285037

@breakds (Contributor, Author) commented Feb 9, 2024

I am closing this as at least the bin version is fixed now. Thanks a lot for the help @samuela @SomeoneSerge!

breakds closed this as completed Feb 9, 2024
@samuela (Member) commented Feb 9, 2024

Fixing the *-bin version is great, but let's keep this open until the source version is fixed as well

samuela reopened this Feb 9, 2024
@rekado commented Feb 10, 2024

FWIW, the Guix version (there is only a source version) is here: https://github.com/guix-science/guix-science-nonfree/blob/master/guix-science-nonfree/packages/machine-learning.scm#L216
It works fine with CUDA.

The definition of the CUDA-less version that it references is here:
https://github.com/guix-science/guix-science/blob/master/guix-science/packages/python.scm#L548

Maybe there's an obvious difference somewhere in the arguments.

@samuela (Member) commented Feb 14, 2024

Ok I think I've found the culprit. Somewhere along the way we removed --config=cuda from bazelFlags and/or .jax_configure.bazelrc. Oops!

Working on a fix now... Currently blocked on jax-ml/jax#19811 (comment). AFAICT we do not package CUB and it is not provided in any of the cudaPackages that I've consulted thus far. If anyone has any ideas on the best way to proceed, do let me know!
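
Very roughly, the idea is to put the flag back when cudaSupport is on, something like the sketch below (attribute names are illustrative, not the actual patch; the real fix is in #288857):

  bazelFlags = [
    # ... existing flags ...
  ] ++ lib.optionals cudaSupport [
    "--config=cuda"   # without this, bazel silently builds a CPU-only jaxlib
  ];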

@SomeoneSerge (Contributor) commented:

AFAICT we do not package CUB and it is not provided in any of the cudaPackages that I've consulted thus far. If anyone has any ideas on the best way to proceed, do let me know!

Pretty sure it's part of cccl, and we also used to distribute it as part of nvidia-thrust which we deleted

@samuela (Member) commented Feb 14, 2024

Ah, thanks for the tip @SomeoneSerge! I wasn't able to find cudaPackages.cccl or cccl:

ubuntu@bitbop:~/nixpkgs$ nix-build -A cudaPackagesGoogle.cccl
error: attribute 'cccl' in selection path 'cudaPackagesGoogle.cccl' not found
       Did you mean nccl?
ubuntu@bitbop:~/nixpkgs$ nix-build -A cccl
error: attribute 'cccl' in selection path 'cccl' not found
       Did you mean one of ccal, cccc, ccl, rccl or abcl?

Perhaps it has a different attribute name?

Assuming a typo, I checked nccl which is already a dependency of jaxlib, but to no avail:

ubuntu@bitbop:~/nixpkgs$ nix-build -A cudaPackagesGoogle.nccl
/nix/store/5qap16w6v0q38qrikg59gxrgm9wbgdca-nccl-2.19.3-1
ubuntu@bitbop:~/nixpkgs$ sudo find /nix -type f -name "cublock_load.cuh"
ubuntu@bitbop:~/nixpkgs$ 

Perhaps I'm missing something?

Btw, draft PR started here: #288857

@samuela (Member) commented Feb 14, 2024

Oop, found cudaPackages.cuda_cccl, but not finding that header file sadly:

ubuntu@bitbop:~/nixpkgs$ nix-build -A cudaPackagesGoogle.cuda_cccl
these 2 paths will be fetched (1.15 MiB download, 12.74 MiB unpacked):
  /nix/store/r9677wlkn8bq9idbr2f3gwqpnjzagp4h-cuda_cccl-11.8.89
  /nix/store/30cq1zsb0a5a2xkca1jkn7gliyjbswsg-cuda_cccl-11.8.89-dev
copying path '/nix/store/30cq1zsb0a5a2xkca1jkn7gliyjbswsg-cuda_cccl-11.8.89-dev' from 'https://cuda-maintainers.cachix.org'...
copying path '/nix/store/r9677wlkn8bq9idbr2f3gwqpnjzagp4h-cuda_cccl-11.8.89' from 'https://cuda-maintainers.cachix.org'...
/nix/store/r9677wlkn8bq9idbr2f3gwqpnjzagp4h-cuda_cccl-11.8.89
ubuntu@bitbop:~/nixpkgs$ sudo find /nix -type f -name "cublock_load.cuh"
ubuntu@bitbop:~/nixpkgs$ 

@SomeoneSerge (Contributor) commented:

20 | #include "cub/block/block_load.cuh"

Not "cublock_load.cuh"?

❯ nix build -f . --arg config '{ allowUnfree = true; cudaSupport = true; }' cudaPackages.cuda_cccl -o cccl -L
❯ fd block_load cccl/
cccl/include/cub/block/block_load.cuh

@SomeoneSerge (Contributor) commented:

packaging CUB from scratch

#224292

@samuela (Member) commented Feb 15, 2024

Adding cudaPackages.cuda_cccl.dev seemed to solve the problem, despite the difference in filenames. Not sure exactly what kind of wizardry they're doing to adjust header names, but it seems happy for now.

@SomeoneSerge (Contributor) commented:

despite the difference in filenames

Looking at your issue upstream, I think the names match?

@samuela (Member) commented Feb 15, 2024

Looking at your issue upstream, I think the names match?

Oop, yeah you're right. I was confusing that with something else

samuela added a commit to samuela/nixpkgs that referenced this issue Feb 17, 2024
This change snowballed to cover a number of improvements:
1. Fix NixOS#282184. The `--config=cuda` flag was lost at some point, disabling CUDA builds even with `cudaSupport = true`.
2. Migrate to using the CUDA redist packages instead of cudatoolkit.
3. Unify stdenv behind `effectiveStdenv` following a pattern that has precedent in OpenCV's derivation and was recommended in NixOS#288857 (comment).
SomeoneSerge added a commit that referenced this issue Feb 23, 2024
python3Packages.jaxlib: fix #282184 and migrate to cuda redist packages
@nixos-discourse commented:

This issue has been mentioned on NixOS Discourse. There might be relevant details there:

https://discourse.nixos.org/t/jaxlibwithcuda-not-using-cuda/36873/5

@CMCDragonkai (Member) commented:

If you're using NixOS, is there a platform-wide config to always have the CUDA libraries enabled?

@aryanjassal commented:

I'm trying to compile jax using commit cdd38b2, which appears to have closed this issue. To run jax on CUDA with this commit, I need to compile jaxlib from scratch. During the compilation, I get a hash mismatch error in a library that jaxlib is compiling.

[user@system:~/jax]$ nix-shell
[...snip...]
INFO: Analyzed 2 targets (240 packages loaded, 20380 targets configured).
 checking cached actions
INFO: Found 2 targets...
INFO: Elapsed time: 115.749s, Critical Path: 0.00s
INFO: 0 processes.
INFO: Build completed successfully, 0 total actions
buildPhase completed in 1 minutes 56 seconds
Running phase: installPhase
installPhase completed in 49 seconds
error: hash mismatch in fixed-output derivation '/nix/store/yjm7vri9hzasddanrjhzd4fi1a9406ip-bazel-build-jaxlib-0.4.24-deps.tar.gz.drv':
         specified: sha256-IEKoHjCOtKZKvU/DUUjbvXldORFJuyO1R3F6CZZDXxM=
            got:    sha256-DLx5NHoSP9cOCf6hBgLXqNDb/XhJ3OWxA8/HGiTBBUo=
error: 1 dependencies of derivation '/nix/store/dsd7ssdqbmi82fvj1fcf8pnzv95wf6c4-bazel-build-jaxlib-0.4.24.drv' failed to build
error: 1 dependencies of derivation '/nix/store/yivr8ykzvg43g33hlm87f400568vd9ly-python3.11-jaxlib-0.4.24.drv' failed to build
error: 1 dependencies of derivation '/nix/store/vhr8syx8h5s3j18bp4kw9ifr3wviqd8p-python3.11-jax-0.4.24.drv' failed to build

I also tried a later commit, d97b37, and got the following output:

[user@system:~/jax]$ TF_CPP_MIN_LOG_LEVEL=0 python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"
2024-06-17 06:08:48.914216: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-17 06:08:48.970135: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-17 06:08:48.970677: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-17 06:08:49.123745: I external/tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
CUDA backend failed to initialize: Unable to load CUDA. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu

This is my shell.nix:

{ pkgs ? import ./pkgs.nix { inherit builtins; } }:

pkgs.mkShell {
  name = "jax-gpu";

  buildInputs = with pkgs; with pkgs.python3Packages; with pkgs.cudaPackages; [
    python3
    jax
    jaxlib-bin
    numpy
    cuda_cccl.dev
  ];
}

And this is my pkgs.nix (for the newer commit):

{ builtins }:

let
  revision = "d97b37430f8f0262f97f674eb357551b399c2003";
  url = "https://github.com/nixos/nixpkgs/archive/${revision}.tar.gz";

  nixpkgs = builtins.fetchTarball {
    url = url;
    sha256 = "0za8bwgvk83r18z1b7ll1hlh45ijimp388ll5r8z44yglfcp28gx";
  };
in
  import nixpkgs {
    config = {
      allowUnfree = true;
      cudaSupport = true;
    };
  }

@SomeoneSerge (Contributor) commented:

If you're using NixOS, is there a platform-wide config to always have the CUDA libraries enabled? @CMCDragonkai

This seems like a support question; you're more likely to find help on NixOS Discourse or Matrix. But the answer is nixpkgs.config.cudaSupport = true, and note that this will trigger a lot of rebuilds.
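
In a NixOS configuration that looks something like the following (a minimal sketch):

{
  nixpkgs.config = {
    allowUnfree = true;  # CUDA packages are unfree
    cudaSupport = true;  # packages that honor this flag get built with CUDA; expect many rebuilds from source
  };
}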

I get a hash mismatch error in a library that jaxlib is compiling. @aryanjassal

Interesting, this might be a reproducibility issue with the way we fetch Bazel dependencies; in that case it deserves a separate issue.
