Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions include/tvm/tirx/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ TVM_DLL const Op& large_uint_imm();
* (i.e., round(x.1) = x and round (x.5) = x+1)
*/
TVM_DLL const Op& q_multiply_shift();
TVM_DLL const Op& q_multiply_shift_per_axis();

/*!
* \brief Returns the address of an element in the buffer (see pseudocode below).
Expand Down Expand Up @@ -499,6 +500,11 @@ TVM_DLL const Op& tvm_call_trace_packed_lowered();
*/
TVM_DLL const Op& tvm_storage_sync();

/*!
* \brief Marker where a transform should replace generated kernel initialization.
*/
TVM_DLL const Op& tvm_kernel_replace_point();

/*!
* \brief See pseudo code
*
Expand Down Expand Up @@ -912,6 +918,11 @@ TVM_DLL const Op& cuda_atomic_add();
*/
TVM_DLL const Op& cuda_thread_fence();

/*!
* \brief tvm intrinsic for cuda warpgroup sync instruction
*/
TVM_DLL const Op& cuda_warpgroup_sync();

/*!
* \brief Warp-level butterfly shuffle-XOR reduction.
*
Expand Down Expand Up @@ -952,6 +963,11 @@ TVM_DLL const Op& cuda_cta_sync();
*/
TVM_DLL const Op& cuda_grid_sync();

/*!
* \brief tvm intrinsic for cuda cluster-wide sync instruction
*/
TVM_DLL const Op& cuda_cluster_sync();

/*!
* \brief tvm intrinsic that returns ``cooperative_groups::thread_rank()``
* for the enclosing CTA (linear thread index within the block).
Expand Down Expand Up @@ -1053,25 +1069,19 @@ TVM_DLL const Op& ptx_reduce3_max_f32();
*/
TVM_DLL const Op& ptx_reduce3_min_f32();

/*!
* \brief tvm intrinsic for PTX packed add instruction (sm_100a+)
*/
TVM_DLL const Op& ptx_add_packed_f32x2();

/*!
* \brief tvm intrinsic for PTX packed subtract instruction (sm_100a+)
*/
TVM_DLL const Op& ptx_sub_packed_f32x2();

/*!
* \brief tvm intrinsic for PTX packed multiply instruction (sm_100a+)
*/
TVM_DLL const Op& ptx_mul_packed_f32x2();

/*!
* \brief tvm intrinsic for PTX packed FMA instruction (sm_100a+)
*/
TVM_DLL const Op& ptx_fma_packed_f32x2();
TVM_DLL const Op& ptx_add_f32();
TVM_DLL const Op& ptx_add_f32x2();
TVM_DLL const Op& ptx_add_f64();
TVM_DLL const Op& ptx_sub_f32();
TVM_DLL const Op& ptx_sub_f32x2();
TVM_DLL const Op& ptx_sub_f64();
TVM_DLL const Op& ptx_mul_f32();
TVM_DLL const Op& ptx_mul_f32x2();
TVM_DLL const Op& ptx_mul_f64();
TVM_DLL const Op& ptx_fma_f32();
TVM_DLL const Op& ptx_fma_f32x2();
TVM_DLL const Op& ptx_fma_f64();
TVM_DLL const Op& ptx_max_f32();

} // namespace builtin
} // namespace tirx
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/tirx/exec_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ struct ExecContext {
/*! \brief Apply modulo filter on a factorized CTA axis such as cbx/cby/cbz. */
bool WithCtaAxisModulo(const std::string& axis, int64_t modulus, int64_t residue,
ExecContext* out, std::string* err) const;

/*! \brief Apply scope_switch; A preserved, split recomputed for new scope_kind. */
bool WithScopeSwitch(ScopeKind new_scope_kind, ExecContext* out, std::string* err) const;
};

/*!
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/tirx/exec_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ namespace tvm {
namespace tirx {

/*!
* \brief The target execution scope kind of an ExecScopeStmt.
* \brief The target execution scope kind of a tile primitive call.
*
* Replaces the string-keyed name of ExecScope. One value per user-facing
* `with T.<kind>():` construct. Ordered from coarsest to finest; smaller
* integer = wider scope, so ``ScopeKindHigher`` is a plain ``<``.
* Identifies the granularity at which an op executes (the per-call
* ``scope`` on a ``TilePrimitiveCall``, e.g. ``Tx.warp.copy(...)``).
* Ordered from coarsest to finest; smaller integer = wider scope, so
* ``ScopeKindHigher`` is a plain ``<``.
*/
enum class ScopeKind : int {
kCluster = 2,
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tirx/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ namespace attr {
*/
constexpr const char* kKernelLaunchParams = "tirx.kernel_launch_params";

/*!
* \brief CUDA launch bound minimum CTAs per SM.
*
* Type: IntImm
*/
constexpr const char* kLaunchBoundsMinBlocksPerSM = "tirx.launch_bounds_min_blocks_per_sm";

/*!
* \brief Whether to set noalias rule on the function arguments.
*
Expand Down
7 changes: 5 additions & 2 deletions include/tvm/tirx/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <tvm/ir/type.h>
#include <tvm/tirx/builtin.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/op_attr_types.h>
#include <tvm/tirx/stmt.h>
#include <tvm/tirx/target_builtin/cuda.h>
#include <tvm/tirx/target_builtin/trn.h>
Expand All @@ -43,8 +44,10 @@

namespace tvm {

#define TVM_TIR_REGISTER_OP(OpName) \
TVM_REGISTER_OP("tirx." OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", OpName)
#define TVM_TIR_REGISTER_OP(OpName) \
TVM_REGISTER_OP("tirx." OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", OpName) \
.set_attr<TIRxOpCategory>("TIRxOpCategory", ffi::String("builtin"), /*plevel=*/1)

#define TVM_TIRX_REGISTER_OP(OpName) TVM_TIR_REGISTER_OP(OpName)

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/tirx/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ enum class ScriptDtypePrintLocation : int {

using TScriptDtypePrintLocation = int64_t;

/*!
* \brief Broad TIRx op category.
*
* Expected values:
* - "builtin"
* - "tile_primitive"
* - "device_intrin"
*/
using TIRxOpCategory = ffi::String;

/*!
* \brief Device intrinsic namespace.
*
* Expected values include "cuda", "ptx", "nvshmem", "nki", and "metal".
*/
using TDeviceIntrinsicNamespace = ffi::String;

/*!
* \brief The effect type of the call.
*/
Expand Down
46 changes: 0 additions & 46 deletions include/tvm/tirx/script/builder/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,52 +247,6 @@ class BlockInitFrame : public TIRFrame {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockInitFrame, TIRFrame, BlockInitFrameNode);
};

/*!
* \brief A frame that represents an execution scope (e.g. cta, warp, thread).
*
* When exiting this frame, it produces an ExecScopeStmt wrapping the body.
* This is the new IR pattern, replacing the old pattern of storing exec_scope on SBlock.
*
* \sa ExecScopeFrame
*/
class ExecScopeFrameNode : public TIRFrameNode {
public:
/*! \brief The execution scope (always plain kind; no slice). */
ffi::Optional<tvm::tirx::ExecScope> exec_scope;
/*! \brief Optional surface-syntax guards for ``with Tx.scope(cond)``. */
ffi::Array<PrimExpr> guards;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ExecScopeFrameNode>()
.def_ro("exec_scope", &ExecScopeFrameNode::exec_scope)
.def_ro("guards", &ExecScopeFrameNode::guards);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.ExecScopeFrame", ExecScopeFrameNode,
TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to ExecScopeFrameNode.
*
* \sa ExecScopeFrameNode
*/
class ExecScopeFrame : public TIRFrame {
public:
explicit ExecScopeFrame(ffi::ObjectPtr<ExecScopeFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
TVM_FFI_ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ExecScopeFrame, TIRFrame, ExecScopeFrameNode);
};

/*!
* \brief A frame that represents the for loop.
*
Expand Down
15 changes: 0 additions & 15 deletions include/tvm/tirx/script/builder/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,6 @@ SBlockFrame Block(ffi::String name, bool no_realize = false, ffi::String exec_sc

void TilePrimitiveCall(tvm::tirx::TilePrimitiveCall op_call);

/*!
* \brief Create an ExecScopeFrame for execution scope contexts.
* \param exec_scope_name The name of the execution scope (e.g. "cta", "warp").
* \return The ExecScopeFrame.
*/
ExecScopeFrame ExecScopeBlock(ffi::String exec_scope_name,
ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());

ExecScopeFrame Kernel(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
ExecScopeFrame Cluster(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
ExecScopeFrame WarpGroup(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
ExecScopeFrame CTA(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
ExecScopeFrame Warp(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());
ExecScopeFrame Thread(ffi::Array<PrimExpr> guards = ffi::Array<PrimExpr>());

ffi::Array<tvm::tirx::Var> KernelId(ffi::Array<PrimExpr> extents, ffi::String parent);

ffi::Array<tvm::tirx::Var> CtaId(ffi::Array<PrimExpr> extents, ffi::String parent);
Expand Down
44 changes: 1 addition & 43 deletions include/tvm/tirx/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -952,53 +952,11 @@ class SBlockRealize : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockRealizeNode);
};

/*!
* \brief A statement that annotates the execution scope for its body.
*
* ExecScopeStmt represents a hardware execution scope (e.g. cta, warp, thread)
* that wraps a body statement. This decouples the execution scope concept from
* SBlock, making the IR structure cleaner.
*
* Example:
* \code
* with T.cta():
* ...
* \endcode
*/
class ExecScopeStmtNode : public StmtNode {
public:
/*! \brief The execution scope. */
ExecScope exec_scope;
/*! \brief The body statement under this execution scope. */
Stmt body;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ExecScopeStmtNode>()
.def_ro("exec_scope", &ExecScopeStmtNode::exec_scope)
.def_ro("body", &ExecScopeStmtNode::body);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ExecScopeStmt", ExecScopeStmtNode, StmtNode);
};

/*!
* \brief Managed reference to ExecScopeStmtNode.
* \sa ExecScopeStmtNode
*/
class ExecScopeStmt : public Stmt {
public:
TVM_DLL ExecScopeStmt(ExecScope exec_scope, Stmt body, Span span = Span());

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExecScopeStmt, Stmt, ExecScopeStmtNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecScopeStmtNode);
};

/*!
* \brief Standalone statement that declares a scope-id binding (e.g. cta_id,
* warp_id, lane_id). Carries a ``ScopeIdDef`` value.
*
* Unlike legacy ``ExecScopeStmt::scope_id_def`` (an array payload), each
* declaration is a flat stmt within the device-region body. The declared
* Each declaration is a flat stmt within the device-region body. The declared
* ``Var``\ s are visible in subsequent stmts in the same enclosing scope
* (the AttrStmt ``kDeviceEntry`` body), analogous to ``BindNode``.
*/
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/tirx/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SBlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ExecScopeStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ScopeIdDefStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const tirx::TilePrimitiveCallNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const ffi::Object* op, Args...) {
Expand All @@ -127,7 +126,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(SBlockNode);
IR_STMT_FUNCTOR_DISPATCH(SBlockRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(ExecScopeStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ScopeIdDefStmtNode);
IR_STMT_FUNCTOR_DISPATCH(tirx::TilePrimitiveCallNode);
vtable.Finalize();
Expand Down Expand Up @@ -185,7 +183,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SBlockNode* op) override;
void VisitStmt_(const SBlockRealizeNode* op) override;
void VisitStmt_(const ExecScopeStmtNode* op) override;
void VisitStmt_(const ScopeIdDefStmtNode* op) override;
void VisitStmt_(const tirx::TilePrimitiveCallNode* op) override;
};
Expand Down Expand Up @@ -304,7 +301,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const EvaluateNode* op) override;
Stmt VisitStmt_(const SBlockNode* op) override;
Stmt VisitStmt_(const SBlockRealizeNode* op) override;
Stmt VisitStmt_(const ExecScopeStmtNode* op) override;
Stmt VisitStmt_(const ScopeIdDefStmtNode* op) override;
Stmt VisitStmt_(const tirx::TilePrimitiveCallNode* op) override;
/*!
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/tirx/target_builtin/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,6 @@ TVM_DLL const Op& mma_fill_legacy();
*/
TVM_DLL const Op& ptx_ldg32();

/*!
* \brief tvm intrinsic for ptx predicate load with 32-bit data type.
*
*/
TVM_DLL const Op& ptx_ldg32();

/*!
* \brief tvm intrinsic for sparse tensor core ptx instructions.
*
Expand Down Expand Up @@ -374,6 +368,12 @@ TVM_DLL const Op& ptx_fence_mbarrier_init();
*/
TVM_DLL const Op& ptx_fetch_register();

/*!
* \brief PTX programmatic dependent launch synchronization.
*/
TVM_DLL const Op& ptx_griddepcontrol_wait();
TVM_DLL const Op& ptx_griddepcontrol_launch_dependents();

/*!
* \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
* For example, if each thread in a warp of size 32 has 4 elements from the result of
Expand Down
15 changes: 9 additions & 6 deletions include/tvm/tirx/tirx_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ TVM_DLL const Op& sqrt();

TVM_DLL const Op& exp();

TVM_DLL const Op& exp2();

TVM_DLL const Op& add();

TVM_DLL const Op& sub();
Expand Down Expand Up @@ -221,12 +223,13 @@ TVM_DLL const Op& binary_chain();

TVM_DLL const Op& select();

/*!
* \brief See pesudo code below:
*
* tvm_kernel_replace_point()
*/
TVM_DLL const Op& tvm_kernel_replace_point();
TVM_DLL const Op& fma();

TVM_DLL const Op& silu();

TVM_DLL const Op& compose_op();

TVM_DLL const Op& permute_layout();

} // namespace tirx
} // namespace tvm
Expand Down
Loading
Loading