Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[IrSchedule] SamplePerfectTile #1142

Merged
merged 3 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2871,5 +2871,23 @@ void TestIrSchedule_ReduceSum(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}

// Checks that SamplePerfectTile returns exactly the requested number of factors (3)
// when applied to the first loop of block "B".
TEST(IrSchedule, SamplePerfectTile) {
  Context::Global().ResetNameId();

  // Build B(i) = A(i) + 1 over a one-dimensional domain of 1024 elements.
  Expr extent(1024);
  Placeholder<int> input("A", {extent});
  auto output = Compute(
      {extent}, [&](Expr idx) { return input(idx) + 1; }, "B");

  poly::StageMap stage_map = CreateStages({input, output});
  auto lowered = cinn::lang::LowerVec(
      "test_sampleperfecttile", stage_map, {input, output}, {}, {}, nullptr, common::DefaultHostTarget(), true);

  // Schedule the body of the first lowered function.
  ir::ModuleExpr module_expr({lowered[0]->body});
  ir::IRSchedule ir_sch(module_expr);
  auto block_loops = ir_sch.GetLoops("B");

  // Ask for 3 tile factors with the innermost one bounded by 64.
  std::vector<Expr> tile_factors = ir_sch.SamplePerfectTile(block_loops[0], 3, 64);
  LOG(INFO) << "SamplePerfectTile result: " << tile_factors;
  ASSERT_EQ(tile_factors.size(), 3);
}

} // namespace backends
} // namespace cinn
38 changes: 38 additions & 0 deletions cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "cinn/common/cas.h"
Expand Down Expand Up @@ -70,6 +72,7 @@ class ScheduleImpl {
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
std::vector<Expr> Split(const std::string& block_name, int loop_index, const std::vector<int>& factors);
std::vector<Expr> SamplePerfectTile(const uint32_t seed, const Expr& loop, int n, int max_innermost_factor);
Expr Fuse(const std::vector<Expr>& loops);
Expr Fuse(const std::string& block_name, const std::vector<int>& loops_index);
Expr Fuse(const Expr& block, const std::vector<int>& loops_index);
Expand Down Expand Up @@ -1790,6 +1793,32 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block
this->Replace(all_loops[0], res);
}

std::vector<Expr> ScheduleImpl::SamplePerfectTile(const uint32_t seed,
const Expr& loop,
int n,
int max_innermost_factor) {
CHECK(loop.As<ir::For>()) << "Expr param of SamplePerfectTile should be a For loop";
CHECK_GE(n, 2) << "The number of tile factors should be at least 2";
CHECK_GE(max_innermost_factor, 1) << "The max innermost factor should be at least 1";
CHECK(common::is_zero(loop.As<ir::For>()->min)) << "The For loop should start from 0";
int loop_extent = GetLoopExtent(loop);
std::vector<int> innermost_factors;
for (int i = max_innermost_factor; i >= 1; --i) {
if (loop_extent % i == 0) {
innermost_factors.push_back(i);
}
}
CHECK(!innermost_factors.empty()) << "No innermost factor found";
int innermost_factor = innermost_factors[ir::SampleInt(0, innermost_factors.size() - 1, seed)];
auto result = SampleTile(seed, n - 1, loop_extent / innermost_factor);
std::vector<Expr> result_expr;
for (auto& factor : result) {
result_expr.push_back(Expr(factor));
}
result_expr.push_back(Expr(innermost_factor));
return result_expr;
}

// Default constructor: the body is empty and performs no work of its own.
IRSchedule::IRSchedule() {}

IRSchedule::IRSchedule(const ModuleExpr& module_expr, bool debug_flag) {
Expand Down Expand Up @@ -2042,5 +2071,14 @@ void IRSchedule::CopyTransformAndLoopInfo(const std::string& block_name, const s
// don't support to trace, because we can't ensure both blocks are from the same ModuleExpr
}

// Samples tile factors for `loop` by forwarding to the implementation object with the seed
// ir::RandomSeedController::seed, then appends a "SamplePerfectTile" step to the trace.
// The recorded step carries the loop as its "loop" input, `n` and `max_innermost_factor` as
// attributes, and the sampled factors as its outputs.
// Returns the sampled factors unchanged.
std::vector<Expr> IRSchedule::SamplePerfectTile(const Expr& loop, int n, int max_innermost_factor) {
auto result = impl_->SamplePerfectTile(ir::RandomSeedController::seed, loop, n, max_innermost_factor);
// Record the call together with its outputs in the schedule trace.
trace_.Append(ScheduleDesc::Step("SamplePerfectTile",
{{"loop", std::vector<Expr>({loop})}},
{{"n", n}, {"max_innermost_factor", max_innermost_factor}},
{result}));
return result;
}

} // namespace ir
} // namespace cinn
15 changes: 15 additions & 0 deletions cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,21 @@ class IRSchedule {
// TODO(sunli): Solve Index Simplify.
void FlattenLoops(const std::vector<Expr>& loops, const bool force_flat = false);

/*!
 * \brief Sample the factors to tile a specific loop perfectly.
 * \param loop the loop to be split; the implementation checks that it is a For loop starting from 0
 * \param n the number of loop layers to split into, i.e. the number of factors returned;
 *        the implementation checks that it is at least 2
 * \param max_innermost_factor the maximum factor of the innermost loop;
 *        the implementation checks that it is at least 1
 * \return the split factors of the loop (the larger the index, the inner the corresponding loop);
 *         the last element is the innermost factor
 * For example, return {16,64} means the loop will be like this:
 *   for (i, 0, 16) {
 *     for (j, 0, 64) {
 *       ...
 *     }
 *   }
 */
std::vector<Expr> SamplePerfectTile(const Expr& loop, int n, int max_innermost_factor);
CtfGo marked this conversation as resolved.
Show resolved Hide resolved

private:
std::unique_ptr<ScheduleImpl> impl_;
mutable ScheduleDesc trace_; // trace the scheduling process
Expand Down
43 changes: 43 additions & 0 deletions cinn/ir/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,5 +962,48 @@ bool ContainVar(const std::vector<Expr>& exprs, const std::string& var_name) {
return false;
}

// Draws one integer uniformly from the closed range [min, max].
// A fresh engine is built from `seed` on every call, so identical arguments
// always produce the identical value.
int SampleInt(int min, int max, uint32_t seed) {
  // TODO(PuQing): If the seed always be a constant, the random number will be the same.
  using UniformInt = std::uniform_int_distribution<int>;
  std::default_random_engine engine(seed);
  return UniformInt(min, max)(engine);
}

// Decomposes `n` into prime factors by trial division.
// Returns a map from each prime factor to its exponent, e.g. 12 -> {2: 2, 3: 1}.
// Values below 2 (including 0 and negative numbers) have no prime factorization and
// yield an empty map.
std::unordered_map<int, int> PrimeFactorize(int n) {
  std::unordered_map<int, int> factors;
  // Guard first: 0 is divisible by every candidate, so the division loops below would never end.
  if (n < 2) {
    return factors;
  }
  // Strip all factors of 2 so that only odd candidates need to be tried afterwards.
  while (n % 2 == 0) {
    ++factors[2];
    n /= 2;
  }
  // `i <= n / i` is the exact integer form of `i <= sqrt(n)`: no floating point, no overflow.
  for (int i = 3; i <= n / i; i += 2) {
    while (n % i == 0) {
      ++factors[i];
      n /= i;
    }
  }
  // Whatever is left above 1 is a single prime larger than the square root of the original value.
  if (n > 1) {
    factors[n] = 1;
  }
  return factors;
}

// Splits `extent` into `n` integer factors whose product equals `extent`.
//
// Each of the first n - 1 factors is built from the prime factorization of the remaining
// extent: every distinct prime is raised to an exponent sampled from [1, multiplicity] and
// the powers are multiplied together. The last factor is whatever remains.
// When n <= 1 the result is the single element {extent}.
//
// NOTE(review): because the sampled exponent starts at 1, each leading factor contains every
// distinct prime of the remaining extent, so not every factorization is reachable — confirm
// whether that is intended.
//
// \param seed seed forwarded to ir::SampleInt for every exponent draw
// \param n number of factors to produce
// \param extent the value to split
std::vector<int> SampleTile(uint32_t seed, int n, int extent) {
  std::vector<int> tile;
  while (n > 1) {
    int product = 1;
    // Values below 2 have no prime factors; skipping them keeps the factor at 1 and avoids
    // handing 0 or negative numbers to the factorization routine.
    if (extent > 1) {
      std::unordered_map<int, int> factors = PrimeFactorize(extent);
      for (const auto& factor : factors) {
        if (factor.second >= 1) {
          int num = ir::SampleInt(1, factor.second, seed);
          // Exact integer power; avoids the int -> double -> int round trip of std::pow.
          for (int k = 0; k < num; ++k) {
            product *= factor.first;
          }
        }
      }
    }
    tile.push_back(product);
    extent /= product;
    --n;
  }
  tile.push_back(extent);
  return tile;
}
} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions cinn/ir/ir_schedule_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#pragma once
#include <map>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -413,5 +415,39 @@ std::vector<IterRange> CalculateRequiredRegions(const Expr& block,

Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root);

/*!
 * \brief Holder of the random seed used across the program.
 * The seed is a compile-time constant, currently 1.
 */
class RandomSeedController {
 public:
  // TODO(PuQing): This is a temporary solution, maybe change it to a better one.
  static constexpr uint32_t seed{1};
};
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved

/*!
 * \brief Sample an int in [min, max] (both ends inclusive).
 * \param min The min value of the range.
 * \param max The max value of the range.
 * \param seed The seed used to construct the random engine for this call;
 *        the same (min, max, seed) always yields the same value.
 * \return The sampled integer.
 */
int SampleInt(int min, int max, uint32_t seed);

/*!
 * \brief Get the prime factors of a number.
 * For example, 12 = 2^2 * 3^1, then the return value is {2: 2, 3: 1}.
 * \param n The number to be factorized; callers should pass a positive integer.
 * \return A map of prime factors and their corresponding exponents (empty for n = 1).
 */
std::unordered_map<int, int> PrimeFactorize(int n);

/*!
 * \brief Split a number into n integer factors whose product equals that number.
 * For example:
 *   n = 2, dividend = 12, returns {6, 2} or {12, 1} depending on the seed.
 * NOTE(review): the implementation gives every leading factor at least one power of each
 * distinct prime of the remaining value, so results such as {2, 6} or {3, 4} do not appear
 * to be reachable — confirm whether that is intended.
 * \param seed The seed forwarded to SampleInt for every random draw.
 * \param n The number of factors to produce (a single-element result is returned when n <= 1).
 * \param dividend The number to be split into factors.
 * \return The factors, in the order they were sampled; the last one is the remainder.
 */
std::vector<int> SampleTile(uint32_t seed, int n, int dividend);
} // namespace ir
} // namespace cinn
5 changes: 5 additions & 0 deletions cinn/ir/schedule_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ CINN_BUILD_STEP_KIND(FlattenLoops)
.Attrs({"force_flat"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::FlattenLoops)));

// Step kind for IRSchedule::SamplePerfectTile: one "loop" input plus the "n" and
// "max_innermost_factor" attributes, applied through IRSchedule::SamplePerfectTile.
// NOTE(review): the exact registration semantics come from CINN_BUILD_STEP_KIND, which is
// not visible here — confirm.
CINN_BUILD_STEP_KIND(SamplePerfectTile)
.Inputs({"loop"})
.Attrs({"n","max_innermost_factor"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SamplePerfectTile)));

// clang-format on

// ------ Following codes are about member function implement of the ScheduleDesc class
Expand Down
25 changes: 25 additions & 0 deletions cinn/ir/schedule_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,5 +676,30 @@ TEST_F(TestScheduleDesc, StepKind_Unannotate) {
CheckReplayResult(ir_sch, trace);
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

// Exercises the SamplePerfectTile step kind: the step is recorded both by hand and by the
// schedule itself, and both traces are checked for their outputs and for replay.
TEST_F(TestScheduleDesc, StepKind_SamplePerfectTile) {
  Expr extent(1024);
  Var addend(1, "n");

  // Build B(i) = A(i) + n over a one-dimensional domain of 1024 elements.
  Placeholder<int> input("A", {extent});
  auto output = Compute(
      {extent}, [&](Expr idx) { return input(idx) + addend; }, "B");
  lowered_funcs = cinn::lang::LowerVec(
      "test_sample_perfect_tile", CreateStages({input, output}), {input, output}, {}, {}, nullptr, target, true);

  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

  // Fetch the loops of block "B" and mirror that call in the hand-written trace.
  auto block_loops = ir_sch.GetLoops("B");
  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, block_loops));

  // Sample 2 factors with the innermost bounded by 64 and mirror that call as well.
  auto tile_factors = ir_sch.SamplePerfectTile(block_loops[0], 2, 64);
  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
                                  {{"loop", std::vector<Expr>({block_loops[0]})}},
                                  {{"n", 2}, {"max_innermost_factor", 64}},
                                  tile_factors));

  CheckTracingOutputs(tile_factors, trace);
  CheckTracingOutputs(tile_factors, ir_sch.GetTraceDesc());
  CheckReplayResult(ir_sch, trace);
  CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

} // namespace ir
} // namespace cinn