Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reduce_scatter_add #64198

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions paddle/fluid/operators/collective/c_reducescatter_add_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reducescatter_add_op.h"

#include <memory>

namespace paddle {
namespace operators {

// Operator definition for c_reducescatter_add: a NCCL reduce-scatter over
// input "x" whose scattered result has "bias" added element-wise (the add is
// fused in the kernel / comm-context call).
class CReduceScatterAddOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers out's shape: identical to x except dim[0] is divided by nranks,
  // since each rank receives 1/nranks of the reduced tensor.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("x"), "Input", "x", "ReduceScatterAdd");
    OP_INOUT_CHECK(ctx->HasInput("bias"), "Input", "bias", "ReduceScatterAdd");
    // Fix: the op name here was "ReduceScatter", so a missing output would be
    // reported against the wrong operator; use "ReduceScatterAdd".
    OP_INOUT_CHECK(ctx->HasOutput("out"), "Output", "out", "ReduceScatterAdd");
    int nranks = ctx->Attrs().Get<int>("nranks");
    framework::DDim dim = ctx->GetInputDim("x");
    // dim[0] in {0, -1} means the extent is unknown at compile time; the
    // divisibility check is then deferred to the runtime kernel.
    if (dim[0] > 0 || dim[0] < -1) {
      PADDLE_ENFORCE_EQ(
          dim[0] % nranks,
          0,
          phi::errors::InvalidArgument(
              "dim[0] (%d) is not divisible by nranks(%d)", dim[0], nranks));
      dim[0] /= nranks;
    }
    ctx->SetOutputDim("out", dim);
  }
};

// Proto maker for c_reducescatter_add: declares the op's inputs, output, and
// attributes for the static-graph registry. Note the registration order of
// inputs/outputs is part of the op proto.
class CReduceScatterAddOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("x", "(Tensor) tensor to be reduce scatter");
    AddInput("bias", "(Tensor) tensor to be add after reduce scatter");
    AddOutput("out", "(Tensor) the reduce scatter result");
    // Which communication ring (communicator) the op participates in.
    AddAttr<int>("ring_id", "(int default 0) communication ring id.")
        .SetDefault(0);
    AddAttr<int>("nranks",
                 "Total trainer count of the distributed training job")
        .SetDefault(1);

    // When true, the CUDA kernel issues the collective on the calculation
    // stream instead of the dedicated communication stream.
    AddAttr<bool>(
        "use_calc_stream",
        "(bool default false) eject CUDA operations to calculation stream.")
        .SetDefault(false);
    AddComment(R"DOC(
CReduceScatterAdd Operator

Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#reducescatter
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// Collective communication ops are not differentiated, so no gradient op is
// registered.
REGISTER_OP_WITHOUT_GRADIENT(c_reducescatter_add,
                             ops::CReduceScatterAddOp,
                             ops::CReduceScatterAddOpMaker);

// CPU kernel registration. The CPU kernel only raises Unimplemented (see
// c_reducescatter_add_op.h); the real work is in the CUDA kernel.
PD_REGISTER_STRUCT_KERNEL(c_reducescatter_add,
                          CPU,
                          ALL_LAYOUT,
                          ops::CReduceScatterAddOpCPUKernel,
                          float,
                          double,
                          int,
                          int64_t,
                          phi::dtype::float16) {}
146 changes: 146 additions & 0 deletions paddle/fluid/operators/collective/c_reducescatter_add_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reducescatter_add_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/common/flags.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle {
namespace operators {

// CUDA kernel for c_reducescatter_add: NCCL reduce-scatter of "x" with a
// fused element-wise add of "bias" into the scattered output.
// Supports two communication stacks: the new unified CommContext (selected by
// FLAGS_dynamic_static_unified_comm) and the legacy platform::NCCLComm.
template <typename T, typename DeviceContext>
class CReduceScatterAddOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    auto in = ctx.Input<phi::DenseTensor>("x");
    auto bias = ctx.Input<phi::DenseTensor>("bias");
    auto out = ctx.Output<phi::DenseTensor>("out");

    int rid = ctx.Attr<int>("ring_id");
    auto place = ctx.GetPlace();

    // Starts as x's dims; dim[0] is divided by nranks before allocating out.
    auto out_dims = in->dims();
    gpuStream_t stream = nullptr;
    platform::NCCLComm* comm = nullptr;
    phi::distributed::NCCLCommContext* comm_ctx = nullptr;
    const auto& comm_context_manager =
        phi::distributed::CommContextManager::GetInstance();
    // Exactly one of comm_ctx (new stack) / comm (legacy stack) is non-null
    // after this branch; both branches also pick the communication stream.
    if (FLAGS_dynamic_static_unified_comm) {
      PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
                        true,
                        phi::errors::InvalidArgument(
                            "You choose to use new communication library by "
                            "setting environment "
                            "variable FLAGS_dynamic_static_unified_comm True. "
                            "But ring_id(%d) is "
                            "not found in comm_context_manager.",
                            std::to_string(rid)));
      comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
          comm_context_manager.Get(std::to_string(rid)));
      PADDLE_ENFORCE_NE(comm_ctx,
                        nullptr,
                        phi::errors::Unavailable(
                            "NCCLCommContext is nullptr, collective op should "
                            "has ring_id attr."));
      PADDLE_ENFORCE_EQ(out_dims[0] % comm_ctx->GetSize(),
                        0,
                        phi::errors::InvalidArgument(
                            "The input tensor X's "
                            "dim[0] (%d) should be divisible by nranks(%d)",
                            out_dims[0],
                            comm_ctx->GetSize()));

      stream = comm_ctx->GetStream();
      VLOG(3) << "new comm_context_manager has ring_id " << rid;
    } else {  // old comm_context
      comm = platform::NCCLCommContext::Instance().Get(rid, place);
      PADDLE_ENFORCE_EQ(out_dims[0] % comm->nranks(),
                        0,
                        phi::errors::InvalidArgument(
                            "The input tensor X's "
                            "dim[0] (%d) should be divisible by nranks(%d)",
                            out_dims[0],
                            comm->nranks()));
      stream = comm->stream();
      VLOG(3) << "old NCCLCommContext has ring_id " << rid;
    }
    // use_calc_stream overrides the communicator's stream with the
    // calculation stream so communication is serialized with computation.
    if (ctx.Attr<bool>("use_calc_stream")) {
      // should ExecutionContext for calc stream.
      stream = ctx.cuda_device_context().stream();
    }

    int nranks = comm_ctx ? comm_ctx->GetSize() : comm->nranks();
    // NOTE(review): this divisibility check duplicates the per-branch checks
    // above — harmless but redundant.
    PADDLE_ENFORCE_EQ(out_dims[0] % nranks,
                      0,
                      phi::errors::InvalidArgument(
                          "The input tensor X's "
                          "dim[0] (%d) should be divisible by nranks(%d)",
                          out_dims[0],
                          nranks));
    // Each rank receives 1/nranks of the reduced tensor.
    out_dims[0] = out_dims[0] / nranks;
    out->mutable_data<T>(out_dims, place);

    int64_t recv_numel = in->numel() / nranks;
    const T* send_buff = in->data<T>();
    const T* add_buff = bias->data<T>();
    T* recv_buff = out->data<T>();
    int dtype =
        platform::ToNCCLDataType(framework::TransToProtoVarType(in->dtype()));

    if (comm_ctx) {
      // New stack: shape checks and the raw NCCL call live inside
      // NCCLCommContext::ReduceScatterAdd. The raw buffers computed above are
      // only used on the legacy path below.
      comm_ctx->ReduceScatterAdd(out, *in, *bias, ncclSum, stream);
    } else {
      // Legacy stack: invoke the ncclReduceScatterAdd extension directly.
      // NOTE(review): bias's numel is not validated against recv_numel here —
      // presumably bias has the scattered (output) shape; confirm callers.
      PADDLE_ENFORCE_GPU_SUCCESS(
          phi::dynload::ncclReduceScatterAdd(send_buff,
                                             recv_buff,
                                             add_buff,
                                             recv_numel,
                                             static_cast<ncclDataType_t>(dtype),
                                             ncclSum,
                                             comm->comm(),
                                             stream));
    }
#else
    PADDLE_THROW(phi::errors::PreconditionNotMet(
        "PaddlePaddle should compile with GPU."));
#endif
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

// GPU kernel registration. bfloat16 support requires NCCL >= 2.10
// (NCCL_VERSION_CODE >= 21000) and CUDA >= 11.0.
PD_REGISTER_STRUCT_KERNEL(c_reducescatter_add,
                          GPU,
                          ALL_LAYOUT,
                          ops::CReduceScatterAddOpCUDAKernel,
                          float,
                          double,
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
                          phi::dtype::bfloat16,
#endif
                          int,
                          int64_t,
                          phi::dtype::float16) {
}
40 changes: 40 additions & 0 deletions paddle/fluid/operators/collective/c_reducescatter_add_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/common/ddim.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace paddle {
namespace operators {

// CPU placeholder kernel: c_reducescatter_add is implemented only for GPU
// (NCCL/RCCL), so any CPU execution raises Unimplemented.
template <typename T, typename DeviceContext>
class CReduceScatterAddOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx UNUSED) const override {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unimplemented cpu kernel for CReduceScatterAddOp."));
  }
};

} // namespace operators
} // namespace paddle
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@
kernel :
func : c_reducescatter

# Reduce-scatter across ranks with a fused bias add; out keeps x's shape
# except dim[0] is divided by nranks (shape via ReduceScatterInferMeta).
- op : c_reducescatter_add
  args : (Tensor x, Tensor bias, int ring_id = 0, int nranks = 1, bool use_calc_stream = false)
  output : Tensor(out)
  infer_meta :
    func : ReduceScatterInferMeta
    param: [x, nranks]
  kernel :
    func : c_reducescatter_add

- op : c_scatter
args : (Tensor x, int ring_id = 0, int root = 0, int nranks = 0, bool use_calc_stream = false)
output : Tensor(out)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const std::unordered_set<std::string> LegacyOpList = {
CReduceSumOp::name(),
CReduceSum_Op::name(),
CReducescatterOp::name(),
CReducescatterAddOp::name(),
CAllreduceMax_Op::name(),
CAllreduceMin_Op::name(),
CAllgatherOp::name(),
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace dynload {
__macro(ncclGroupEnd); \
__macro(ncclReduce); \
__macro(ncclReduceScatter); \
__macro(ncclReduceScatterAdd); \
__macro(ncclCommGetAsyncError); \
__macro(ncclGetErrorString);

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ extern void* nccl_dso_handle;
__macro(ncclGroupEnd); \
__macro(ncclReduce); \
__macro(ncclReduceScatter); \
__macro(ncclReduceScatterAdd); \
__macro(ncclCommGetAsyncError); \
__macro(ncclGetErrorString);

Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/core/distributed/nccl_comm_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,33 @@ void NCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
stream));
}

// Reduce-scatter with fused bias add: reduces in_tensor across the size_
// ranks with reduce_type, scatters 1/size_ of the result to each rank, and
// adds bias_tensor into the scattered output.
// NOTE(review): ncclReduceScatterAdd is not part of stock NCCL — presumably a
// vendored/patched extension; verify it exists in the linked NCCL build.
void NCCLCommContext::ReduceScatterAdd(phi::DenseTensor* out_tensor,
                                       const phi::DenseTensor& in_tensor,
                                       const phi::DenseTensor& bias_tensor,
                                       ncclRedOp_t reduce_type,
                                       gpuStream_t stream) {
  // Static shape check: out must be in's shape with dim[0] divided by size_.
  // NOTE(review): bias_tensor's shape is not validated here — confirm callers
  // guarantee it matches the scattered output shape.
  phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor,
                                                      in_tensor,
                                                      /*dst_rank*/ rank_,
                                                      /*cur_rank*/ rank_,
                                                      size_);
  if (FLAGS_enable_nccl_dynamic_check) {
    phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
                                                   /*root_rank*/ 0,
                                                   rank_,
                                                   nccl_comm_);
  }
  // const_cast: the extension API takes a non-const bias pointer; bias data
  // is presumably not modified — TODO confirm against the extension's
  // contract.
  NCCL_CHECK(
      phi::dynload::ncclReduceScatterAdd(in_tensor.data(),
                                         out_tensor->data(),
                                         const_cast<void*>(bias_tensor.data()),
                                         out_tensor->numel(),
                                         ToNCCLDataType(in_tensor.type()),
                                         reduce_type,
                                         nccl_comm_,
                                         stream));
}

void NCCLCommContext::Send(const phi::DenseTensor& in_tensor,
const int64_t& count,
const int& peer,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/core/distributed/nccl_comm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ class NCCLCommContext final : public CommContext {
ncclRedOp_t reduce_type,
gpuStream_t stream);

  // Reduce-scatter in_tensor across all ranks with reduce_type and add
  // bias_tensor into the scattered output, on the given stream.
  // Implemented in nccl_comm_context.cc.
  void ReduceScatterAdd(phi::DenseTensor* out_tensor,
                        const phi::DenseTensor& in_tensor,
                        const phi::DenseTensor& bias_tensor,
                        ncclRedOp_t reduce_type,
                        gpuStream_t stream);

void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
gpuStream_t stream);
Expand Down