Skip to content

Commit

Permalink
Make einsum_v2 support multi-operands (#42327)
Browse files Browse the repository at this point in the history
* Extend the Python einsum interface so that einsum_v2 supports multiple operands, and switch einsum_v2 to be the default.

* add opt_einsum dependency

* add yaml and support eager mode

* fix by code review
  • Loading branch information
2742195759 committed Apr 29, 2022
1 parent 21d94dd commit 32cae24
Show file tree
Hide file tree
Showing 9 changed files with 679 additions and 49 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/einsum_op.cc
Expand Up @@ -18,7 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -85,7 +85,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferShape));
PD_INFER_META(phi::EinsumInferMeta));

REGISTER_OPERATOR(einsum, ops::EinsumOp, ops::EinsumOpMaker,
EinsumInferShapeFunctor,
Expand Down
40 changes: 40 additions & 0 deletions paddle/phi/infermeta/unary.cc
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"

namespace phi {

Expand Down Expand Up @@ -398,6 +399,45 @@ void EighInferMeta(const MetaTensor& x,
out_v->set_dims(input_dim);
}

void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                     const std::string& equation,
                     MetaTensor* out) {
  // Parse the einsum equation against the operand shapes and propagate the
  // inferred output shape/dtype to `out`.
  //
  // Parameters:
  //   inputs   - meta tensors of every einsum operand (multi-operand capable).
  //   equation - einsum equation string, e.g. "ij,jk->ik".
  //   out      - output meta tensor to receive the inferred dims/dtype.
  //
  // Collect the following information to prepare einsum.
  LabelMap labelshape(0);
  LabelMap labeltype(LabelType::Reduction);
  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
  std::vector<char> all_labels;
  std::vector<int> broadcast_dims;
  std::vector<int> output_dims;
  // Fix: one ellipsis-dims slot per operand. The previous hard-coded size of 2
  // would be insufficient once more than two operands are supported.
  std::vector<std::vector<int>> ellipsis_dims(inputs.size());

  std::vector<DDim> input_dims;
  for (auto& i : inputs) {
    input_dims.push_back(i->dims());
  }
  std::string right;
  ParseEinsumEquation(equation,
                      input_dims,
                      &labelshape,
                      &labeltype,
                      &all_labels,
                      &label2perms,
                      &ellipsis_dims,
                      &broadcast_dims,
                      &output_dims,
                      &right);

  VLOG(3) << "Einsum Infershape: input dims:"
          << paddle::string::join_strings(input_dims, "\n");
  VLOG(3) << "Einsum Infershape: equation:" << equation;
  VLOG(3) << "Einsum Infershape: all_labels:"
          << paddle::string::join_strings(all_labels, ",");
  VLOG(3) << "Einsum Infershape: output dims:"
          << paddle::string::join_strings(output_dims, ",");
  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);

  // Fix: the computed output meta was never written back, leaving `out`
  // untouched and making this InferMeta a no-op for shape inference.
  out->set_dims(make_ddim(output_dims));
  out->set_dtype(inputs[0]->dtype());
}

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/unary.h
Expand Up @@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor* out_w,
MetaTensor* out_v);

// Infers the output meta (dims/dtype) of the einsum op from the operand meta
// tensors and the einsum equation string (e.g. "ij,jk->ik"). Accepts an
// arbitrary number of operands.
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);
Expand Down
56 changes: 9 additions & 47 deletions paddle/phi/kernels/impl/einsum_impl.h
Expand Up @@ -13,14 +13,14 @@
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/utils/string/string_helper.h"

namespace phi {

// check the validation of the Einsum equation.
// 1. the label must between 'a' - 'z'.
// 2. the dim of the same label must be same.
Expand Down Expand Up @@ -302,45 +302,6 @@ inline static void ParseEinsumEquation(
}
}

// Legacy shape-inference entry point for einsum. Parses `equation` against the
// operand shapes and propagates the inferred output shape/dtype to `out`.
// NOTE(review): this duplicates EinsumInferMeta in paddle/phi/infermeta/unary.cc;
// prefer that entry point for new code.
inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs,
                             const std::string& equation,
                             MetaTensor* out) {
  // Collect the following information to prepare einsum.
  LabelMap labelshape(0);
  LabelMap labeltype(LabelType::Reduction);
  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
  std::vector<char> all_labels;
  std::vector<int> broadcast_dims;
  std::vector<int> output_dims;
  // Fix: one ellipsis-dims slot per operand instead of a hard-coded 2, so the
  // parse cannot index past the vector with more than two operands.
  std::vector<std::vector<int>> ellipsis_dims(inputs.size());

  std::vector<DDim> input_dims;
  for (auto& i : inputs) {
    input_dims.push_back(i->dims());
  }
  std::string right;
  ParseEinsumEquation(equation,
                      input_dims,
                      &labelshape,
                      &labeltype,
                      &all_labels,
                      &label2perms,
                      &ellipsis_dims,
                      &broadcast_dims,
                      &output_dims,
                      &right);

  VLOG(3) << "Einsum Infershape: input dims:"
          << paddle::string::join_strings(input_dims, "\n");
  VLOG(3) << "Einsum Infershape: equation:" << equation;
  VLOG(3) << "Einsum Infershape: all_labels:"
          << paddle::string::join_strings(all_labels, ",");
  VLOG(3) << "Einsum Infershape: output dims:"
          << paddle::string::join_strings(output_dims, ",");
  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);

  // Fix: the computed output dims were never written back, so callers saw an
  // unmodified `out`.
  out->set_dims(make_ddim(output_dims));
  out->set_dtype(inputs[0]->dtype());
}

template <typename T>
std::vector<T> GetLabelIndexByType(const std::vector<char>& all_labels,
const LabelMap& type,
Expand Down Expand Up @@ -394,19 +355,20 @@ DenseTensor PerformReduction(const Context& dev_ctx,
return Sum<T, Context>(dev_ctx, tensor, indices, tensor.dtype(), true);
}

inline bool is_no_need_transpose(const std::vector<int>& axis) {
for (size_t i = 0; i < axis.size(); ++i) {
if (i != static_cast<size_t>(axis[i])) return false;
}
return true;
}

template <typename T, typename Context>
DenseTensor PerformTranspose(const Context& dev_ctx,
const DenseTensor& tensor,
const LabelMap& label2perm,
const std::vector<char>& all_labels,
const std::vector<int>& ellipsis,
const LabelMap& label2type) {
auto is_no_need_transpose = [](std::vector<int>& axis) {
for (size_t i = 0; i < axis.size(); ++i) {
if (i != size_t(axis[i])) return false;
}
return true;
};
auto axis = GetLabelIndexByType<int>(
all_labels, label2type, label2perm, ellipsis, LabelType::ALL_TYPE);
VLOG(5) << "PerformTranspose: " << paddle::string::join_strings(axis, ",");
Expand Down Expand Up @@ -496,9 +458,9 @@ void TransposeToOutput(const Context& dev_ctx,
axis.push_back(it - all_labels.begin() + offset);
}
}
if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans);
VLOG(5) << "call TransposeToOutput: with axis: "
<< paddle::string::join_strings(axis, ",");
if (axis.size() == 0) return output->ShareBufferWith(to_trans);
return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
}

Expand Down

0 comments on commit 32cae24

Please sign in to comment.