
Can't build jaxlib in GH200 #21299

Open
giladqm opened this issue May 19, 2024 · 27 comments
Labels: bug (Something isn't working), NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments

@giladqm

giladqm commented May 19, 2024

Description

I'm trying to run some code on my GH200, without success: I'm unable to build jaxlib for this GPU.

System info (python version, jaxlib version, accelerator, etc.)

root@470c73980644:~/jax# nvidia-smi
Sun May 19 12:13:00 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GH200 480GB On | 00000009:01:00.0 Off | On |
| N/A 23C P0 62W / 900W | 5MiB / 97871MiB | N/A Default |
| | | Enabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| MIG devices: |
+------------------+----------------------------------+-----------+-----------------------+
| GPU GI CI MIG | Memory-Usage | Vol| Shared |
| ID ID Dev | BAR1-Usage | SM Unc| CE ENC DEC OFA JPG |
| | | ECC| |
|==================+==================================+===========+=======================|
| No MIG devices found |
+-----------------------------------------------------------------------------------------+

+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
root@470c73980644:~/jax# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:24:28_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

The error I get:
Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc".
Compilation terminated.
Target //jaxlib/tools:build_gpu_plugin_wheel failed to build
INFO: Elapsed time: 7.262s, Critical Path: 4.88s
INFO: 73 processes: 73 internal.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
Traceback (most recent call last):
File "/root/jax/build/build.py", line 733, in
main()
File "/root/jax/build/build.py", line 727, in main
shell(build_pjrt_plugin_command)
File "/root/jax/build/build.py", line 45, in shell
output = subprocess.check_output(cmd)
File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
File "/usr/lib/python3.10/subprocess.py", line 526, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/local/bin/bazel', 'run', '--verbose_failures=true', '//jaxlib/tools:build_gpu_plugin_wheel', '--', '--output_path=/root/jax/dist', '--jaxlib_git_hash=45a7c22e932fee257016bf0da1022be146ed6095', '--cpu=aarch64', '--cuda_version=12']' returned non-zero exit status 1.

giladqm added the bug label on May 19, 2024
@superbobry
Member

Can you share the compilation errors you're getting in external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc?

@giladqm
Author

giladqm commented May 20, 2024

Thanks for the quick reply...

#include "absl/base/call_once.h"
#include "absl/base/const_init.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
#include "xla/stream_executor/cuda/cuda_driver.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/redzone_allocator_kernel.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "xla/stream_executor/typed_kernel_factory.h"
#include "tsl/platform/statusor.h"
namespace stream_executor {
// Maintains a cache of pointers to loaded kernels
template <typename... Args>
static absl::StatusOr<TypedKernel<Args...>> LoadKernelOrGetPtr(
StreamExecutor
executor, absl::string_view kernel_name,
absl::string_view ptx, absl::Span cubin_data) {
using KernelPtrCacheKey =
std::tuple<CUcontext, absl::string_view, absl::string_view>;

static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit);
static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) =
*new absl::node_hash_map<KernelPtrCacheKey, TypedKernel<Args...>>();
CUcontext current_context = cuda::CurrentContextOrDie();
KernelPtrCacheKey kernel_ptr_cache_key{current_context, kernel_name, ptx};
absl::MutexLock lock(&kernel_ptr_cache_mutex);
auto it = kernel_ptr_cache.find(kernel_ptr_cache_key);
if (it == kernel_ptr_cache.end()) {
TF_ASSIGN_OR_RETURN(TypedKernel<Args...> loaded,
(TypedKernelFactory<Args...>::Create(
executor, kernel_name, ptx, cubin_data)));
it =
kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first;
}

CHECK(it != kernel_ptr_cache.end());
return &it->second;
}
// PTX blob for the function which checks that every byte in
// input_buffer (length is buffer_length) is equal to redzone_pattern.
//
// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr.
//
// Generated from:
// global void redzone_checker(unsigned char* input_buffer,
// unsigned char redzone_pattern,
// unsigned long long buffer_length,
// int* out_mismatched_ptr) {
// unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x;
// if (idx >= buffer_length) return;
// if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1);
// }
//
// Code must compile for the oldest GPU XLA may be compiled for.
static const char* redzone_checker_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

.visible .entry redzone_checker(
.param .u64 input_buffer,
.param .u8 redzone_pattern,
.param .u64 buffer_length,
.param .u64 out_mismatch_cnt_ptr
)
{
.reg .pred %p<3>;
.reg .b16 %rs<3>;
.reg .b32 %r<6>;
.reg .b64 %rd<8>;

ld.param.u64 %rd6, [buffer_length];
ld.param.u64 %rd4, [input_buffer];
cvta.to.global.u64 %rd2, %rd4;
add.s64 %rd7, %rd2, %rd3;
ld.global.u8 %rs2, [%rd7];
setp.eq.s16 %p2, %rs2, %rs1;
@%p2 bra LBB6_3;
ld.param.u64 %rd5, [out_mismatch_cnt_ptr];
ld.param.u8 %rs1, [redzone_pattern];
ld.param.u64 %rd4, [input_buffer];
cvta.to.global.u64 %rd2, %rd4;
add.s64 %rd7, %rd2, %rd3;
ld.global.u8 %rs2, [%rd7];
setp.eq.s16 %p2, %rs2, %rs1;
@%p2 bra LBB6_3;
ld.param.u64 %rd5, [out_mismatch_cnt_ptr];
cvta.to.global.u64 %rd1, %rd5;
atom.global.add.u32 %r5, [%rd1], 1;
LBB6_3:
ret;
}
)";

absl::StatusOr<const ComparisonKernel*> GetComparisonKernel(
StreamExecutor* executor, GpuAsmOpts gpu_asm_opts) {
absl::Span compiled_ptx = {};
absl::StatusOr<absl::Span> compiled_ptx_or =
CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx,
gpu_asm_opts);
if (compiled_ptx_or.ok()) {
compiled_ptx = compiled_ptx_or.value();
} else {
static absl::once_flag ptxas_not_found_logged;
absl::call_once(ptxas_not_found_logged, & {
LOG(WARNING) << compiled_ptx_or.status()
<< "\nRelying on driver to perform ptx compilation. "
<< "\nModify $PATH to customize ptxas location."
<< "\nThis message will be only logged once.";
});
}

return LoadKernelOrGetPtr<DeviceMemory<uint8_t>, uint8_t, uint64_t,
DeviceMemory<uint64_t>>(
executor, "redzone_checker", redzone_checker_ptx, compiled_ptx);
}
} // namespace stream_executor
.reg .b16 %rs<3>;

@superbobry
Member

This snippet doesn't contain any compilation errors AFAICT. Can you upload the output of the compiler to a gist?

@giladqm
Author

giladqm commented May 20, 2024

I'm sorry but I don't understand the request. Can you be more specific and include the linux terminal commands you want me to run?

@superbobry
Member

The message you posted initially

Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc".

is usually preceded by compilation error messages, describing what went wrong while compiling jaxlib. If you upload the full output of build.py, that would include the error messages as well.
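
For example, assuming you're building from the jax checkout with the same flags as before, something like

python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 2>&1 | tee build.log

will write the complete log to build.log, which you can then upload to a gist.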

@mjsML
Collaborator

mjsML commented May 20, 2024

cc: @nouiz

mjsML added the NVIDIA GPU label on May 20, 2024
@giladqm
Author

giladqm commented May 20, 2024

This is what I get now (I'm in a different Docker image now):

(base) root@8c1c1dd5a763:~/jax# python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12

[JAX ASCII-art banner]

Bazel binary path: ./bazel-6.5.0-linux-arm64
Bazel version: 6.5.0
Python binary path: /root/miniconda3/bin/python3
Python version: 3.12
Use clang: no
MKL-DNN enabled: yes
Target CPU: aarch64
Target CPU features: release
CUDA enabled: yes
NCCL enabled: yes
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-6.5.0-linux-arm64 run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=/root/jax/dist --jaxlib_git_hash=ffdb9bb0b0755e66f55995cafa2cf0946ed66598 --cpu=aarch64 --skip_gpu_kernels
INFO: Options provided by the client:
Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /root/jax/.bazelrc:
Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /root/jax/.bazelrc:
Inherited 'build' options: --nocheck_visibility --apple_platform_type=macos --macos_minimum_os=10.14 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --experimental_cc_shared_library --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --define=tsl_link_protobuf=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. --@xla//xla/python:enable_gpu=false
INFO: Reading rc options for 'run' from /root/jax/.jax_configure.bazelrc:
Inherited 'build' options: --strategy=Genrule=standalone --config=mkl_open_source_only --config=cuda --config=cuda_plugin --repo_env HERMETIC_PYTHON_VERSION=3.12
INFO: Found applicable config definition build:short_logs in file /root/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /root/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file /root/jax/.bazelrc: --repo_env TF_NEED_CUDA=1 --repo_env TF_NCCL_USE_STUB=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --@xla//xla/python:enable_gpu=true --@xla//xla/python:jax_cuda_pip_rpaths=true --define=xla_python_enable_gpu=true --linkopt=-Wl,--disable-new-dtags
INFO: Found applicable config definition build:cuda_plugin in file /root/jax/.bazelrc: --@xla//xla/python:enable_gpu=false --define=xla_python_enable_gpu=false
INFO: Found applicable config definition build:linux in file /root/jax/.bazelrc: --config=posix --copt=-Wno-unknown-warning-option --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
INFO: Found applicable config definition build:posix in file /root/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
Loading:
INFO: Repository local_config_cuda instantiated at:
/root/jax/WORKSPACE:45:15: in <toplevel>
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/xla/workspace2.bzl:121:19: in workspace
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/workspace2.bzl:601:19: in workspace
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/workspace2.bzl:72:19: in _tf_toolchains
Repository rule cuda_configure defined at:
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl:1542:33: in <toplevel>
ERROR: An error occurred during the fetch of repository 'local_config_cuda':
Traceback (most recent call last):
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1491, column 38, in _cuda_autoconf_impl
_create_local_cuda_repository(repository_ctx)
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1040, column 35, in _create_local_cuda_repository
cuda_config = _get_cuda_config(repository_ctx)
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 716, column 30, in _get_cuda_config
config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 693, column 26, in find_cuda_config
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/remote_config/common.bzl", line 230, column 13, in execute
fail(
Error in fail: Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
''
'include'
'include/cuda'
'include/*-linux-gnu'
'extras/CUPTI/include'
'include/cuda/CUPTI'
'local/cuda/extras/CUPTI/include'
'targets/x86_64-linux/include'
of:
'/lib'
'/usr'
'/usr/lib/aarch64-linux-gnu'
'/usr/lib/aarch64-linux-gnu/libfakeroot'
ERROR: /root/jax/WORKSPACE:45:15: fetching cuda_configure rule //external:local_config_cuda: Traceback (most recent call last):
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1491, column 38, in _cuda_autoconf_impl
_create_local_cuda_repository(repository_ctx)
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1040, column 35, in _create_local_cuda_repository
cuda_config = _get_cuda_config(repository_ctx)
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 716, column 30, in _get_cuda_config
config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 693, column 26, in find_cuda_config
exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/remote_config/common.bzl", line 230, column 13, in execute
fail(
Error in fail: Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
''
'include'
'include/cuda'
'include/*-linux-gnu'
'extras/CUPTI/include'
'include/cuda/CUPTI'
'local/cuda/extras/CUPTI/include'
'targets/x86_64-linux/include'
of:
'/lib'
'/usr'
'/usr/lib/aarch64-linux-gnu'
'/usr/lib/aarch64-linux-gnu/libfakeroot'
INFO: Repository rules_cc instantiated at:
/root/jax/WORKSPACE:48:15: in <toplevel>
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/xla/workspace1.bzl:12:19: in workspace
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/workspace1.bzl:30:14: in workspace
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/com_github_grpc_grpc/bazel/grpc_deps.bzl:158:21: in grpc_deps
Repository rule http_archive defined at:
/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/bazel_tools/tools/build_defs/repo/http.bzl:372:31: in <toplevel>
ERROR: Skipping '@xla//xla/python:enable_gpu': no such package '@local_config_cuda//cuda': Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
''
'include'
'include/cuda'
'include/*-linux-gnu'
'extras/CUPTI/include'
'include/cuda/CUPTI'
'local/cuda/extras/CUPTI/include'
'targets/x86_64-linux/include'
of:
'/lib'
'/usr'
'/usr/lib/aarch64-linux-gnu'
'/usr/lib/aarch64-linux-gnu/libfakeroot'
WARNING: Target pattern parsing failed.
ERROR: @xla//xla/python:enable_gpu :: Error loading option @xla//xla/python:enable_gpu: no such package '@local_config_cuda//cuda': Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
''
'include'
'include/cuda'
'include/*-linux-gnu'
'extras/CUPTI/include'
'include/cuda/CUPTI'
'local/cuda/extras/CUPTI/include'
'targets/x86_64-linux/include'
of:
'/lib'
'/usr'
'/usr/lib/aarch64-linux-gnu'
'/usr/lib/aarch64-linux-gnu/libfakeroot'
Traceback (most recent call last):
File "/root/jax/build/build.py", line 733, in
main()
File "/root/jax/build/build.py", line 699, in main
shell(build_cpu_wheel_command)
File "/root/jax/build/build.py", line 45, in shell
output = subprocess.check_output(cmd)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/subprocess.py", line 466, in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-6.5.0-linux-arm64', 'run', '--verbose_failures=true', '//jaxlib/tools:build_wheel', '--', '--output_path=/root/jax/dist', '--jaxlib_git_hash=ffdb9bb0b0755e66f55995cafa2cf0946ed66598', '--cpu=aarch64', '--skip_gpu_kernels']' returned non-zero exit status 2.

@superbobry
Member

Okay, from this it looks like your CUDA installation is missing development headers:

Could not find any cuda.h matching version '' in any subdirectory:

@nouiz
Collaborator

nouiz commented May 20, 2024

JAX-Toolbox has nightly JAX containers for ARM: https://github.com/NVIDIA/JAX-Toolbox
For example: ghcr.io/nvidia/jax:jax for the latest nightly.

If you want to build JAX yourself, this container already contains CUDA:

docker pull nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

I almost always use those two containers for JAX development.

@giladqm
Author

giladqm commented May 21, 2024

@nouiz thanks, I tried those two options without success. I'm using a GH200 and I'm trying to use jax with the GPU, but it always fails.

@hawkinsp
Member

I'll also note that we (JAX) release CUDA ARM wheels on PyPI, which should just work on GH200. Try:

pip install jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt

(The more usual pip install jax[cuda12] won't work because NVIDIA doesn't release ARM wheels of CUDA, last I checked.)

@nouiz
Collaborator

nouiz commented May 21, 2024

We released ARM wheels last week, but they aren't tested yet.
So let's try jax[cuda12]:

docker run -it --gpus all ubuntu
apt-get update; apt-get install -y python3-pip python3.12-venv
python3 -m venv path/to/venv
source path/to/venv/bin/activate
pip install jax[cuda12]  # works
python3 -c "import jax; jax.numpy.zeros(3)"  # fails with a cudnn init error

This installed cuDNN 9.1.1. cuDNN 8 isn't supported on Grace Hopper to my knowledge.
@hawkinsp Are the JAX wheels for ARM also built with cuDNN 8?
Any idea when a cuDNN 9 version can be created?

@giladqm
Author

giladqm commented May 21, 2024

Thanks for the update, I'll check it right away.

@giladqm
Author

giladqm commented May 21, 2024

@hawkinsp After executing

pip install jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt

I'm trying to run the following code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Specify the index of the GPU you want to use

import jax
import jax.numpy as jnp

def main():
    # Explicitly place arrays on GPU using jax.device_put
    gpu_device = jax.devices("gpu")[0]  # Use the first GPU
    size = 1000  # matrix dimension; this definition was missing from the snippet
    a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
    b = jax.random.normal(jax.random.PRNGKey(1), (size, size))
    a_gpu = jax.device_put(a, device=gpu_device)
    b_gpu = jax.device_put(b, device=gpu_device)

    # Run matrix multiplication on GPU
    result = jnp.dot(a_gpu, b_gpu)

    # Print the result
    print("Result of matrix multiplication:")
    print(result)

if __name__ == "__main__":
    main()

but I get:

RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

@giladqm
Author

giladqm commented May 21, 2024

@nouiz does this mean I can't yet run jax with GPU acceleration on GH200?

@nouiz
Collaborator

nouiz commented May 21, 2024

It is possible. Can you give the exact command line you use to start the Docker container?
What is the output of nvidia-smi inside it?
Can you try the JAX container we provide? It should work and won't require you to compile JAX.
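
As a quick check, something like this should print the GPU if everything is wired up (assuming the image accepts a command directly):

docker run --rm --gpus all ghcr.io/nvidia/jax:jax python3 -c "import jax; print(jax.devices())"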

@giladqm
Author

giladqm commented May 21, 2024

Last week I used the docker jax:jax here: https://github.com/NVIDIA/JAX-Toolbox
I don't mind trying it again.
Or do you mean nvcr.io/nvidia/jax:24.04-maxtext-py3 (from here: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax)?

@giladqm
Author

giladqm commented May 21, 2024


@nouiz

@giladqm
Author

giladqm commented May 21, 2024

$ docker pull nvcr.io/nvidia/jax:24.04-maxtext-py3
24.04-maxtext-py3: Pulling from nvidia/jax
no matching manifest for linux/arm64/v8 in the manifest list entries

@nouiz

@giladqm
Author

giladqm commented May 21, 2024

gilad@gracehopper:~$ docker pull ghcr.io/nvidia/jax:jax 
jax: Pulling from nvidia/jax

works
@nouiz

@giladqm
Author

giladqm commented May 21, 2024


gilad@gracehopper:~$ docker run -it --gpus all ghcr.io/nvidia/jax:jax

==========
== CUDA ==
==========

CUDA Version 12.4.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

WARNING: Your shm is currently less than 1GB. This may cause SIGBUS errors.
To avoid this problem, you can manually set the shm size in docker with:

docker run ... --shm-size=1g ...

root@f85914843395:/# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:24:28_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
root@f85914843395:/# nvidia-smi
Tue May 21 17:27:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GH200 480GB On | 00000009:01:00.0 Off | On |
| N/A 23C P0 62W / 900W | 6MiB / 97871MiB | N/A Default |
| | | Enabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| MIG devices: |
+------------------+----------------------------------+-----------+-----------------------+
| GPU GI CI MIG | Memory-Usage | Vol| Shared |
| ID ID Dev | BAR1-Usage | SM Unc| CE ENC DEC OFA JPG |
| | | ECC| |
|==================+==================================+===========+=======================|
| No MIG devices found |
+-----------------------------------------------------------------------------------------+

+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
root@f85914843395:/#

@giladqm
Author

giladqm commented May 21, 2024

@nouiz I still get:
root@f85914843395:~# python jax_program.py
2024-05-21 18:06:49.257055: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 679, in backends
backend = _init_backend(platform)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 761, in _init_backend
backend = registration.factory()
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 509, in factory
return xla_client.make_c_api_client(plugin_name, options, None)
File "/usr/local/lib/python3.10/dist-packages/jaxlib/xla_client.py", line 190, in make_c_api_client
return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/jax_program.py", line 23, in
main()
File "/root/jax_program.py", line 9, in main
gpu_device = jax.devices("gpu")[0] # Use the first GPU
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 872, in devices
return get_backend(backend).devices()
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 806, in get_backend
return _get_backend_uncached(platform)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 786, in _get_backend_uncached
bs = backends()
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 695, in backends
raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

@nouiz
Collaborator

nouiz commented May 21, 2024

From the output of nvidia-smi, the issue seems to be that MIG is enabled, but no MIG "instance" has been created.
If that's right, it would make all software fail on that node.
Can you ask your admins how they set up MIG and how to get a MIG instance created?
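
If you have admin rights yourself, something along these lines usually does it (the profile name below is just an example; list the profiles first):

sudo nvidia-smi mig -lgip             # list the available GPU instance profiles
sudo nvidia-smi mig -cgi 7g.96gb -C   # create a GPU instance and its compute instance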

@giladqm
Author

giladqm commented May 22, 2024

@nouiz you were right! My colleague fixed the issue:

(base) nikola@gracehopper:~$  sudo nvidia-smi mig -lgi
+-------------------------------------------------------+
| GPU instances:                                        |
| GPU   Name             Profile  Instance   Placement  |
|                          ID       ID       Start:Size |
|=======================================================|
|   0  MIG 7g.96gb          0        0          0:8     |
+-------------------------------------------------------+
(base) nikola@gracehopper:~$ echo $CUDA_VISIBLE_DEVICES

(base) nikola@gracehopper:~$ nvidia-smi -L
GPU 0: NVIDIA GH200 480GB (UUID: GPU-d8731c65-c898-919e-74c9-286b27400dac)
  MIG 7g.96gb     Device  0: (UUID: MIG-7baedeb1-c0d7-53ba-9926-2e341a42b470)

But now when I run the following code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Specify the index of the GPU you want to use

import jax
import jax.numpy as jnp

# Let's define a simple matrix multiplication function
def matmul_on_gpu(a, b):
    return jnp.dot(a, b)

# Main function to demonstrate GPU acceleration
def main():
    # Create some random matrices
    size = 1000
    a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
    b = jax.random.normal(jax.random.PRNGKey(1), (size, size))

    # Run matrix multiplication on GPU
    result = matmul_on_gpu(a, b)

    # Print the result
    print("Result of matrix multiplication:")
    print(result)

if __name__ == "__main__":
    main()

I get the following error:

root@ce139f7a5d68:~# python jax_program.py 
2024-05-22 19:00:07.301687: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-05-22 19:00:07.301822: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 100780867584 bytes free, 101468602368 bytes total.
2024-05-22 19:00:07.302174: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-05-22 19:00:07.302277: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 100780867584 bytes free, 101468602368 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/jax_program.py", line 26, in <module>
    main()
  File "/root/jax_program.py", line 15, in main
    a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 240, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 202, in _key
    return prng.random_seed(seed, impl=impl)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py", line 595, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 2217, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 2172, in array
    out_array: Array = lax_internal._convert_element_type(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py", line 560, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

@nouiz is this the cuDNN problem you mentioned? What can I do to check it?

@giladqm
Author

giladqm commented May 22, 2024

Yes, I have cuDNN version 9.1.0.

@hawkinsp
Member

Yes, you need to downgrade to CUDNN 8.9 for now. JAX doesn't yet release with CUDNN 9.
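
If cuDNN came in via pip, something like

pip install "nvidia-cudnn-cu12==8.9.*"

should pin it, assuming an aarch64 wheel of that version exists for your platform.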

@nouiz
Collaborator

nouiz commented May 22, 2024

Before downgrading: which container do you use, and how was JAX installed?
If you use the JAX-Toolbox jax container, you have a good combination of JAX (nightly), cuDNN (9.1.1), and CUDA 12.4.1.

Did you try the JAX-Toolbox container without setting CUDA_VISIBLE_DEVICES?
Why do you set it? If there is only one GPU, JAX will normally just find and use it, so you don't need to set it.
A MIG instance isn't listed the same way as a normal GPU.
Also, you can't do multi-GPU across MIG instances.
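
If you do want to target the MIG instance explicitly, CUDA expects the MIG UUID rather than a plain index, e.g. with the UUID from your nvidia-smi -L output above:

export CUDA_VISIBLE_DEVICES=MIG-7baedeb1-c0d7-53ba-9926-2e341a42b470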
