Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at #14854

Merged
merged 6 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,19 +422,43 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
var_dom = arith::IntSet::Interval(required_min, required_max);
var_bound = arith::IntSet::Interval(0, dim_max);
} else {
arith::PVar<PrimExpr> p_f;
if ((floordiv(p_v, p_f)).Match(provided_min)) {
// a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
PrimExpr fac = p_f.Eval();
arith::PVar<PrimExpr> p_f1, p_f2;
if ((floordiv(p_f1, p_f2).Match(provided_min))) {
PrimExpr var_expr = p_f1.Eval();
PrimExpr fac = p_f2.Eval();
if (analyzer->CanProveGreaterEqual(fac, 1)) {
var = p_v.Eval();
var_dom = arith::IntSet::Interval(required_min * fac,
analyzer->Simplify(required_max * fac + fac - 1));
var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
if (var_expr->IsInstance<VarNode>()) {
// a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
var = Downcast<Var>(var_expr);
var_dom = arith::IntSet::Interval(required_min * fac,
analyzer->Simplify(required_max * fac + fac - 1));
var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
} else {
const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
const arith::IntSet new_required = arith::IntSet::Interval(
required_min * fac, analyzer->Simplify(required_max * fac + fac - 1));
return SolveBlockVarDomain(new_provided, new_required, dim_max, analyzer);
}
}
} else if ((floormod(p_f1, p_f2).Match(provided_min))) {
PrimExpr var_expr = p_f1.Eval();
if (var_expr->IsInstance<VarNode>()) {
// In general, the domain of (x % fac) imposes no constraint on the domain of x
Var var_mod = Downcast<Var>(var_expr);
return {var_mod, BlockVarDomainInfo()};
} else {
PrimExpr mod_1 = p_f1.Eval();
PrimExpr mod_2 = p_f2.Eval();
if (analyzer->CanProveGreaterEqual(mod_1, 1) &&
analyzer->CanProveGreaterEqual(mod_2, 1)) {
const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
if (analyzer->CanProveGreaterEqual(required_min, 0)) {
const arith::IntSet new_required =
arith::IntSet::Interval(required_min, arith::SymbolicLimits::pos_inf_);
return SolveBlockVarDomain(new_provided, new_required, dim_max, analyzer);
}
}
}
} else if ((floormod(p_v, p_f).Match(provided_min))) {
// In general, the domain of (x % fac) imposes no constraint on the domain of x
return {p_v.Eval(), BlockVarDomainInfo()};
}
}
}
Expand Down
52 changes: 52 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,48 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han
Y[v_i] = temp[v_i // 16, v_i % 16]


# Input workload with floordiv/floormod buffer indices on both blocks.
#
# Block "In" fills the 7-D intermediate buffer B from A; its reads of A use
# floordiv/floormod applied to affine combinations of the block axes
# (e.g. `(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32`).
# Block "Out" reads B through floordiv/floormod of its own axes, including the
# nested form `v2 % 4 // 2`, and writes the result to C.
@T.prim_func
def recursive_floordiv_floormod(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"),
                                C: T.Buffer((3, 512, 512), "float32")) -> None:
    T.func_attr({"tir.noalias": True})
    # with T.block("root"):
    B = T.alloc_buffer((1, 128, 16, 8, 2, 32, 2), "float32")
    # Producer: one iteration per element of B.
    for axis1, axis2, axis3, axis4, axis5, axis6, axis7 in T.grid(1, 128, 16, 8, 2, 32, 2):
        with T.block("In"):
            v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7 = T.axis.remap("SSSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6, axis7])
            T.reads(A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32])
            T.writes(B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7])
            B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7] = A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32] + 3
    # Consumer: one iteration per element of C, indexing B via // and %.
    for ax1, ax2, ax3 in T.grid(3, 512, 512):
        with T.block("Out"):
            v1, v2, v3 = T.axis.remap("SSS", [ax1, ax2, ax3])
            T.reads(B[v1 // 8, v2 // 4, v3 // 32, v1, v2 % 4 // 2, v3 % 32, v2 % 2])
            T.writes(C[v1, v2, v3])
            C[v1, v2, v3] = B[v1 // 8, v2 // 4, v3 // 32, v1, v2 % 4 // 2, v3 % 32, v2 % 2] * 2


# Expected IR with block "Out" nested under the `axis1, axis2, axis3` loop
# that also encloses block "In".
#
# Inside that loop, "Out" iterates over a (3, 4, 32) tile: v1 spans all of
# its extent-3 axis, while v2 and v3 are rebased onto the outer loop variables
# as `axis2 * 4 + ax1` and `axis3 * 32 + ax2`.
@T.prim_func
def recursive_floordiv_floormod_after_reverse_compute_at(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), C: T.Buffer((3, 512, 512), "float32")) -> None:
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    B = T.alloc_buffer((1, 128, 16, 8, 2, 32, 2))
    for axis1, axis2, axis3 in T.grid(1, 128, 16):
        # Producer: remaining four axes of B for the current outer iteration.
        for axis4, axis5, axis6, axis7 in T.grid(8, 2, 32, 2):
            with T.block("In"):
                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7 = T.axis.remap("SSSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6, axis7])
                T.reads(A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32])
                T.writes(B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7])
                B[v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6, v_axis7] = A[(v_axis2 * 4 + v_axis5 * 2 + v_axis7) // 32, (v_axis3 * 32 + v_axis6) // 8, (v_axis1 * 8 + v_axis4) // 8, (v_axis3 * 32 + v_axis6) % 8, v_axis1 * 8 + v_axis4, (v_axis2 * 4 + v_axis5 * 2 + v_axis7) % 32] + T.float32(3)
        # Consumer: sibling of the producer loop nest; it uses axis2 and axis3,
        # so it must stay inside the outer loop.
        for ax0, ax1, ax2 in T.grid(3, 4, 32):
            with T.block("Out"):
                v1 = T.axis.spatial(3, ax0)
                v2 = T.axis.spatial(512, axis2 * 4 + ax1)
                v3 = T.axis.spatial(512, axis3 * 32 + ax2)
                T.reads(B[v1 // 8, v2 // 4, v3 // 32, v1, v2 % 4 // 2, v3 % 32, v2 % 2])
                T.writes(C[v1, v2, v3])
                C[v1, v2, v3] = B[v1 // 8, v2 // 4, v3 // 32, v1, v2 % 4 // 2, v3 % 32, v2 % 2] * T.float32(2)


@T.prim_func
def tiled_repeat_op(x: T.Buffer((4,), "float32"), T_repeat: T.Buffer((64,), "float32")) -> None:
T_add = T.alloc_buffer([4], dtype="float32")
Expand Down Expand Up @@ -1255,6 +1297,16 @@ def test_reverse_compute_at_floordiv_and_floormod_indices(use_block_name):
verify_trace_roundtrip(sch=sch, mod=floordiv_and_floormod_indices)


def test_reverse_compute_at_floordiv_and_floormod_recursive(use_block_name):
    """Check reverse_compute_at on a consumer indexed by nested floordiv/floormod.

    Moves block "Out" of ``recursive_floordiv_floormod`` under the third loop
    (index 2) of block "In", then asserts that the scheduled module is
    structurally equal to
    ``recursive_floordiv_floormod_after_reverse_compute_at`` and that the
    schedule trace round-trips.

    ``use_block_name`` is not read in this body.
    NOTE(review): presumably a pytest fixture shared with sibling tests; confirm.
    """
    sch = tir.Schedule(recursive_floordiv_floormod, debug_mask="all")
    write_block = sch.get_block("Out")
    # Target loop: the third loop enclosing block "In".
    sch.reverse_compute_at(write_block, sch.get_loops("In")[2])
    tvm.ir.assert_structural_equal(
        recursive_floordiv_floormod_after_reverse_compute_at, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=recursive_floordiv_floormod)


def test_read_out_of_bound(use_block_name):
sch = tir.Schedule(read_out_of_bound, debug_mask="all")
block = sch.get_block("B")
Expand Down