[IR] Ir fill constant #56520

Merged
merged 26 commits into from Aug 23, 2023

@wanghuancoder wanghuancoder commented Aug 22, 2023

PR types

Bug fixes

PR changes

Others

Description

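Pcard-67164
Fixes the IR's lack of support for passing ShapeTensor, ShapeTensorList, and ValueTensor to fill_constant.

  1. Add the fill_with_tensor operator, with parameters (Tensor shape, Tensor value, DataType dtype=DataType::FLOAT32).
  2. fill_with_tensor calls FullKernel under the hood.
  3. Add FillConstantTranscriber to the IR Translator:
  • If none of ShapeTensor, ShapeTensorList, or ValueTensor is present, call pd.full.
  • If ShapeTensorList is present, insert pd.stack to merge the tensors into one.
  • If neither ShapeTensor nor ShapeTensorList is present, insert pd.full_int_array to convert shape into a Tensor.
  • If ValueTensor is not present, insert pd.full to convert value into a Tensor.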
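
The branching in item 3 can be sketched as a small planning function. This is only an illustration of the rules listed above, not the translator's code: `FillConstantInputs` and `PlanFillConstant` are hypothetical names, the op names are plain strings, and giving ShapeTensorList precedence when both shape inputs are present is an assumption the description does not settle.

```cpp
#include <string>
#include <vector>

// Hypothetical summary of which optional tensor inputs a legacy
// fill_constant op carries. Not a Paddle type.
struct FillConstantInputs {
  bool has_shape_tensor = false;
  bool has_shape_tensor_list = false;
  bool has_value_tensor = false;
};

// Returns the names of the ops that would be emitted, in order.
std::vector<std::string> PlanFillConstant(const FillConstantInputs& in) {
  const bool has_mutable_attribute =
      in.has_shape_tensor || in.has_shape_tensor_list || in.has_value_tensor;
  if (!has_mutable_attribute) {
    // Everything is a plain attribute: a single pd.full is enough.
    return {"pd.full"};
  }
  std::vector<std::string> ops;
  if (in.has_shape_tensor_list) {
    // Assumption: the list takes precedence if both shape inputs exist.
    ops.push_back("pd.stack");
  } else if (!in.has_shape_tensor) {
    // Shape is an attribute: materialize it as a tensor.
    ops.push_back("pd.full_int_array");
  }
  if (!in.has_value_tensor) {
    // Value is an attribute: materialize it as a tensor.
    ops.push_back("pd.full");
  }
  ops.push_back("pd.fill_with_tensor");
  return ops;
}
```

With only ValueTensor present, for example, the sketch yields pd.full_int_array followed by pd.fill_with_tensor.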

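On the implementation of fill_with_tensor: the Tensor shape may only hold a value at run time, and InferMeta receives only the MetaTensor of shape, so the output shape cannot be inferred during InferMeta. InferMeta therefore sets the output shape to [-1], and FullWithTensorKernel performs a Resize based on the actual shape.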
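
A minimal sketch of that two-phase idea, using toy types rather than Paddle's: `ToyTensor`, `ToyInferMeta`, and `ToyFullWithTensorKernel` are hypothetical stand-ins, and the shape is passed as a plain vector where the real kernel would read it from a tensor.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a dense tensor: dims plus flat storage.
struct ToyTensor {
  std::vector<int64_t> dims;
  std::vector<float> data;
};

// Compile-time step: the contents of the shape tensor are unknown here,
// so the output is marked with a single dynamic dimension.
void ToyInferMeta(ToyTensor* out) { out->dims = {-1}; }

// Run-time step: the shape values are available now, so resize the
// output to the real shape and fill it with the requested value.
void ToyFullWithTensorKernel(const std::vector<int64_t>& shape,
                             float value,
                             ToyTensor* out) {
  int64_t numel = 1;
  for (int64_t d : shape) numel *= d;
  out->dims = shape;
  out->data.assign(static_cast<std::size_t>(numel), value);
}
```

Calling `ToyInferMeta` first and `ToyFullWithTensorKernel` afterwards on the same tensor replaces the placeholder dims with the actual ones.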

paddle-bot bot commented Aug 22, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

ir::Program* program,
const std::vector<std::string>& args,
int axis = 0) {
ir::OpInfo op_info = ctx->GetRegisteredOpInfo("pd.stack");

Suggested change
- ir::OpInfo op_info = ctx->GetRegisteredOpInfo("pd.stack");
+ ir::OpInfo op_info = ctx->GetRegisteredOpInfo(dialect::StackOp::name());
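
One way to see the benefit of this suggestion: when the op class exposes its registered name, lookups and registration share a single source of truth, so a misspelled literal cannot silently miss. The sketch below uses hypothetical `ToyStackOp`, `ToyFullOp`, and `ToyRegistry` types. It only assumes that an op class has a static `name()` function, as the suggested line implies, and is not Paddle's registry.

```cpp
#include <map>
#include <string>

// Hypothetical op classes that own their registered name.
struct ToyStackOp {
  static const char* name() { return "pd.stack"; }
};
struct ToyFullOp {
  static const char* name() { return "pd.full"; }
};

// Hypothetical registry mapping a name to an opaque id; 0 means "not found".
class ToyRegistry {
 public:
  template <typename Op>
  void Register(int id) {
    ids_[Op::name()] = id;
  }

  int Lookup(const std::string& name) const {
    std::map<std::string, int>::const_iterator it = ids_.find(name);
    return it == ids_.end() ? 0 : it->second;
  }

 private:
  std::map<std::string, int> ids_;
};
```

A lookup through `ToyStackOp::name()` always matches what was registered, while a hand-typed literal such as "pd.stak" compiles fine and only fails at run time.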

@@ -1175,6 +1198,149 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber {
}
};

struct FillConstantTranscriber : public OpTranscriber {
ir::Operation* operator()(ir::IrContext* ctx,

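Overriding operator() directly is not recommended. Take a look at the implementation of OpTranscriber::operator(): it does quite a lot, and a hand-written replacement can easily miss something.
If you only need to change how inputs are handled, override just OpTranscriber::GenerateOperationInput.
If you need to change how attributes are handled, override OpTranscriber::TranslateOpAttribute.
If you need to handle specific inputs or attributes, consider overriding GetSpecialInputHandlers/GetSpecialAttributeHandlers.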
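
The advice above is the template-method pattern: the base class keeps the full sequence in operator() and subclasses customize only the hooks. The following is a heavily simplified sketch with hypothetical names (`ToyTranscriber`, `GenerateInput`, `TranslateAttribute`). The real OpTranscriber hooks have different signatures, and the step list here is invented for illustration.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified base: operator() fixes the overall sequence
// and delegates the customizable steps to virtual hooks.
struct ToyTranscriber {
  virtual ~ToyTranscriber() = default;

  // Returns a log of the steps performed, for illustration only.
  std::vector<std::string> operator()(const std::string& op_name) {
    std::vector<std::string> log;
    log.push_back("lookup:" + op_name);
    log.push_back(GenerateInput());
    log.push_back(TranslateAttribute());
    // A step that a hand-written operator() could easily leave out.
    log.push_back("record_result_mapping");
    return log;
  }

 protected:
  virtual std::string GenerateInput() { return "default_input"; }
  virtual std::string TranslateAttribute() { return "default_attribute"; }
};

// Customizes input handling only; every other step is inherited.
struct ToyFillConstantTranscriber : ToyTranscriber {
 protected:
  std::string GenerateInput() override { return "shape_and_value_as_tensors"; }
};
```

Because the subclass never touches operator(), the bookkeeping step at the end still runs for it.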

has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) &&
op_desc.Input("ValueTensor", true).size() > 0;

if (!has_mutable_attribute) {

I suggest writing two separate OpTranscribers and specifying the dispatch by op name, for example:

struct FillConstantTranscriber : public OpTranscriber {
  ir::Operation* operator()(ir::IrContext* ctx,
                            TranslationContext* param_map,
                            const OpDesc& op_desc,
                            ir::Program* program) override {
    // True when the shape or the value is supplied as a tensor input.
    bool has_mutable_attribute = op_desc.HasInput("ShapeTensor", true) &&
                                 op_desc.Input("ShapeTensor", true).size() > 0;
    has_mutable_attribute |= op_desc.HasInput("ShapeTensorList", true) &&
                             op_desc.Input("ShapeTensorList", true).size() > 0;
    has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) &&
                             op_desc.Input("ValueTensor", true).size() > 0;
    if (has_mutable_attribute) {
      // Tensor inputs present: use the transcriber that takes inputs.
      return FullWithInputOpTranscriber()(ctx, param_map, op_desc, program);
    } else {
      // Attributes only: translate to a plain full op.
      return FullOpTranscriber()(ctx, param_map, op_desc, program);
    }
  }
};

FullOpTranscriber and FullWithInputOpTranscriber can then share code through inheritance.

auto op_info = ctx->GetRegisteredOpInfo("pd.fill_with_tensor");
if (!op_info) {
IR_THROW(
"Op tril_triu should have corresponding OpInfo "

Suggested change
- "Op tril_triu should have corresponding OpInfo "
+ "Op fill_constant with mutable attribute should have corresponding OpInfo "

op_desc.Input("ValueTensor", true).size() > 0;

if (!has_mutable_attribute) {
auto op_info = ctx->GetRegisteredOpInfo("pd.full");

Suggested change
- auto op_info = ctx->GetRegisteredOpInfo("pd.full");
+ auto op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name());

ir::Operation* operation = ir::Operation::Create(
{}, attribute_map, {defining_info.value.type()}, op_info);
program->block()->push_back(operation);
return operation;

The current implementation is missing a call to OpTranscriber::RecordOpResultMapping. That function records the mapping between VarDesc and ir::Value; leaving it out will cause bugs.
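
To illustrate the kind of bug meant here: if a producer op's result is never recorded, a later op that consumes the same variable cannot resolve it. The sketch below is a toy model with hypothetical names (`ToyTranslationContext`, `ToyTranslateProducer`, `ToyTranslateConsumer`), and integer ids stand in for ir::Value. It does not reproduce the real RecordOpResultMapping signature.

```cpp
#include <map>
#include <string>

// Hypothetical stand-in for the translator's variable-name -> value table.
class ToyTranslationContext {
 public:
  void Record(const std::string& var_name, int value_id) {
    table_[var_name] = value_id;
  }

  // Returns -1 when the variable was never recorded.
  int Find(const std::string& var_name) const {
    std::map<std::string, int>::const_iterator it = table_.find(var_name);
    return it == table_.end() ? -1 : it->second;
  }

 private:
  std::map<std::string, int> table_;
};

// Translates an op that produces `out_var`. The `record` flag toggles
// the bookkeeping step that the review comment says is missing.
int ToyTranslateProducer(ToyTranslationContext* ctx,
                         const std::string& out_var,
                         int value_id,
                         bool record) {
  if (record) ctx->Record(out_var, value_id);
  return value_id;
}

// A consumer op resolves its input by variable name.
bool ToyTranslateConsumer(const ToyTranslationContext& ctx,
                          const std::string& in_var) {
  return ctx.Find(in_var) != -1;
}
```

With `record` set to false the producer still creates its value, but the consumer's lookup fails, which is the failure mode a missing mapping would produce.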

kangguangli
kangguangli previously approved these changes Aug 22, 2023
@kangguangli kangguangli left a comment

LGTM

param : [shape, dtype]
kernel :
func : full_with_tensor
data_type : dtype

Does this need to support an inplace version? See full_ for reference.

@wanghuancoder wanghuancoder merged commit e914f7f into PaddlePaddle:develop Aug 23, 2023
26 checks passed
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023