diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 91605abf3f..d8b1e188f8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -74,7 +74,9 @@ #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/service.h" +#if defined(REACTANT_CUDA) || defined(REACTANT_ROCM) #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#endif #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -439,6 +441,7 @@ MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices, int64_t num_allowed_devices, double memory_fraction, bool preallocate, const char *platform_name, const char **error, void *distributed_runtime_client) { +#if defined(REACTANT_CUDA) || defined(REACTANT_ROCM) GpuClientOptions options; if (num_nodes > 1) { @@ -483,6 +486,10 @@ MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices, auto client = std::move(clientErr).value(); return client.release(); } +#else + *error = "ReactantExtra was not built with GPU support"; + return nullptr; +#endif } const char *const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH"; diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 0df55adea0..cd29cdb699 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -914,7 +914,7 @@ cc_library( "-Werror=return-type", "-Werror=unused-result", "-Wno-error=stringop-truncation", - ] + if_cuda(["-DREACTANT_CUDA=1"]), + ] + if_cuda(["-DREACTANT_CUDA=1"]) + if_rocm(["-DREACTANT_ROCM=1"]), linkopts = select({ "//conditions:default": [], "@bazel_tools//src/conditions:darwin": [ @@ -1089,7 +1089,6 @@ cc_library( "@xla//xla/stream_executor/tpu:tpu_executor", "@xla//xla/stream_executor/tpu:tpu_transfer_manager", "@xla//xla/service/cpu:cpu_transfer_manager", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", "@xla//xla/tsl/protobuf:protos_all_cc_impl", "@xla//xla/tsl/framework:allocator_registry_impl", "@xla//xla/pjrt:status_casters", @@ -1143,11 +1142,7 @@ cc_library( "@llvm-project//mlir:CAPILLVMObjects", "@jax//jaxlib/mosaic:tpu_dialect_capi_objects", "@jax//jaxlib/triton:triton_dialect_capi_objects", - "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", - "@xla//xla/service:gpu_plugin", - "@xla//xla/pjrt/c:pjrt_c_api_gpu", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "@xla//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client", "@stablehlo//:linalg_passes", "@stablehlo//:tosa_passes", "@stablehlo//:stablehlo_passes", @@ -1163,6 +1158,11 @@ cc_library( "@xla//xla/service:hlo_proto_cc_impl", "@com_google_absl//absl/status:statusor", ] + if_cuda([ + "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", + "@xla//xla/service:gpu_plugin", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", + "@xla//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client", "@jax//jaxlib/cuda:cuda_gpu_kernels", "@xla//xla/backends/profiler:profiler_backends", "@xla//xla/backends/profiler/gpu:device_tracer", @@ -1175,6 +1175,10 @@ cc_library( "@xla//xla/stream_executor:kernel", "@xla//xla/stream_executor/cuda:all_runtime", ]) + if_rocm([ + "@xla//xla/service:gpu_plugin", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", + "@xla//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client", "@xla//xla/stream_executor:rocm_platform", "@xla//xla/service/gpu:amdgpu_compiler", "@xla//xla/backends/profiler/gpu:device_tracer",