[IR] Ir fill constant #56520
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open source project!
```cpp
                           ir::Program* program,
                           const std::vector<std::string>& args,
                           int axis = 0) {
  ir::OpInfo op_info = ctx->GetRegisteredOpInfo("pd.stack");
```
Suggested change:
```diff
- ir::OpInfo op_info = ctx->GetRegisteredOpInfo("pd.stack");
+ ir::OpInfo op_info = ctx->GetRegisteredOpInfo(dialect::StackOp::name());
```
```cpp
@@ -1175,6 +1198,149 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber {
  }
};

struct FillConstantTranscriber : public OpTranscriber {
  ir::Operation* operator()(ir::IrContext* ctx,
```
We don't recommend overloading `operator()` directly. Take a look at the implementation of `OpTranscriber::operator()`: it actually does quite a lot, and re-implementing it yourself makes it easy to miss something.
If you only need to change how inputs are handled, you can override just `OpTranscriber::GenerateOperationInput`;
if you need to change how attributes are handled, override `OpTranscriber::TranslateOpAttribute`;
if you need to handle specific inputs/attributes, consider overriding `GetSpecialInputHandlers` / `GetSpecialAttributeHandlers`.
```cpp
  has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) &&
                           op_desc.Input("ValueTensor", true).size() > 0;

  if (!has_mutable_attribute) {
```
We suggest writing two separate `OpTranscriber`s and dispatching based on the op name, e.g.:

```cpp
struct FillConstantTranscriber : public OpTranscriber {
  ir::Operation* operator()(ir::IrContext* ctx,
                            TranslationContext* param_map,
                            const OpDesc& op_desc,
                            ir::Program* program) override {
    bool has_mutable_attribute =
        op_desc.HasInput("ShapeTensor", true) &&
        op_desc.Input("ShapeTensor", true).size() > 0;
    has_mutable_attribute |= op_desc.HasInput("ShapeTensorList", true) &&
                             op_desc.Input("ShapeTensorList", true).size() > 0;
    has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) &&
                             op_desc.Input("ValueTensor", true).size() > 0;
    if (has_mutable_attribute) {
      return FullWithInputOpTranscriber()(ctx, param_map, op_desc, program);
    } else {
      return FullOpTranscriber()(ctx, param_map, op_desc, program);
    }
  }
};
```

Then `FullOpTranscriber` and `FullWithInputOpTranscriber` can share code through inheritance.
```cpp
  auto op_info = ctx->GetRegisteredOpInfo("pd.fill_with_tensor");
  if (!op_info) {
    IR_THROW(
        "Op tril_triu should have corresponding OpInfo "
```
Suggested change:
```diff
- "Op tril_triu should have corresponding OpInfo "
+ "Op fill_constant with mutable attribute should have corresponding OpInfo "
```
```cpp
                           op_desc.Input("ValueTensor", true).size() > 0;

  if (!has_mutable_attribute) {
    auto op_info = ctx->GetRegisteredOpInfo("pd.full");
```
Suggested change:
```diff
- auto op_info = ctx->GetRegisteredOpInfo("pd.full");
+ auto op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name());
```
```cpp
  ir::Operation* operation = ir::Operation::Create(
      {}, attribute_map, {defining_info.value.type()}, op_info);
  program->block()->push_back(operation);
  return operation;
```
The current implementation is missing a call to `OpTranscriber::RecordOpResultMapping`. That function records the mapping between `VarDesc` and `ir::Value`; omitting it will cause bugs.
LGTM
```yaml
  param : [shape, dtype]
  kernel :
    func : full_with_tensor
    data_type : dtype
```
Do we need to support an inplace version? See `full_` for reference.
* support ir fill constant
PR types
Bug fixes
PR changes
Others
Description
Pcard-67164
Fixes the issue that the new IR does not support passing ShapeTensor, ShapeTensorList, or ValueTensor to fill_constant.
About the full_with_tensor implementation: since the tensor's shape may only become known at runtime, and InferMeta receives the shape as a MetaTensor, the shape cannot be inferred at InferMeta time. So InferMeta sets the shape to [-1], and FullWithTensorKernel later performs a Resize according to the actual shape.