Can't build jaxlib in GH200 #21299
Can you share the compilation errors you're getting in `external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc`?
Thanks for the quick reply. The lines in question:

```
#include "absl/base/call_once.h"
static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit);
CHECK(it != kernel_ptr_cache.end());
.visible .entry redzone_checker(
ld.param.u64 %rd6, [buffer_length];
absl::StatusOr<const ComparisonKernel*> GetComparisonKernel(
return LoadKernelOrGetPtr<DeviceMemory<uint8_t>, uint8_t, uint64_t,
```
This snippet doesn't contain any compilation errors AFAICT. Can you upload the output of the compiler to a gist?
I'm sorry, but I don't understand the request. Can you be more specific and include the Linux terminal commands you want me to run?
The message you posted initially is usually preceded by compilation error messages describing what went wrong while compiling jaxlib. If you upload the full output of the build, we can see what actually failed.
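One way to capture that full build output in a file suitable for a gist is to redirect both streams through `tee` (a sketch; the `build.py` flags below are the ones used elsewhere in this thread, and the `echo` line is only a stand-in so the redirection pattern can be demonstrated anywhere):

```shell
# Capture both stdout and stderr of the jaxlib build so the complete
# compiler error log can be uploaded to a gist:
# python3 build/build.py --enable_cuda --build_gpu_plugin \
#     --gpu_plugin_cuda_version=12 2>&1 | tee build.log
# Stand-in command demonstrating the same pattern:
echo "nvcc fatal : example compiler error" 2>&1 | tee build.log
```

`tee` prints the output to the terminal while also writing it to `build.log`, so nothing is lost between the console scrollback and the file.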
cc @nouiz
This is what I get now (I'm in a different docker image now):

```
(base) root@8c1c1dd5a763:~/jax# python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12
Bazel binary path: ./bazel-6.5.0-linux-arm64
Building XLA and installing it in the jaxlib source tree...
```
Okay, from this it looks like your CUDA installation is missing development headers.
JAX-Toolbox has a nightly JAX container for ARM: https://github.com/NVIDIA/JAX-Toolbox

If you want to build JAX yourself, this container already contains CUDA:

```
docker pull nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
```

I almost always use those two containers for JAX development.
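A quick way to tell whether a given environment actually has the CUDA development headers (a sketch; `/usr/local/cuda` is the conventional install prefix and may differ on your system):

```shell
# A runtime-only CUDA image lacks the development headers that the jaxlib
# build needs; the *-devel container images include them.
if [ -e /usr/local/cuda/include/cuda.h ]; then
  echo "CUDA development headers present"
else
  echo "CUDA development headers missing: use a *-devel container image"
fi
```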
@nouiz thanks, I tried those two options without success. I'm using a GH200 and trying to use JAX with the GPU, but it always fails.
I'll also note that we (JAX) release CUDA ARM wheels on PyPI, which should just work on GH200. Try:
(The more usual …)
We released ARM wheels last week, but they aren't tested.
This installed cuDNN 9.1.1. cuDNN 8 isn't supported on Grace Hopper to my knowledge.
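For reference, a sketch of installing the published CUDA wheels and checking which backend JAX actually selected. The `pip` line is JAX's documented CUDA-wheel install and may not be the exact command elided above, so it's shown commented out:

```shell
# Hedged: JAX's documented CUDA-wheel install (run in the target environment):
# pip install -U "jax[cuda12]"
# Then confirm which backend JAX picked; "gpu" means the CUDA plugin loaded.
python3 - <<'EOF'
try:
    import jax
    print("default backend:", jax.default_backend())
except ImportError:
    print("jax is not installed in this environment")
EOF
```

If this prints `cpu`, jaxlib loaded but could not initialize a CUDA backend, which is a different failure mode from a missing install.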
Thanks for the update, I'll check it right away |
@hawkinsp After executing
but get |
@nouiz does this mean I can't yet run jax with GPU acceleration on GH200? |
It is possible. Can you give the exact command line you use to start the docker container? |
Last week I used the `jax:jax` docker image from here: https://github.com/NVIDIA/JAX-Toolbox
This works:
```
gilad@gracehopper:~$ docker run -it --gpus all ghcr.io/nvidia/jax:jax
==========
```
@nouiz I still get:

```
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
```
From the output of nvidia-smi, the issue seems to be that MIG is enabled, but no MIG instance is created.
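For context, a sketch of the `nvidia-smi` MIG commands involved. The profile ID below is only an example, and these commands require root on the host, so they are shown commented out:

```shell
# List the available GPU instance profiles to pick one:
# nvidia-smi mig -lgip
# Create a GPU instance (profile ID 9 is an example) together with its
# compute instance (-C); CUDA applications see no device until one exists:
# sudo nvidia-smi mig -cgi 9 -C
# Alternatively, disable MIG mode entirely if partitioning isn't needed:
# sudo nvidia-smi -mig 0
echo "MIG needs at least one GPU+compute instance, or MIG mode disabled"
```

This matches the `nvidia-smi` output below: `MIG M.` shows `Enabled` but the MIG-devices table reads `No MIG devices found`.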
@nouiz you were right! My colleague fixed the issue:
But now when I run the following code:
I get the following error:
@nouiz is this the cuDNN problem you mentioned? What can I do to check it?
Yes I have cuDNN version 9.1.0. |
Yes, you need to downgrade to cuDNN 8.9 for now. JAX doesn't yet release with cuDNN 9.
Before downgrading: which container do you use, and how was JAX installed? Did you try the JAX-Toolbox container without setting CUDA_VISIBLE_DEVICES?
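A sketch of checking which cuDNN version is installed. The header path and wheel name below are the common defaults, and the stand-in header only exists to demonstrate the grep pattern:

```shell
# Where cuDNN versions usually show up:
#   pip show nvidia-cudnn-cu12                       # pip-installed cuDNN
#   grep CUDNN_MAJOR /usr/include/cudnn_version.h    # system install
# Stand-in header so the grep pattern can be demonstrated anywhere:
cat > /tmp/cudnn_version.h <<'EOF'
#define CUDNN_MAJOR 9
#define CUDNN_MINOR 1
#define CUDNN_PATCHLEVEL 0
EOF
grep -E 'CUDNN_(MAJOR|MINOR|PATCHLEVEL)' /tmp/cudnn_version.h
```

On a real system, a `CUDNN_MAJOR` of 9 here would confirm the mismatch described above.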
Description
I'm trying to run some code on my GH200, without success: I'm unable to build jaxlib for my GPU.
System info (python version, jaxlib version, accelerator, etc.)
```
root@470c73980644:~/jax# nvidia-smi
Sun May 19 12:13:00 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GH200 480GB             On  |   00000009:01:00.0 Off |                   On |
| N/A   23C    P0             62W /  900W |       5MiB /  97871MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| MIG devices:                                                                            |
+------------------+----------------------------------+-----------+-----------------------+
| GPU  GI  CI  MIG |                     Memory-Usage |        Vol|        Shared         |
|      ID  ID  Dev |                       BAR1-Usage | SM     Unc| CE ENC DEC OFA JPG    |
|                  |                                  |        ECC|                       |
|==================+==================================+===========+=======================|
|  No MIG devices found                                                                   |
+-----------------------------------------------------------------------------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
root@470c73980644:~/jax# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:24:28_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
```
The error I get:
```
Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc".
Compilation terminated.
Target //jaxlib/tools:build_gpu_plugin_wheel failed to build
INFO: Elapsed time: 7.262s, Critical Path: 4.88s
INFO: 73 processes: 73 internal.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
Traceback (most recent call last):
  File "/root/jax/build/build.py", line 733, in <module>
    main()
  File "/root/jax/build/build.py", line 727, in main
    shell(build_pjrt_plugin_command)
  File "/root/jax/build/build.py", line 45, in shell
    output = subprocess.check_output(cmd)
  File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/local/bin/bazel', 'run', '--verbose_failures=true', '//jaxlib/tools:build_gpu_plugin_wheel', '--', '--output_path=/root/jax/dist', '--jaxlib_git_hash=45a7c22e932fee257016bf0da1022be146ed6095', '--cpu=aarch64', '--cuda_version=12']' returned non-zero exit status 1.
```