Skip to content

Commit

Permalink
Updated use-def analysis and llvm codegen to support duplicated letno…
Browse files Browse the repository at this point in the history
…des.
  • Loading branch information
dpankratz committed Jun 28, 2020
1 parent 5c2e58a commit 1f2b1c6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 14 deletions.
9 changes: 8 additions & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
return var_map_[op->var.get()];
} else {
let_binding_[op->var] = op;
}
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -322,6 +323,10 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
Expand Down
21 changes: 10 additions & 11 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
if (dtype.is_float()){
if (dtype.is_float()) {
// floor(a / b)
return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
}
else if (dtype.is_int() && dtype.bits() <= 32) {
} else if (dtype.is_int() && dtype.bits() <= 32) {
/* NOTE:
This must be restricted to int32 or less since floats can losslessly represent integers
only if the number of bits in the mantissa exceeds the number of bits in the integer.
Expand All @@ -121,18 +120,18 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// floor(a / b)
auto fdtype = DataType::Float(dtype.bits() * 2, dtype.lanes());
auto div = tir::Div(tir::Cast(fdtype, op->a), tir::Cast(fdtype, op->b));
auto f = tvm::floor(div);
return tir::Cast(dtype, VisitExpr_(f.as<CallNode>()));
return tir::Cast(dtype, VisitExpr_(tvm::floor(div).as<CallNode>()));
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
auto rmod = tir::Var("rmod",dtype);
auto rdiv = tir::Var("rdiv",dtype);
auto rmod = tir::Var("rmod", dtype);
auto rdiv = tir::Var("rdiv", dtype);
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr floordiv = Let(rdiv, truncdiv(op->a, op->b),
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, rdiv - make_const(dtype, 1));
return Let(rmod, truncmod(op->a, op->b), floordiv);
PrimExpr let_rdiv = tir::Let(rdiv, truncdiv(op->a, op->b),
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rdiv, rdiv - make_const(dtype, 1)));
return Let(rmod, truncmod(op->a, op->b), let_rdiv);
}
}
}
Expand Down Expand Up @@ -192,7 +191,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
auto rmod = tir::Var("rmod",dtype);
auto rmod = tir::Var("rmod", dtype);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
Expand Down
22 changes: 20 additions & 2 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,27 @@ class VarUseDefAnalysis : public StmtExprMutator {
}

PrimExpr VisitExpr_(const LetNode* op) final {
this->HandleDef(op->var.get());
// Weaker SSA condition
// A single var can be binded in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
PrimExpr value = this->VisitExpr(op->value);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, value))
<< "Let cannot bind the same var to two different values";
return GetRef<PrimExpr>(it->second);
} else {
this->HandleDef(op->var.get());
let_binding_[op->var] = op;
}
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down Expand Up @@ -155,6 +169,10 @@ class VarUseDefAnalysis : public StmtExprMutator {
Array<PrimExpr> thread_extent_;
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;

private:
ExprDeepEqual deep_equal_;
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
Expand Down

0 comments on commit 1f2b1c6

Please sign in to comment.