[PIR] Migrate API fused_linear_activation, fused_matmul_bias, softmax_mask_fuse, FusedLinear to PIR (#61331)

* adapt api

* fix gemm epilogue

* fix llm

* fix gemm

* fix compile
YuanRisheng committed Feb 1, 2024
1 parent 969037f commit 5eae192
Showing 17 changed files with 306 additions and 47 deletions.
20 changes: 20 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -230,6 +230,26 @@ pir::Value assign(const pir::Value& x) {
}
}

std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
pir::Value y,
pir::Value bias,
bool trans_x,
bool trans_y,
std::string activation) {
pir::IrContext* ctx = pir::IrContext::Instance();
pir::AttributeMap attribute_map = {
{"trans_x", pir::BoolAttribute::get(ctx, trans_x)},
{"trans_y", pir::BoolAttribute::get(ctx, trans_y)},
{"activation", pir::StrAttribute::get(ctx, activation)}};
auto fused_gemm_epilogue_op =
ApiBuilder::Instance()
.GetBuilder()
->Build<paddle::dialect::FusedGemmEpilogueOp>(
x, y, bias, attribute_map);
return std::make_tuple(fused_gemm_epilogue_op.result(0),
fused_gemm_epilogue_op.result(1));
}
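
For orientation, a minimal sketch of how this helper is reached from Python — assuming a CUDA build with cuBLASLt and the PIR static mode this PR targets; tensor names and shapes are illustrative:

```python
import paddle

paddle.enable_static()  # with PIR enabled, _C_ops calls route to the static API
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data("x", [8, 16], "float16")
    w = paddle.static.data("w", [16, 32], "float16")
    b = paddle.static.data("b", [32], "float16")
    # One FusedGemmEpilogueOp is built; result(0) is out, result(1) reserve_space.
    out, reserve_space = paddle._C_ops.fused_gemm_epilogue(
        x, w, b, False, False, "relu"
    )
```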

pir::Value array_pop(pir::Value input, int index) {
if (input.type().isa<paddle::dialect::DenseTensorArrayType>()) {
paddle::dialect::ArrayPopOp array_pop_op =
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -89,6 +89,12 @@ pir::Value slice_array_dense(pir::Value input, pir::Value starts);

pir::Value assign(const pir::Value& x);

std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
pir::Value y,
pir::Value bias,
bool trans_x,
bool trans_y,
std::string activation);
pir::Value array_pop(pir::Value input, int index);

} // namespace dialect
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -661,6 +661,16 @@
  view : (mean -> mean_out), (variance -> variance_out)
  backward : fused_bn_add_activation_grad

- op : fused_softmax_mask
  args : (Tensor x, Tensor mask)
  output : Tensor(out)
  infer_meta :
    func : SoftmaxMaskFuseInferMeta
  kernel :
    func : fused_softmax_mask
    data_type : x
  backward: fused_softmax_mask_grad
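
This entry registers the fused kernel with the op system; at the Python level it backs paddle.incubate.softmax_mask_fuse, which computes softmax(x + mask) over the last axis. A usage sketch (assuming a GPU build — the kernel has dtype and sequence-length constraints — with illustrative shapes):

```python
import paddle

# [batch, heads, seq_q, seq_k] attention scores plus an additive mask.
x = paddle.rand([2, 8, 32, 32], dtype="float16")
mask = paddle.full([2, 1, 32, 32], -10000.0, dtype="float16")
out = paddle.incubate.softmax_mask_fuse(x, mask)  # softmax(x + mask), last axis
```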

- op : fused_softmax_mask_upper_triangle
  args : (Tensor X)
  output : Tensor(Out)
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -324,6 +324,17 @@
    func: fused_feedforward_grad
  optional: linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln1_out, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, dropout2_out, ln1_scale_grad, ln1_bias_grad, ln2_scale_grad, ln2_bias_grad, linear2_bias_grad

- backward_op : fused_softmax_mask_grad
  forward : fused_softmax_mask (Tensor x, Tensor mask) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : GeneralUnaryGradInferMeta
    param: [out]
  kernel :
    func : fused_softmax_mask_grad
    data_type : out
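
The backward takes only out and out_grad because softmax's derivative is expressible in terms of its own output, and the additive mask receives no gradient. A reference sketch of the intended math (an assumption based on the standard softmax backward, not taken from the kernel source):

```python
import paddle

def fused_softmax_mask_grad_ref(out, out_grad):
    # x_grad = out * (out_grad - sum(out * out_grad, axis=-1, keepdim=True))
    inner = (out * out_grad).sum(axis=-1, keepdim=True)
    return out * (out_grad - inner)
```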

- backward_op : fused_softmax_mask_upper_triangle_grad
  forward : fused_softmax_mask_upper_triangle(Tensor X) -> Tensor(Out)
  args: (Tensor Out, Tensor Out_grad)
53 changes: 53 additions & 0 deletions paddle/fluid/pybind/manual_static_op_function.h
@@ -778,6 +778,39 @@ static PyObject *run_custom_op(PyObject *self,
}
}

static PyObject *static_api_fused_gemm_epilogue(PyObject *self,
PyObject *args,
PyObject *kwargs) {
try {
VLOG(6) << "Running Static API: fused_gemm_epilogue";

VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get Value inputs from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2Value(x_obj, "fused_gemm_epilogue", 0);
PyObject *y_obj = PyTuple_GET_ITEM(args, 1);
auto y = CastPyArg2Value(y_obj, "fused_gemm_epilogue", 1);
PyObject *bias_obj = PyTuple_GET_ITEM(args, 2);
auto bias = CastPyArg2Value(bias_obj, "fused_gemm_epilogue", 2);

// Parse Attributes if needed
PyObject *trans_x_obj = PyTuple_GET_ITEM(args, 3);
bool trans_x = CastPyArg2Boolean(trans_x_obj, "fused_gemm_epilogue", 3);
PyObject *trans_y_obj = PyTuple_GET_ITEM(args, 4);
bool trans_y = CastPyArg2Boolean(trans_y_obj, "fused_gemm_epilogue", 4);
PyObject *activation_obj = PyTuple_GET_ITEM(args, 5);
std::string activation =
CastPyArg2String(activation_obj, "fused_gemm_epilogue", 5);

// Call ir static api
auto out = paddle::dialect::fused_gemm_epilogue(
x, y, bias, trans_x, trans_y, activation);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyObject *static_api_array_pop(PyObject *self,
PyObject *args,
PyObject *kwargs) {
@@ -802,6 +835,22 @@ static PyObject *static_api_array_pop(PyObject *self,
}
}

extern PyObject *eager_api_fused_gemm_epilogue(PyObject *self,
PyObject *args,
PyObject *kwargs);

static PyObject *fused_gemm_epilogue(PyObject *self,
PyObject *args,
PyObject *kwargs) {
if (egr::Controller::Instance().GetCurrentTracer() == nullptr) {
VLOG(6) << "Call static_api_fused_gemm_epilogue";
return static_api_fused_gemm_epilogue(self, args, kwargs);
} else {
VLOG(6) << "Call eager_api_fused_gemm_epilogue";
return eager_api_fused_gemm_epilogue(self, args, kwargs);
}
}
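
This wrapper gives one Python symbol for both execution modes: with an active dygraph tracer it forwards to the eager kernel, and with no tracer it builds a PIR op through the static path. A sketch, assuming a CUDA build where the binding is registered:

```python
import paddle

x = paddle.rand([4, 8], dtype="float16")
w = paddle.rand([8, 16], dtype="float16")
b = paddle.rand([16], dtype="float16")
# A tracer is active in dygraph, so this routes to eager_api_fused_gemm_epilogue.
out, _ = paddle._C_ops.fused_gemm_epilogue(x, w, b, False, False, "none")

paddle.enable_static()
# No tracer now: the same call goes through static_api_fused_gemm_epilogue
# and builds a FusedGemmEpilogueOp instead of executing eagerly.
```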

static PyMethodDef ManualOpsAPI[] = {
{"set_parameter",
(PyCFunction)(void (*)(void))static_api_set_parameter,
@@ -851,6 +900,10 @@ static PyMethodDef ManualOpsAPI[] = {
(PyCFunction)(void (*)(void))static_api_slice_array_dense,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for slice_array_dense."},
{"fused_gemm_epilogue",
(PyCFunction)(void (*)(void))fused_gemm_epilogue,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for fused_gemm_epilogue."},
{"_run_custom_op",
(PyCFunction)(void (*)(void))run_custom_op,
METH_VARARGS | METH_KEYWORDS,
113 changes: 113 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -25,13 +25,18 @@ limitations under the License. */
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/utils/flags.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif

PD_DECLARE_int32(low_precision_op_list);

namespace paddle {
namespace experimental {

@@ -221,6 +226,114 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
return out;
}

std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
const Tensor& x,
const Tensor& y,
const Tensor& bias,
bool trans_x,
bool trans_y,
const std::string& activation) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;

if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}

VLOG(6) << "fused_gemm_epilogue API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"fused_gemm_epilogue",
{kernel_backend, kernel_layout, kernel_data_type},
true);
const auto& kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {
phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
"fused_gemm_epilogue", kernel_data_type);
}
VLOG(6) << "fused_gemm_epilogue kernel: " << kernel;
// Use actual_kernel_backend to select the kernel backend after a potential
// fallback to CPU
Backend actual_kernel_backend =
kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend;
auto* dev_ctx = GetDeviceContextByBackend(actual_kernel_backend);

auto input_x = PrepareData(
x,
GetKernelInputArgDef(kernel.InputAt(0), actual_kernel_backend),
{},
kernel_result.is_stride_kernel);
auto input_y = PrepareData(
y,
GetKernelInputArgDef(kernel.InputAt(1), actual_kernel_backend),
{},
kernel_result.is_stride_kernel);
auto input_bias = PrepareData(
bias,
GetKernelInputArgDef(kernel.InputAt(2), actual_kernel_backend),
{},
kernel_result.is_stride_kernel);

std::tuple<Tensor, Tensor> api_output;
auto kernel_out_0 = SetKernelOutput(&std::get<0>(api_output));
phi::DenseTensor* kernel_out_1 = nullptr;
if (activation != "none") {
kernel_out_1 = SetKernelOutput(&std::get<1>(api_output));
}

phi::MetaTensor meta_out_0(kernel_out_0, kernel_result.is_stride_kernel);
phi::MetaTensor meta_out_1(kernel_out_1, kernel_result.is_stride_kernel);

phi::FusedGemmEpilogueInferMeta(MakeMetaTensor(*input_x),
MakeMetaTensor(*input_y),
MakeMetaTensor(*input_bias),
trans_x,
trans_y,
activation,
kernel_out_0 ? &meta_out_0 : nullptr,
kernel_out_1 ? &meta_out_1 : nullptr);

using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
bool,
bool,
const std::string&,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();

(*kernel_fn)(*dev_ctx,
*input_x,
*input_y,
*input_bias,
trans_x,
trans_y,
activation,
kernel_out_0,
kernel_out_1);

if (kernel_result.has_fallback_cpu) {
TransDataBackend(kernel_out_0, kernel_backend, kernel_out_0);
TransDataBackend(kernel_out_1, kernel_backend, kernel_out_1);
}
return api_output;
}
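
Note that reserve_space is only given backing storage when an activation is actually fused: with activation == "none" the kernel receives a null second output pointer and the epilogue reduces to bias addition. From the caller's side, an illustrative sketch (CUDA build assumed):

```python
import paddle

x = paddle.rand([4, 8], dtype="float16")
w = paddle.rand([8, 16], dtype="float16")
b = paddle.rand([16], dtype="float16")
# "relu": reserve_space is materialized for reuse in the backward pass.
out, reserve = paddle._C_ops.fused_gemm_epilogue(x, w, b, False, False, "relu")
# "none": this impl never allocates the second output.
out2, empty = paddle._C_ops.fused_gemm_epilogue(x, w, b, False, False, "none")
```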

////////////////// Backward(grad) api impls //////////////////////

void embedding_grad_impl(const Tensor& x,
8 changes: 8 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
@@ -35,6 +35,14 @@ Tensor add_n_impl(const std::vector<Tensor>& x);

Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);

std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
const Tensor& x,
const Tensor& y,
const Tensor& bias,
bool trans_x,
bool trans_y,
const std::string& activation);

////////////////// Backward(grad) api impls //////////////////////

void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad);
21 changes: 21 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -281,6 +281,17 @@
    data_type : out_grad
  optional : reserve_space

- backward_op : fused_softmax_mask_grad
  forward : fused_softmax_mask (Tensor x, Tensor mask) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : GeneralUnaryGradInferMeta
    param: [out]
  kernel :
    func : fused_softmax_mask_grad
    data_type : out

- backward_op : fused_softmax_mask_upper_triangle_grad
  forward : fused_softmax_mask_upper_triangle(Tensor X) -> Tensor(Out)
  args: (Tensor Out, Tensor Out_grad)
@@ -855,6 +866,16 @@
    func: check_model_nan_inf
    data_type: out_grad

- backward_op: fused_gemm_epilogue_grad
  forward : fused_gemm_epilogue(Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation) -> Tensor(out), Tensor(reserve_space)
  args : (Tensor x, Tensor y, Tensor reserve_space, Tensor out_grad, bool trans_x, bool trans_y, str activation)
  output : Tensor(x_grad), Tensor(y_grad), Tensor(bias_grad)
  infer_meta :
    func : FusedGemmEpilogueGradInferMeta
  kernel:
    func : fused_gemm_epilogue_grad
  optional : reserve_space

- backward_op: unpool_grad
  forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
  args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -584,6 +584,23 @@
  view : (mean -> mean_out), (variance -> variance_out)
  backward : fused_bn_add_activation_grad

- op : fused_gemm_epilogue
  args : (Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation)
  output : Tensor(out), Tensor(reserve_space)
  invoke : fused_gemm_epilogue_impl(x, y, bias, trans_x, trans_y, activation)
  backward: fused_gemm_epilogue_grad
  optional: reserve_space

- op : fused_softmax_mask
  args : (Tensor x, Tensor mask)
  output : Tensor(out)
  infer_meta :
    func : SoftmaxMaskFuseInferMeta
  kernel :
    func : fused_softmax_mask
    data_type : x
  backward: fused_softmax_mask_grad

- op : fused_softmax_mask_upper_triangle
  args : (Tensor X)
  output : Tensor(Out)
8 changes: 5 additions & 3 deletions python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -16,7 +16,6 @@

import paddle
from paddle.autograd import PyLayer
-from paddle.base import core
from paddle.distributed import fleet
from paddle.nn import functional as F

@@ -34,7 +33,7 @@


def is_fused_matmul_bias_supported():
-    return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')
+    return hasattr(paddle._C_ops, 'fused_gemm_epilogue')


def is_fused_linear_param_grad_add_supported():
@@ -214,7 +213,7 @@ def forward(
        if not fuse_matmul_bias:
            return paddle._C_ops.linear(x, weight, bias)
        else:
-            return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
+            result, _ = paddle._C_ops.fused_gemm_epilogue(
+                x, weight, bias, False, False, "none"
+            )
+            return result

    @staticmethod
    def backward(ctx, dy):
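
The call-site change above captures the migration pattern: the legacy op took three inputs and returned one tensor, while the PIR-era API takes explicit trans_x/trans_y/activation attributes and returns an (out, reserve_space) pair. A before/after sketch with illustrative shapes:

```python
import paddle

x = paddle.rand([4, 8], dtype="float16")
weight = paddle.rand([8, 16], dtype="float16")
bias = paddle.rand([16], dtype="float16")

# Before: out = paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
# After: attributes are positional, and the reserve_space output is
# discarded here since no activation is fused.
out, _ = paddle._C_ops.fused_gemm_epilogue(x, weight, bias, False, False, "none")
```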