Skip to content

Commit

Permalink
[CINN] Add FactorizeReduction schedule primitive (#57777)
Browse files Browse the repository at this point in the history
Add FactorizeReduction schedule primitive.
The difference between FactorizeReduction primitive and the original RFactor primitive are:

FactorizeReduction supports complex iters_value subscript, which means that FactorizeReduction can be used after using primitives such as Fuse and Split, and RFactor does not support this.
FactorizeReduction does not change the order of the original loop, while RFactor may have an implicit Reorder.
FactorizeReduction supports the transformation of one reduce block in a complex AST, while RFactor only supports the case where the AST is entirely composed of one reduce block.
  • Loading branch information
BiynXu committed Oct 9, 2023
1 parent 8a42a34 commit 0e4f474
Show file tree
Hide file tree
Showing 7 changed files with 820 additions and 1 deletion.
264 changes: 264 additions & 0 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,270 @@ void test_rfactor(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}

TEST(IrSchedule, factorize_reduction) {
Context::Global().ResetNameId();
Expr M(3);
Expr N(4);
Expr K(5);

Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N, K});
Var j(4, "j0");
Var k(5, "k0");
auto B = Compute(
{M},
[&](Var i) {
return lang::ReduceSum(A(i, j, k), {j, k});
},
"B");

auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec("test_factorize_reduction",
stages,
{A, B},
{},
{},
nullptr,
target,
true);
CHECK(!func.empty());
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
auto loops = ir_sch.GetLoops("B");
CHECK_EQ(loops.size(), 3U);
auto new_rf_tensor = ir_sch.FactorizeReduction(loops[1], 0);
auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
CHECK(new_rf_tensor_ref);
CHECK(new_rf_tensor_ref->buffer.defined());
func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
func[0]->PrepareBufferCastExprs();
std::string origin = utils::GetStreamCnt(func[0]);
LOG(INFO) << origin;
EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[vj0, i0_0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[vj0, i0_0] = (B_rf[vj0, i0_0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[vj0, i0_0])
}
}
}
}
}
}
)ROC"));
}

TEST(IrSchedule, factorize_reduction1) {
Context::Global().ResetNameId();
Expr M(3);
Expr N(4);
Expr K(5);

Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N, K});
Var j(4, "j0");
Var k(5, "k0");
auto B = Compute(
{M},
[&](Var i) {
return lang::ReduceSum(A(i, j, k), {j, k});
},
"B");

auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec("test_factorize_reduction",
stages,
{A, B},
{},
{},
nullptr,
target,
true);
CHECK(!func.empty());
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
auto loops = ir_sch.GetLoops("B");
CHECK_EQ(loops.size(), 3U);
auto new_rf_tensor = ir_sch.FactorizeReduction(loops[1], 1);
auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
CHECK(new_rf_tensor_ref);
CHECK(new_rf_tensor_ref->buffer.defined());
func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
func[0]->PrepareBufferCastExprs();
std::string origin = utils::GetStreamCnt(func[0]);
LOG(INFO) << origin;
EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (k0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, i2 = axis.bind(j0, i, k0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, vj0, i2])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"));
}

TEST(IrSchedule, factorize_reduction2) {
Context::Global().ResetNameId();
Expr M(3);
Expr N(4);
Expr K(5);

Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N * K});
Var j(4 * 5, "j0");
auto B = Compute(
{M}, [&](Var i) { return lang::ReduceSum(A(i, j), {j}); }, "B");

auto stages = CreateStages({A, B});
auto func = cinn::lang::LowerVec("test_factorize_reduction",
stages,
{A, B},
{},
{},
nullptr,
target,
true);
CHECK(!func.empty());
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
auto loops = ir_sch.GetLoops("B");
CHECK_EQ(loops.size(), 2U);
auto splited_loops = ir_sch.Split(loops[1], {4, 5});
CHECK_EQ(splited_loops.size(), 2U);
auto new_rf_tensor = ir_sch.FactorizeReduction(splited_loops[0], 1);
auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
CHECK(new_rf_tensor_ref);
CHECK(new_rf_tensor_ref->buffer.defined());
func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
func[0]->PrepareBufferCastExprs();
std::string origin = utils::GetStreamCnt(func[0]);
LOG(INFO) << origin;
EXPECT_EQ(origin, utils::Trim(R"ROC(
function test_factorize_reduction (_A, _B)
{
ScheduleBlock(root)
{
{
serial for (i, 0, 3)
{
serial for (j0, 0, 4)
{
ScheduleBlock(B_rf__reduce_init)
{
vj0, i0_0 = axis.bind(j0, i)
B_rf__reduce_init[i0_0, vj0] = 0.00000000f
}
serial for (j0_0, 0, 5)
{
ScheduleBlock(B_rf)
{
vj0, i0_0, vj0_0 = axis.bind(j0, i, j0_0)
B_rf[i0_0, vj0] = (B_rf[i0_0, vj0] + A[i0_0, ((5 * vj0) + vj0_0)])
}
}
}
}
serial for (i, 0, 3)
{
ScheduleBlock(B__reduce_init)
{
i0_0 = axis.bind(i)
B__reduce_init[i0_0] = 0.00000000f
}
serial for (j0, 0, 4)
{
ScheduleBlock(B)
{
vj0, i0_0 = axis.bind(j0, i)
B[i0_0] = (B[i0_0] + B_rf[i0_0, vj0])
}
}
}
}
}
}
)ROC"));
}

TEST(IrSchedule, compute_inline1) {
Context::Global().ResetNameId();
Expr M(32);
Expand Down
Loading

0 comments on commit 0e4f474

Please sign in to comment.