Skip to content

Commit

Permalink
Use an attribute (relax.force_pure) to control forcing purity
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed May 17, 2023
1 parent b11171c commit d1651ce
Show file tree
Hide file tree
Showing 37 changed files with 126 additions and 177 deletions.
18 changes: 9 additions & 9 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -922,14 +922,11 @@ class FunctionNode : public BaseFuncNode {
StructInfo ret_struct_info;
/*! \brief Whether the function is annotated as pure or not. */
bool is_pure;
/*! \brief Override checking purity for this function (only if purity is set to true) */
bool force_pure;

void VisitAttrs(AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("is_pure", &is_pure);
v->Visit("force_pure", &force_pure);
v->Visit("ret_struct_info", &ret_struct_info);
v->Visit("attrs", &attrs);
v->Visit("struct_info_", &struct_info_);
Expand All @@ -941,8 +938,7 @@ class FunctionNode : public BaseFuncNode {
equal->MarkGraphNode();
return equal.DefEqual(params, other->params) && equal(body, other->body) &&
equal(ret_struct_info, other->ret_struct_info) && equal(is_pure, other->is_pure) &&
equal(force_pure, other->force_pure) && equal(attrs, other->attrs) &&
equal(struct_info_, other->struct_info_);
equal(attrs, other->attrs) && equal(struct_info_, other->struct_info_);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -951,7 +947,6 @@ class FunctionNode : public BaseFuncNode {
hash_reduce(body);
hash_reduce(ret_struct_info);
hash_reduce(is_pure);
hash_reduce(force_pure);
hash_reduce(attrs);
hash_reduce(struct_info_);
}
Expand All @@ -965,13 +960,12 @@ class FunctionNode : public BaseFuncNode {
class Function : public BaseFunc {
public:
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure = true, bool force_pure = false,
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span());

/*!
* \brief Mimics the constructor but without body Expr.
* \note ret_struct_info is required, since it can not deduced by the body.
* force_pure is omitted because the purity will not be checked anyway.
*/
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Expand All @@ -997,6 +991,12 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief The required workspace for an external function. */
constexpr const char* kWorkspaceSize = "WorkspaceSize";

// Note: in the future, we prefer snake_case instead of CamelCase for attributes.
// Past ones will be kept for backwards compatibility.
/*! \brief Override checking purity for this function and treat as pure
* (is_pure must be set to true) */
constexpr const char* kForcePure = "relax.force_pure";
} // namespace attr

/*! \brief The extern function, which can represent packed function. */
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ class FunctionFrameNode : public SeqExprFrameNode {
Optional<tvm::relax::StructInfo> ret_struct_info;
/*! \brief Whether the function is annotated as pure */
Optional<Bool> is_pure;
/*! \brief Whether the function is forced pure*/
Optional<Bool> force_pure;
/*! \brief The function attributes. */
Map<String, ObjectRef> attrs;
/*! \brief The block builder to create Relax function. */
Expand All @@ -112,7 +110,6 @@ class FunctionFrameNode : public SeqExprFrameNode {
v->Visit("params", &params);
v->Visit("ret_struct_info", &ret_struct_info);
v->Visit("is_pure", &is_pure);
v->Visit("force_pure", &force_pure);
v->Visit("attrs", &attrs);
v->Visit("binding_blocks", &binding_blocks);
v->Visit("output", &output);
Expand Down
6 changes: 0 additions & 6 deletions include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);
*/
TVM_DLL void FuncIsPure(bool purity);

/*!
* \brief Specify whether the last function frame is forced to be pure.
* \param force_pure Whether purity should be forced.
*/
TVM_DLL void FuncForcePure(bool force_pure);

/*!
* \brief Specify the return struct info of the last function frame.
* \param ret_sinfo The return struct info.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/backend/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def __init__(self, mod):
def visit_function_(self, f):
if f.attrs is None or "Composite" not in f.attrs:
body = super().visit_expr(f.body)
new_f = Function(f.params, body, f.ret_struct_info, f.attrs, f.span)
new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)

if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
composite_func = body.blocks[0].bindings[0].value
Expand Down
3 changes: 0 additions & 3 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,6 @@ class Function(BaseFunc, Scriptable):
body: Expr
ret_struct_info: StructInfo
is_pure: bool
force_pure: bool
attrs: Optional[tvm.ir.DictAttrs]

def __init__(
Expand All @@ -570,7 +569,6 @@ def __init__(
body: Expr,
ret_struct_info: Optional[StructInfo] = None,
is_pure: Optional[bool] = True,
force_pure: Optional[bool] = False,
attrs: Optional[tvm.ir.DictAttrs] = None,
span: Optional[Span] = None,
) -> None:
Expand All @@ -580,7 +578,6 @@ def __init__(
body,
ret_struct_info,
is_pure,
force_pure,
attrs,
span, # type: ignore
) # type: ignore
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relax/testing/ast_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def visit_function_(self, op: relax.Function) -> str:
"body": self.visit_expr(op.body),
"ret_struct_info": self.visit_struct_info_(op.ret_struct_info),
"is_pure": op.is_pure,
"force_pure": op.force_pure,
}
if op.attrs:
fields["attrs"] = self.build_list(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:


def RemovePurityChecking() -> tvm.ir.transform.Pass:
"""Activate force_pure on all pure functions in the module
"""Activate relax.force_pure on all pure functions in the module
and unwrap all pure override ops into the normal versions.
This effectively means that there will be no more purity tracking,
Expand Down
11 changes: 0 additions & 11 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,6 @@ def is_impure() -> None:
return _ffi_api.FuncIsPure(False) # type: ignore[attr-defined] # pylint: disable=no-member


def force_pure(forced: bool = True) -> None:
"""Specify whether the last function frame is forced to be pure.
Parameters
----------
forced: bool
Whether purity is forced for the function or not
"""
return _ffi_api.FuncForcePure(forced) # type: ignore[attr-defined] # pylint: disable=no-member


def func_ret_struct_info(ret_sinfo: StructInfo) -> None:
"""Specify the return struct info of the last function frame.
Parameters
Expand Down Expand Up @@ -619,7 +609,6 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"flip",
"floor",
"floor_divide",
"force_pure",
"full",
"full_like",
"func_attr",
Expand Down
24 changes: 14 additions & 10 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@
* 14. DataflowBlocks may not contain If nodes.
* 15. DataflowBlocks may not contain calls to impure functions or operators
* (only checked if check_struct_info is true).
* 16. If a function has is_pure set to true and force_pure is not set to true,
* the body may not contain any impure call
* (only checked if check_struct_info is true).
* 17. If force_pure is true for a function, that function's is_pure must also be true.
* 16. If a function has is_pure set to true and the kForcePure attribute is not set,
* the body may not contain any impure call (only checked if check_struct_info is true).
* 17. If the kForcePure attribute is set for a function,
* that function's is_pure field must be true.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
Expand Down Expand Up @@ -228,10 +228,11 @@ class WellFormedChecker : public relax::ExprVisitor,
});

// ensure the purity attributes are valid
if (op->force_pure && !op->is_pure) {
if (op->GetAttr<Bool>(relax::attr::kForcePure).value_or(Bool(false))->value && !op->is_pure) {
Malformed(Diagnostic::Error(op->span)
<< "Function " << op << " has true for force_pure but false for is_pure;"
<< " force_pure should be true only if is_pure is also true.");
<< "Function " << op << " has true for " << relax::attr::kForcePure
<< " but false for is_pure; " << relax::attr::kForcePure
<< " should be true only if is_pure is also true.");
}

// check all expr are well defined.
Expand All @@ -255,11 +256,14 @@ class WellFormedChecker : public relax::ExprVisitor,

// if we are not forcing purity and the function is annotated as pure, it must not contain an
// impure call
if (check_struct_info_ && !op->force_pure && op->is_pure && ContainsImpureCall(op->body)) {
if (check_struct_info_ &&
!op->GetAttr<Bool>(relax::attr::kForcePure).value_or(Bool(false))->value && op->is_pure &&
ContainsImpureCall(op->body)) {
Malformed(Diagnostic::Error(op)
<< "Function " << op << " is annotated as pure but contains an impure call; "
<< "please set force_pure to true or use a pure operator variant "
<< "(e.g., call_pure_packed) if it is necessary to override this judgment.");
<< "please set " << relax::attr::kForcePure << " to true "
<< "or use a pure operator variant (e.g., call_pure_packed) "
<< "if it is necessary to override this judgment.");
}

if (auto seq = op->body.as<SeqExprNode>()) {
Expand Down
3 changes: 1 addition & 2 deletions src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ class VMShapeLowerMutator

auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body));
// create a new function
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->force_pure,
func->attrs);
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs);
}

//-------------------------------------------------------
Expand Down
3 changes: 1 addition & 2 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
if (new_body.same_as(op->body)) {
return GetRef<Function>(op);
} else {
return Function(op->params, new_body, op->ret_struct_info, op->is_pure, op->force_pure,
op->attrs);
return Function(op->params, new_body, op->ret_struct_info, op->is_pure, op->attrs);
}
}

Expand Down
8 changes: 3 additions & 5 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ TVM_REGISTER_GLOBAL("relax.SeqExpr")
TVM_REGISTER_NODE_TYPE(FunctionNode);

Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info, bool is_pure,
bool force_pure, DictAttrs attrs, Span span) {
DictAttrs attrs, Span span) {
// Set the function type.
// For function, we take a conservative approach and require the function type
// to be known at construction time.
Expand Down Expand Up @@ -457,7 +457,6 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
n->body = std::move(body);
n->ret_struct_info = std::move(ret_struct_info.value());
n->is_pure = is_pure;
n->force_pure = force_pure;
n->checked_type_ = GetStaticType(func_sinfo);
n->struct_info_ = std::move(func_sinfo);
n->attrs = std::move(attrs);
Expand All @@ -467,8 +466,8 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct

TVM_REGISTER_GLOBAL("relax.Function")
.set_body_typed([](Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure, bool force_pure, DictAttrs attrs, Span span) {
return Function(params, body, ret_struct_info, is_pure, force_pure, attrs, span);
bool is_pure, DictAttrs attrs, Span span) {
return Function(params, body, ret_struct_info, is_pure, attrs, span);
});

Function Function::CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
Expand All @@ -487,7 +486,6 @@ Function Function::CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bo
n->params = std::move(params);
n->body = Expr();
n->is_pure = is_pure;
n->force_pure = false;
n->checked_type_ = GetStaticType(finfo);
n->struct_info_ = std::move(finfo);
n->ret_struct_info = std::move(ret_struct_info);
Expand Down
4 changes: 2 additions & 2 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) {
if (body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Function(op->params, body, op->ret_struct_info, op->is_pure, op->force_pure, op->attrs);
return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs);
}
}

Expand Down Expand Up @@ -589,7 +589,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
if (all_params_unchanged && body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Function(params, body, op->ret_struct_info, op->is_pure, op->force_pure, op->attrs);
return Function(params, body, op->ret_struct_info, op->is_pure, op->attrs);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/relax/training/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class AppendLossMutator : private ExprMutator {
loss_function_->params.end());
Expr new_body = this->VisitExpr(func->body);

return Function(new_params, new_body, NullOpt, func->is_pure, func->force_pure, func->attrs);
return Function(new_params, new_body, NullOpt, func->is_pure, func->attrs);
}

Expr VisitExpr_(const SeqExprNode* seq_expr) final {
Expand Down
4 changes: 2 additions & 2 deletions src/relax/transform/allocate_workspace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ExternFunctionRewriter : ExprMutator {

new_params.push_back(workspace_param);
return Function(new_params, VisitExpr(func_node->body), func_node->ret_struct_info,
func_node->is_pure, func_node->force_pure, func_node->attrs);
func_node->is_pure, func_node->attrs);
}
return ExprMutator::VisitExpr_(func_node);
}
Expand Down Expand Up @@ -128,7 +128,7 @@ class WorkspaceProvider : ExprMutator {
auto gvar = mod_->GetGlobalVar("main");
auto func = Downcast<Function>(mod_->Lookup(gvar));
auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
func->is_pure, func->force_pure, func->attrs);
func->is_pure, func->attrs);
builder_->UpdateFunction(gvar, new_func);
return builder_->GetContextIRModule();
}
Expand Down
4 changes: 2 additions & 2 deletions src/relax/transform/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class CommonSubexprEliminator : public ExprMutator {
if (new_body.same_as(func->body)) {
return GetRef<Expr>(func);
}
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->force_pure,
func->attrs, func->span);
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs,
func->span);
}

// this should happen only for the inner function case
Expand Down
4 changes: 1 addition & 3 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,6 @@ class FunctionCreator : public ExprMutator {
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*is_pure=*/true, //
/*force_pure=*/false, //
/*attrs=*/DictAttrs(group_attrs));
Array<PrimExpr> free_vars =
FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
Expand All @@ -485,7 +484,6 @@ class FunctionCreator : public ExprMutator {
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*is_pure=*/true, //
/*force_pure=*/false, //
/*attrs=*/DictAttrs(group_attrs));
}
function_ = SymbolicVarRenewMutator::Renew(function);
Expand Down Expand Up @@ -1092,7 +1090,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
auto new_body = VisitExpr(func->body);
if (!new_body.same_as(func->body)) {
auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
func->is_pure, func->force_pure, func->attrs, func->span);
func->is_pure, func->attrs, func->span);
builder_->UpdateFunction(entry.first, new_func);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ class GradientMutator : private ExprMutator {

Expr new_body = this->VisitExpr(func->body);

return Function(func->params, new_body, NullOpt, func->is_pure, func->force_pure, func->attrs);
return Function(func->params, new_body, NullOpt, func->is_pure, func->attrs);
}

Expr VisitExpr_(const SeqExprNode* seq_expr) final {
Expand Down
Loading

0 comments on commit d1651ce

Please sign in to comment.