Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/relax/transform/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,37 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
Map<tir::Var, tir::Var> var_map_;
};

/*!
* \brief Copy a function while renewing the relax Vars and the tir Vars.
* \details All variables that are bound inside the original function would be copied to satisfy
* the restriction in the well-formed check: Variables in Relax must be bound exactly once.
*/
class FunctionCopier : public ExprMutator {
public:
Function Copy(Function func) {
auto new_func = Downcast<Function>(VisitExpr(func));
return SymbolicVarRenewMutator::Renew(new_func);
}

Var VisitVarDef_(const DataflowVarNode* var) override {
Var new_var = ExprMutator::VisitVarDef_(var);
Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span);
var_remap_[var->vid] = copied_var;
var_map.Set(GetRef<Var>(var), copied_var);
return copied_var;
}

Var VisitVarDef_(const VarNode* var) override {
Var new_var = ExprMutator::VisitVarDef_(var);
Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span);
var_remap_[var->vid] = copied_var;
var_map.Set(GetRef<Var>(var), copied_var);
return copied_var;
}

Map<Var, Var> var_map;
};

/*!
* \brief Create a Constant with a scalar
*
Expand Down
30 changes: 2 additions & 28 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,39 +111,13 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}

/*! \brief Helper to implement CopyWithNewVars.*/
class FunctionCopier : public ExprMutator {
public:
static Function Transform(Function func) {
FunctionCopier copier;
// All variables that are bound inside the original function would be copied
// to satisfy the restriction in the well-formed check: Variables in Relax
// must be bound exactly once.
auto new_func = Downcast<Function>(copier.VisitExpr(func));
return SymbolicVarRenewMutator::Renew(new_func);
}

Var VisitVarDef_(const DataflowVarNode* var) override {
Var new_var = ExprMutator::VisitVarDef_(var);
Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span);
var_remap_[var->vid] = copied_var;
return copied_var;
}

Var VisitVarDef_(const VarNode* var) override {
Var new_var = ExprMutator::VisitVarDef_(var);
Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span);
var_remap_[var->vid] = copied_var;
return copied_var;
}
};

/*!
* \brief Copy a new Relax function with new remapped vars and symbolic vars.
* To get the var mapping from old vars to new vars, see FuncCopier in src/relax/transform/utils.h.
* \param func The Relax function we want to copy.
* \return The copied function.
*/
Function CopyWithNewVars(Function func) { return FunctionCopier::Transform(func); }
Function CopyWithNewVars(Function func) { return FunctionCopier().Copy(func); }

TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars);

Expand Down