Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make einsum_v2 support multi-operands #42327

Merged
merged 4 commits from the contributor's branch into the base branch
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/operators/einsum_op.cc
Expand Up @@ -18,7 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -85,7 +85,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferShape));
PD_INFER_META(phi::EinsumInferMeta));

REGISTER_OPERATOR(einsum, ops::EinsumOp, ops::EinsumOpMaker,
EinsumInferShapeFunctor,
Expand Down
40 changes: 40 additions & 0 deletions paddle/phi/infermeta/unary.cc
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

namespace phi {

Expand Down Expand Up @@ -398,6 +399,45 @@ void EighInferMeta(const MetaTensor& x,
out_v->set_dims(input_dim);
}

// Infers the output meta (shape and dtype) of einsum from the input
// metas and the einsum `equation` string.
// @param inputs    metas of all operand tensors (multi-operand supported).
// @param equation  einsum equation, e.g. "ij,jk->ik".
// @param out       output meta to be filled in.
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                     const std::string& equation,
                     MetaTensor* out) {
  // Collect the information needed to parse the einsum equation:
  // per-label dim sizes, label categories, per-input permutations, the
  // broadcast ("...") dims and the final output dims.
  LabelMap labelshape(0);
  LabelMap labeltype(LabelType::Reduction);
  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
  std::vector<char> all_labels;
  std::vector<int> broadcast_dims;
  std::vector<int> output_dims;
  // NOTE(review): sized 2 here while >2 operands are allowed — presumably
  // ParseEinsumEquation resizes/handles this; confirm for the multi-operand
  // case.
  std::vector<std::vector<int>> ellipsis_dims(2);

  std::vector<DDim> input_dims;
  input_dims.reserve(inputs.size());
  for (auto& i : inputs) {
    input_dims.push_back(i->dims());
  }
  std::string right;
  ParseEinsumEquation(equation,
                      input_dims,
                      &labelshape,
                      &labeltype,
                      &all_labels,
                      &label2perms,
                      &ellipsis_dims,
                      &broadcast_dims,
                      &output_dims,
                      &right);

  VLOG(3) << "Einsum Infershape: input dims:"
          << paddle::string::join_strings(input_dims, "\n");
  VLOG(3) << "Einsum Infershape: equation:" << equation;
  VLOG(3) << "Einsum Infershape: all_labels:"
          << paddle::string::join_strings(all_labels, ",");
  VLOG(3) << "Einsum Infershape: output dims:"
          << paddle::string::join_strings(output_dims, ",");
  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);

  // FIX: the original computed `output_dims` but never wrote anything to
  // `out`, so the output MetaTensor's shape/dtype were left unset. An
  // InferMeta function must propagate the inferred meta to its output.
  out->set_dims(phi::make_ddim(output_dims));
  out->set_dtype(inputs[0]->dtype());
}

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/unary.h
Expand Up @@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor* out_w,
MetaTensor* out_v);

// Infers the output meta of einsum for an arbitrary number of operands
// described by `equation`; fills `out` with the inferred result meta.
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                     const std::string& equation,
                     MetaTensor* out);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);
Expand Down
56 changes: 9 additions & 47 deletions paddle/phi/kernels/impl/einsum_impl.h
Expand Up @@ -13,14 +13,14 @@
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/utils/string/string_helper.h"

namespace phi {

// check the validation of the Einsum equation.
// 1. the label must between 'a' - 'z'.
// 2. the dim of the same label must be same.
Expand Down Expand Up @@ -302,45 +302,6 @@ inline static void ParseEinsumEquation(
}
}

// Parses the einsum `equation` against the metas of all operands and logs
// the derived label bookkeeping (shapes, types, output dims) at VLOG(3).
inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs,
                             const std::string& equation,
                             MetaTensor* out) {
  // Per-label bookkeeping filled in by the equation parser.
  LabelMap labelshape(0);
  LabelMap labeltype(LabelType::Reduction);
  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
  std::vector<char> all_labels;
  std::vector<int> broadcast_dims;
  std::vector<int> output_dims;
  std::vector<std::vector<int>> ellipsis_dims(2);

  // Snapshot every operand's dims for the parser.
  std::vector<DDim> input_dims;
  input_dims.reserve(inputs.size());
  for (const auto* in : inputs) {
    input_dims.push_back(in->dims());
  }

  std::string right;
  ParseEinsumEquation(equation,
                      input_dims,
                      &labelshape,
                      &labeltype,
                      &all_labels,
                      &label2perms,
                      &ellipsis_dims,
                      &broadcast_dims,
                      &output_dims,
                      &right);

  VLOG(3) << "Einsum Infershape: input dims:"
          << paddle::string::join_strings(input_dims, "\n");
  VLOG(3) << "Einsum Infershape: equation:" << equation;
  VLOG(3) << "Einsum Infershape: all_labels:"
          << paddle::string::join_strings(all_labels, ",");
  VLOG(3) << "Einsum Infershape: output dims:"
          << paddle::string::join_strings(output_dims, ",");
  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
}

template <typename T>
std::vector<T> GetLabelIndexByType(const std::vector<char>& all_labels,
const LabelMap& type,
Expand Down Expand Up @@ -394,19 +355,20 @@ DenseTensor PerformReduction(const Context& dev_ctx,
return Sum<T, Context>(dev_ctx, tensor, indices, tensor.dtype(), true);
}

// Returns true when `axis` is the identity permutation (0, 1, 2, ...),
// i.e. transposing with this axis order would be a no-op.
inline bool is_no_need_transpose(const std::vector<int>& axis) {
  int expected = 0;
  for (const int a : axis) {
    if (a != expected) return false;
    ++expected;
  }
  return true;
}

template <typename T, typename Context>
DenseTensor PerformTranspose(const Context& dev_ctx,
const DenseTensor& tensor,
const LabelMap& label2perm,
const std::vector<char>& all_labels,
const std::vector<int>& ellipsis,
const LabelMap& label2type) {
auto is_no_need_transpose = [](std::vector<int>& axis) {
for (size_t i = 0; i < axis.size(); ++i) {
if (i != size_t(axis[i])) return false;
}
return true;
};
auto axis = GetLabelIndexByType<int>(
all_labels, label2type, label2perm, ellipsis, LabelType::ALL_TYPE);
VLOG(5) << "PerformTranspose: " << paddle::string::join_strings(axis, ",");
Expand Down Expand Up @@ -496,9 +458,9 @@ void TransposeToOutput(const Context& dev_ctx,
axis.push_back(it - all_labels.begin() + offset);
}
}
if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans);
VLOG(5) << "call TransposeToOutput: with axis: "
<< paddle::string::join_strings(axis, ",");
if (axis.size() == 0) return output->ShareBufferWith(to_trans);
return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
}

Expand Down