Skip to content

Commit

Permalink
[ROCm] re-enable support RCCL with fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jayfurmanek committed May 15, 2024
1 parent a8fb85f commit 36afb63
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 16 deletions.
29 changes: 20 additions & 9 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,8 @@ cc_library(
"nccl_recv_thunk.h",
"nccl_send_thunk.h",
],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW=1"]),
visibility = ["//visibility:public"],
deps = [
":backend_configs_cc",
Expand Down Expand Up @@ -936,7 +937,6 @@ cc_library(
# have `if_nccl` and `if_gpu_configured` that do not compose. NCCL header included directly in
# :nccl_api target and all other targets should use this header to launch collective operations.
# This allows to minimize the spreading of #ifdef all over the XLA code base.

alias(
name = "nccl_api",
actual = if_nccl(":_nccl_api_impl", ":_nccl_api_stub"),
Expand All @@ -945,13 +945,13 @@ alias(

cc_library(
name = "_nccl_api_impl",
srcs = if_cuda_is_configured(
srcs = if_gpu_is_configured(
["nccl_api.cc"],
["nccl_api_stub.cc"],
),
hdrs = ["nccl_api.h"],
compatible_with = get_compatible_with_portable(),
defines = if_cuda_is_configured(["XLA_ENABLE_XCCL"]), # TODO(ezhulenev): Remove!
defines = if_gpu_is_configured(["XLA_ENABLE_XCCL"]), # TODO(ezhulenev): Remove!
visibility = ["//visibility:public"],
deps = [
":nccl_clique_key",
Expand All @@ -972,6 +972,10 @@ cc_library(
"@local_config_nccl//:nccl",
"//xla/stream_executor/cuda:cuda_driver",
"//xla/stream_executor/cuda:cuda_executor",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rccl",
"//xla/stream_executor/rocm:rocm_driver",
"//xla/stream_executor/rocm:rocm_executor",
]) + if_gpu_is_configured([
"//xla/stream_executor/gpu:gpu_stream",
]),
Expand Down Expand Up @@ -1059,7 +1063,8 @@ cc_library(
name = "mock_nccl_xml_google",
srcs = ["mock_nccl_xml.cc"],
hdrs = ["mock_nccl_xml.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
tags = ["manual"],
visibility = ["//visibility:public"],
deps = [
Expand All @@ -1075,14 +1080,17 @@ cc_library(
"@local_tsl//tsl/platform:regexp",
] + if_cuda_is_configured([
"@local_config_nccl//:nccl",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rccl",
]),
)

xla_cc_test(
name = "mock_nccl_xml_test",
size = "small",
srcs = if_google(["mock_nccl_xml_test.cc"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
tags = tf_cuda_tests_tags(),
deps = [
"//xla:status",
Expand All @@ -1092,6 +1100,8 @@ xla_cc_test(
":mock_nccl_xml_google",
]) + if_cuda_is_configured([
"@local_config_nccl//:nccl",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rccl",
]),
)

Expand Down Expand Up @@ -1194,6 +1204,8 @@ cc_library(
"//xla/stream_executor",
]) + if_cuda_is_configured([
"@local_config_nccl//:nccl",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rccl",
]),
)

Expand All @@ -1211,9 +1223,8 @@ cc_library(
hdrs = [
"gpu_executable.h",
],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
visibility = ["//visibility:public"],
deps = [
":buffer_allocations",
Expand Down
16 changes: 15 additions & 1 deletion third_party/xla/xla/service/gpu/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,18 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"

#if TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#if (TF_ROCM_VERSION >= 50200)
#include "rocm/include/rccl/rccl.h"
#else
#include "rocm/include/rccl.h"
#endif // TF_ROCM_VERSION >= 50200
#else
#include "third_party/nccl/nccl.h"
#endif // TENSORFLOW_USE_ROCM || GOOGLE_CUDA

#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/nccl_clique_key.h"
Expand Down Expand Up @@ -510,9 +521,10 @@ DefaultNcclApi::RegisterBuffer(NcclCommHandle comm,
"Register buffer for NCCL communicator; buffer=%p; size=%d; comm=%p",
buffer.opaque(), buffer.size(), comm);
void* handle = nullptr;
#if (NCCL_VERSION_CODE >= 21901)
XLA_NCCL_RETURN_IF_ERROR(
ncclCommRegister(Cast(comm), buffer.opaque(), buffer.size(), &handle));

#endif // NCCL_VERSION_CODE >= 21901
return reinterpret_cast<NcclRegisteredBufferHandle>(handle);
}

Expand All @@ -522,8 +534,10 @@ DefaultNcclApi::DeregisterBuffer(NcclCommHandle comm,
VLOG(3) << absl::StreamFormat(
"Deregister buffer for NCCL communicator; handle=%p; comm=%p", handle,
comm);
#if (NCCL_VERSION_CODE >= 21901)
return XLA_NCCL_STATUS(
ncclCommDeregister(Cast(comm), reinterpret_cast<void*>(handle)));
#endif // NCCL_VERSION_CODE >= 21901
}

} // namespace xla::gpu
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ NcclAllGatherStartThunk::NcclAllGatherStartThunk(
ThunkInfo thunk_info, NcclApi* nccl_api,
const HloAllGatherInstruction* inst, std::vector<Buffer> buffers)
: NcclCollectiveThunk(Thunk::kNcclAllGatherStart, thunk_info, nccl_api,
inst->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_sync()),
IsSyncCollective(inst)),
config_(impl::GetNcclAllGatherConfig(inst)),
buffers_(std::move(buffers)) {
CHECK_EQ(config_.config.operand_count, buffers_.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,7 @@ NcclAllReduceStartThunk::NcclAllReduceStartThunk(
: NcclAllReduceReduceScatterThunkBase(
Thunk::kNcclAllReduceStart, thunk_info, nccl_api,
impl::GetNcclAllReduceConfigInst(inst), std::move(buffers),
inst->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_sync()) {}
IsSyncCollective(inst)) {}

absl::Status NcclAllReduceStartThunk::CheckImplementable(
AllReduceStartOp op, int64_t replica_count, int64_t partition_count) {
Expand Down

0 comments on commit 36afb63

Please sign in to comment.