Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Strong constraint branch #58719

Merged
merged 6 commits into from
Nov 13, 2023

Conversation

jiahy0825
Copy link
Contributor

@jiahy0825 jiahy0825 commented Nov 6, 2023

PR types

New features

PR changes

Others

Description

pcard-76996

一、背景概述

前端 fusion_merge pass 决定了 Group 划分,Group 划分对生成的 kernel 性能有着至关重要的影响。前端的 Group 圈定了哪些 OP 应该放到一个 kernel 里,Group 相当于是对 OP 的一个 view,真正的代码生成逻辑在后端开展。在实际工作时发现,如果 fusion_merge 内的融合策略发生了变化,生成出来的某些 Group,在和后端对接时,程序会直接崩溃,这极为容易引发用户的 panic。

对上述情况进行梳理,可以发现存在的两个关键问题:

  • 问题1:不能够明确代码生成失败,是前端划定的 Group 有问题,还是后端的机制不支持——前后端缺少一个协议
  • 问题2:循环融合散落在两个地方处理:Group 融合时考虑了循环融合的可能,但是循环融合实际却发生在 CINN 后端——两者不一致,会经常导致 bug

现在前后端耦合程度高,问题排查难。需要有一个前后端公认的协议来作为它们之间的桥梁
在强约束分支的设计里,MapExpr 正是这一座桥梁。为了 MapExpr 的正确生成,我们发展了一套关于下标方程组的理论,从理论上保障融合正确性、完备性

二、架构设计

image

在工程实现上,强约束分支可总结为如下五个关键模块:

  • 联立器:为每个算子构造约束方程组,并对方程组进行联立
  • 划分器:基于约束方程组,实现算子融合,将可融合的算子划分到同一个 IGroup 内(IGroup = Inline Group)
  • 判别器:判断基于 IGroup 的方程组是否有解
  • 求解器:求解出每个 Tensor 对应的下标索引表达式
  • 生成器:根据 IGroup 划分结果和方程组的解,生成 MapExpr

三、子模块介绍

3.1 联立器

Q: 如何构造约束方程组?多个方程组之间如何联立?
A: 为算子挂载GenerateEquations多态方法,构造算子的约束方程组

核心代码:

  • paddle/cinn/hlir/op 目录内,为 relu、broadcast_to 等算子挂载了 generate_equations 函数
  • paddle/cinn/adt/naive_op_equation_context.cc 中的 GenerateOpEquations函数,查找并调用 op 对应的 generate_equations 函数

代码示例

void GenerateEquationsForRelu(cinn::adt::config::OpEquationContext *ctx) {
  CHECK(ctx->GetInTensorsRanks().size() != 0)
      << "The inputs is empty! Please check again.";
  ctx->Equal(ctx->GetInIndex(0), ctx->GetOutIndex(0));
}

OpEquationContext 为用户提供下列接口,基本可以满足构造方程组的需求
image

3.2 划分器

Q: 如何基于约束方程组,实现算子融合?
A: 约束方程组等价转换为 EquationGraph,在图上基于 AnchorTensor 进行遍历,遍历到的节点表示可以放到同一个 IGroup 内。

核心代码:

  • paddle/cinn/adt/partition_op_stmts.h 中的 PartitionOpStmts函数

流程图-202311072118

从 AnchorTensor 开始对 EquationGraph 进行遍历,遍历到的所有 Tensor 可以放到同一组内。上述过程对 Tensor 进行了划分,实际上融合的粒度却是算子,如何利用划分好的 op 集合,实现算子融合?
IGroup 融合:即算子融合。如果算子的输入输出 Tensor 均在划分出来的 op 集合内,则表明算子属于当前的 IGroup(即 Inline Group,表示 IGroup 内的算子可以内联在一起)。

下述伪代码介绍了核心逻辑,实际函数签名和伪代码有出入

std::vector<IGroup> PartitionIGroups(std::unordered_set<AnchorIndex>* candidate_anchor_indexes,
                                     const EquationGraph& equation_graph) {
  std::vector<IGroup> igroups{};
  while (!candidate_anchor_indexes->empty()) {
    AnchorIndex anchor_index = PickThenEraseAnchorIndex(candidate_anchor_indexes);
    const auto& visited_op_stmts = FindVisitedOpStmts(anchor_index, equation_graph);
    EraseCandidateAnchorIndexes(visited_op_stmts, candidate_anchor_indexes);
    igroups.emplace_back(IGroup{visited_op_stmts});
  }
  return igroups;
}

3.3 判别器

Q: 如何判断约束方程组是否有解?
A: 方程组有解,需要确保:如果有多条路径可以遍历到同一节点,该节点对应的解是唯一的

核心代码:

  • paddle/cinn/adt/equation_solver.cc 中的 CheckEquationsSolvable函数
  • paddle/cinn/adt/equation_solver.cc 中的 MergeInferedValuesIntoCtx函数,对方程组是否有解做了判断

不可解的例子

x: (a0, 1)
y: (1, a0)
z: (a0, a0)

y = reshape(op_0...op_n(x))
z = x + y

// 以 z 为 AnchorTensor,即遍历起点
z_index = Dot((z_i, z_j), (a0, 1))
// 1. 直接由 z 来算 x 的下标表达式
x_index1 = z_i
// 2. 如果 z 经过 y 算 x 的下标表达式
x_index2 = z_j
// 此时方程无解
x_index1 != x_index2

3.4 求解器

Q: 约束方程组的解如何定义?
A: 根据约束方程组,求解出每个 Tensor 对应的下标索引表达式。该表达式指明 Tensor 下标和调度描述符的换算关系
Q: 如何求解约束方程组?
A: 从 Schedule Descriptor 开始,遍历 Equation Graph,在遍历的过程中,计算每个 Tensor 对应的表达式

核心代码:

  • paddle/cinn/adt/equation_solver.cc 中的 SolveEquations函数

3.5 生成器

Q: 如何根据方程组的解,完成代码生成?
A: 前置步骤已经提供了充分的信息:IGroup 的划分方式 + 每个 Tensor 索引的表达式,本组件只需要根据 MapExpr 的格式,按部就班生成即可。

核心代码:

  • paddle/cinn/adt/generate_map_expr.cc 中的 GenerateMapExpr函数

四、如何运行

4.1 开启 FLAGS_cinn_enable_map_expr:

export FLAGS_cinn_enable_map_expr=true

4.2 执行 python 脚本:

cd test/cinn/adt
python test_reduce_fusion.py

五、输出预览

5.1 简单示例:以 Tensor x 和 y 为输入,执行 elementwise_add 和 reduce 两个算子

builder = NetBuilder("MapExprTest")
x = builder.create_input(Float(32), self.inputs["x"].shape, "x")
y = builder.create_input(Float(32), self.inputs["y"].shape, "y")
t = builder.elementwise_add(x, y)
out = builder.reduce_sum(t, [0], False)

5.2 输出结果

fn_elementwise_add_0_reduce_sum_1_3(&t_var_2[], t_y[], t_x[]) {
  AnchoredMapStmt(t_var_1) {
    MapStmt([i_75, i_76]) {
      reduce_sum_init(&t_var_2);
      MapStmt([i_77]) {
        elementwise_add(&t_var_1, t_x, t_y);
        reduce_sum_acc(&t_var_2, t_var_1);
      }
    }
  }
}

5.3 按行解析各字段含义

字段 含义
fn_elementwise_add_0_reduce_sum_1_3(&t_var_2[], t_y[], t_x[]) Kernel 名称为 fn_elementwise_add_0_reduce_sum_1_3,该 Kernel 以 t_var_2 为输出,t_x 和 t_y 为输入,&为输出 Tensor 标识符
AnchoredMapStmt(t_var_1) AnchoredMapStmt 内的 op 以 t_var_1 为 AnchorTensor,根据 t_var_1 的下标可以推断出 Stmt 内所有其他 Tensor 的下标
MapStmt([i_75, i_76]) MapStmt 内所有 op 的循环迭代量有两个 i_75 和 i_76
reduce_sum_init(&t_var_2) reduce_sum_init 算子的输出 Tensor 为 t_var_2
elementwise_add(&t_var_1, t_x, t_y) elementwise_add 算子的输出 Tensor 为 t_var_1,输入 Tensor 为 t_x 和 t_y
reduce_sum_acc(&t_var_2, t_var_1) reduce_sum_acc 算子的输出 Tensor 为 t_var_2,输入 Tensor 为 t_var_1

@jiahy0825
Copy link
Contributor Author

由于原 PR ( #57543 )的 cla 检查一直处于 pending 状态,将代码合并为一个 commit 并迁移至本 PR

Copy link
Contributor Author

@jiahy0825 jiahy0825 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

迁移原来 PR #57543 中的 comments

paddle/cinn/adt/print_utils/CMakeLists.txt Outdated Show resolved Hide resolved
paddle/cinn/hlir/framework/op_lowering_impl.cc Outdated Show resolved Hide resolved
paddle/cinn/adt/adapter_tensor.h Outdated Show resolved Hide resolved
paddle/cinn/adt/adt.h Show resolved Hide resolved
paddle/cinn/adt/kgroup.h Show resolved Hide resolved
paddle/cinn/hlir/framework/op_lowering_impl.cc Outdated Show resolved Hide resolved
paddle/cinn/hlir/op/reduction.cc Outdated Show resolved Hide resolved
paddle/cinn/hlir/op/broadcast.cc Show resolved Hide resolved
Copy link
Contributor

@BiynXu BiynXu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for OpLower

Copy link
Member

@zhhsplendid zhhsplendid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jiahy0825 jiahy0825 merged commit e1c9e50 into PaddlePaddle:develop Nov 13, 2023
28 checks passed
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* Strong Constraint Branch

* Change UpdateOpLoweredFuncKey location (PaddlePaddle#86)

* Remove useless parameter (PaddlePaddle#87)

* Change codes according to comments (PaddlePaddle#89)

* Delete useless code (PaddlePaddle#91)
SecretXV pushed a commit to SecretXV/Paddle that referenced this pull request Nov 28, 2023
* Strong Constraint Branch

* Change UpdateOpLoweredFuncKey location (PaddlePaddle#86)

* Remove useless parameter (PaddlePaddle#87)

* Change codes according to comments (PaddlePaddle#89)

* Delete useless code (PaddlePaddle#91)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants