diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 34b0155a07ac..e150ff38041b 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -118,7 +118,7 @@ class BufferNode : public Object { return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); } - static constexpr const char* _type_key = "Buffer"; + static constexpr const char* _type_key = "tir.Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); @@ -228,7 +228,7 @@ class DataProducerNode : public Object { void SHashReduce(SHashReducer hash_reduce) const {} - static constexpr const char* _type_key = "DataProducer"; + static constexpr const char* _type_key = "tir.DataProducer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object); diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index b7cb68688066..d3a77cc81063 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -112,7 +112,7 @@ class LayoutNode : public Object { v->Visit("axes", &axes); } - static constexpr const char* _type_key = "Layout"; + static constexpr const char* _type_key = "tir.Layout"; TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; @@ -308,7 +308,7 @@ class BijectiveLayoutNode : public Object { v->Visit("backward_rule", &backward_rule); } - static constexpr const char* _type_key = "BijectiveLayout"; + static constexpr const char* _type_key = "tir.BijectiveLayout"; TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); }; diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index cfb7f1ef0d5a..1518d1ff548e 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -64,7 +64,7 @@ class StringImmNode : public PrimExprNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "StringImm"; + static constexpr const char* _type_key = "tir.StringImm"; TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); }; @@ -101,7 +101,7 @@ class CastNode : public PrimExprNode { hash_reduce(value); } - static constexpr const char* _type_key = "Cast"; + static constexpr const char* _type_key = "tir.Cast"; TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); }; @@ -149,7 +149,7 @@ class BinaryOpNode : public PrimExprNode { /*! \brief a + b */ class AddNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Add"; + static constexpr const char* _type_key = "tir.Add"; }; /*! @@ -165,7 +165,7 @@ class Add : public PrimExpr { /*! \brief a - b */ class SubNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Sub"; + static constexpr const char* _type_key = "tir.Sub"; }; /*! @@ -181,7 +181,7 @@ class Sub : public PrimExpr { /*! \brief a * b */ class MulNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Mul"; + static constexpr const char* _type_key = "tir.Mul"; }; /*! @@ -200,7 +200,7 @@ class Mul : public PrimExpr { */ class DivNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Div"; + static constexpr const char* _type_key = "tir.Div"; }; /*! @@ -219,7 +219,7 @@ class Div : public PrimExpr { */ class ModNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Mod"; + static constexpr const char* _type_key = "tir.Mod"; }; /*! @@ -235,7 +235,7 @@ class Mod : public PrimExpr { /*! \brief Floor division, floor(a/b) */ class FloorDivNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "FloorDiv"; + static constexpr const char* _type_key = "tir.FloorDiv"; }; /*! @@ -251,7 +251,7 @@ class FloorDiv : public PrimExpr { /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "FloorMod"; + static constexpr const char* _type_key = "tir.FloorMod"; }; /*! @@ -267,7 +267,7 @@ class FloorMod : public PrimExpr { /*! \brief min(a, b) */ class MinNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Min"; + static constexpr const char* _type_key = "tir.Min"; }; /*! @@ -283,7 +283,7 @@ class Min : public PrimExpr { /*! \brief max(a, b) */ class MaxNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Max"; + static constexpr const char* _type_key = "tir.Max"; }; /*! @@ -330,7 +330,7 @@ class CmpOpNode : public PrimExprNode { /*! \brief a == b */ class EQNode : public CmpOpNode { public: - static constexpr const char* _type_key = "EQ"; + static constexpr const char* _type_key = "tir.EQ"; }; /*! @@ -346,7 +346,7 @@ class EQ : public PrimExpr { /*! \brief a != b */ class NENode : public CmpOpNode { public: - static constexpr const char* _type_key = "NE"; + static constexpr const char* _type_key = "tir.NE"; }; /*! @@ -362,7 +362,7 @@ class NE : public PrimExpr { /*! \brief a < b */ class LTNode : public CmpOpNode { public: - static constexpr const char* _type_key = "LT"; + static constexpr const char* _type_key = "tir.LT"; }; /*! @@ -378,7 +378,7 @@ class LT : public PrimExpr { /*! \brief a <= b */ struct LENode : public CmpOpNode { public: - static constexpr const char* _type_key = "LE"; + static constexpr const char* _type_key = "tir.LE"; }; /*! @@ -394,7 +394,7 @@ class LE : public PrimExpr { /*! \brief a > b */ class GTNode : public CmpOpNode { public: - static constexpr const char* _type_key = "GT"; + static constexpr const char* _type_key = "tir.GT"; }; /*! @@ -410,7 +410,7 @@ class GT : public PrimExpr { /*! \brief a >= b */ class GENode : public CmpOpNode { public: - static constexpr const char* _type_key = "GE"; + static constexpr const char* _type_key = "tir.GE"; }; /*! @@ -447,7 +447,7 @@ class AndNode : public PrimExprNode { hash_reduce(b); } - static constexpr const char* _type_key = "And"; + static constexpr const char* _type_key = "tir.And"; TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); }; @@ -485,7 +485,7 @@ class OrNode : public PrimExprNode { hash_reduce(b); } - static constexpr const char* _type_key = "Or"; + static constexpr const char* _type_key = "tir.Or"; TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); }; @@ -519,7 +519,7 @@ class NotNode : public PrimExprNode { hash_reduce(a); } - static constexpr const char* _type_key = "Not"; + static constexpr const char* _type_key = "tir.Not"; TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); }; @@ -568,7 +568,7 @@ class SelectNode : public PrimExprNode { hash_reduce(false_value); } - static constexpr const char* _type_key = "Select"; + static constexpr const char* _type_key = "tir.Select"; TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); }; @@ -617,7 +617,7 @@ class BufferLoadNode : public PrimExprNode { hash_reduce(indices); } - static constexpr const char* _type_key = "BufferLoad"; + static constexpr const char* _type_key = "tir.BufferLoad"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); }; @@ -664,7 +664,7 @@ class ProducerLoadNode : public PrimExprNode { hash_reduce(indices); } - static constexpr const char* _type_key = "ProducerLoad"; + static constexpr const char* _type_key = "tir.ProducerLoad"; TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); }; @@ -722,7 +722,7 @@ class LoadNode : public PrimExprNode { hash_reduce(predicate); } - static constexpr const char* _type_key = "Load"; + static constexpr const char* _type_key = "tir.Load"; TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode); }; @@ -773,7 +773,7 @@ class RampNode : public PrimExprNode { hash_reduce(lanes); } - static constexpr const char* _type_key = "Ramp"; + static constexpr const char* _type_key = "tir.Ramp"; TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); }; @@ -811,7 +811,7 @@ class BroadcastNode : public PrimExprNode { hash_reduce(lanes); } - static constexpr const char* _type_key = "Broadcast"; + static constexpr const char* _type_key = "tir.Broadcast"; TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); }; @@ -856,7 +856,7 @@ class LetNode : public PrimExprNode { hash_reduce(body); } - static constexpr const char* _type_key = "Let"; + static constexpr const char* _type_key = "tir.Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); }; @@ -928,7 +928,7 @@ class CallNode : public PrimExprNode { /*! \return Whether call node can be vectorized. */ bool is_vectorizable() const; - static constexpr const char* _type_key = "Call"; + static constexpr const char* _type_key = "tir.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); // Build-in intrinsics @@ -990,7 +990,7 @@ class ShuffleNode : public PrimExprNode { hash_reduce(indices); } - static constexpr const char* _type_key = "Shuffle"; + static constexpr const char* _type_key = "tir.Shuffle"; TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); }; @@ -1048,7 +1048,7 @@ class CommReducerNode : public Object { hash_reduce(identity_element); } - static constexpr const char* _type_key = "CommReducer"; + static constexpr const char* _type_key = "tir.CommReducer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); @@ -1108,7 +1108,7 @@ class ReduceNode : public PrimExprNode { hash_reduce(value_index); } - static constexpr const char* _type_key = "Reduce"; + static constexpr const char* _type_key = "tir.Reduce"; TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); }; @@ -1136,7 +1136,7 @@ class AnyNode : public PrimExprNode { /*! \brief Convert to var. */ Var ToVar() const { return Var("any_dim", DataType::Int(32)); } - static constexpr const char* _type_key = "Any"; + static constexpr const char* _type_key = "tir.Any"; TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index ee8e1ebbfbe5..be1c567198d9 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -37,7 +37,7 @@ namespace tir { /*! \brief Base node of all statements. */ class StmtNode : public Object { public: - static constexpr const char* _type_key = "Stmt"; + static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 15; @@ -79,7 +79,7 @@ class LetStmtNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "LetStmt"; + static constexpr const char* _type_key = "tir.LetStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); }; @@ -134,7 +134,7 @@ class AttrStmtNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "AttrStmt"; + static constexpr const char* _type_key = "tir.AttrStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); }; @@ -181,7 +181,7 @@ class AssertStmtNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "AssertStmt"; + static constexpr const char* _type_key = "tir.AssertStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); }; @@ -244,7 +244,7 @@ class StoreNode : public StmtNode { hash_reduce(predicate); } - static constexpr const char* _type_key = "Store"; + static constexpr const char* _type_key = "tir.Store"; TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); }; @@ -295,7 +295,7 @@ class BufferStoreNode : public StmtNode { hash_reduce(indices); } - static constexpr const char* _type_key = "BufferStore"; + static constexpr const char* _type_key = "tir.BufferStore"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); }; @@ -355,7 +355,7 @@ class BufferRealizeNode : public StmtNode { BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) : buffer(buffer), bounds(bounds), condition(condition), body(body) {} - static constexpr const char* _type_key = "BufferRealize"; + static constexpr const char* _type_key = "tir.BufferRealize"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); }; @@ -406,7 +406,7 @@ class ProducerStoreNode : public StmtNode { hash_reduce(indices); } - static constexpr const char* _type_key = "ProducerStore"; + static constexpr const char* _type_key = "tir.ProducerStore"; TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); }; @@ -462,7 +462,7 @@ class ProducerRealizeNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "ProducerRealize"; + static constexpr const char* _type_key = "tir.ProducerRealize"; TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); }; @@ -529,7 +529,7 @@ class AllocateNode : public StmtNode { */ TVM_DLL static int32_t constant_allocation_size(const Array& extents); - static constexpr const char* _type_key = "Allocate"; + static constexpr const char* _type_key = "tir.Allocate"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); }; @@ -559,7 +559,7 @@ class FreeNode : public StmtNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); } - static constexpr const char* _type_key = "Free"; + static constexpr const char* _type_key = "tir.Free"; TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode); }; @@ -598,7 +598,7 @@ class SeqStmtNode : public StmtNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); } - static constexpr const char* _type_key = "SeqStmt"; + static constexpr const char* _type_key = "tir.SeqStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); }; @@ -697,7 +697,7 @@ class IfThenElseNode : public StmtNode { hash_reduce(else_case); } - static constexpr const char* _type_key = "IfThenElse"; + static constexpr const char* _type_key = "tir.IfThenElse"; TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); }; @@ -731,7 +731,7 @@ class EvaluateNode : public StmtNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "Evaluate"; + static constexpr const char* _type_key = "tir.Evaluate"; TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); }; @@ -817,7 +817,7 @@ class ForNode : public StmtNode { hash_reduce(body); } - static constexpr const char* _type_key = "For"; + static constexpr const char* _type_key = "tir.For"; TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; @@ -860,7 +860,7 @@ class PrefetchNode : public StmtNode { PrefetchNode() = default; PrefetchNode(Buffer buffer, Array bounds) : buffer(buffer), bounds(bounds) {} - static constexpr const char* _type_key = "Prefetch"; + static constexpr const char* _type_key = "tir.Prefetch"; TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); }; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 2a44909f531d..f1651c118010 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -266,7 +266,7 @@ class IterVarNode : public Object { hash_reduce(thread_tag); } - static constexpr const char* _type_key = "IterVar"; + static constexpr const char* _type_key = "tir.IterVar"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 94e9cf3e8213..8b7568574b47 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -138,11 +138,48 @@ def _convert(item, nodes): # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], - "StringImm": [_update_from_std_str("value")], - "Call": [_update_from_std_str("name")], - "AttrStmt": [_update_from_std_str("attr_key")], - "Layout": [_update_from_std_str("name")], - "Buffer": [_update_from_std_str("name"), _update_from_std_str("scope")], + "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")], + "Cast": [_rename("tir.Cast")], + "Add": [_rename("tir.Add")], + "Sub": [_rename("tir.Sub")], + "Mul": [_rename("tir.Mul")], + "Div": [_rename("tir.Div")], + "Mod": [_rename("tir.Mod")], + "FloorDiv": [_rename("tir.FloorDiv")], + "FloorMod": [_rename("tir.FloorMod")], + "Min": [_rename("tir.Min")], + "Max": [_rename("tir.Max")], + "EQ": [_rename("tir.EQ")], + "NE": [_rename("tir.NE")], + "LT": [_rename("tir.LT")], + "LE": [_rename("tir.LE")], + "GT": [_rename("tir.GT")], + "GE": [_rename("tir.GE")], + "And": [_rename("tir.And")], + "Or": [_rename("tir.Or")], + "Not": [_rename("tir.Not")], + "Select": [_rename("tir.Select")], + "Load": [_rename("tir.Load")], + "BufferLoad": [_rename("tir.BufferLoad")], + "Ramp": [_rename("tir.Ramp")], + "Broadcast": [_rename("tir.Broadcast")], + "Shuffle": [_rename("tir.Shuffle")], + "Call": [_rename("tir.Call"), _update_from_std_str("name")], + "Let": [_rename("tir.Let")], + "Any": [_rename("tir.Any")], + "LetStmt": [_rename("tir.LetStmt")], + "AssertStmt": [_rename("tir.AssertStmt")], + "Store": [_rename("tir.Store")], + "BufferStore": [_rename("tir.BufferStore")], + "BufferRealize": [_rename("tir.BufferRealize")], + "Allocate": [_rename("tir.Allocate")], + "IfThenElse": [_rename("tir.IfThenElse")], + "Evaluate": [_rename("tir.Evaluate")], + "Prefetch": [_rename("tir.Prefetch")], + "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")], + "Layout": [_rename("tir.Layout"), _update_from_std_str("name")], + "Buffer": [ + _rename("tir.Buffer"), _update_from_std_str("name"), _update_from_std_str("scope")], } return create_updater(node_map, "0.6", "0.7") diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py index 810509b6e9cd..891d7baf893e 100644 --- a/python/tvm/te/hybrid/util.py +++ b/python/tvm/te/hybrid/util.py @@ -83,7 +83,7 @@ def replace(op): return _expr.ProducerLoad(buf, op.indices) return None - return stmt_functor.ir_transform(body, None, replace, ['ProducerStore', 'ProducerLoad']) + return stmt_functor.ir_transform(body, None, replace, ['tir.ProducerStore', 'tir.ProducerLoad']) def _is_tvm_arg_types(args): diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index e4dec5f30950..11bfb4c55921 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -24,7 +24,7 @@ from . import _ffi_api -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Buffer") class Buffer(Object): """Symbolic data buffer in TVM. @@ -247,6 +247,6 @@ def decl_buffer(shape, data_alignment, offset_factor, buffer_type) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.DataProducer") class DataProducer(Object): pass diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index fd8c7a942297..161647377e37 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -20,7 +20,7 @@ from tvm.runtime import Object from . import _ffi_api -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Layout") class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and @@ -77,7 +77,7 @@ def factor_of(self, axis): return _ffi_api.LayoutFactorOf(self, axis) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BijectiveLayout") class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index d55370e8bdfa..f8cb05431a5b 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -321,7 +321,7 @@ def __init__(self, name, dtype): _ffi_api.SizeVar, name, dtype) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.IterVar") class IterVar(Object, ExprOp): """Represent iteration variable. @@ -373,7 +373,7 @@ def __init__(self, dom, var, iter_type, thread_tag=""): _ffi_api.IterVar, dom, var, iter_type, thread_tag) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.CommReducer") class CommReducer(Object): """Communicative reduce operator @@ -396,7 +396,7 @@ def __init__(self, lhs, rhs, result, identity_element): _ffi_api.CommReducer, lhs, rhs, result, identity_element) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Reduce") class Reduce(PrimExprWithOp): """Reduce node. @@ -475,7 +475,7 @@ def __bool__(self): return self.__nonzero__() -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.StringImm") class StringImm(ConstExpr): """String constant. @@ -499,7 +499,7 @@ def __ne__(self, other): return self.value != other -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): """Cast expression. @@ -516,7 +516,7 @@ def __init__(self, dtype, value): _ffi_api.Cast, dtype, value) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Add") class Add(BinaryOpExpr): """Add node. @@ -533,7 +533,7 @@ def __init__(self, a, b): _ffi_api.Add, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Sub") class Sub(BinaryOpExpr): """Sub node. @@ -550,7 +550,7 @@ def __init__(self, a, b): _ffi_api.Sub, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Mul") class Mul(BinaryOpExpr): """Mul node. @@ -567,7 +567,7 @@ def __init__(self, a, b): _ffi_api.Mul, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Div") class Div(BinaryOpExpr): """Div node. @@ -584,7 +584,7 @@ def __init__(self, a, b): _ffi_api.Div, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Mod") class Mod(BinaryOpExpr): """Mod node. @@ -601,7 +601,7 @@ def __init__(self, a, b): _ffi_api.Mod, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.FloorDiv") class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -618,7 +618,7 @@ def __init__(self, a, b): _ffi_api.FloorDiv, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. @@ -635,7 +635,7 @@ def __init__(self, a, b): _ffi_api.FloorMod, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Min") class Min(BinaryOpExpr): """Min node. @@ -652,7 +652,7 @@ def __init__(self, a, b): _ffi_api.Min, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Max") class Max(BinaryOpExpr): """Max node. @@ -669,7 +669,7 @@ def __init__(self, a, b): _ffi_api.Max, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.EQ") class EQ(CmpExpr): """EQ node. @@ -686,7 +686,7 @@ def __init__(self, a, b): _ffi_api.EQ, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.NE") class NE(CmpExpr): """NE node. @@ -703,7 +703,7 @@ def __init__(self, a, b): _ffi_api.NE, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.LT") class LT(CmpExpr): """LT node. @@ -720,7 +720,7 @@ def __init__(self, a, b): _ffi_api.LT, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.LE") class LE(CmpExpr): """LE node. @@ -737,7 +737,7 @@ def __init__(self, a, b): _ffi_api.LE, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.GT") class GT(CmpExpr): """GT node. @@ -754,7 +754,7 @@ def __init__(self, a, b): _ffi_api.GT, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.GE") class GE(CmpExpr): """GE node. @@ -771,7 +771,7 @@ def __init__(self, a, b): _ffi_api.GE, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.And") class And(LogicalExpr): """And node. @@ -788,7 +788,7 @@ def __init__(self, a, b): _ffi_api.And, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Or") class Or(LogicalExpr): """Or node. @@ -805,7 +805,7 @@ def __init__(self, a, b): _ffi_api.Or, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Not") class Not(LogicalExpr): """Not node. @@ -819,7 +819,7 @@ def __init__(self, a): _ffi_api.Not, a) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Select") class Select(PrimExprWithOp): """Select node. @@ -847,7 +847,7 @@ def __init__(self, condition, true_value, false_value): _ffi_api.Select, condition, true_value, false_value) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Load") class Load(PrimExprWithOp): """Load node. @@ -871,7 +871,7 @@ def __init__(self, dtype, buffer_var, index, predicate=None): _ffi_api.Load, dtype, buffer_var, index, *args) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BufferLoad") class BufferLoad(PrimExprWithOp): """Buffer load node. @@ -888,7 +888,7 @@ def __init__(self, buffer, indices): _ffi_api.BufferLoad, buffer, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.ProducerLoad") class ProducerLoad(PrimExprWithOp): """Producer load node. @@ -905,7 +905,7 @@ def __init__(self, producer, indices): _ffi_api.ProducerLoad, producer, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Ramp") class Ramp(PrimExprWithOp): """Ramp node. @@ -925,7 +925,7 @@ def __init__(self, base, stride, lanes): _ffi_api.Ramp, base, stride, lanes) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Broadcast") class Broadcast(PrimExprWithOp): """Broadcast node. @@ -942,7 +942,7 @@ def __init__(self, value, lanes): _ffi_api.Broadcast, value, lanes) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Shuffle") class Shuffle(PrimExprWithOp): """Shuffle node. @@ -959,7 +959,7 @@ def __init__(self, vectors, indices): _ffi_api.Shuffle, vectors, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Call") class Call(PrimExprWithOp): """Call node. @@ -987,7 +987,7 @@ def __init__(self, dtype, name, args, call_type): _ffi_api.Call, dtype, name, args, call_type) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Let") class Let(PrimExprWithOp): """Let node. @@ -1007,7 +1007,7 @@ def __init__(self, var, value, body): _ffi_api.Let, var, value, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Any") class Any(PrimExpr): """Any node. """ diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index f4d84716a47d..4536580737e5 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -36,7 +36,7 @@ class Stmt(Object): """Base class of all the statements.""" -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.LetStmt") class LetStmt(Stmt): """LetStmt node. @@ -56,7 +56,7 @@ def __init__(self, var, value, body): _ffi_api.LetStmt, var, value, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. @@ -76,7 +76,7 @@ def __init__(self, condition, message, body): _ffi_api.AssertStmt, condition, message, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.For") class For(Stmt): """For node. @@ -116,7 +116,7 @@ def __init__(self, for_type, device_api, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Store") class Store(Stmt): """Store node. @@ -140,7 +140,7 @@ def __init__(self, buffer_var, value, index, predicate=None): _ffi_api.Store, buffer_var, value, index, *args) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BufferStore") class BufferStore(Stmt): """Buffer store node. @@ -160,7 +160,7 @@ def __init__(self, buffer, value, indices): _ffi_api.BufferStore, buffer, value, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BufferRealize") class BufferRealize(Stmt): """Buffer realize node. @@ -183,7 +183,7 @@ def __init__(self, buffer, bounds, condition, body): _ffi_api.BufferRealize, buffer, bounds, condition, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.ProducerStore") class ProducerStore(Stmt): """ProducerStore node. @@ -203,7 +203,7 @@ def __init__(self, producer, value, indices): _ffi_api.ProducerStore, producer, value, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Allocate") class Allocate(Stmt): """Allocate node. @@ -235,7 +235,7 @@ def __init__(self, extents, condition, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. @@ -258,7 +258,7 @@ def __init__(self, node, attr_key, value, body): _ffi_api.AttrStmt, node, attr_key, value, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Free") class Free(Stmt): """Free node. @@ -272,7 +272,7 @@ def __init__(self, buffer_var): _ffi_api.Free, buffer_var) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.ProducerRealize") class ProducerRealize(Stmt): """ProducerRealize node. @@ -299,7 +299,7 @@ def __init__(self, _ffi_api.ProducerRealize, producer, bounds, condition, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. @@ -319,7 +319,7 @@ def __len__(self): return len(self.seq) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.IfThenElse") class IfThenElse(Stmt): """IfThenElse node. @@ -339,7 +339,7 @@ def __init__(self, condition, then_case, else_case): _ffi_api.IfThenElse, condition, then_case, else_case) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Evaluate") class Evaluate(Stmt): """Evaluate node. @@ -353,7 +353,7 @@ def __init__(self, value): _ffi_api.Evaluate, value) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Prefetch") class Prefetch(Stmt): """Prefetch node. diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 868845fdc237..d1e24b94a32f 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -159,7 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"For"}); + return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"tir.For"}); } // Remove IfThenElse node from a For node. @@ -183,9 +183,9 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { } }); - then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"IfThenElse"}); + then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"tir.IfThenElse"}); if (if_stmt.as()->else_case.defined()) { - else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array{"IfThenElse"}); + else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array{"tir.IfThenElse"}); } return std::make_pair(then_for, else_for); @@ -393,7 +393,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { *ret = new_for; } }); - return IRTransform(stmt, nullptr, replace_top_for, Array{"For"}); + return IRTransform(stmt, nullptr, replace_top_for, Array{"tir.For"}); } Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index bafa9577cb36..1a7163ff129d 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -214,7 +214,7 @@ def vectorizer(op): def _transform(f, *_): return f.with_body( - tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For'])) + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For'])) return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize") with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 34db08f40c2f..1173b71ade6f 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -724,7 +724,7 @@ def vectorizer(op): def _transform(f, *_): return f.with_body( - tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['For'])) + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For'])) return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize") diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_pass_hoist_if.py index 346239d302cf..80e93a706ee7 100644 --- a/tests/python/unittest/test_tir_pass_hoist_if.py +++ b/tests/python/unittest/test_tir_pass_hoist_if.py @@ -33,12 +33,12 @@ def _visit(op): if isinstance(op, tvm.tir.IfThenElse): global var_list tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars) - val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] + val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))] var_list.clear() elif isinstance(op, tvm.tir.For): - val = [(op.body,), ("For", op.loop_var.name)] + val = [(op.body,), ("tir.For", op.loop_var.name)] elif isinstance(op, tvm.tir.AttrStmt): - val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] + val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))] else: return node_dict[key] = val @@ -68,9 +68,9 @@ def test_basic(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), - ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), - ('For', 'i'): (('IfThenElse', ('i',)),)} + expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) def test_no_else(): @@ -87,9 +87,9 @@ def test_no_else(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), - ('IfThenElse', ('i',)): (('For', 'j'), None), - ('For', 'i'): (('IfThenElse', ('i',)),)} + expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) def test_attr_stmt(): @@ -114,10 +114,10 @@ def test_attr_stmt(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), - ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), - ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), - ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} + expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')), + ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),), + ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),), + ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)} verify_structure(new_stmt, expected_struct) def test_nested_for(): @@ -138,9 +138,9 @@ def test_nested_for(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), - ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), - ('For', 'i'): (('IfThenElse', ('i',)),)} + expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.For', 'l'): (('tir.IfThenElse', ('i', 'j')),), + ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For', 'j'): (None,), ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) def test_if_block(): @@ -171,10 +171,10 @@ def test_if_block(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), - ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), - ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), - ('IfThenElse', ('n',)): (('For', 'j'), None)} + expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None), + ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),), + ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)} verify_structure(new_stmt, expected_struct) diff --git a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py index 7bf70119e4aa..38529e927d52 100644 --- a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py +++ b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py @@ -37,7 +37,7 @@ def postorder(op): if op.name == "TestA": return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1) return op - body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["Call"]) + body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) assert stmt_list[0].value.args[0].name == "TestB" assert stmt_list[1].value.value == 0 diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index db5057288cf0..17f864f4414e 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -84,7 +84,7 @@ loops = [] def find_width8(op): - """ Find all the 'For' nodes whose extent can be divided by 8. """ + """ Find all the 'tir.For' nodes whose extent can be divided by 8. """ if isinstance(op, tvm.tir.For): if isinstance(op.extent, tvm.tir.IntImm): if op.extent.value % 8 == 0: @@ -129,7 +129,7 @@ def vectorize(f, mod, ctx): # The last list arugment indicates what kinds of nodes will be transformed. # Thus, in this case only `For` nodes will call `vectorize8` return f.with_body( - tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For'])) + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['tir.For'])) ##################################################################### diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 37b4e0e3e7c4..207f784b5885 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -87,7 +87,7 @@ def _post_order(op): return op ret = tvm.tir.stmt_functor.ir_transform( - stmt.body, None, _post_order, ["Call"]) + stmt.body, None, _post_order, ["tir.Call"]) if not fail[0] and all(x is not None for x in gemm_offsets): def _visit(op): @@ -132,7 +132,7 @@ def _do_fold(stmt): def _ftransform(f, mod, ctx): return f.with_body(tvm.tir.stmt_functor.ir_transform( - f.body, _do_fold, None, ["AttrStmt"])) + f.body, _do_fold, None, ["tir.AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.FoldUopLoop") @@ -188,7 +188,7 @@ def _post_order(op): stmt_in = f.body stmt = tvm.tir.stmt_functor.ir_transform( - stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) + stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"]) for buffer_var, new_var in rw_info.items(): stmt = tvm.tir.LetStmt( @@ -254,7 +254,7 @@ def _post_order(op): raise RuntimeError("not reached") stmt_in = f.body stmt = tvm.tir.stmt_functor.ir_transform( - stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) + stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt", "tir.For"]) assert len(lift_stmt) == 1 return f.with_body(_merge_block(lift_stmt[0], stmt)) @@ -277,7 +277,7 @@ def _do_fold(stmt): def _ftransform(f, mod, ctx): return f.with_body(tvm.tir.stmt_functor.ir_transform( - f.body, _do_fold, None, ["AttrStmt"])) + f.body, _do_fold, None, ["tir.AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy") @@ -307,7 +307,7 @@ def _do_fold(stmt): op.device_api, op.body) return None return f.with_body(tvm.tir.stmt_functor.ir_transform( - f.body, None, _do_fold, ["AttrStmt"])) + f.body, None, _do_fold, ["tir.AttrStmt"])) return tvm.transform.Sequential( [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"), tvm.tir.transform.CoProcSync()], @@ -708,7 +708,7 @@ def _do_fold(op): return None return func.with_body(tvm.tir.stmt_functor.ir_transform( - func.body, _do_fold, None, ["AttrStmt"])) + func.body, _do_fold, None, ["tir.AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip") @@ -737,7 +737,7 @@ def _do_fold(stmt): return stmt return func.with_body(tvm.tir.stmt_functor.ir_transform( - func.body, None, _do_fold, ["AttrStmt"])) + func.body, None, _do_fold, ["tir.AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope") @@ -956,7 +956,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): return stmt return func.with_body(tvm.tir.stmt_functor.ir_transform( - func.body, None, _do_fold, ["AttrStmt"])) + func.body, None, _do_fold, ["tir.AttrStmt"])) return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")