Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ class IntImm : public PrimExpr {
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
};

/*!
Expand Down Expand Up @@ -572,6 +573,7 @@ class FloatImm : public PrimExpr {
TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
};

/*!
Expand Down
30 changes: 30 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class StringImm : public PrimExpr {
public:
TVM_DLL StringImm(String value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
};

/*!
Expand Down Expand Up @@ -117,6 +118,7 @@ class Cast : public PrimExpr {
public:
TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode);
};

/*!
Expand Down Expand Up @@ -165,6 +167,7 @@ class Add : public PrimExpr {
public:
TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode);
};

/*! \brief a - b */
Expand All @@ -181,6 +184,7 @@ class Sub : public PrimExpr {
public:
TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode);
};

/*! \brief a * b */
Expand All @@ -197,6 +201,7 @@ class Mul : public PrimExpr {
public:
TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode);
};

/*!
Expand All @@ -216,6 +221,7 @@ class Div : public PrimExpr {
public:
TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode);
};

/*!
Expand All @@ -235,6 +241,7 @@ class Mod : public PrimExpr {
public:
TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode);
};

/*! \brief Floor division, floor(a/b) */
Expand All @@ -251,6 +258,7 @@ class FloorDiv : public PrimExpr {
public:
TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode);
};

/*! \brief The remainder of the floordiv */
Expand All @@ -267,6 +275,7 @@ class FloorMod : public PrimExpr {
public:
TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode);
};

/*! \brief min(a, b) */
Expand All @@ -283,6 +292,7 @@ class Min : public PrimExpr {
public:
TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode);
};

/*! \brief max(a, b) */
Expand All @@ -299,6 +309,7 @@ class Max : public PrimExpr {
public:
TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode);
};

/*!
Expand Down Expand Up @@ -347,6 +358,7 @@ class EQ : public PrimExpr {
public:
TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode);
};

/*! \brief a != b */
Expand All @@ -363,6 +375,7 @@ class NE : public PrimExpr {
public:
TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode);
};

/*! \brief a < b */
Expand All @@ -379,6 +392,7 @@ class LT : public PrimExpr {
public:
TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode);
};

/*! \brief a <= b */
Expand All @@ -395,6 +409,7 @@ class LE : public PrimExpr {
public:
TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode);
};

/*! \brief a > b */
Expand All @@ -411,6 +426,7 @@ class GT : public PrimExpr {
public:
TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode);
};

/*! \brief a >= b */
Expand All @@ -427,6 +443,7 @@ class GE : public PrimExpr {
public:
TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode);
};

/*! \brief a && b */
Expand Down Expand Up @@ -466,6 +483,7 @@ class And : public PrimExpr {
public:
TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode);
};

/*! \brief a || b */
Expand Down Expand Up @@ -505,6 +523,7 @@ class Or : public PrimExpr {
public:
TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode);
};

/*! \brief !a */
Expand Down Expand Up @@ -540,6 +559,7 @@ class Not : public PrimExpr {
public:
TVM_DLL Not(PrimExpr a, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode);
};

/*!
Expand Down Expand Up @@ -591,6 +611,7 @@ class Select : public PrimExpr {
TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode);
};

/*!
Expand Down Expand Up @@ -706,6 +727,7 @@ class ProducerLoad : public PrimExpr {
TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
};

/*!
Expand Down Expand Up @@ -765,6 +787,7 @@ class Load : public PrimExpr {
TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode);
};

/*!
Expand Down Expand Up @@ -817,6 +840,7 @@ class Ramp : public PrimExpr {
public:
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
};

/*! \brief Create a vector where all the elements are value. */
Expand Down Expand Up @@ -856,6 +880,7 @@ class Broadcast : public PrimExpr {
public:
TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
};

/*!
Expand Down Expand Up @@ -902,6 +927,7 @@ class Let : public PrimExpr {
public:
TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
};

/*!
Expand Down Expand Up @@ -948,6 +974,7 @@ class Call : public PrimExpr {
public:
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

/*!
Expand Down Expand Up @@ -995,6 +1022,7 @@ class Shuffle : public PrimExpr {
TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode);
};

// Reduce operator
Expand Down Expand Up @@ -1124,6 +1152,7 @@ class Reduce : public PrimExpr {
int value_index, Array<PrimExpr> init, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
};

/*! \brief Any shape. */
Expand Down Expand Up @@ -1159,6 +1188,7 @@ class Any : public PrimExpr {
TVM_DLL Any(Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};

/*
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class LetStmt : public Stmt {
TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode);
};

/*!
Expand Down Expand Up @@ -158,6 +159,7 @@ class AttrStmt : public Stmt {
TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
};

/*!
Expand Down Expand Up @@ -206,6 +208,7 @@ class AssertStmt : public Stmt {
TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
};

/*!
Expand Down Expand Up @@ -271,6 +274,7 @@ class Store : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode);
};

/*!
Expand Down Expand Up @@ -442,6 +446,7 @@ class ProducerStore : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
};

/*!
Expand Down Expand Up @@ -505,6 +510,7 @@ class ProducerRealize : public Stmt {
String storage_scope = "", Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
};

/*!
Expand Down Expand Up @@ -679,6 +685,7 @@ class AllocateConst : public Stmt {
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode);
};

/*! \brief Declare a buffer that can be used in the body */
Expand Down Expand Up @@ -812,6 +819,7 @@ class SeqStmt : public Stmt {
};

TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode);
};

/*!
Expand Down Expand Up @@ -898,6 +906,7 @@ class Evaluate : public Stmt {
explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
};

/*!
Expand Down Expand Up @@ -1055,6 +1064,7 @@ class While : public Stmt {
TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
};

/*!
Expand Down Expand Up @@ -1099,6 +1109,7 @@ class Prefetch : public Stmt {
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
};

/*!
Expand Down Expand Up @@ -1203,6 +1214,7 @@ class MatchBufferRegion : public ObjectRef {
TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);

TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode);
};

/*!
Expand Down