
[CINN] Strong constraint branch supports dynamic shape #59309

Merged 42 commits into PaddlePaddle:develop on Dec 6, 2023

Conversation

@jiahy0825 (Contributor) commented on Nov 23, 2023:

PR types

New features

PR changes

Others

Description

pcard-76996

The strong constraint branch now supports dynamic shapes: constraint relations from ShapeDialect are used to build a system of equations, which is then solved symbolically to derive the symbolic expression corresponding to each iterator.

For a minimal exp + subtract subgraph, the solved MapExpr looks like this:

MapExprTest(t_var_2, t_var_1) {
  AnchoredMapStmt(t_var_0) {
    MapStmt([i_59, i_60]) {
      exp(
          &t_var[IndexDot([BI(i_59, sym_17), 0], [sym_17, 1])],
          t_var_1[IndexDot([BI(i_59, sym_17), 0], [sym_17, 1])]);
      subtract(
          &t_var_0[IndexDot([i_59, i_60], [sym_17, 1])],
          t_var_2[IndexDot([BI(i_59, sym_17), 0], [sym_17, 1])],
          t_var[IndexDot([BI(i_59, sym_17), 0], [sym_17, 1])]);
    }
  }
}
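
To make the setup concrete, here is a minimal, hypothetical C++ sketch (an editor's illustration, not the CINN implementation) of how equality constraints collected from a shape dialect could be solved with union-find so that every dynamic dimension resolves to one canonical symbol, the way sym_17 appears throughout the MapExpr above:

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical solver: each dynamic dimension gets an id, and equality
// constraints (e.g. required by elementwise ops) merge the ids.
class DimConstraintSolver {
 public:
  int NewDim() {
    parent_.push_back(static_cast<int>(parent_.size()));
    return parent_.back();
  }
  // Record a shape-dialect constraint "dim a == dim b".
  void AddEqual(int a, int b) { parent_[Find(a)] = Find(b); }
  // Every dimension in one equivalence class shares a symbol.
  std::string SymbolOf(int dim) { return "sym_" + std::to_string(Find(dim)); }

 private:
  int Find(int x) {
    while (parent_[x] != x) x = parent_[x] = parent_[parent_[x]];
    return x;
  }
  std::vector<int> parent_;
};

int main() {
  DimConstraintSolver solver;
  int a = solver.NewDim();  // e.g. dim 0 of t_var_1
  int b = solver.NewDim();  // e.g. dim 0 of t_var_2
  solver.AddEqual(a, b);    // elementwise subtract requires matching dims
  // Both dims now resolve to the same symbol.
  std::printf("%s == %s\n", solver.SymbolOf(a).c_str(),
              solver.SymbolOf(b).c_str());
  return 0;
}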

paddle-bot (bot) commented on Nov 23, 2023:

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Aurelius84 (Contributor) previously approved these changes on Dec 1, 2023 and left a comment:

Great work!

}
};

inline std::size_t GetHashValueImpl(const DynamicTensor& tensor) {
Contributor commented:

Using size_t as the hash key carries a hidden hash-collision problem. In a hash map, when keys collide, operator== is called as a second-stage check. But when the size_t value itself is the key, that second-stage operator== check is ineffective. Please be aware that this risk can cause problems that are very hard to track down.
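
To illustrate the risk (an editor's sketch, not code from this PR): if the precomputed size_t is itself the map key, two distinct tensors whose hashes collide silently share one entry, whereas keying on the object lets the second-stage operator== resolve the collision.

#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

struct Tensor {
  std::string name;
  bool operator==(const Tensor& other) const { return name == other.name; }
};

// Deliberately terrible hash: every tensor collides.
struct TensorHash {
  std::size_t operator()(const Tensor&) const { return 42; }
};

int main() {
  // Anti-pattern: the size_t hash value is the key, so operator== on
  // Tensor never runs and colliding tensors merge into one entry.
  std::unordered_map<std::size_t, int> by_hash;
  by_hash[TensorHash{}(Tensor{"a"})] = 1;
  by_hash[TensorHash{}(Tensor{"b"})] = 2;  // silently overwrites "a"
  assert(by_hash.size() == 1);

  // Keying on the object: colliding buckets fall back to operator==.
  std::unordered_map<Tensor, int, TensorHash> by_tensor;
  by_tensor[Tensor{"a"}] = 1;
  by_tensor[Tensor{"b"}] = 2;  // kept distinct despite equal hashes
  assert(by_tensor.size() == 2);
  return 0;
}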

Contributor commented:

Actually, Value can be used directly as the hash key.

@jiahy0825 (Author) replied on Dec 1, 2023:

Thanks for the reminder 😃 pir::Value is indeed already used directly as the hash key here.
This code overloads the std::hash function for Tensor; since DynamicTensor is a subtype of Tensor, the strong constraint branch has to implement GetHashValueImpl for DynamicTensor.
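
A minimal sketch of the pattern being described (simplified names, not the PR's exact code): std::hash is specialized once for the Tensor sum type, and each alternative, including DynamicTensor, supplies its own GetHashValueImpl overload:

#include <cstddef>
#include <functional>
#include <variant>

struct StaticTensor { int id; };
struct DynamicTensor { int id; };

// Each alternative supplies its own hashing implementation.
inline std::size_t GetHashValueImpl(const StaticTensor& t) {
  return std::hash<int>()(t.id);
}
inline std::size_t GetHashValueImpl(const DynamicTensor& t) {
  return std::hash<int>()(t.id) ^ 0x9e3779b9;
}

struct Tensor {
  std::variant<StaticTensor, DynamicTensor> data;
};

// One std::hash specialization dispatches to the matching Impl, which
// is why DynamicTensor needs a GetHashValueImpl in this branch.
namespace std {
template <>
struct hash<Tensor> {
  std::size_t operator()(const Tensor& t) const {
    return std::visit(
        [](const auto& impl) { return GetHashValueImpl(impl); }, t.data);
  }
};
}  // namespace std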

@@ -74,6 +84,7 @@ class Tuple {
std::make_shared<std::tuple<Ts...>>(std::forward<Args>(args)...)) {}

const std::tuple<Ts...>& tuple() const { return *tuple_; }
std::tuple<Ts...>* mut_tuple() { return &*tuple_; }
Contributor commented:

Does mut stand for mutable? A bit Rust-flavored~~ The framework mostly uses the full word mutable; would it be better to stay consistent?

@jiahy0825 (Author) replied on Dec 1, 2023:

Yes, it stands for mutable~ I'll submit a follow-up PR to change it.

} else {
return Op<DimExpr>{ret_operand};
}
LOG(FATAL) << "Dead code.";
Contributor commented:

Is this line ever reachable?

@jiahy0825 (Author) replied:

As the code evolves, the if-branch logic may become more complex. This line is defensive code: it guarantees that one of the if/else branches is always taken and that this line is never reached.
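
The pattern, as a generic hedged sketch (not the PR's exact code; LOG(FATAL) is glog's, as used in the diff above): every path returns from the if/else chain, and the trailing fatal log only fires if a future edit breaks that invariant.

#include <glog/logging.h>

enum class Kind { kAdd, kMul };

const char* Name(Kind kind) {
  if (kind == Kind::kAdd) {
    return "add";
  } else if (kind == Kind::kMul) {
    return "mul";
  }
  // Unreachable today; catches a future Kind added without updating
  // this function, instead of silently returning garbage.
  LOG(FATAL) << "Dead code.";
}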

} else {
return expr;
}
LOG(FATAL) << "Dead code";
Contributor commented:

Same as above; there seem to be more instances below as well.

@jiahy0825 (Author) replied:

Same as above.


template <>
struct GetOrderValue<BroadcastedDim<DimExpr>> {
static constexpr int value = 10;
Contributor commented:

The value here seems tied to the IsLhsBeforeRhs logic. I'd like to understand the basis for the ordering: for example, why does Product come before Sum? And the order for std::int64_t is neither first nor last but somewhere in the middle, why is that?

Contributor commented:

Is it operator precedence, like multiplication binding tighter than addition?

@jiahy0825 (Author) replied:

The order values have no particular basis. The ordering mainly ensures that DimExprs of the same type end up adjacent after sorting, which simplifies equality checking between DimExprs.
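
An editor's sketch of the mechanism under discussion (simplified alternatives, arbitrary order values): each DimExpr alternative gets a fixed order value, and sorting operands by it clusters same-typed subexpressions together so that equality can then compare element-wise:

#include <algorithm>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>

struct Symbol { std::string name; };
using DimExpr = std::variant<std::int64_t, Symbol>;

// The concrete values are arbitrary; they only need to be distinct so
// that sorting groups operands of the same alternative together.
template <typename T> struct GetOrderValue;
template <> struct GetOrderValue<std::int64_t> { static constexpr int value = 0; };
template <> struct GetOrderValue<Symbol> { static constexpr int value = 10; };

int OrderOf(const DimExpr& expr) {
  return std::visit(
      [](const auto& v) {
        return GetOrderValue<std::decay_t<decltype(v)>>::value;
      },
      expr);
}

bool IsLhsBeforeRhs(const DimExpr& lhs, const DimExpr& rhs) {
  return OrderOf(lhs) < OrderOf(rhs);
}

int main() {
  std::vector<DimExpr> operands{Symbol{"S0"}, std::int64_t{4}, Symbol{"S1"}};
  // After sorting, the integer precedes the symbols, so two operand
  // lists that are permutations of each other compare equal in order.
  std::stable_sort(operands.begin(), operands.end(), IsLhsBeforeRhs);
  return 0;
}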

};

std::int64_t GetInteger(const DimExpr& expr) {
if (expr.Has<Negative<DimExpr>>()) {
Contributor commented:

Has here feels like it has Is semantics, or Is semantics applied to the outermost layer?

@jiahy0825 (Author) replied on Dec 1, 2023:

Has determines which subtype the outermost layer of a DimExpr is, so it is indeed very close to Is semantics.
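
A hedged sketch of what such a Has looks like if DimExpr is variant-based (illustrative types, not the PR's definitions): it tests only the outermost alternative, which is why it reads like Is:

#include <cstdint>
#include <memory>
#include <variant>

struct DimExpr;
struct Negative { std::shared_ptr<DimExpr> operand; };

struct DimExpr : public std::variant<std::int64_t, Negative> {
  using std::variant<std::int64_t, Negative>::variant;
  // Has<T> inspects only the outermost alternative of the expression,
  // so it behaves like an "Is" check on the top-level node.
  template <typename T>
  bool Has() const { return std::holds_alternative<T>(*this); }
};

int main() {
  DimExpr expr{std::int64_t{8}};
  return expr.Has<std::int64_t>() ? 0 : 1;
}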


template <>
struct GetInversed<Sum> {
static DimExpr Call(const DimExpr& expr) { return Negative<DimExpr>(expr); }
Contributor commented:

Is the inverse of S1 + S2 equal to -(S1 + S2), or to S1 - S2?

@jiahy0825 (Author) replied:

Looking only at the Inversed implementation here, S1 + S2 maps to -(S1 + S2). In practice the call site is inside the FlattenOperands pass: S1 + S2 is flattened first and then negated, so after the pass runs the expression becomes -S1 - S2.
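
A sketch of that flatten-then-negate behavior (hypothetical helper names, not the pass's actual code):

#include <string>
#include <vector>

// Hypothetical flat form: a sum is a list of signed terms.
struct Term {
  std::string symbol;
  int sign;  // +1 or -1
};
using FlatSum = std::vector<Term>;

// Negating an already-flattened sum negates each term, so
// -(S1 + S2) ends up as -S1 - S2 instead of a nested Negative node.
FlatSum Negate(FlatSum sum) {
  for (Term& term : sum) term.sign = -term.sign;
  return sum;
}

int main() {
  FlatSum sum{{"S1", +1}, {"S2", +1}};  // S1 + S2
  FlatSum negated = Negate(sum);        // -S1 - S2
  return negated[0].sign == -1 ? 0 : 1;
}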

DimExpr ret = expr;
for (bool keep_rewrite = true; keep_rewrite;) {
keep_rewrite = false;
DoPass<SimplifyOneOperand<Negative>>(&keep_rewrite, &ret);
Contributor commented:

This is a gradually converging optimization process, right? Have you evaluated its execution efficiency, could it add noticeable time cost?

@jiahy0825 (Author) replied:

Regarding "This is a gradually converging optimization process, right?": yes. It hasn't been measured in practice, but DimExprs are relatively simple, so this is not expected to become a performance bottleneck.
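
A hedged sketch of the run-to-fixpoint pattern visible in the diff above (generic names, toy pass):

#include <string>

// Generic run-to-fixpoint driver: keep applying the rewrite pass until
// a full sweep changes nothing. Termination relies on each rewrite
// producing a strictly simpler expression.
template <typename Expr, typename Pass>
Expr SimplifyToFixpoint(Expr expr, Pass pass) {
  for (bool keep_rewrite = true; keep_rewrite;) {
    keep_rewrite = false;
    pass(&keep_rewrite, &expr);  // sets keep_rewrite if expr changed
  }
  return expr;
}

int main() {
  // Toy pass: strip one leading "--" pair per sweep ("--x" -> "x").
  auto strip_double_negation = [](bool* changed, std::string* s) {
    if (s->size() >= 2 && (*s)[0] == '-' && (*s)[1] == '-') {
      s->erase(0, 2);
      *changed = true;
    }
  };
  std::string out =
      SimplifyToFixpoint(std::string("----x"), strip_double_negation);
  return out == "x" ? 0 : 1;
}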

constants_provider)
explicit IGroup(const List<OpStmt>& op_stmts,
const AnchorIndex& anchor_index,
const EquationCtx4OpStmtT& EquationCtx4OpStmt)
Contributor commented:

Does the 4 here stand for "For"?

@jiahy0825 (Author) replied:

Yes, it does~


PD_DEFINE_bool(cinn_enable_map_expr_index_detail,
BoolFromEnv("FLAGS_cinn_enable_map_expr_index_detail", false),
"It controls whether to display datail tensor index");
@ZzSean (Contributor) commented on Dec 4, 2023:

datail->detailed

@jiahy0825 (Author) replied:

Thanks~ I'll fix it in the next PR.

@@ -91,6 +91,14 @@ PD_DEFINE_bool(cinn_enable_map_expr_inline,
BoolFromEnv("FLAGS_cinn_enable_map_expr_inline", false),
"It controls whether to inline by map_expr");

PD_DEFINE_bool(cinn_enable_map_expr_dynamic_shape,
BoolFromEnv("FLAGS_cinn_enable_map_expr_dynamic_shape", false),
Contributor commented:

This flag doesn't seem to be used anywhere.

@jiahy0825 (Author) replied:

It is used in the test/cpp/pir/cinn/adt/map_expr_test.cc unit test and in paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc; it indicates that symbolic derivation should be performed based on the shape dialect.
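
For reference, a hedged sketch of how such a flag typically gates the behavior (illustrative only; the real flag is produced by Paddle's PD_DEFINE_bool macro shown above, and the real call sites are in the files just named):

#include <cstdio>

// Simplified stand-in for the global that PD_DEFINE_bool ultimately defines.
bool FLAGS_cinn_enable_map_expr_dynamic_shape = false;

void InferSymbolicDims() {
  if (!FLAGS_cinn_enable_map_expr_dynamic_shape) {
    std::puts("dynamic-shape symbolic inference disabled");
    return;  // fall back to static-shape handling
  }
  std::puts("deriving symbols from shape-dialect constraints");
}

int main() {
  FLAGS_cinn_enable_map_expr_dynamic_shape = true;
  InferSymbolicDims();
  return 0;
}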

@ZzSean (Contributor) left a comment:

LGTM

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@jiahy0825 merged commit 74dd107 into PaddlePaddle:develop on Dec 6, 2023.
29 checks passed