Skip to content

Commit

Permalink
[Bugfix] Update FAttrsGetter to return Map<String, ObjectRef> (#17096)
Browse files Browse the repository at this point in the history
Prior to this commit, `FAttrsGetter` was defined as a function that
returned `Map<String, String>`.  However, it is used to define
attributes in a `Map<String, ObjectRef>`, and in some cases is used to
define attributes whose value is a dictionary (e.g. `msc_attrs_getter`
in `python/tvm/contrib/msc/core/transform/pattern.py`).

This commit updates the type signature of `FAttrsGetter` to match its
usage, returning a `Map<String, ObjectRef>`.
  • Loading branch information
Lunderberg committed Jun 18, 2024
1 parent f6fe2aa commit e58cb27
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class FusionPatternNode : public Object {
* \brief The function to get attributes for fused function
*
* It should have signature
* Map<String, String>(const Map<String, Expr>& context)
* Map<String, ObjectRef>(const Map<String, Expr>& context)
*/
Optional<PackedFunc> attrs_getter;

Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ class PatternBasedPartitioner : ExprVisitor {
using PatternCheckContext = transform::PatternCheckContext;
using ExprVisitor::VisitExpr_;
using FCheckMatch = runtime::TypedPackedFunc<bool(const transform::PatternCheckContext&)>;
using FAttrsGetter = runtime::TypedPackedFunc<Map<String, String>(const Map<String, Expr>&)>;
using FAttrsGetter = runtime::TypedPackedFunc<Map<String, ObjectRef>(const Map<String, Expr>&)>;

static GroupMap Run(String pattern_name, DFPattern pattern,
Map<String, DFPattern> annotation_patterns, FCheckMatch check, Expr expr,
Expand Down

0 comments on commit e58cb27

Please sign in to comment.