Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[IrSchedule] SamplePerfectTile #1142

Merged
merged 3 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2871,5 +2871,23 @@ void TestIrSchedule_ReduceSum(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}

TEST(IrSchedule, SamplePerfectTile) {
Context::Global().ResetNameId();
Expr M(1024);
Placeholder<int> A("A", {M});
auto B = Compute(
{M}, [&](Expr i) { return A(i) + 1; }, "B");
poly::StageMap stages = CreateStages({A, B});

auto funcs = cinn::lang::LowerVec(
"test_sampleperfecttile", stages, {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true);

ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
auto loops_b = ir_sch.GetLoops("B");
std::vector<Expr> result = ir_sch.SamplePerfectTile(loops_b[0], 3, 64);
LOG(INFO) << "SamplePerfectTile result: " << result;
ASSERT_EQ(result.size(), 3);
}

} // namespace backends
} // namespace cinn
38 changes: 38 additions & 0 deletions cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "cinn/common/cas.h"
Expand Down Expand Up @@ -70,6 +72,7 @@ class ScheduleImpl {
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
std::vector<Expr> Split(const std::string& block_name, int loop_index, const std::vector<int>& factors);
std::vector<Expr> SamplePerfectTile(const uint32_t seed, const Expr& loop, int n, int max_innermost_factor);
Expr Fuse(const std::vector<Expr>& loops);
Expr Fuse(const std::string& block_name, const std::vector<int>& loops_index);
Expr Fuse(const Expr& block, const std::vector<int>& loops_index);
Expand Down Expand Up @@ -1790,6 +1793,32 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block
this->Replace(all_loops[0], res);
}

std::vector<Expr> ScheduleImpl::SamplePerfectTile(const uint32_t seed,
const Expr& loop,
int n,
int max_innermost_factor) {
CHECK(loop.As<ir::For>()) << "Expr param of SamplePerfectTile should be a For loop";
CHECK_GE(n, 2) << "The number of tile factors should be at least 2";
CHECK_GE(max_innermost_factor, 1) << "The max innermost factor should be at least 1";
CHECK(common::is_zero(loop.As<ir::For>()->min)) << "The For loop should start from 0";
int loop_extent = GetLoopExtent(loop);
std::vector<int> innermost_factors;
for (int i = max_innermost_factor; i >= 1; --i) {
if (loop_extent % i == 0) {
innermost_factors.push_back(i);
}
}
CHECK(!innermost_factors.empty()) << "No innermost factor found";
int innermost_factor = innermost_factors[ir::SampleInt(0, innermost_factors.size() - 1, seed)];
auto result = SampleTile(seed, n - 1, loop_extent / innermost_factor);
std::vector<Expr> result_expr;
for (auto& factor : result) {
result_expr.push_back(Expr(factor));
}
result_expr.push_back(Expr(innermost_factor));
return result_expr;
}

IRSchedule::IRSchedule() {}

IRSchedule::IRSchedule(const ModuleExpr& module_expr, bool debug_flag) {
Expand Down Expand Up @@ -2042,5 +2071,14 @@ void IRSchedule::CopyTransformAndLoopInfo(const std::string& block_name, const s
// don't support to trace, because we can't ensure both blocks are from the same ModuleExpr
}

std::vector<Expr> IRSchedule::SamplePerfectTile(const Expr& loop, int n, int max_innermost_factor) {
auto result = impl_->SamplePerfectTile(ir::RandomSeedController::seed, loop, n, max_innermost_factor);
trace_.Append(ScheduleDesc::Step("SamplePerfectTile",
{{"loop", std::vector<Expr>({loop})}},
{{"n", n}, {"max_innermost_factor", max_innermost_factor}},
{result}));
return result;
}

} // namespace ir
} // namespace cinn
15 changes: 15 additions & 0 deletions cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,21 @@ class IRSchedule {
// TODO(sunli): Solve Index Simplify.
void FlattenLoops(const std::vector<Expr>& loops, const bool force_flat = false);

/*!
* \brief Number of random sampling cycles for the loop to be split
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
* \param loop the loop to be split
* \param n the number of loop layers to split
* \param max_innermost_factor the maximum factor of the innermost loop
* \return the split factors of the loop (The larger the index, the inner the corresponding loop)
* For example, return {16,64} means the loop will be like this:
* for (i, 0, 16) {
* for (j, 0, 64) {
* ...
* }
* }
*/
std::vector<Expr> SamplePerfectTile(const Expr& loop, int n, int max_innermost_factor);
CtfGo marked this conversation as resolved.
Show resolved Hide resolved

private:
std::unique_ptr<ScheduleImpl> impl_;
mutable ScheduleDesc trace_; // trace the scheduling process
Expand Down
42 changes: 42 additions & 0 deletions cinn/ir/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,5 +962,47 @@ bool ContainVar(const std::vector<Expr>& exprs, const std::string& var_name) {
return false;
}

int SampleInt(int min, int max, uint32_t seed) {
// TODO(PuQing): If the seed always be a constant, the random number will be the same.
std::default_random_engine rng(seed);
std::uniform_int_distribution<int> dist(min, max);
return dist(rng);
}

std::unordered_map<int, int> PrimeFactorize(int n) {
std::unordered_map<int, int> factors;
while (n % 2 == 0) {
++factors[2];
n /= 2;
}
for (int i = 3; i <= sqrt(n); i += 2) {
while (n % i == 0) {
++factors[i];
n /= i;
}
}
if (n > 2) {
factors[n] = 1;
}
return factors;
}

std::vector<int> SampleTile(uint32_t seed, int n, int extent) {
if (n == 1) {
return {extent};
}
std::vector<int> tile;
std::unordered_map<int, int> factors = PrimeFactorize(extent);
int product = 1;
for (auto& factor : factors) {
if (factor.second >= 1) {
int num = ir::SampleInt(1, factor.second, seed);
product *= std::pow(factor.first, num);
}
}
auto result = SampleTile(seed, n - 1, extent / product);
result.push_back(product);
CtfGo marked this conversation as resolved.
Show resolved Hide resolved
return result;
}
} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions cinn/ir/ir_schedule_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#pragma once
#include <map>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -413,5 +415,39 @@ std::vector<IterRange> CalculateRequiredRegions(const Expr& block,

Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root);

/*!
* \brief RandomSeedController is used to control the random seed in the whole program.
*/
class RandomSeedController {
// TODO(PuQing): This is a temporary solution, maybe change it to a better one.
public:
static constexpr uint32_t seed = 1;
};
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief sample a int in [min, max].
* \param min The min value of the range.
* \param max The max value of the range.
* \param seed The random number generator to use.
*/
int SampleInt(int min, int max, uint32_t seed);

/*!
* \brief Get the prime factors of a number.
* For example, 12 = 2^2 * 3^1, then the return value is {2: 2, 3: 1}.
* \param n The number to be factorized.
* \return A map of prime factors and their corresponding exponents.
*/
std::unordered_map<int, int> PrimeFactorize(int n);

/*!
* \brief Given a number returns the form of the product of its n factors
* For example:
* n = 2, dividend = 12, return one of {2, 6}, {6, 2}, {3, 4}, {4, 3}
* \param seed The random number generator to use.
* \param n The number to be factorized.
* \param dividend The dividend of the number.
*/
std::vector<int> SampleTile(uint32_t seed, int n, int dividend);
} // namespace ir
} // namespace cinn
5 changes: 5 additions & 0 deletions cinn/ir/schedule_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ CINN_BUILD_STEP_KIND(FlattenLoops)
.Attrs({"force_flat"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::FlattenLoops)));

CINN_BUILD_STEP_KIND(SamplePerfectTile)
.Inputs({"loop"})
.Attrs({"n","max_innermost_factor"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SamplePerfectTile)));

// clang-format on

// ------ Following codes are about member function implement of the ScheduleDesc class
Expand Down
25 changes: 25 additions & 0 deletions cinn/ir/schedule_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,5 +676,30 @@ TEST_F(TestScheduleDesc, StepKind_Unannotate) {
CheckReplayResult(ir_sch, trace);
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

TEST_F(TestScheduleDesc, StepKind_SamplePerfectTile) {
Expr M(1024);
Var n(1, "n");

Placeholder<int> A("A", {M});
auto B = Compute(
{M}, [&](Expr i) { return A(i) + n; }, "B");
lowered_funcs =
cinn::lang::LowerVec("test_sample_perfect_tile", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);

ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
auto loops = ir_sch.GetLoops("B");
trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
auto result = ir_sch.SamplePerfectTile(loops[0], 2, 64);
trace.Append(ScheduleDesc::Step("SamplePerfectTile",
{{"loop", std::vector<Expr>({loops[0]})}},
{{"n", 2}, {"max_innermost_factor", 64}},
result));
CheckTracingOutputs(result, trace);
CheckTracingOutputs(result, ir_sch.GetTraceDesc());
CheckReplayResult(ir_sch, trace);
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

} // namespace ir
} // namespace cinn