Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Bugfix] Improved massive build times caused by tir.floormod and tir.floordiv. Fixed Topi testcase. #5666

Merged
merged 9 commits into from
Jul 28, 2020
8 changes: 7 additions & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -322,6 +323,10 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
Expand Down
8 changes: 7 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -760,8 +760,14 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
}

void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
std::string value = PrintExpr(op->value);
CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
os << PrintExpr(op->body);
}
Expand Down
5 changes: 5 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ir/op.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
Expand Down Expand Up @@ -269,6 +270,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
bool print_ssa_form_{false};
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

} // namespace codegen
Expand Down
8 changes: 7 additions & 1 deletion src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
}

spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
Expand Down Expand Up @@ -140,6 +141,10 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
std::unordered_map<const VarNode*, spirv::Value> var_map_;
// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

} // namespace codegen
Expand Down
47 changes: 31 additions & 16 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,22 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1));
if (dtype.is_float()) {
// floor(a / b)
return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
auto rmod = tir::Var("rmod", dtype);
auto rdiv = tir::Var("rdiv", dtype);
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr let_rdiv =
tir::Let(rdiv, truncdiv(op->a, op->b),
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1)));
return Let(rmod, truncmod(op->a, op->b), let_rdiv);
}
}
}

Expand Down Expand Up @@ -152,14 +160,21 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
PrimExpr rmod = truncmod(op->a, op->b);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b);
if (dtype.is_float()) {
// a - floor(a / b) * b
return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
auto rmod = tir::Var("rmod", dtype);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return Let(
rmod, truncmod(op->a, op->b),
Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
}
}
}

Expand Down
22 changes: 20 additions & 2 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,28 @@ class VarUseDefAnalysis : public StmtExprMutator {
}

PrimExpr VisitExpr_(const LetNode* op) final {
this->HandleDef(op->var.get());
// Weaker SSA condition
// A single var can be binded in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
PrimExpr value = this->VisitExpr(op->value);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, value))
<< "Let cannot bind the same var to two different values";
return GetRef<PrimExpr>(it->second);
} else {
this->HandleDef(op->var.get());
let_binding_[op->var] = op;
}
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down Expand Up @@ -157,6 +171,10 @@ class VarUseDefAnalysis : public StmtExprMutator {
Array<PrimExpr> thread_extent_;
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;

private:
ExprDeepEqual deep_equal_;
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
Expand Down
27 changes: 9 additions & 18 deletions topi/tests/python/test_topi_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,6 @@ def check_device(device):
rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
out_npy = fnumpy(lhs_npy, rhs_npy)

if fnumpy == np.floor_divide:
# avoid check too close to X.5 and X.0
# FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
# However the result is somehow incorrect - need to further investigate.
# And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6,
np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6)
if mask.any():
lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy
lhs_npy = lhs_npy.astype(dtype)
lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item()
out_npy = fnumpy(lhs_npy, rhs_npy)

out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
foo(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
Expand Down Expand Up @@ -151,12 +138,14 @@ def test_divide():
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)

def test_floor_divide():
def _canonical_floor_div(a,b):
return np.floor(a / b)
verify_broadcast_binary_ele(
None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
verify_broadcast_binary_ele(
(), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001)
(), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
(2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)

def test_maximum_minmum():
verify_broadcast_binary_ele(
Expand All @@ -175,10 +164,12 @@ def test_mod():
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")

def test_floor_mod():
def _canonical_floor_mod(a,b):
return a - np.floor(a / b) * b
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="int32")
(1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32")
verify_broadcast_binary_ele(
(3, 4, 5), (3, 4, 5), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="float32")
(3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32")

def test_cmp():
# explicit specify the output type
Expand Down