From e2af93bce4672326eb997b6c42869e2740b1fc0c Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Fri, 5 Jun 2026 14:58:59 -0400 Subject: [PATCH 1/4] feat(lower-tirx): support scoped ops and launch bounds Fold the latest TIRx op-dispatch and lowering follow-ups into one upstream-facing commit. This updates the scoped Tx op surface, splits TIRx op namespaces, and adds explicit CUDA launch bounds support. --- include/tvm/tirx/builtin.h | 43 +- include/tvm/tirx/exec_context.h | 3 - include/tvm/tirx/exec_scope.h | 9 +- include/tvm/tirx/function.h | 7 + include/tvm/tirx/op.h | 7 +- include/tvm/tirx/op_attr_types.h | 28 + include/tvm/tirx/script/builder/frame.h | 46 - include/tvm/tirx/script/builder/ir.h | 15 - include/tvm/tirx/stmt.h | 44 +- include/tvm/tirx/stmt_functor.h | 4 - include/tvm/tirx/target_builtin/cuda.h | 12 +- include/tvm/tirx/tirx_op.h | 10 + include/tvm/tirx/tirx_stmt.h | 9 +- python/tvm/runtime/script_printer.py | 8 +- python/tvm/s_tir/backend/adreno/pipeline.py | 4 +- python/tvm/s_tir/pipeline.py | 4 +- python/tvm/script/parser/core/entry.py | 3 +- python/tvm/tirx/__init__.py | 2 +- python/tvm/tirx/bench.py | 36 +- python/tvm/tirx/lang/alloc_pool.py | 35 +- python/tvm/tirx/lang/pipeline.py | 66 +- python/tvm/tirx/lang/smem_desc.py | 16 +- python/tvm/tirx/lang/tile_scheduler.py | 294 ++- python/tvm/tirx/lang/warp_role.py | 55 +- python/tvm/tirx/op.py | 58 +- .../tvm/tirx/operator/intrinsics/_schema.py | 38 +- .../tvm/tirx/operator/intrinsics/cuda/misc.py | 4 +- .../tirx/operator/intrinsics/cuda/registry.py | 42 +- .../operator/tile_primitive/cuda/common.py | 48 +- .../tile_primitive/cuda/copy/_swizzle_iter.py | 24 +- .../tile_primitive/cuda/copy/fallback.py | 10 +- .../tile_primitive/cuda/copy/gmem_smem.py | 8 +- .../tile_primitive/cuda/copy/ld_stmatrix.py | 22 +- .../operator/tile_primitive/cuda/copy/reg.py | 22 +- .../tile_primitive/cuda/copy_async/dsmem.py | 24 +- .../tile_primitive/cuda/copy_async/ldgsts.py | 14 +- .../cuda/copy_async/tcgen05_cp.py | 20 +- .../cuda/copy_async/tcgen05_ldst.py | 50 +- .../tile_primitive/cuda/copy_async/tma.py | 37 +- .../cuda/elementwise/_common.py | 12 +- .../cuda/elementwise/ops/__init__.py | 2 +- .../cuda/elementwise/ops/unary.py | 18 +- .../tile_primitive/cuda/elementwise/reg.py | 30 +- .../tile_primitive/cuda/elementwise/smem.py | 30 +- .../cuda/elementwise/vec_emit/__init__.py | 2 +- .../cuda/elementwise/vec_emit/binary_f32x2.py | 14 +- .../cuda/elementwise/vec_emit/cast_vec2.py | 8 +- .../cuda/elementwise/vec_emit/fma_f32x2.py | 12 +- .../tile_primitive/cuda/exec_scope_utils.py | 36 +- .../tile_primitive/cuda/gemm/mma_m16n8k_.py | 20 +- .../tile_primitive/cuda/gemm_async/tcgen05.py | 68 +- .../cuda/permute_layout/warp_xor_swizzle.py | 46 +- .../tile_primitive/cuda/reduction/local.py | 116 +- .../tile_primitive/cuda/reduction/shared.py | 58 +- .../cuda/reduction/sm100_packed.py | 175 +- .../tile_primitive/cuda/reduction/utils.py | 10 +- .../tvm/tirx/operator/tile_primitive/ops.py | 94 +- .../tile_primitive/trn/binary/default.py | 34 +- .../trn/compose_op/binary_chain.py | 30 +- .../trn/compose_op/binary_reduce.py | 66 +- .../trn/compose_op/compose_op.py | 2 +- .../trn/compose_op/reduce_negate.py | 2 +- .../trn/compose_op/unary_reduce.py | 66 +- .../tile_primitive/trn/compose_op/utils.py | 26 +- .../tile_primitive/trn/copy/default.py | 121 +- .../operator/tile_primitive/trn/dim_utils.py | 8 +- .../tile_primitive/trn/gemm/default.py | 66 +- .../trn/instruction_generator.py | 8 +- .../tile_primitive/trn/private_alloc.py | 29 +- .../tile_primitive/trn/reduction/utils.py | 58 +- .../tile_primitive/trn/select/default.py | 32 +- .../tile_primitive/trn/unary/default.py | 4 +- .../tile_primitive/trn/unary/utils.py | 45 +- .../trn/unary/with_bias_scale.py | 4 +- python/tvm/tirx/script/__init__.py | 51 +- python/tvm/tirx/script/builder/__init__.py | 4 +- python/tvm/tirx/script/builder/frame.py | 12 - python/tvm/tirx/script/builder/ir.py | 204 +- python/tvm/tirx/script/builder/tirx.py | 317 +++- python/tvm/tirx/script/parser/__init__.py | 3 + python/tvm/tirx/script/parser/entry.py | 33 +- python/tvm/tirx/script/parser/parser.py | 2 +- python/tvm/tirx/script/tile.py | 121 ++ python/tvm/tirx/stmt.py | 78 +- python/tvm/tirx/stmt_functor.py | 25 +- python/tvm/tirx/transform/common.py | 7 +- .../transform/trn/private_buffer_alloc.py | 27 +- src/target/cuda/codegen_cuda.cc | 45 +- src/target/cuda/intrin_rule_cuda.cc | 20 +- .../hexagon/llvm/intrin_rule_hexagon.cc | 1 + src/target/intrin_rule.cc | 2 + src/target/llvm/codegen_llvm.cc | 2 - src/target/llvm/codegen_llvm.h | 1 - src/target/metal/intrin_rule_metal.cc | 16 +- src/target/source/codegen_c.cc | 2 - src/target/source/codegen_c.h | 1 - src/target/source/codegen_trn.cc | 36 +- src/target/webgpu/intrin_rule_webgpu.cc | 17 +- src/tirx/analysis/exec_context.cc | 10 - src/tirx/analysis/filter_canonical.cc | 10 +- src/tirx/analysis/verify_tirx_well_formed.cc | 55 +- src/tirx/ir/stmt.cc | 15 - src/tirx/ir/stmt_functor.cc | 16 - src/tirx/ir/tir_visitor_with_path.cc | 4 - src/tirx/ir/tir_visitor_with_path.h | 1 - src/tirx/ir/tirx_stmt.cc | 8 +- src/tirx/ir/transform.cc | 2 +- src/tirx/op/runtime.cc | 2 + src/tirx/op/target_builtin/cuda.cc | 251 ++- src/tirx/op/target_builtin/trn.cc | 61 + src/tirx/op/tirx.cc | 33 +- src/tirx/script/builder/frame.cc | 19 +- src/tirx/script/builder/ir.cc | 24 +- src/tirx/script/builder/utils.h | 15 - src/tirx/script/printer/block.cc | 10 +- src/tirx/script/printer/buffer.cc | 2 +- src/tirx/script/printer/expr.cc | 2 +- src/tirx/script/printer/stmt.cc | 44 +- src/tirx/script/printer/utils.h | 11 - src/tirx/transform/lower_tirx.cc | 34 +- src/tirx/transform/lower_tirx_cleanup.cc | 30 - src/tirx/transform/split_host_device.cc | 45 +- src/tirx/transform/tile_primitive_dispatch.cc | 173 +- tests/python/codegen/test_inject_ptx_ldg32.py | 2 +- .../test_s_tir_transform_inject_ptx_ldg32.py | 2 +- tests/python/tirx-base/test_tir_op_types.py | 10 +- .../python/tirx-base/test_tir_stmt_functor.py | 6 +- .../tirx/codegen/test_codegen_ampere.py | 206 +- .../tirx/codegen/test_codegen_blackwell.py | 448 +++-- .../python/tirx/codegen/test_codegen_cuda.py | 574 +++--- .../python/tirx/codegen/test_codegen_dsmem.py | 48 +- .../tirx/codegen/test_codegen_hopper.py | 751 ++++---- tests/python/tirx/codegen/test_codegen_nki.py | 217 ++- .../tirx/codegen/test_codegen_nvshmem.py | 164 +- tests/python/tirx/codegen/test_cuda_copy.py | 220 +-- .../tirx/codegen/test_cuda_cta_reduce.py | 158 +- .../tirx/codegen/test_cuda_warp_reduce.py | 110 +- .../tile_primitive/cuda/copy/test_fallback.py | 138 +- .../cuda/copy/test_gmem_smem.py | 172 +- .../cuda/copy/test_ld_stmatrix.py | 377 ++-- .../tile_primitive/cuda/copy/test_reg.py | 338 ++-- .../cuda/copy_async/test_dsmem.py | 100 +- .../cuda/copy_async/test_ldgsts.py | 30 +- .../cuda/copy_async/test_smem_tmem.py | 364 ++-- .../cuda/copy_async/test_tma.py | 423 ++--- .../cuda/copy_async/test_tmem.py | 314 ++-- .../cuda/copy_async/test_tmem_16xnb.py | 482 +++-- .../cuda/elementwise/test_binary.py | 736 ++++---- .../cuda/elementwise/test_fma.py | 238 ++- .../cuda/elementwise/test_unary.py | 1047 +++++------ .../cuda/gemm/test_gemm_mma_m16n8k_.py | 457 +++-- .../cuda/gemm_async/test_gemm_async.py | 1257 ++++++------- .../permute_layout/test_permute_layout.py | 154 +- .../cuda/reduction/test_reduction.py | 826 ++++----- .../tile_primitive/test_dispatcher.py | 26 +- .../tile_primitive/trn/test_binary_trn.py | 263 ++- .../tile_primitive/trn/test_compose_op_trn.py | 743 ++++---- .../tile_primitive/trn/test_copy_trn.py | 925 +++++---- .../tile_primitive/trn/test_gemm_trn.py | 495 +++-- .../trn/test_private_alloc_trn.py | 313 ++-- .../tile_primitive/trn/test_reduction_trn.py | 245 ++- .../tile_primitive/trn/test_select_trn.py | 131 +- .../tile_primitive/trn/test_unary_trn.py | 245 ++- tests/python/tirx/test_buffer_print.py | 104 +- tests/python/tirx/test_control_flow.py | 99 +- tests/python/tirx/test_hint.py | 119 +- tests/python/tirx/test_inline.py | 25 +- tests/python/tirx/test_jit.py | 114 +- tests/python/tirx/test_layout.py | 4 +- tests/python/tirx/test_op.py | 132 +- .../python/tirx/test_op_namespace_cleanup.py | 248 +++ tests/python/tirx/test_parser_printer.py | 1639 ++++++++-------- .../tirx/test_printer_tir_namespaces.py | 288 +-- .../python/tirx/test_roundtrip_namespaces.py | 22 +- tests/python/tirx/test_verifier.py | 415 ++--- .../tirx/transform/test_stmt_functor.py | 31 +- .../transform/test_transform_lower_tirx.py | 1649 ++++++++--------- .../test_transform_naive_allocator.py | 143 +- 178 files changed, 11421 insertions(+), 11803 deletions(-) create mode 100644 python/tvm/tirx/script/tile.py create mode 100644 tests/python/tirx/test_op_namespace_cleanup.py diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h index ff61386699c4..1a11598fa427 100644 --- a/include/tvm/tirx/builtin.h +++ b/include/tvm/tirx/builtin.h @@ -134,6 +134,7 @@ TVM_DLL const Op& large_uint_imm(); * (i.e., round(x.1) = x and round (x.5) = x+1) */ TVM_DLL const Op& q_multiply_shift(); +TVM_DLL const Op& q_multiply_shift_per_axis(); /*! * \brief Returns the address of an element in the buffer (see pseudocode below). @@ -912,6 +913,11 @@ TVM_DLL const Op& cuda_atomic_add(); */ TVM_DLL const Op& cuda_thread_fence(); +/*! + * \brief tvm intrinsic for cuda warpgroup sync instruction + */ +TVM_DLL const Op& cuda_warpgroup_sync(); + /*! * \brief Warp-level butterfly shuffle-XOR reduction. * @@ -952,6 +958,11 @@ TVM_DLL const Op& cuda_cta_sync(); */ TVM_DLL const Op& cuda_grid_sync(); +/*! + * \brief tvm intrinsic for cuda cluster-wide sync instruction + */ +TVM_DLL const Op& cuda_cluster_sync(); + /*! * \brief tvm intrinsic that returns ``cooperative_groups::thread_rank()`` * for the enclosing CTA (linear thread index within the block). @@ -1053,25 +1064,19 @@ TVM_DLL const Op& ptx_reduce3_max_f32(); */ TVM_DLL const Op& ptx_reduce3_min_f32(); -/*! - * \brief tvm intrinsic for PTX packed add instruction (sm_100a+) - */ -TVM_DLL const Op& ptx_add_packed_f32x2(); - -/*! - * \brief tvm intrinsic for PTX packed subtract instruction (sm_100a+) - */ -TVM_DLL const Op& ptx_sub_packed_f32x2(); - -/*! - * \brief tvm intrinsic for PTX packed multiply instruction (sm_100a+) - */ -TVM_DLL const Op& ptx_mul_packed_f32x2(); - -/*! - * \brief tvm intrinsic for PTX packed FMA instruction (sm_100a+) - */ -TVM_DLL const Op& ptx_fma_packed_f32x2(); +TVM_DLL const Op& ptx_add_f32(); +TVM_DLL const Op& ptx_add_f32x2(); +TVM_DLL const Op& ptx_add_f64(); +TVM_DLL const Op& ptx_sub_f32(); +TVM_DLL const Op& ptx_sub_f32x2(); +TVM_DLL const Op& ptx_sub_f64(); +TVM_DLL const Op& ptx_mul_f32(); +TVM_DLL const Op& ptx_mul_f32x2(); +TVM_DLL const Op& ptx_mul_f64(); +TVM_DLL const Op& ptx_fma_f32(); +TVM_DLL const Op& ptx_fma_f32x2(); +TVM_DLL const Op& ptx_fma_f64(); +TVM_DLL const Op& ptx_max_f32(); } // namespace builtin } // namespace tirx diff --git a/include/tvm/tirx/exec_context.h b/include/tvm/tirx/exec_context.h index d8caedce754b..01422703896c 100644 --- a/include/tvm/tirx/exec_context.h +++ b/include/tvm/tirx/exec_context.h @@ -136,9 +136,6 @@ struct ExecContext { /*! \brief Apply modulo filter on a factorized CTA axis such as cbx/cby/cbz. */ bool WithCtaAxisModulo(const std::string& axis, int64_t modulus, int64_t residue, ExecContext* out, std::string* err) const; - - /*! \brief Apply scope_switch; A preserved, split recomputed for new scope_kind. */ - bool WithScopeSwitch(ScopeKind new_scope_kind, ExecContext* out, std::string* err) const; }; /*! diff --git a/include/tvm/tirx/exec_scope.h b/include/tvm/tirx/exec_scope.h index 189c538a434e..027bff550e8c 100644 --- a/include/tvm/tirx/exec_scope.h +++ b/include/tvm/tirx/exec_scope.h @@ -35,11 +35,12 @@ namespace tvm { namespace tirx { /*! - * \brief The target execution scope kind of an ExecScopeStmt. + * \brief The target execution scope kind of a tile primitive call. * - * Replaces the string-keyed name of ExecScope. One value per user-facing - * `with T.():` construct. Ordered from coarsest to finest; smaller - * integer = wider scope, so ``ScopeKindHigher`` is a plain ``<``. + * Identifies the granularity at which an op executes (the per-call + * ``scope`` on a ``TilePrimitiveCall``, e.g. ``Tx.warp.copy(...)``). + * Ordered from coarsest to finest; smaller integer = wider scope, so + * ``ScopeKindHigher`` is a plain ``<``. */ enum class ScopeKind : int { kCluster = 2, diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h index dd2aefdc1268..651c49133691 100644 --- a/include/tvm/tirx/function.h +++ b/include/tvm/tirx/function.h @@ -305,6 +305,13 @@ namespace attr { */ constexpr const char* kKernelLaunchParams = "tirx.kernel_launch_params"; +/*! + * \brief CUDA launch bound minimum CTAs per SM. + * + * Type: IntImm + */ +constexpr const char* kLaunchBoundsMinBlocksPerSM = "tirx.launch_bounds_min_blocks_per_sm"; + /*! * \brief Whether to set noalias rule on the function arguments. * diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 549aab4df807..60b292bbb265 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -43,8 +44,10 @@ namespace tvm { -#define TVM_TIR_REGISTER_OP(OpName) \ - TVM_REGISTER_OP("tirx." OpName).set_attr("TScriptPrinterName", OpName) +#define TVM_TIR_REGISTER_OP(OpName) \ + TVM_REGISTER_OP("tirx." OpName) \ + .set_attr("TScriptPrinterName", OpName) \ + .set_attr("TIRxOpCategory", ffi::String("builtin"), /*plevel=*/1) #define TVM_TIRX_REGISTER_OP(OpName) TVM_TIR_REGISTER_OP(OpName) diff --git a/include/tvm/tirx/op_attr_types.h b/include/tvm/tirx/op_attr_types.h index f766ad19d70b..7ebd87ed6f3c 100644 --- a/include/tvm/tirx/op_attr_types.h +++ b/include/tvm/tirx/op_attr_types.h @@ -82,6 +82,34 @@ enum class ScriptDtypePrintLocation : int { using TScriptDtypePrintLocation = int64_t; +/*! + * \brief Broad TIRx op category. + * + * Expected values: + * - "builtin" + * - "tile_primitive" + * - "device_intrin" + */ +using TIRxOpCategory = ffi::String; + +/*! + * \brief Tile primitive subcategory. + * + * Expected values: + * - "dispatch" + * - "compose" + * - "async" + * - "marker" + */ +using TTilePrimitiveKind = ffi::String; + +/*! + * \brief Device intrinsic namespace. + * + * Expected values include "cuda", "ptx", "nvshmem", "nki", and "metal". + */ +using TDeviceIntrinsicNamespace = ffi::String; + /*! * \brief The effect type of the call. */ diff --git a/include/tvm/tirx/script/builder/frame.h b/include/tvm/tirx/script/builder/frame.h index 3906705819da..5b2d3953269b 100644 --- a/include/tvm/tirx/script/builder/frame.h +++ b/include/tvm/tirx/script/builder/frame.h @@ -247,52 +247,6 @@ class BlockInitFrame : public TIRFrame { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockInitFrame, TIRFrame, BlockInitFrameNode); }; -/*! - * \brief A frame that represents an execution scope (e.g. cta, warp, thread). - * - * When exiting this frame, it produces an ExecScopeStmt wrapping the body. - * This is the new IR pattern, replacing the old pattern of storing exec_scope on SBlock. - * - * \sa ExecScopeFrame - */ -class ExecScopeFrameNode : public TIRFrameNode { - public: - /*! \brief The execution scope (always plain kind; no slice). */ - ffi::Optional exec_scope; - /*! \brief Optional surface-syntax guards for ``with Tx.scope(cond)``. */ - ffi::Array guards; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("exec_scope", &ExecScopeFrameNode::exec_scope) - .def_ro("guards", &ExecScopeFrameNode::guards); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.ExecScopeFrame", ExecScopeFrameNode, - TIRFrameNode); - - public: - /*! - * \brief The method called when exiting RAII scope. - * \sa tvm::support::With - */ - void ExitWithScope() final; -}; - -/*! - * \brief Managed reference to ExecScopeFrameNode. - * - * \sa ExecScopeFrameNode - */ -class ExecScopeFrame : public TIRFrame { - public: - explicit ExecScopeFrame(ffi::ObjectPtr data) : TIRFrame(ffi::UnsafeInit{}) { - TVM_FFI_ICHECK(data != nullptr); - data_ = std::move(data); - } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ExecScopeFrame, TIRFrame, ExecScopeFrameNode); -}; - /*! * \brief A frame that represents the for loop. * diff --git a/include/tvm/tirx/script/builder/ir.h b/include/tvm/tirx/script/builder/ir.h index 644fd5d2e130..ad18d7ac4001 100644 --- a/include/tvm/tirx/script/builder/ir.h +++ b/include/tvm/tirx/script/builder/ir.h @@ -139,21 +139,6 @@ SBlockFrame Block(ffi::String name, bool no_realize = false, ffi::String exec_sc void TilePrimitiveCall(tvm::tirx::TilePrimitiveCall op_call); -/*! - * \brief Create an ExecScopeFrame for execution scope contexts. - * \param exec_scope_name The name of the execution scope (e.g. "cta", "warp"). - * \return The ExecScopeFrame. - */ -ExecScopeFrame ExecScopeBlock(ffi::String exec_scope_name, - ffi::Array guards = ffi::Array()); - -ExecScopeFrame Kernel(ffi::Array guards = ffi::Array()); -ExecScopeFrame Cluster(ffi::Array guards = ffi::Array()); -ExecScopeFrame WarpGroup(ffi::Array guards = ffi::Array()); -ExecScopeFrame CTA(ffi::Array guards = ffi::Array()); -ExecScopeFrame Warp(ffi::Array guards = ffi::Array()); -ExecScopeFrame Thread(ffi::Array guards = ffi::Array()); - ffi::Array KernelId(ffi::Array extents, ffi::String parent); ffi::Array CtaId(ffi::Array extents, ffi::String parent); diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h index 1730c2b22e18..1ed4d5acac54 100644 --- a/include/tvm/tirx/stmt.h +++ b/include/tvm/tirx/stmt.h @@ -952,53 +952,11 @@ class SBlockRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(SBlockRealizeNode); }; -/*! - * \brief A statement that annotates the execution scope for its body. - * - * ExecScopeStmt represents a hardware execution scope (e.g. cta, warp, thread) - * that wraps a body statement. This decouples the execution scope concept from - * SBlock, making the IR structure cleaner. - * - * Example: - * \code - * with T.cta(): - * ... - * \endcode - */ -class ExecScopeStmtNode : public StmtNode { - public: - /*! \brief The execution scope. */ - ExecScope exec_scope; - /*! \brief The body statement under this execution scope. */ - Stmt body; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("exec_scope", &ExecScopeStmtNode::exec_scope) - .def_ro("body", &ExecScopeStmtNode::body); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ExecScopeStmt", ExecScopeStmtNode, StmtNode); -}; - -/*! - * \brief Managed reference to ExecScopeStmtNode. - * \sa ExecScopeStmtNode - */ -class ExecScopeStmt : public Stmt { - public: - TVM_DLL ExecScopeStmt(ExecScope exec_scope, Stmt body, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ExecScopeStmt, Stmt, ExecScopeStmtNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecScopeStmtNode); -}; - /*! * \brief Standalone statement that declares a scope-id binding (e.g. cta_id, * warp_id, lane_id). Carries a ``ScopeIdDef`` value. * - * Unlike legacy ``ExecScopeStmt::scope_id_def`` (an array payload), each - * declaration is a flat stmt within the device-region body. The declared + * Each declaration is a flat stmt within the device-region body. The declared * ``Var``\ s are visible in subsequent stmts in the same enclosing scope * (the AttrStmt ``kDeviceEntry`` body), analogous to ``BindNode``. */ diff --git a/include/tvm/tirx/stmt_functor.h b/include/tvm/tirx/stmt_functor.h index 85b467e1857b..0262a167918d 100644 --- a/include/tvm/tirx/stmt_functor.h +++ b/include/tvm/tirx/stmt_functor.h @@ -100,7 +100,6 @@ class StmtFunctor { virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SBlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ExecScopeStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ScopeIdDefStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const tirx::TilePrimitiveCallNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const ffi::Object* op, Args...) { @@ -127,7 +126,6 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); IR_STMT_FUNCTOR_DISPATCH(SBlockNode); IR_STMT_FUNCTOR_DISPATCH(SBlockRealizeNode); - IR_STMT_FUNCTOR_DISPATCH(ExecScopeStmtNode); IR_STMT_FUNCTOR_DISPATCH(ScopeIdDefStmtNode); IR_STMT_FUNCTOR_DISPATCH(tirx::TilePrimitiveCallNode); vtable.Finalize(); @@ -185,7 +183,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SBlockNode* op) override; void VisitStmt_(const SBlockRealizeNode* op) override; - void VisitStmt_(const ExecScopeStmtNode* op) override; void VisitStmt_(const ScopeIdDefStmtNode* op) override; void VisitStmt_(const tirx::TilePrimitiveCallNode* op) override; }; @@ -304,7 +301,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const EvaluateNode* op) override; Stmt VisitStmt_(const SBlockNode* op) override; Stmt VisitStmt_(const SBlockRealizeNode* op) override; - Stmt VisitStmt_(const ExecScopeStmtNode* op) override; Stmt VisitStmt_(const ScopeIdDefStmtNode* op) override; Stmt VisitStmt_(const tirx::TilePrimitiveCallNode* op) override; /*! diff --git a/include/tvm/tirx/target_builtin/cuda.h b/include/tvm/tirx/target_builtin/cuda.h index 76472f70fa4c..ff10ee0b43e6 100644 --- a/include/tvm/tirx/target_builtin/cuda.h +++ b/include/tvm/tirx/target_builtin/cuda.h @@ -126,12 +126,6 @@ TVM_DLL const Op& mma_fill_legacy(); */ TVM_DLL const Op& ptx_ldg32(); -/*! - * \brief tvm intrinsic for ptx predicate load with 32-bit data type. - * - */ -TVM_DLL const Op& ptx_ldg32(); - /*! * \brief tvm intrinsic for sparse tensor core ptx instructions. * @@ -374,6 +368,12 @@ TVM_DLL const Op& ptx_fence_mbarrier_init(); */ TVM_DLL const Op& ptx_fetch_register(); +/*! + * \brief PTX programmatic dependent launch synchronization. + */ +TVM_DLL const Op& ptx_griddepcontrol_wait(); +TVM_DLL const Op& ptx_griddepcontrol_launch_dependents(); + /*! * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer. * For example, if each thread in a warp of size 32 has 4 elements from the result of diff --git a/include/tvm/tirx/tirx_op.h b/include/tvm/tirx/tirx_op.h index 299a960fb88b..772f6ce34f06 100644 --- a/include/tvm/tirx/tirx_op.h +++ b/include/tvm/tirx/tirx_op.h @@ -189,6 +189,8 @@ TVM_DLL const Op& sqrt(); TVM_DLL const Op& exp(); +TVM_DLL const Op& exp2(); + TVM_DLL const Op& add(); TVM_DLL const Op& sub(); @@ -221,6 +223,14 @@ TVM_DLL const Op& binary_chain(); TVM_DLL const Op& select(); +TVM_DLL const Op& fma(); + +TVM_DLL const Op& silu(); + +TVM_DLL const Op& compose_op(); + +TVM_DLL const Op& permute_layout(); + /*! * \brief See pesudo code below: * diff --git a/include/tvm/tirx/tirx_stmt.h b/include/tvm/tirx/tirx_stmt.h index 62df8a0a53e1..9f141a8c3a2e 100644 --- a/include/tvm/tirx/tirx_stmt.h +++ b/include/tvm/tirx/tirx_stmt.h @@ -49,6 +49,9 @@ class TilePrimitiveCallNode : public StmtNode { // Optional dispatch variant name registered via @register_dispatch. ffi::Optional dispatch{std::nullopt}; + // Cooperation scope of this call. Default thread (an unscoped call). + ExecScope scope = ExecScope(ScopeKind::kThread); + static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() @@ -56,7 +59,8 @@ class TilePrimitiveCallNode : public StmtNode { .def_ro("args", &TilePrimitiveCallNode::args) .def_ro("workspace", &TilePrimitiveCallNode::workspace) .def_ro("config", &TilePrimitiveCallNode::config) - .def_ro("dispatch", &TilePrimitiveCallNode::dispatch); + .def_ro("dispatch", &TilePrimitiveCallNode::dispatch) + .def_ro("scope", &TilePrimitiveCallNode::scope); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TilePrimitiveCall", TilePrimitiveCallNode, StmtNode); @@ -71,7 +75,8 @@ class TilePrimitiveCall : public Stmt { TVM_DLL TilePrimitiveCall(tvm::Op op, ffi::Array args, ffi::Map workspace = {}, ffi::Map config = {}, - ffi::Optional dispatch = std::nullopt); + ffi::Optional dispatch = std::nullopt, + ExecScope scope = ExecScope(ScopeKind::kThread)); static bool IsValidOpCallArgType(const ffi::Any& arg); diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 209efe77a0cc..238973725fbc 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -176,7 +176,7 @@ def script( extra_config : Optional[dict] = None Dialect-specific configuration passed through to PrinterConfig.extra_config. Keys are conventionally namespaced as ".", e.g. - ``{"tirx.prefix": "Tx"}``. + ``{"tirx.prefix": "T"}``. path_to_underline : Optional[List[AccessPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[AccessPath, str]] = None @@ -192,10 +192,10 @@ def script( The TVM Script of the given TVM IR """ - # Auto-switch to tirx (`Tx`/`tirx`) flavor only when explicitly + # Auto-switch to tirx (`T`/`tirx`) flavor only when explicitly # printing a PrimFunc / IRModule that has no s_tir-tagged content. # Free objects (Buffer, BufferRegion, ...) keep the default `T`/`tir` - # flavor — they have no enclosing function to indicate tirx vs s_tir. + # flavor -- they have no enclosing function to indicate tirx vs s_tir. merged_extra: dict = {} if extra_config is not None: merged_extra.update(extra_config) @@ -224,7 +224,7 @@ def script( if any_prim and not any_s_tir: switch_to_tirx = True if switch_to_tirx: - merged_extra["tirx.prefix"] = "Tx" + merged_extra["tirx.prefix"] = "T" return _script( self, diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index a185f2e4f036..06168d902a33 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -76,7 +76,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Additional passes based on configuration. if bool(config.get("tirx.instrument_bound_checkers", False)): passes.append(s_tir.transform.InstrumentBoundCheckers()) - if bool(config.get("tirx.ptx_ldg32", False)): + if bool(config.get("tirx.ptx.ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32(True)) if not bool(config.get("tirx.disable_cse_tir", False)): passes.append(tirx.transform.CommonSubexprElim()) @@ -104,7 +104,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I ) if bool(config.get("tirx.use_async_copy", False)): passes.append(s_tir.transform.InjectPTXAsyncCopy()) - if bool(config.get("tirx.ptx_ldg32", False)): + if bool(config.get("tirx.ptx.ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32()) passes.extend( [ diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index 070deb7681ae..fb8310dc2604 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -76,7 +76,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Additional passes based on configuration. if bool(config.get("tirx.instrument_bound_checkers", False)): passes.append(s_tir.transform.InstrumentBoundCheckers()) - if bool(config.get("tirx.ptx_ldg32", False)): + if bool(config.get("tirx.ptx.ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32(True)) if not bool(config.get("tirx.disable_cse_tir", False)): passes.append(tirx.transform.CommonSubexprElim()) @@ -104,7 +104,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I ) if bool(config.get("tirx.use_async_copy", False)): passes.append(s_tir.transform.InjectPTXAsyncCopy()) - if bool(config.get("tirx.ptx_ldg32", False)): + if bool(config.get("tirx.ptx.ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32()) passes.extend( [ diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 7764d30b4887..e9e670c5be79 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -44,6 +44,7 @@ def _default_globals() -> dict[str, Any]: relax, # pylint: disable=import-outside-toplevel ) from tvm.script.parser import tirx as _tirx_parser # pylint: disable=import-outside-toplevel + from tvm.script.tirx import tile as _tirx_tile # pylint: disable=import-outside-toplevel from tvm.tirx import layout as _tirx_layout # pylint: disable=import-outside-toplevel # Expose the layout `Axis` class so printed layout sugar like @@ -58,7 +59,7 @@ def _default_globals() -> dict[str, Any]: "tir": _tirx_parser, "R": relax, "relax": relax, - "Tx": _tirx_dsl, + "Tx": _tirx_tile, "tirx": _tirx_dsl, "Axis": _tirx_layout.Axis, } diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index efda655066cd..4378a9dfbe6c 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -44,7 +44,7 @@ from .stmt import SeqStmt from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, SBlock, SBlockRealize -from .stmt import TilePrimitiveCall, ExecScopeStmt, ScopeIdDefStmt +from .stmt import TilePrimitiveCall, ScopeIdDefStmt from .function import PrimFunc, TensorIntrin, IndexMap diff --git a/python/tvm/tirx/bench.py b/python/tvm/tirx/bench.py index 69f39ffbd13f..d12ff2e3d04d 100644 --- a/python/tvm/tirx/bench.py +++ b/python/tvm/tirx/bench.py @@ -30,7 +30,7 @@ import tvm_ffi import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.support import nvcc @@ -566,9 +566,9 @@ def export_to_perfetto_trace( tgen.flush() -@Tx.meta_class +@T.meta_class class CudaProfiler: - """A lightweight wrapper around Tx.timer_* CUDA intrinsics. + """A lightweight wrapper around T.timer_* CUDA intrinsics. Stores repeated arguments used by timer_init/start/end/finalize so users can call concise methods in kernels. Intended to mirror Pipeline/TileScheduler helpers. @@ -580,7 +580,7 @@ class CudaProfiler: def __init__( self, - profiler_buffer: Tx.Buffer, + profiler_buffer: T.Buffer, write_stride: int, num_groups: int, default_leader: None | tvm.tirx.PrimExpr | bool = None, @@ -590,30 +590,30 @@ def __init__( self.write_stride = write_stride self.num_groups = num_groups self.default_leader = default_leader - # Accept either a Python bool or a PrimExpr; normalize simple bools to Tx.bool + # Accept either a Python bool or a PrimExpr; normalize simple bools to T.bool # so we can use it uniformly inside macros for conditional emission. if isinstance(profiler_enabled, bool | np.bool_): - self.profiler_enabled = Tx.bool(bool(profiler_enabled)) + self.profiler_enabled = T.bool(bool(profiler_enabled)) else: # Assume PrimExpr-like input; use as-is self.profiler_enabled = profiler_enabled # type: ignore[assignment] - self.profiler_tag = Tx.alloc_buffer([1], "uint64", scope="local", align=8) - self.profiler_write_offset = Tx.alloc_buffer([1], "uint32", scope="local", align=8) + self.profiler_tag = T.alloc_buffer([1], "uint64", scope="local", align=8) + self.profiler_write_offset = T.alloc_buffer([1], "uint32", scope="local", align=8) def _leader(self, leader: None | tvm.tirx.PrimExpr | bool): if leader is not None: if isinstance(leader, bool | np.bool_): - return Tx.bool(bool(leader)) + return T.bool(bool(leader)) return leader if self.default_leader is not None: return self.default_leader - return Tx.bool(True) + return T.bool(True) - @Tx.inline + @T.inline def init(self, group_id: tvm.tirx.PrimExpr): if self.profiler_enabled: - Tx.timer_init_cuda( + T.timer_init_cuda( self.buffer.data, self.profiler_tag.data, self.profiler_write_offset.data, @@ -621,10 +621,10 @@ def init(self, group_id: tvm.tirx.PrimExpr): group_id, ) - @Tx.inline + @T.inline def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): if self.profiler_enabled: - Tx.timer_start_cuda( + T.timer_start_cuda( event_type, self.buffer.data, self.profiler_tag.data, @@ -633,10 +633,10 @@ def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None self._leader(leader), ) - @Tx.inline + @T.inline def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): if self.profiler_enabled: - Tx.timer_end_cuda( + T.timer_end_cuda( event_type, self.buffer.data, self.profiler_tag.data, @@ -645,10 +645,10 @@ def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): self._leader(leader), ) - @Tx.inline + @T.inline def finalize(self, leader: None | tvm.tirx.PrimExpr | bool = None): if self.profiler_enabled: - Tx.timer_finalize_cuda( + T.timer_finalize_cuda( self.buffer.data, self.profiler_tag.data, self.profiler_write_offset.data, diff --git a/python/tvm/tirx/lang/alloc_pool.py b/python/tvm/tirx/lang/alloc_pool.py index fd4e2c54cd74..48bb9929c618 100644 --- a/python/tvm/tirx/lang/alloc_pool.py +++ b/python/tvm/tirx/lang/alloc_pool.py @@ -248,8 +248,8 @@ def __init__( # tcgen05 alloc/dealloc are warp-uniform PTX instructions: every lane # in the chosen warp must participate, and exactly one warp in the # CTA must execute them. The pool emits its own - # ``if warp_id() == target_warp: with Tx.warp(): tcgen05.alloc(...)`` - # guard, using the cta->warp scope id ``Tx.warp_id()``. + # ``if warp_id() == target_warp: tcgen05.alloc(...)`` + # guard, using the cta->warp scope id ``T.warp_id()``. # NOTE: synccheck currently false-deadlocks on kernels that declare a # second warp-scope id (cpusim binds only one warp var); the generated # CUDA is equivalent to ``thread_rank() // 32 == target_warp``. @@ -275,12 +275,13 @@ def _addr_slot(self): def addr(self): return self._addr_slot() - def _emit_warp_guard(self, Tx, target_warp, emit): - warp_id = Tx.warp_id() - with Tx.If(warp_id == target_warp): - with Tx.Then(): - with Tx.warp(): - emit() + def _emit_warp_guard(self, target_warp, emit): + from tvm.script import tirx as T + + warp_id = T.warp_id() + with T.If(warp_id == target_warp): + with T.Then(): + emit() def _resolve_cols(self, shape, dtype, cols, layout=None): if cols is not None: @@ -379,33 +380,33 @@ def move_base_to(self, col): def commit(self): assert not self._committed, "TMEMPool.commit() can only be called once" - from tvm.script import tirx as Tx + from tvm.script import tirx as T def emit_alloc(): _emit_stmt( - Tx.ptx.tcgen05.alloc( - Tx.address_of(self.addr), n_cols=self.total_cols, cta_group=self.cta_group + T.ptx.tcgen05.alloc( + T.address_of(self.addr), n_cols=self.total_cols, cta_group=self.cta_group ) ) if self.sync_after_alloc: - _emit_stmt(Tx.cuda.warp_sync()) + _emit_stmt(T.cuda.warp_sync()) - self._emit_warp_guard(Tx, self.alloc_warp, emit_alloc) + self._emit_warp_guard(self.alloc_warp, emit_alloc) self._committed = True def dealloc(self): assert self._committed, "TMEMPool.dealloc() called before commit()" assert not self._deallocated, "TMEMPool.dealloc() can only be called once" self._deallocated = True - from tvm.script import tirx as Tx + from tvm.script import tirx as T def emit_dealloc(): - _emit_stmt(Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=self.cta_group)) + _emit_stmt(T.ptx.tcgen05.relinquish_alloc_permit(cta_group=self.cta_group)) _emit_stmt( - Tx.ptx.tcgen05.dealloc(self.addr, n_cols=self.total_cols, cta_group=self.cta_group) + T.ptx.tcgen05.dealloc(self.addr, n_cols=self.total_cols, cta_group=self.cta_group) ) - self._emit_warp_guard(Tx, self.dealloc_warp, emit_dealloc) + self._emit_warp_guard(self.dealloc_warp, emit_dealloc) # --------------------------------------------------------------------------- diff --git a/python/tvm/tirx/lang/pipeline.py b/python/tvm/tirx/lang/pipeline.py index c3e5ca20e1e6..ee86090398e9 100644 --- a/python/tvm/tirx/lang/pipeline.py +++ b/python/tvm/tirx/lang/pipeline.py @@ -16,14 +16,14 @@ # under the License. """Reusable pipeline state and mbarrier helpers for SM100 kernels. -These classes emit TIR via @Tx.inline. Decorate with @Tx.meta_class so that -instances are automatically treated as meta values inside @Tx.prim_func. +These classes emit TIR via @T.inline. Decorate with @T.meta_class so that +instances are automatically treated as meta values inside @T.prim_func. """ -from tvm.script import tirx as Tx +from tvm.script import tirx as T -@Tx.meta_class +@T.meta_class class PipelineState: """Tracks stage and phase for a software-pipelined ring buffer. @@ -40,18 +40,18 @@ class PipelineState: """ def __init__(self, depth: int, phase=None): - self.stage = Tx.local_scalar("int32") - self.phase = Tx.local_scalar("int32") + self.stage = T.local_scalar("int32") + self.phase = T.local_scalar("int32") self.depth = depth if phase is not None: self.init(phase) - @Tx.inline + @T.inline def init(self, phase): self.stage = 0 self.phase = phase - @Tx.inline + @T.inline def advance(self): if self.depth > 1: self.stage = self.stage + 1 @@ -62,7 +62,7 @@ def advance(self): self.phase = self.phase ^ 1 -@Tx.meta_class +@T.meta_class class MBarrier: """Mbarrier wrapper with regular ``mbarrier.arrive``. @@ -76,14 +76,14 @@ class MBarrier: XORed into the phase bit on every ``wait`` / ``arrive``. leader : PrimExpr, optional Boolean predicate selecting the single thread that runs - ``mbarrier.init``. Defaults to ``Tx.cuda.thread_rank() == 0`` -- + ``mbarrier.init``. Defaults to ``T.cuda.thread_rank() == 0`` -- thread 0 of the enclosing CTA, which always picks exactly one thread regardless of which scope_id vars the caller declared. Override only when you want a different CTA-local thread to do the init. - Note: the default deliberately avoids ``Tx.warp_id()`` / - ``Tx.lane_id()``. Those introduce deferred ``cta->warp`` / + Note: the default deliberately avoids ``T.warp_id()`` / + ``T.lane_id()``. Those introduce deferred ``cta->warp`` / ``warp->thread`` ScopeIdDefs that the verifier cannot pin down unless the kernel header declares the full warp/lane chain (e.g. a single-CTA DSMEM kernel that only declares ``thread_id``). It also @@ -95,21 +95,21 @@ def __init__(self, pool, depth, phase_offset=0, leader=None): self.buf = pool.alloc((depth,), "uint64", align=8) self.depth = depth self.phase_offset = phase_offset - self.leader = leader if leader is not None else (Tx.cuda.thread_rank() == 0) + self.leader = leader if leader is not None else (T.cuda.thread_rank() == 0) - @Tx.inline + @T.inline def init(self, count): if self.leader: - for i in Tx.unroll(self.depth): - Tx.ptx.mbarrier.init(self.buf.ptr_to([i]), count) + for i in T.unroll(self.depth): + T.ptx.mbarrier.init(self.buf.ptr_to([i]), count) - @Tx.inline + @T.inline def wait(self, stage, phase): # Blocks: ``mbarrier.try_wait`` loops internally until the phase flips, # so this returns only once the barrier has completed. - Tx.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset) + T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset) - @Tx.inline + @T.inline def arrive(self, stage, cta_id=None, pred=None): # Default: local-CTA arrive — emits the simple # ``mbarrier.arrive.shared.b64`` form. To arrive on a remote @@ -120,10 +120,10 @@ def arrive(self, stage, cta_id=None, pred=None): # silently ``mapa`` ed across the cluster) and a per-call cost # of ~3 PTX ops on every single-CTA kernel. if cta_id is None: - Tx.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) else: actual_pred = True if pred is None else pred - Tx.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) def ptr_to(self, idx): return self.buf.ptr_to(idx) @@ -138,10 +138,10 @@ def remote_view(self, rank): from tvm.ir import PointerType, PrimType from tvm.tirx import Var as TIRVar - expr = Tx.reinterpret("handle", Tx.ptx.map_shared_rank(self.buf.ptr_to([0]), rank)) + expr = T.reinterpret("handle", T.ptx.map_shared_rank(self.buf.ptr_to([0]), rank)) ptr = TIRVar("remote_mbar_ptr", PointerType(PrimType("uint64"))) - Tx.Bind(expr, var=ptr) - buf = Tx.decl_buffer([self.depth], "uint64", data=ptr, scope="shared") + T.Bind(expr, var=ptr) + buf = T.decl_buffer([self.depth], "uint64", data=ptr, scope="shared") remote = object.__new__(type(self)) remote.buf = buf remote.depth = self.depth @@ -156,7 +156,7 @@ class TMABar(MBarrier): (matching MBarrier.arrive defaults). """ - @Tx.inline + @T.inline def arrive(self, stage, tx_count=None, cta_id=None, pred=None): # NOTE: this arrive() kwarg set intentionally differs from # MBarrier.arrive (hardware necessity, LSP-incompatible by design). @@ -166,36 +166,36 @@ def arrive(self, stage, tx_count=None, cta_id=None, pred=None): # arrive is local-CTA only. See ``MBarrier.arrive`` for the # full default-local rationale. if tx_count is not None: - Tx.ptx.mbarrier.arrive.expect_tx(self.buf.ptr_to([stage]), tx_count) + T.ptx.mbarrier.arrive.expect_tx(self.buf.ptr_to([stage]), tx_count) elif cta_id is None: - Tx.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) else: actual_pred = True if pred is None else pred - Tx.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) class TCGen05Bar(MBarrier): """Barrier signaled by ``tcgen05`` commit. The caller is responsible for ensuring only one thread issues the - commit, e.g. by wrapping the call in ``if Tx.ptx.elect_sync():``. + commit, e.g. by wrapping the call in ``if T.ptx.elect_sync():``. """ - @Tx.inline + @T.inline def arrive(self, stage, cta_group=1, cta_mask=None): # NOTE: this arrive() kwarg set intentionally differs from # MBarrier.arrive (hardware necessity, LSP-incompatible by design). if cta_mask is None and cta_group == 1: - Tx.ptx.tcgen05.commit(self.buf.ptr_to([stage])) + T.ptx.tcgen05.commit(self.buf.ptr_to([stage])) else: - Tx.ptx.tcgen05.commit(self.buf.ptr_to([stage]), cta_group=cta_group, cta_mask=cta_mask) + T.ptx.tcgen05.commit(self.buf.ptr_to([stage]), cta_group=cta_group, cta_mask=cta_mask) # Barrier-type tags accepted by Pipeline's ``full=`` / ``empty=`` arguments. _BAR_KINDS = {"tma": TMABar, "tcgen05": TCGen05Bar, "mbar": MBarrier} -@Tx.meta_class +@T.meta_class class Pipeline: """A full/empty mbarrier pair for a software-pipelined data flow. diff --git a/python/tvm/tirx/lang/smem_desc.py b/python/tvm/tirx/lang/smem_desc.py index 0a88aa414ba5..c858cb70690c 100644 --- a/python/tvm/tirx/lang/smem_desc.py +++ b/python/tvm/tirx/lang/smem_desc.py @@ -17,25 +17,25 @@ """SMEM matrix descriptor helper for tcgen05 / wgmma.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx.operator.tile_primitive.cuda.common import smem_desc_add_16B_offset -@Tx.meta_class +@T.meta_class class SmemDescriptor: """Encoded once via :meth:`init`, reused via :meth:`add_16B_offset`.""" def __init__(self): - self._buf = Tx.alloc_local([1], "uint64") + self._buf = T.alloc_local([1], "uint64") @property def desc(self): return self._buf[0] - @Tx.inline + @T.inline def init(self, smem_ptr, ldo, sdo, swizzle): - Tx.ptx.tcgen05.encode_matrix_descriptor( - Tx.address_of(self._buf[0]), smem_ptr, ldo, sdo, swizzle + T.ptx.tcgen05.encode_matrix_descriptor( + T.address_of(self._buf[0]), smem_ptr, ldo, sdo, swizzle ) def add_16B_offset(self, offset): @@ -50,6 +50,6 @@ def make_lo_uniform(self): d->lo = __shfl_sync(0xffffffff, d->lo, 0); }} """ - return Tx.cuda.func_call( - func_name, Tx.address_of(self._buf[0]), source_code=source_code, return_type="void" + return T.cuda.func_call( + func_name, T.address_of(self._buf[0]), source_code=source_code, return_type="void" ) diff --git a/python/tvm/tirx/lang/tile_scheduler.py b/python/tvm/tirx/lang/tile_scheduler.py index 99936613d060..3fd27f25ee5f 100644 --- a/python/tvm/tirx/lang/tile_scheduler.py +++ b/python/tvm/tirx/lang/tile_scheduler.py @@ -16,33 +16,33 @@ # under the License. """Reusable tile scheduler helpers for TIR tests/kernels. -These classes emit TIR via @Tx.inline. Decorate with @Tx.meta_class so that -instances are automatically treated as meta values inside @Tx.prim_func. +These classes emit TIR via @T.inline. Decorate with @T.meta_class so that +instances are automatically treated as meta values inside @T.prim_func. """ -from tvm.script import tirx as Tx +from tvm.script import tirx as T -@Tx.meta_class +@T.meta_class class BaseTileScheduler: """Base class for tile schedulers with common state and macros.""" def __init__(self, prefix: str): - self.m_idx = Tx.local_scalar("int32") - self.n_idx = Tx.local_scalar("int32") - self.linear_idx = Tx.local_scalar("int32") + self.m_idx = T.local_scalar("int32") + self.n_idx = T.local_scalar("int32") + self.linear_idx = T.local_scalar("int32") - @Tx.inline + @T.inline def update_current_m_n_idx(self, linear_idx): # To be implemented by subclasses pass - @Tx.inline + @T.inline def init(self, linear_init): self.linear_idx = linear_init self.update_current_m_n_idx(linear_init) - @Tx.inline + @T.inline def next_tile(self, step): self.linear_idx = self.linear_idx + step self.update_current_m_n_idx(self.linear_idx) @@ -96,7 +96,7 @@ class ClusterPersistentScheduler2D(BaseTileScheduler): ---------- prefix : str Prefix for TIR variable names - num_m_tiles : int | Tx.ExprLike + num_m_tiles : int | T.ExprLike Total number of tiles in M dimension (can be runtime expression) num_n_tiles : int Total number of tiles in N dimension @@ -114,13 +114,13 @@ class ClusterPersistentScheduler2D(BaseTileScheduler): Attributes ---------- - m_idx : Tx.local_scalar + m_idx : T.local_scalar Current M tile index (output) - n_idx : Tx.local_scalar + n_idx : T.local_scalar Current N tile index (output) - work_idx : Tx.local_scalar + work_idx : T.local_scalar Global work item index for this cluster - tile_count : Tx.local_scalar + tile_count : T.local_scalar Number of tiles processed by this cluster so far Usage @@ -133,8 +133,8 @@ class ClusterPersistentScheduler2D(BaseTileScheduler): scheduler.init(cluster_id) # cluster_id = cta_idx // CLUSTER_SIZE while scheduler.valid(): - m = Tx.meta_var(scheduler.m_idx) # current M tile - n = Tx.meta_var(scheduler.n_idx) # current N tile + m = T.meta_var(scheduler.m_idx) # current M tile + n = T.meta_var(scheduler.n_idx) # current N tile # ... process tile (m, n) ... scheduler.next_tile() ``` @@ -220,7 +220,7 @@ def __init__( # Rename internal state for clarity self.work_idx = self.linear_idx # alias: global work item index - self.tile_count = Tx.local_scalar("int32") + self.tile_count = T.local_scalar("int32") self.tile_idx = self.tile_count # alias for backward compatibility is_static_m = isinstance(num_m_tiles, int) @@ -234,10 +234,8 @@ def __init__( self._FULL_GROUPS = self._M_TILE_ROWS // l2_group_size else: # Dynamic expressions for runtime M - self._M_TILE_ROWS = Tx.truncdiv( - self._num_m_tiles + self._cluster_m - 1, self._cluster_m - ) - self._FULL_GROUPS = Tx.truncdiv(self._M_TILE_ROWS, self._l2_group_size) + self._M_TILE_ROWS = T.truncdiv(self._num_m_tiles + self._cluster_m - 1, self._cluster_m) + self._FULL_GROUPS = T.truncdiv(self._M_TILE_ROWS, self._l2_group_size) self._TAIL_ROWS = self._M_TILE_ROWS - self._FULL_GROUPS * l2_group_size self._TOTAL_TILES = self._M_TILE_ROWS * n_tile_cols * cluster_m * cluster_n @@ -248,7 +246,7 @@ def __init__( self._M_BLOCKS = ( self._M_TILE_ROWS // l2_group_size if is_static_m - else Tx.truncdiv(self._M_TILE_ROWS, l2_group_size) + else T.truncdiv(self._M_TILE_ROWS, l2_group_size) ) self._BLOCK_SIZE = l2_group_size * l2_group_size # tiles per block self._FULL_BLOCK_TILES = self._M_BLOCKS * self._N_BLOCKS * self._BLOCK_SIZE @@ -257,19 +255,19 @@ def __init__( self._RESIDUAL_M = self._M_TILE_ROWS - self._M_BLOCKS * l2_group_size # fmt: off - @Tx.inline + @T.inline def update_current_m_n_idx(self, work_idx): """Convert global work index to (m_idx, n_idx) tile coordinates.""" - CLUSTER_M = Tx.meta_var(self._cluster_m) - CLUSTER_N = Tx.meta_var(self._cluster_n) + CLUSTER_M = T.meta_var(self._cluster_m) + CLUSTER_N = T.meta_var(self._cluster_n) # Extract hierarchical cluster-local offsets - cluster_m_offset = Tx.meta_var(work_idx % CLUSTER_M) - t = Tx.meta_var(work_idx // CLUSTER_M) - cluster_n_offset = Tx.meta_var(t % CLUSTER_N) - tile_linear = Tx.meta_var(t // CLUSTER_N) + cluster_m_offset = T.meta_var(work_idx % CLUSTER_M) + t = T.meta_var(work_idx // CLUSTER_M) + cluster_n_offset = T.meta_var(t % CLUSTER_N) + tile_linear = T.meta_var(t // CLUSTER_N) - @Tx.inline + @T.inline def set_tile_coords(tile_row, tile_col): self.m_idx = tile_row * CLUSTER_M + cluster_m_offset self.n_idx = tile_col * CLUSTER_N + cluster_n_offset @@ -299,59 +297,59 @@ def _update_group_major(self, tile_linear, set_tile_coords): else: self._gm_emit_full_and_tail(tile_linear, set_tile_coords) - @Tx.inline + @T.inline def _gm_emit_zero(self, set_tile_coords): set_tile_coords(0, 0) - @Tx.inline + @T.inline def _gm_emit_full_only(self, tile_linear, set_tile_coords): - FULL_GROUPS = Tx.meta_var(self._FULL_GROUPS) - GROUP_SIZE = Tx.meta_var(self._l2_group_size) - GROUP_SPAN = Tx.meta_var(self._l2_group_size * self._N_TILE_COLS) + FULL_GROUPS = T.meta_var(self._FULL_GROUPS) + GROUP_SIZE = T.meta_var(self._l2_group_size) + GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN): - group_id: Tx.let = tile_linear // GROUP_SPAN - within_group: Tx.let = tile_linear % GROUP_SPAN - tile_row: Tx.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) - tile_col: Tx.let = within_group // GROUP_SIZE + group_id: T.let = tile_linear // GROUP_SPAN + within_group: T.let = tile_linear % GROUP_SPAN + tile_row: T.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) + tile_col: T.let = within_group // GROUP_SIZE set_tile_coords(tile_row, tile_col) else: set_tile_coords(0, 0) - @Tx.inline + @T.inline def _gm_emit_tail_only(self, tile_linear, set_tile_coords): - FULL_GROUPS = Tx.meta_var(self._FULL_GROUPS) - TAIL_ROWS = Tx.meta_var(self._TAIL_ROWS) - GROUP_SIZE = Tx.meta_var(self._l2_group_size) - GROUP_SPAN = Tx.meta_var(self._l2_group_size * self._N_TILE_COLS) + FULL_GROUPS = T.meta_var(self._FULL_GROUPS) + TAIL_ROWS = T.meta_var(self._TAIL_ROWS) + GROUP_SIZE = T.meta_var(self._l2_group_size) + GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) if TAIL_ROWS > 0: - rem: Tx.let = tile_linear - FULL_GROUPS * GROUP_SPAN - tile_row: Tx.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) - tile_col: Tx.let = rem // TAIL_ROWS + rem: T.let = tile_linear - FULL_GROUPS * GROUP_SPAN + tile_row: T.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) + tile_col: T.let = rem // TAIL_ROWS set_tile_coords(tile_row, tile_col) else: set_tile_coords(0, 0) - @Tx.inline + @T.inline def _gm_emit_full_and_tail(self, tile_linear, set_tile_coords): - FULL_GROUPS = Tx.meta_var(self._FULL_GROUPS) - TAIL_ROWS = Tx.meta_var(self._TAIL_ROWS) - GROUP_SIZE = Tx.meta_var(self._l2_group_size) - GROUP_SPAN = Tx.meta_var(self._l2_group_size * self._N_TILE_COLS) + FULL_GROUPS = T.meta_var(self._FULL_GROUPS) + TAIL_ROWS = T.meta_var(self._TAIL_ROWS) + GROUP_SIZE = T.meta_var(self._l2_group_size) + GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN): - group_id: Tx.let = tile_linear // GROUP_SPAN - within_group: Tx.let = tile_linear % GROUP_SPAN - tile_row: Tx.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) - tile_col: Tx.let = within_group // GROUP_SIZE + group_id: T.let = tile_linear // GROUP_SPAN + within_group: T.let = tile_linear % GROUP_SPAN + tile_row: T.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) + tile_col: T.let = within_group // GROUP_SIZE set_tile_coords(tile_row, tile_col) elif TAIL_ROWS > 0: - rem: Tx.let = tile_linear - FULL_GROUPS * GROUP_SPAN - tile_row: Tx.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) - tile_col: Tx.let = rem // TAIL_ROWS + rem: T.let = tile_linear - FULL_GROUPS * GROUP_SPAN + tile_row: T.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) + tile_col: T.let = rem // TAIL_ROWS set_tile_coords(tile_row, tile_col) else: set_tile_coords(0, 0) - @Tx.inline + @T.inline def _update_serpentine(self, tile_linear, set_tile_coords): """CUTLASS-style 2D block swizzle with serpentine traversal. @@ -365,52 +363,52 @@ def _update_serpentine(self, tile_linear, set_tile_coords): This maximizes L2 reuse for both A and B matrices. """ - S = Tx.meta_var(self._l2_group_size) # swizzle_size - M_BLOCKS = Tx.meta_var(self._M_BLOCKS) - N_BLOCKS = Tx.meta_var(self._N_BLOCKS) - BLOCK_SIZE = Tx.meta_var(self._BLOCK_SIZE) # S * S - FULL_BLOCK_TILES = Tx.meta_var(self._FULL_BLOCK_TILES) - M_TILE_ROWS = Tx.meta_var(self._M_TILE_ROWS) - Tx.meta_var(self._N_TILE_COLS) - RESIDUAL_N = Tx.meta_var(self._RESIDUAL_N) - RESIDUAL_M = Tx.meta_var(self._RESIDUAL_M) + S = T.meta_var(self._l2_group_size) # swizzle_size + M_BLOCKS = T.meta_var(self._M_BLOCKS) + N_BLOCKS = T.meta_var(self._N_BLOCKS) + BLOCK_SIZE = T.meta_var(self._BLOCK_SIZE) # S * S + FULL_BLOCK_TILES = T.meta_var(self._FULL_BLOCK_TILES) + M_TILE_ROWS = T.meta_var(self._M_TILE_ROWS) + T.meta_var(self._N_TILE_COLS) + RESIDUAL_N = T.meta_var(self._RESIDUAL_N) + RESIDUAL_M = T.meta_var(self._RESIDUAL_M) # Check if we're in the full block region if (M_BLOCKS > 0) & (N_BLOCKS > 0) & (tile_linear < FULL_BLOCK_TILES): # Which block (in linear order along columns of blocks) - block_linear: Tx.let = tile_linear // BLOCK_SIZE - within_block: Tx.let = tile_linear % BLOCK_SIZE + block_linear: T.let = tile_linear // BLOCK_SIZE + within_block: T.let = tile_linear % BLOCK_SIZE # Block column and row - block_col: Tx.let = block_linear // M_BLOCKS - block_row_raw: Tx.let = block_linear % M_BLOCKS + block_col: T.let = block_linear // M_BLOCKS + block_row_raw: T.let = block_linear % M_BLOCKS # Serpentine: odd columns go bottom-to-top - block_row: Tx.let = Tx.Select( + block_row: T.let = T.Select( block_col % 2 == 0, block_row_raw, M_BLOCKS - 1 - block_row_raw ) # Position within block (row-major within block) - local_row: Tx.let = within_block // S - local_col: Tx.let = within_block % S + local_row: T.let = within_block // S + local_col: T.let = within_block % S - tile_row: Tx.let = block_row * S + local_row - tile_col: Tx.let = block_col * S + local_col + tile_row: T.let = block_row * S + local_row + tile_col: T.let = block_col * S + local_col set_tile_coords(tile_row, tile_col) elif RESIDUAL_N > 0: # Residual tiles in the rightmost partial column of blocks # These are tiles where n >= N_BLOCKS * S - rem: Tx.let = tile_linear - FULL_BLOCK_TILES + rem: T.let = tile_linear - FULL_BLOCK_TILES # First handle the right residual strip (full M height, partial N width) - right_strip_tiles: Tx.let = M_TILE_ROWS * RESIDUAL_N + right_strip_tiles: T.let = M_TILE_ROWS * RESIDUAL_N if rem < right_strip_tiles: # Row-major within the right strip - tile_row: Tx.let = rem // RESIDUAL_N - tile_col: Tx.let = N_BLOCKS * S + (rem % RESIDUAL_N) + tile_row: T.let = rem // RESIDUAL_N + tile_col: T.let = N_BLOCKS * S + (rem % RESIDUAL_N) set_tile_coords(tile_row, tile_col) elif RESIDUAL_M > 0: # Bottom residual strip (already covered in right strip overlap) @@ -422,11 +420,11 @@ def _update_serpentine(self, tile_linear, set_tile_coords): elif RESIDUAL_M > 0: # Bottom residual strip only (no right residual) - rem: Tx.let = tile_linear - FULL_BLOCK_TILES - bottom_strip_tiles: Tx.let = RESIDUAL_M * (N_BLOCKS * S) + rem: T.let = tile_linear - FULL_BLOCK_TILES + bottom_strip_tiles: T.let = RESIDUAL_M * (N_BLOCKS * S) if rem < bottom_strip_tiles: - tile_row: Tx.let = M_BLOCKS * S + (rem % RESIDUAL_M) - tile_col: Tx.let = rem // RESIDUAL_M + tile_row: T.let = M_BLOCKS * S + (rem % RESIDUAL_M) + tile_col: T.let = rem // RESIDUAL_M set_tile_coords(tile_row, tile_col) else: set_tile_coords(0, 0) @@ -434,7 +432,7 @@ def _update_serpentine(self, tile_linear, set_tile_coords): # Fallback set_tile_coords(0, 0) - @Tx.inline + @T.inline def init(self, cluster_id): """Initialize scheduler for a given cluster. @@ -447,14 +445,14 @@ def init(self, cluster_id): self.tile_count = 0 self.update_current_m_n_idx(cluster_id) - @Tx.inline + @T.inline def next_tile(self): """Advance to the next tile for this cluster.""" self.linear_idx = self.linear_idx + self._num_clusters self.tile_count = self.tile_count + 1 self.update_current_m_n_idx(self.linear_idx) - @Tx.inline + @T.inline def next_tile_stride(self, stride: int): """Advance by a custom stride (for non-standard scheduling).""" self.linear_idx = self.linear_idx + stride @@ -486,8 +484,8 @@ def __init__( ): super().__init__(prefix) self._step = step - self.tile_idx = Tx.local_scalar("int32") - self.k_idx = Tx.local_scalar("int32") + self.tile_idx = T.local_scalar("int32") + self.k_idx = T.local_scalar("int32") # ---- constants / primexprs baked once ---- self._G = group_rows @@ -501,9 +499,9 @@ def __init__( self._GROUP_SIZE = group_rows * n_tiles * k_tiles self._TOTAL = m_tiles * n_tiles * k_tiles else: - self._GROUPS = Tx.truncdiv(m_tiles, group_rows) + self._GROUPS = T.truncdiv(m_tiles, group_rows) self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows - self._SAFE_FINAL_ROWS = Tx.max(self._FINAL_ROWS, 1) + self._SAFE_FINAL_ROWS = T.max(self._FINAL_ROWS, 1) self._GROUP_SIZE = self._G * self._N * self._K self._TOTAL = m_tiles * n_tiles * k_tiles @@ -513,21 +511,21 @@ def __init__( self._HAS_TAIL = self._FINAL_ROWS > 0 # fmt: off - @Tx.inline + @T.inline def update_current_m_n_idx(self, linear_idx): # full-group formulas - full_m: Tx.let = Tx.floordiv(linear_idx, self._GROUP_SIZE) * self._G + Tx.floormod( + full_m: T.let = T.floordiv(linear_idx, self._GROUP_SIZE) * self._G + T.floormod( linear_idx, self._G ) - full_n: Tx.let = Tx.floormod(Tx.floordiv(linear_idx, self._G), self._N) - full_k: Tx.let = Tx.floordiv(Tx.floormod(linear_idx, self._GROUP_SIZE), self._G * self._N) + full_n: T.let = T.floormod(T.floordiv(linear_idx, self._G), self._N) + full_k: T.let = T.floordiv(T.floormod(linear_idx, self._GROUP_SIZE), self._G * self._N) # tail formulas (relative to FULL_BOUND) # Use _SAFE_FINAL_ROWS (max(FINAL_ROWS, 1)) to avoid divide-by-zero when there is no tail - rem: Tx.let = linear_idx - self._FULL_BOUND - tail_m: Tx.let = self._GROUPS * self._G + Tx.floormod(rem, self._SAFE_FINAL_ROWS) - tail_n: Tx.let = Tx.floordiv(rem, self._SAFE_FINAL_ROWS) % self._N - tail_k: Tx.let = Tx.floordiv(rem, self._SAFE_FINAL_ROWS * self._N) + rem: T.let = linear_idx - self._FULL_BOUND + tail_m: T.let = self._GROUPS * self._G + T.floormod(rem, self._SAFE_FINAL_ROWS) + tail_n: T.let = T.floordiv(rem, self._SAFE_FINAL_ROWS) % self._N + tail_k: T.let = T.floordiv(rem, self._SAFE_FINAL_ROWS * self._N) # choose phase if self._HAS_FULL & (linear_idx < self._FULL_BOUND): @@ -543,19 +541,19 @@ def update_current_m_n_idx(self, linear_idx): self.n_idx = 0 self.k_idx = 0 - @Tx.inline + @T.inline def init(self, linear_init): self.linear_idx = linear_init self.tile_idx = 0 self.update_current_m_n_idx(linear_init) - @Tx.inline + @T.inline def next_tile(self): self.linear_idx = self.linear_idx + self._step self.tile_idx = self.tile_idx + 1 self.update_current_m_n_idx(self.linear_idx) - @Tx.inline + @T.inline def next_tile_stride(self, stride: int): self.linear_idx = self.linear_idx + stride self.tile_idx = self.tile_idx + 1 @@ -581,13 +579,13 @@ def __init__( self._group_size = group_size self._world_size = world_size - @Tx.inline + @T.inline def update_current_m_n_idx(self, linear_idx): - my_rank: Tx.let = Tx.nvshmem.my_pe() - remote_m_clusters: Tx.let = self._m_clusters - self._m_clusters // self._world_size - group_rows: Tx.let = (remote_m_clusters // self._group_size) * self._group_size - final_rows: Tx.let = remote_m_clusters - group_rows - group_repeat: Tx.let = self._group_size * self._n_clusters + my_rank: T.let = T.nvshmem.my_pe() + remote_m_clusters: T.let = self._m_clusters - self._m_clusters // self._world_size + group_rows: T.let = (remote_m_clusters // self._group_size) * self._group_size + final_rows: T.let = remote_m_clusters - group_rows + group_repeat: T.let = self._group_size * self._n_clusters if linear_idx < group_rows * self._n_clusters and group_rows > 0: self.m_idx = ( (linear_idx // group_repeat) * self._group_size @@ -596,7 +594,7 @@ def update_current_m_n_idx(self, linear_idx): ) % self._m_clusters self.n_idx = (linear_idx % group_repeat) // self._group_size elif linear_idx < remote_m_clusters * self._n_clusters: - remainder_idx: Tx.let = linear_idx - group_rows * self._n_clusters + remainder_idx: T.let = linear_idx - group_rows * self._n_clusters self.m_idx = ( group_rows + remainder_idx % final_rows @@ -604,7 +602,7 @@ def update_current_m_n_idx(self, linear_idx): ) % self._m_clusters self.n_idx = remainder_idx // final_rows else: - remainder_idx: Tx.let = linear_idx - remote_m_clusters * self._n_clusters + remainder_idx: T.let = linear_idx - remote_m_clusters * self._n_clusters self.m_idx = ( remote_m_clusters + remainder_idx % (self._m_clusters // self._world_size) @@ -612,7 +610,7 @@ def update_current_m_n_idx(self, linear_idx): ) % self._m_clusters self.n_idx = remainder_idx // (self._m_clusters // self._world_size) - @Tx.inline + @T.inline def next_tile(self, stride: int): self.linear_idx = self.linear_idx + stride self.update_current_m_n_idx(self.linear_idx) @@ -630,24 +628,24 @@ def __init__(self, prefix: str, b_indices, h_indices, q_indices, tiles_indptr): self.h_indices = h_indices self.q_indices = q_indices self.tiles_indptr = tiles_indptr - self.q_idx = Tx.local_scalar("int32") - self.h_idx = Tx.local_scalar("int32") - self.b_idx = Tx.local_scalar("int32") - self.linear_lim = Tx.local_scalar("int32") + self.q_idx = T.local_scalar("int32") + self.h_idx = T.local_scalar("int32") + self.b_idx = T.local_scalar("int32") + self.linear_lim = T.local_scalar("int32") - @Tx.inline + @T.inline def _load(self): self.q_idx = self.q_indices[self.linear_idx] self.h_idx = self.h_indices[self.linear_idx] self.b_idx = self.b_indices[self.linear_idx] - @Tx.inline + @T.inline def init(self, sm): self.linear_idx = self.tiles_indptr[sm] self.linear_lim = self.tiles_indptr[sm + 1] self._load() - @Tx.inline + @T.inline def next_tile(self): self.linear_idx = self.linear_idx + 1 self._load() @@ -690,29 +688,29 @@ def __init__( self._total_tasks = num_batches * num_heads * num_m_blocks # Output indices - self.batch_idx = Tx.local_scalar("int32") - self.head_idx = Tx.local_scalar("int32") - self.m_block_idx = Tx.local_scalar("int32") + self.batch_idx = T.local_scalar("int32") + self.head_idx = T.local_scalar("int32") + self.m_block_idx = T.local_scalar("int32") # fmt: off - @Tx.inline + @T.inline def update_current_m_n_idx(self, linear_idx): """Convert linear index to (batch, head, m_block) coordinates.""" - NUM_HEADS = Tx.meta_var(self._num_heads) - NUM_M_BLOCKS = Tx.meta_var(self._num_m_blocks) - HEAD_M_PRODUCT = Tx.meta_var(NUM_HEADS * NUM_M_BLOCKS) + NUM_HEADS = T.meta_var(self._num_heads) + NUM_M_BLOCKS = T.meta_var(self._num_m_blocks) + HEAD_M_PRODUCT = T.meta_var(NUM_HEADS * NUM_M_BLOCKS) self.batch_idx = linear_idx // HEAD_M_PRODUCT self.head_idx = (linear_idx % HEAD_M_PRODUCT) // NUM_M_BLOCKS self.m_block_idx = linear_idx % NUM_M_BLOCKS - @Tx.inline + @T.inline def init(self, cta_id): """Initialize scheduler with CTA ID.""" self.linear_idx = cta_id self.update_current_m_n_idx(cta_id) - @Tx.inline + @T.inline def next_tile(self): """Advance to next tile by striding by num_ctas.""" self.linear_idx = self.linear_idx + self._num_ctas @@ -770,30 +768,30 @@ def __init__( self._num_hb_quotient = self._num_hb // l2_swizzle # Output indices - self.batch_idx = Tx.local_scalar("int32") - self.head_idx = Tx.local_scalar("int32") - self.m_block_idx = Tx.local_scalar("int32") + self.batch_idx = T.local_scalar("int32") + self.head_idx = T.local_scalar("int32") + self.m_block_idx = T.local_scalar("int32") # fmt: off - @Tx.inline + @T.inline def update_current_m_n_idx(self, linear_idx): """Convert linear index to (batch, head, m_block) with LPT + L2 swizzle.""" - L2_SWIZZLE = Tx.meta_var(self._l2_swizzle) - L2_MAJOR = Tx.meta_var(self._l2_major) - NUM_HB_QUOTIENT = Tx.meta_var(self._num_hb_quotient) - NUM_HB = Tx.meta_var(self._num_hb) - NUM_HEADS = Tx.meta_var(self._num_heads) - NUM_M_BLOCKS = Tx.meta_var(self._num_m_blocks) + L2_SWIZZLE = T.meta_var(self._l2_swizzle) + L2_MAJOR = T.meta_var(self._l2_major) + NUM_HB_QUOTIENT = T.meta_var(self._num_hb_quotient) + NUM_HB = T.meta_var(self._num_hb) + NUM_HEADS = T.meta_var(self._num_heads) + NUM_M_BLOCKS = T.meta_var(self._num_m_blocks) # L2 swizzle decomposition - bidhb: Tx.let = linear_idx // L2_MAJOR - l2_mod: Tx.let = linear_idx % L2_MAJOR + bidhb: T.let = linear_idx // L2_MAJOR + l2_mod: T.let = linear_idx % L2_MAJOR # Handle residual section (last partial swizzle group) - num_hb_remainder: Tx.let = Tx.max(NUM_HB % L2_SWIZZLE, 1) - m_block_raw: Tx.let = Tx.Select(bidhb < NUM_HB_QUOTIENT, l2_mod // L2_SWIZZLE, l2_mod // num_hb_remainder) # noqa: E501 - bidhb_residual: Tx.let = Tx.Select(bidhb < NUM_HB_QUOTIENT, l2_mod % L2_SWIZZLE, l2_mod % num_hb_remainder) # noqa: E501 - bidhb_actual: Tx.let = bidhb * L2_SWIZZLE + bidhb_residual + num_hb_remainder: T.let = T.max(NUM_HB % L2_SWIZZLE, 1) + m_block_raw: T.let = T.Select(bidhb < NUM_HB_QUOTIENT, l2_mod // L2_SWIZZLE, l2_mod // num_hb_remainder) # noqa: E501 + bidhb_residual: T.let = T.Select(bidhb < NUM_HB_QUOTIENT, l2_mod % L2_SWIZZLE, l2_mod % num_hb_remainder) # noqa: E501 + bidhb_actual: T.let = bidhb * L2_SWIZZLE + bidhb_residual self.batch_idx = bidhb_actual // NUM_HEADS self.head_idx = bidhb_actual % NUM_HEADS @@ -801,13 +799,13 @@ def update_current_m_n_idx(self, linear_idx): # LPT: Reverse block order so high-work blocks are processed first self.m_block_idx = (NUM_M_BLOCKS - 1) - m_block_raw - @Tx.inline + @T.inline def init(self, cta_id): """Initialize scheduler with CTA ID.""" self.linear_idx = cta_id self.update_current_m_n_idx(cta_id) - @Tx.inline + @T.inline def next_tile(self): """Advance to next tile by striding by num_ctas.""" self.linear_idx = self._total_tasks diff --git a/python/tvm/tirx/lang/warp_role.py b/python/tvm/tirx/lang/warp_role.py index 874800c78cb4..0258013bab1a 100644 --- a/python/tvm/tirx/lang/warp_role.py +++ b/python/tvm/tirx/lang/warp_role.py @@ -35,28 +35,31 @@ # MMA compute code """ -from tvm.script import tirx as Tx +from tvm.script import tirx as T class WarpRole: """A warp-level role that guards a block of code by warp_id comparison - and wraps it in ``Tx.warp()`` with optional register budget. + with optional register budget. Generates:: if == : - with Tx.warp(): - Tx.ptx.setmaxnreg(, ) # if regs specified - + T.ptx.setmaxnreg(, ) # if regs specified + + + The ``if`` guard narrows the active set to the single warp; individual + tile-primitive calls inside ```` carry their own exec scope via + a scope-namespace prefix (e.g. ``Tx.warp.copy(...)``). Parameters ---------- warp_id_var : Var - The warp_id variable (from ``Tx.warp_id(...)``). + The warp_id variable (from ``T.warp_id(...)``). warp_id_val : int Which warp index this role corresponds to. regs : int, optional - Register budget (passed to ``Tx.ptx.setmaxnreg``). + Register budget (passed to ``T.ptx.setmaxnreg``). If None, no setmaxnreg is emitted. increase : bool Direction for ``setmaxnreg`` (default False = decrease). @@ -69,18 +72,15 @@ def __init__(self, warp_id_var, warp_id_val, regs=None, increase=False): self.increase = increase def __enter__(self): - self._if_frame = Tx.If(self.warp_id_var == self.warp_id_val) + self._if_frame = T.If(self.warp_id_var == self.warp_id_val) self._if_frame.__enter__() - self._then_frame = Tx.Then() + self._then_frame = T.Then() self._then_frame.__enter__() - self._warp_frame = Tx.warp() - self._warp_frame.__enter__() if self.regs is not None: - Tx.evaluate(Tx.ptx.setmaxnreg(self.increase, self.regs)) + T.evaluate(T.ptx.setmaxnreg(self.increase, self.regs)) return self def __exit__(self, *exc): - self._warp_frame.__exit__(*exc) self._then_frame.__exit__(*exc) self._if_frame.__exit__(*exc) return False @@ -88,26 +88,28 @@ def __exit__(self, *exc): class WarpgroupRole: """A warpgroup-level role that guards by wg_id comparison, - wraps in ``Tx.warpgroup()``, with optional register budget. + with optional register budget. Generates (single wg_id):: if == : - with Tx.warpgroup(): - Tx.ptx.setmaxnreg(, ) # if regs specified - + T.ptx.setmaxnreg(, ) # if regs specified + Generates (range of wg_ids, e.g. ``wg_id_val=(0, 2)``):: if 0 <= and < 2: - with Tx.warpgroup(): - Tx.ptx.setmaxnreg(, ) - + T.ptx.setmaxnreg(, ) + + + The ``if`` guard narrows the active set to the target warpgroup(s); + individual tile-primitive calls inside ```` carry their own exec + scope via a scope-namespace prefix (e.g. ``Tx.wg.copy(...)``). Parameters ---------- wg_id_var : Var - The warpgroup_id variable (from ``Tx.warpgroup_id(...)``). + The warpgroup_id variable (from ``T.warpgroup_id(...)``). wg_id_val : int or tuple[int, int] Which warpgroup index (int) or range ``(start, stop)`` this role corresponds to. @@ -126,20 +128,17 @@ def __init__(self, wg_id_var, wg_id_val, regs=None, increase=False): def __enter__(self): if isinstance(self.wg_id_val, tuple): start, stop = self.wg_id_val - self._if_frame = Tx.If(start <= self.wg_id_var and self.wg_id_var < stop) + self._if_frame = T.If(start <= self.wg_id_var and self.wg_id_var < stop) else: - self._if_frame = Tx.If(self.wg_id_var == self.wg_id_val) + self._if_frame = T.If(self.wg_id_var == self.wg_id_val) self._if_frame.__enter__() - self._then_frame = Tx.Then() + self._then_frame = T.Then() self._then_frame.__enter__() - self._wg_frame = Tx.warpgroup() - self._wg_frame.__enter__() if self.regs is not None: - Tx.evaluate(Tx.ptx.setmaxnreg(self.increase, self.regs)) + T.evaluate(T.ptx.setmaxnreg(self.increase, self.regs)) return self def __exit__(self, *exc): - self._wg_frame.__exit__(*exc) self._then_frame.__exit__(*exc) self._if_frame.__exit__(*exc) return False diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index 834908bce36b..1d2c7ec2d167 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -59,6 +59,27 @@ tir = tirx # alias for backward compat with upstream tir.convert() calls +_DEVICE_INTRIN_PREFIX_TO_NAMESPACE = { + "cuda_": "cuda", + "ptx_": "ptx", + "nvshmem_": "nvshmem", + "nki_": "nki", +} + + +def _canonical_device_intrin_name(func_name: str) -> str: + """Return the canonical registry name for statically registered device intrinsics.""" + + if not isinstance(func_name, str) or not func_name.startswith("tirx."): + return func_name + basename = func_name[len("tirx.") :] + if "." in basename: + return func_name + for prefix, namespace in _DEVICE_INTRIN_PREFIX_TO_NAMESPACE.items(): + if basename.startswith(prefix): + return f"tirx.{namespace}.{basename[len(prefix) :]}" + return func_name + def _pack_buffer(buf, span=None): """Build intrinsics that packs the buffer.""" @@ -218,6 +239,8 @@ def call_intrin(dtype, func_name, *args, attrs=None, span=None): call : PrimExpr The call expression. """ + if isinstance(func_name, str): + func_name = _canonical_device_intrin_name(func_name) return Call(dtype, func_name, args, attrs=attrs, span=span) @@ -1523,6 +1546,31 @@ def exp10(x): return call_intrin(x.dtype, "tirx.exp10", x) +def fma(x, y, z): + """Take fused multiply-add of input x, y, z. + + Parameters + ---------- + x : PrimExpr + First input argument. + + y : PrimExpr + Second input argument. + + z : PrimExpr + Third input argument. + + Returns + ------- + out : PrimExpr + The result of x * y + z. + """ + x = tir.convert(x) + y = tir.convert(y) + z = tir.convert(z) + return call_intrin(x.dtype, "tirx.fma", x, y, z) + + def erf(x): """Take gauss error function of the input x. @@ -2220,7 +2268,7 @@ def filter(var, pred, *, span=None): # pylint: disable=redefined-builtin Use this wrapper only when the predicate is *not* in the canonical thread-filter grammar (see ``src/tirx/analysis/filter_canonical.h``). Canonical predicates -- pure conjunctions of ``scopeid_var const`` - comparisons plus bare ``Tx.ptx.elect_sync()`` calls -- are recognized by + comparisons plus bare ``T.ptx.elect_sync()`` calls -- are recognized by the lowering pass directly from ``if cond:``, so the wrapper is redundant for them. @@ -3322,7 +3370,7 @@ def cuda_thread_rank(): referencing user-declared scope_id vars. For example, the idiomatic mbarrier.init leader predicate is:: - Tx.cuda.thread_rank() == 0 + T.cuda.thread_rank() == 0 Returns ------- @@ -3918,7 +3966,7 @@ def ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, size, mbar): mbar : PrimExpr Mbarrier address in shared::cluster space for completion signaling, - usually produced by ``Tx.ptx.map_shared_rank``. + usually produced by ``T.ptx.map_shared_rank``. Returns ------- @@ -4900,7 +4948,7 @@ def ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles): """TVM intrinsic for ldmatrix.sync.aligned.m8n8.x{num}{.trans}.shared.{dtype}. Mirrors the PTX ISA destination form: each output register is a separate - operand. Pass ``Tx.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for + operand. Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for each destination — the slots may be non-contiguous. Parameters @@ -5039,7 +5087,7 @@ def ptx_stmatrix(trans, num, dtype, smem_ptr, *src_handles, shape="m8n8", space= """TVM intrinsic for ``stmatrix.sync.aligned.shape.x{num}{.trans}.space.{dtype}``. Mirrors :func:`ptx_ldmatrix`: each source register is a separate operand. - Pass ``Tx.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for each + Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for each source — the slots may be non-contiguous. Parameters diff --git a/python/tvm/tirx/operator/intrinsics/_schema.py b/python/tvm/tirx/operator/intrinsics/_schema.py index 7d83d5cb7526..57e409e9555c 100644 --- a/python/tvm/tirx/operator/intrinsics/_schema.py +++ b/python/tvm/tirx/operator/intrinsics/_schema.py @@ -25,9 +25,7 @@ ``__forceinline__ __device__ { }``, * registers a codegen function under the op name so ``call_intrin("", "tirx.", *args)`` resolves to a call to that - helper, and -* registers the op with TVM's Op registry (``TCallEffectKind=Opaque``) so - it doesn't need a C++ ``TIR_DEFINE_BUILTIN_FUNC`` entry. + helper. TVM Op registration is static C++ only. Args passed to the codegen are split into ``(forward_args, attr_args)``: the trailing ``n_attrs`` are attrs (consumed by the ``helper_name`` / @@ -144,37 +142,3 @@ def codegen(*args): codegen.__name__ = f"codegen_{op_name}" register_codegen(op_name)(codegen) - _ensure_op_registered(f"tirx.{op_name}") - - -# --------------------------------------------------------------------------- -# Dynamic Op registration — ensures op_name has a TVM Op (with default -# TCallEffectKind=Opaque) so call_intrin can resolve it without requiring a -# C++ TIR_DEFINE_BUILTIN_FUNC entry. -# --------------------------------------------------------------------------- - -import tvm_ffi # noqa: E402 - -_ir_register_op = tvm_ffi.get_global_func("ir.RegisterOp") -_ir_register_op_attr = tvm_ffi.get_global_func("ir.RegisterOpAttr") -# CallEffectKind enum (include/tvm/tir/op_attr_types.h): Opaque = 4. -_CALL_EFFECT_KIND_OPAQUE = 4 -_registered_attrs: set = set() - - -def _ensure_op_registered(op_name: str) -> None: - """Register ``op_name`` if not already in TVM's Op registry, plus a - default ``TCallEffectKind=Opaque`` attribute. Both calls are no-ops when - the op / attribute is already registered (the C++-side registrations win - by plevel).""" - try: - _ir_register_op(op_name, "") - except Exception: - pass - if op_name in _registered_attrs: - return - try: - _ir_register_op_attr(op_name, "TCallEffectKind", _CALL_EFFECT_KIND_OPAQUE, 10) - _registered_attrs.add(op_name) - except Exception: - pass diff --git a/python/tvm/tirx/operator/intrinsics/cuda/misc.py b/python/tvm/tirx/operator/intrinsics/cuda/misc.py index 01404a9cc68a..0cca2cd19456 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/misc.py +++ b/python/tvm/tirx/operator/intrinsics/cuda/misc.py @@ -208,7 +208,7 @@ def codegen_cuda_printf(fmt, *args): if isinstance(fmt, tvm.tirx.StringImm): fmt = fmt.value if not isinstance(fmt, str): - raise ValueError("Tx.cuda.printf format must be a string literal") + raise ValueError("T.cuda.printf format must be a string literal") fmt_literal = json.dumps(fmt) arg_dtypes = [str(arg.dtype) for arg in args] signature = "|".join([fmt, *arg_dtypes]) @@ -232,7 +232,7 @@ def c_type(dtype: str) -> str: return "int" if dtype == "handle": return "void*" - raise ValueError(f"Unsupported Tx.cuda.printf argument dtype: {dtype}") + raise ValueError(f"Unsupported T.cuda.printf argument dtype: {dtype}") params = ", ".join(f"{c_type(dtype)} arg{i}" for i, dtype in enumerate(arg_dtypes)) call_args = ", ".join(f"arg{i}" for i in range(len(args))) diff --git a/python/tvm/tirx/operator/intrinsics/cuda/registry.py b/python/tvm/tirx/operator/intrinsics/cuda/registry.py index 72a0e6ec8e32..e6aad10ddb55 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/registry.py +++ b/python/tvm/tirx/operator/intrinsics/cuda/registry.py @@ -26,8 +26,23 @@ import tvm_ffi CODEGEN_REGISTRY = {} -_CALL_EFFECT_KIND_OPAQUE = 4 -_registered_attrs: set[str] = set() + + +def _canonical_device_intrin_name(op_name: str) -> str: + if not op_name.startswith("tirx."): + return op_name + basename = op_name[len("tirx.") :] + if "." in basename: + return op_name + for prefix, namespace in ( + ("cuda_", "cuda"), + ("ptx_", "ptx"), + ("nvshmem_", "nvshmem"), + ("nki_", "nki"), + ): + if basename.startswith(prefix): + return f"tirx.{namespace}.{basename[len(prefix) :]}" + return op_name @tvm_ffi.register_global_func("tirx.intrinsics.cuda.get_codegen") @@ -45,7 +60,8 @@ def register_codegen(op, backend="cuda"): def decorator(func): full_op_name = "tirx." + op - _ensure_op_registered(full_op_name) + canonical_op_name = _canonical_device_intrin_name(full_op_name) + op_names = {full_op_name, canonical_op_name} @functools.wraps(func) def wrapper(arg_list): @@ -54,24 +70,8 @@ def wrapper(arg_list): return res[0], res[1] return res, list() - CODEGEN_REGISTRY[full_op_name] = wrapper + for op_name in op_names: + CODEGEN_REGISTRY[op_name] = wrapper return wrapper return decorator - - -def _ensure_op_registered(op_name: str) -> None: - """Ensure dynamic TIRx ops also have a purity/effect attribute.""" - try: - tvm_ffi.get_global_func("ir.RegisterOp")(op_name, "") - except Exception: - pass - if op_name in _registered_attrs: - return - try: - tvm_ffi.get_global_func("ir.RegisterOpAttr")( - op_name, "TCallEffectKind", _CALL_EFFECT_KIND_OPAQUE, 10 - ) - _registered_attrs.add(op_name) - except Exception: - pass diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/common.py b/python/tvm/tirx/operator/tile_primitive/cuda/common.py index b7696293c93c..08c56deaecdd 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/common.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/common.py @@ -24,7 +24,7 @@ from tvm.arith.analyzer import Analyzer from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail from tvm.tirx.stmt import TilePrimitiveCall @@ -70,7 +70,7 @@ def smem_desc_add_16B_offset(desc_val, offset): return desc.desc_; }} """ - return Tx.cuda.func_call( + return T.cuda.func_call( func_name, desc_val, offset, source_code=source_code, return_type="uint64" ) @@ -205,41 +205,41 @@ def copy_vec_load_impl( if sctx.is_cta: # fmt: off - @Tx.prim_func + @T.prim_func def impl(): """Implement copy operation with vectorized loads/stores.""" - for s in Tx.serial(0, n_elements // (tx * vec_len)): - for tid_x in Tx.thread_binding(tx, "threadIdx.x"): + for s in T.serial(0, n_elements // (tx * vec_len)): + for tid_x in T.thread_binding(tx, "threadIdx.x"): if inst_type == CopyInstType.NORMAL: - for vec in Tx.vectorized(vec_len): - fused = Tx.meta_var((s * tx + tid_x) * vec_len + vec) - dst_indices = Tx.meta_var(get_indices(fused, dst_st, dst_extent)) - src_indices = Tx.meta_var(get_indices(fused, src_st, src_extent)) + for vec in T.vectorized(vec_len): + fused = T.meta_var((s * tx + tid_x) * vec_len + vec) + dst_indices = T.meta_var(get_indices(fused, dst_st, dst_extent)) + src_indices = T.meta_var(get_indices(fused, src_st, src_extent)) dst[tuple(dst_indices)] = src[tuple(src_indices)] elif inst_type == CopyInstType.CP_ASYNC: - fused = Tx.meta_var((s * tx + tid_x) * vec_len) - dst_indices = Tx.meta_var(get_indices(fused, dst_st, dst_extent)) - src_indices = Tx.meta_var(get_indices(fused, src_st, src_extent)) - Tx.evaluate(Tx.ptx.cp_async(dst.ptr_to(dst_indices), src.ptr_to(src_indices), cp_size)) # noqa: E501 + fused = T.meta_var((s * tx + tid_x) * vec_len) + dst_indices = T.meta_var(get_indices(fused, dst_st, dst_extent)) + src_indices = T.meta_var(get_indices(fused, src_st, src_extent)) + T.evaluate(T.ptx.cp_async(dst.ptr_to(dst_indices), src.ptr_to(src_indices), cp_size)) # noqa: E501 if dst.scope().startswith("shared") and inst_type == CopyInstType.NORMAL: - Tx.tvm_storage_sync("shared") + T.tvm_storage_sync("shared") # fmt: on elif sctx.is_thread: # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - for s in Tx.serial(0, n_elements // (vec_len)): + for s in T.serial(0, n_elements // (vec_len)): if inst_type == CopyInstType.NORMAL: - for vec in Tx.vectorized(vec_len): - fused = Tx.meta_var(s * vec_len + vec) - dst_indices = Tx.meta_var(get_indices(fused, dst_st, dst_extent)) - src_indices = Tx.meta_var(get_indices(fused, src_st, src_extent)) + for vec in T.vectorized(vec_len): + fused = T.meta_var(s * vec_len + vec) + dst_indices = T.meta_var(get_indices(fused, dst_st, dst_extent)) + src_indices = T.meta_var(get_indices(fused, src_st, src_extent)) dst[tuple(dst_indices)] = src[tuple(src_indices)] elif inst_type == CopyInstType.CP_ASYNC: - fused = Tx.meta_var(s * vec_len) - dst_indices = Tx.meta_var(get_indices(fused, dst_st, dst_extent)) - src_indices = Tx.meta_var(get_indices(fused, src_st, src_extent)) - Tx.evaluate(Tx.ptx.cp_async(dst.ptr_to(dst_indices), src.ptr_to(src_indices), cp_size)) # noqa: E501 + fused = T.meta_var(s * vec_len) + dst_indices = T.meta_var(get_indices(fused, dst_st, dst_extent)) + src_indices = T.meta_var(get_indices(fused, src_st, src_extent)) + T.evaluate(T.ptx.cp_async(dst.ptr_to(dst_indices), src.ptr_to(src_indices), cp_size)) # noqa: E501 # fmt: on else: fail(f"unsupported exec_scope {sctx.scope_kind}") diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/_swizzle_iter.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/_swizzle_iter.py index 2f8303f2ccd7..0037c4ac07b8 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/_swizzle_iter.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/_swizzle_iter.py @@ -69,7 +69,7 @@ import tvm from tvm import arith -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx.expr import IntImm as _IntImm from tvm.tirx.layout import ComposeLayout, SwizzleLayout @@ -270,7 +270,7 @@ def try_recognize( def emit_init(pattern: SwizzlePattern, s_off_resolved): - """Emit at thread setup (call from inside the @Tx.prim_func body): + """Emit at thread setup (call from inside the @T.prim_func body): 1. ``base_off = swizzle.apply(s_off_resolved)`` — runtime, per-thread, computed once. @@ -296,7 +296,7 @@ def emit_init(pattern: SwizzlePattern, s_off_resolved): if n == 0: return None, base_off - signed_strides = Tx.alloc_buffer([n], "int32", scope="local") + signed_strides = T.alloc_buffer([n], "int32", scope="local") q = tvm.tirx.floordiv(s_off_resolved, C) def _sigma_bit(bit_pos: int): @@ -308,27 +308,29 @@ def _sigma_bit(bit_pos: int): return _IntImm("int32", 1) - row_bit * _IntImm("int32", 2) for j, (bj, stride) in enumerate(zip(pattern.bit_positions, pattern.iter_strides_elems)): - T = stride # = 2^(bj + p) elements + stride_pow = stride # = 2^(bj + p) elements if 0 <= bj < sw: # Case 1.A (inner): signed_stride = sigma_(at + bj) · T. - value = _sigma_bit(at + bj) * _IntImm("int32", T) + value = _sigma_bit(at + bj) * _IntImm("int32", stride_pow) elif sw <= bj < at: # Case 1.B (mid): signed_stride = +T. - value = _IntImm("int32", T) + value = _IntImm("int32", stride_pow) elif at <= bj < at + sw: # Case 1.C (outer): signed_stride = T + sigma_(bj - at) · T_sec. # Invariant: bj >= at, so T_sec = T >> at = 2^(bj - at + p) # = T(bj - at) is well-defined (no underflow). - T_sec = T >> at - value = _IntImm("int32", T) + _sigma_bit(bj - at) * _IntImm("int32", T_sec) + stride_sec = stride_pow >> at + value = _IntImm("int32", stride_pow) + _sigma_bit(bj - at) * _IntImm( + "int32", stride_sec + ) else: # bj >= at + sw, Case 1.D (above) # No swizzle effect at this bit; signed_stride = +T. - value = _IntImm("int32", T) + value = _IntImm("int32", stride_pow) # NB: Buffer.__setitem__ syntax (``signed_strides[j] = value``) is # intercepted by the TIRx script parser but not by raw Python when - # this function is called from outside an @Tx.inline body. Use the + # this function is called from outside an @T.inline body. Use the # low-level buffer_store builder instead. - Tx.buffer_store(signed_strides, value, [_IntImm("int32", j)]) + T.buffer_store(signed_strides, value, [_IntImm("int32", j)]) return signed_strides, base_off diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/fallback.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/fallback.py index bd0faa3bd8cd..ab69d3924450 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/fallback.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/fallback.py @@ -20,7 +20,7 @@ import warnings import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx.operator.tile_primitive.dispatcher import ( predicate, @@ -75,14 +75,14 @@ def _src_coord(lvs): coord[src_indices[k]] += lv return coord - with Tx.grid(*copy_extents) as lvs: - Tx.buffer_store(dst_buf, src_buf[tuple(_src_coord(lvs))], _dst_coord(lvs)) + with T.grid(*copy_extents) as lvs: + T.buffer_store(dst_buf, src_buf[tuple(_src_coord(lvs))], _dst_coord(lvs)) scope_kind = sctx.scope_kind if scope_kind == "thread": - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): _copy_body(dst, src) @@ -96,7 +96,7 @@ def impl(): elif scope_kind == "cta": first_tid += 32 * int(sctx.intra["warpid"][1]) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): tid = _axis_decl(tid_axis_name, sctx) if tid == first_tid: diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py index fbce2e2e2be7..aee24c62e4f5 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py @@ -27,7 +27,7 @@ import tvm from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx import Var as _TirVar from tvm.tirx.expr import IntImm as _IntImm @@ -136,7 +136,7 @@ def _emit_gmem_smem(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFu # [outer x thread x vec] coord scheme below. vec_bits = vec_len * elem_bits - copy_op = getattr(Tx.cuda, f"copy_{vec_bits}b") + copy_op = getattr(T.cuda, f"copy_{vec_bits}b") # Partition guarantees ``prod(s_p.shard.extents) == prod(g_p.shard.extents) # == n_elements`` (the total transfer count). Express the per-thread @@ -263,7 +263,7 @@ def _s_off(f, s_lin): v0 = _IntImm("int32", 0) # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): tid = _decl_tid() _setup_swizzle(tid) @@ -272,7 +272,7 @@ def impl(): # misaligned vector ops. # # Use a serial TIR loop and let ptxas unroll downstream. Mirrors - # the reg.py rationale in commit ac7ecf70f0: explicit ``Tx.unroll`` + # the reg.py rationale in commit ac7ecf70f0: explicit ``T.unroll`` # materializes the per-iter scratch (s_lin/g_lin/s_off/s_ptr/g_ptr) # as N copies of each ``alignas(64)`` declaration. For large # ``total_outer`` (e.g. thread-scope fp32 swizzled copies of 32x256 diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py index c8243523a6a1..75b0c9015b55 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py @@ -25,7 +25,7 @@ from math import prod import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx import Var as _TirVar from tvm.tirx.expr import IntImm as _IntImm @@ -294,14 +294,14 @@ def _try_num(r_in, s_in, num): # Step 10: emit one ldmatrix/stmatrix per mm, per warp. def _get_warp_idx_in_T(): - # Tx.warp_id_in_wg() / Tx.warp_id() must be called from inside a - # @Tx.prim_func body — wrap so the prim_func parser calls us at parse + # T.warp_id_in_wg() / T.warp_id() must be called from inside a + # @T.prim_func body — wrap so the prim_func parser calls us at parse # time (Python `if` here is plain control flow, not TIR-intercepted). if r_lane_axis == "laneid": return 0 if r_lane_axis == "tid_in_wg": - return Tx.warp_id_in_wg() - return Tx.warp_id() # "tx" + return T.warp_id_in_wg() + return T.warp_id() # "tx" def _seg4_coord(laneid_expr): # num=1: seg 4 trivially extent-1, pass 0. num>1: use lane//8 (tile @@ -373,7 +373,7 @@ def __init__(self): def _resolve_s_off(laneid_var, warp_var): # Build the placeholder→runtime-var map and substitute. Keep this in a - # regular Python helper — the @Tx.prim_func parser intercepts dict + # regular Python helper — the @T.prim_func parser intercepts dict # literals when written directly in the body. vmap = {lane_ph: laneid_var} if warp_ph is not None: @@ -405,10 +405,10 @@ def _smem_off(mm_idx, logical_off): return logical_off # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): r_local = r_buf.local(m_total, layout=TileLayout(S[(m_total,)])) - laneid = Tx.lane_id() + laneid = T.lane_id() warp_idx_in_T = _get_warp_idx_in_T() # Resolve s_off_template by substituting placeholders → actual # scope-id vars (via _resolve_s_off helper to keep the dict literal @@ -416,7 +416,7 @@ def impl(): # without swizzle we keep using the per-iter s.apply directly. if swizzle_pattern is not None: _setup_swizzle(_resolve_s_off(laneid, warp_idx_in_T)) - for mm in Tx.unroll(m_outer): + for mm in T.unroll(m_outer): tile_off = s.apply( warp_idx_in_T, 0, 0, mm, _seg4_coord(laneid), 0, shape=apply_shape, )[s_mem_axis] @@ -430,9 +430,9 @@ def impl(): for i in range(num) ] if direction == "ld": - Tx.ptx.ldmatrix(trans, num, ".b16", smem_ptr, *handles) + T.ptx.ldmatrix(trans, num, ".b16", smem_ptr, *handles) else: - Tx.ptx.stmatrix( + T.ptx.stmatrix( trans, num, ".b16", smem_ptr, *handles, shape="m8n8", space="shared", ) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py index b8de9d641f57..5ea8d40e9382 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py @@ -30,7 +30,7 @@ import tvm from tvm.arith import Analyzer from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx import Var as _TirVar from tvm.tirx.expr import IntImm as _IntImm @@ -410,15 +410,15 @@ def _axis_decl(axis_name: str, sctx: DispatchContext): if axis_name == "tx": return sctx.launch_params["threadIdx.x"].var if axis_name == "laneid": - return Tx.lane_id() + return T.lane_id() if axis_name == "wid_in_wg": - return Tx.warp_id_in_wg() + return T.warp_id_in_wg() if axis_name == "tid_in_wg": - return Tx.thread_id_in_wg() + return T.thread_id_in_wg() if axis_name == "warpid": - return Tx.warp_id() + return T.warp_id() if axis_name == "wgid": - return Tx.warpgroup_id() + return T.warpgroup_id() raise ValueError(f"unsupported thread axis {axis_name}") @@ -463,7 +463,7 @@ def _flat_coords(outer_atoms, flat_idx: int) -> list[int]: def _ptr_off(base_ptr, off): - return Tx.cuda.func_call( + return T.cuda.func_call( "tvm_builtin_pointer_offset", base_ptr, off, @@ -504,7 +504,7 @@ def _emit_reg(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc: # Build the per-thread S offset OUTSIDE the impl using placeholder Vars # (one per thread axis). Inside the impl we'll declare the real scope_ids - # via Tx.lane_id/Tx.thread_id_in_wg/... and substitute them in. + # via T.lane_id/T.thread_id_in_wg/... and substitute them in. placeholders = _make_thread_placeholders(r_p) s_off_template = _s_thread_offset(r_p, s_p, placeholders) @@ -515,7 +515,7 @@ def _emit_reg(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc: for _ax, val in r_p.offset.items(): r_off_base = r_off_base + val - copy_op = getattr(Tx.cuda, f"copy_{vec_bits}b") + copy_op = getattr(T.cuda, f"copy_{vec_bits}b") total_outer = 1 for a in outer: @@ -556,13 +556,13 @@ def _s_iter_off(f, ds, s_off): # fmt: off s_zero_indices = [0] * len(s_buf.shape) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): s_off = _substitute_axes(s_off_template, placeholders, sctx) _setup_swizzle(s_off) r_local = r_buf.local(*per_thread_r_shape) # Keep as a serial TIR loop and let ptxas unroll downstream. An - # explicit ``Tx.unroll`` materializes the per-iter scratch + # explicit ``T.unroll`` materializes the per-iter scratch # (ds/dr/s_ptr/r_ptr, swizzle ``v_[]`` signed-strides) as N # copies of each buffer declaration; on kernels with many R↔S copy # sites and large ``total_outer`` (FA4 writeback) this floods the diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/dsmem.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/dsmem.py index 0266b432f57a..c0e3e7cfdcc8 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/dsmem.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/dsmem.py @@ -21,7 +21,7 @@ import operator import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx.operator.tile_primitive import ( DispatchContext, @@ -132,7 +132,7 @@ def copy_dsmem_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFu outer_src_strides = [grouped_src.shard[i].stride for i in outer_shard_indices] outer_dst_strides = [grouped_dst.shard[i].stride for i in outer_shard_indices] - # Helper to compute element offsets from loop variables (called via Tx.meta_var) + # Helper to compute element offsets from loop variables (called via T.meta_var) def compute_offsets(loop_vars): if len(outer_extents) == 1: lvs = [loop_vars] @@ -149,27 +149,27 @@ def compute_offsets(loop_vars): dst_tile = to_tile_layout(dst_buf.layout, dst_buf.shape) # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): # Map mbar to remote CTA (complete_tx targets the destination's mbar) - remote_mbar = Tx.ptx.map_shared_rank(mbar, remote_cta_id) + remote_mbar = T.ptx.map_shared_rank(mbar, remote_cta_id) if not outer_extents: # Single contiguous chunk — no iteration needed src_ptr = src_buf.ptr_to(src_st) - cluster_dst = Tx.ptx.map_shared_rank(dst_buf.ptr_to(dst_st), remote_cta_id) - Tx.ptx.cp_async.bulk.s2c(cluster_dst, src_ptr, chunk_bytes, remote_mbar) + cluster_dst = T.ptx.map_shared_rank(dst_buf.ptr_to(dst_st), remote_cta_id) + T.ptx.cp_async.bulk.s2c(cluster_dst, src_ptr, chunk_bytes, remote_mbar) else: - for loop_vars in Tx.grid(*outer_extents): - src_elem_offset, dst_elem_offset = Tx.meta_var(compute_offsets(loop_vars)) + for loop_vars in T.grid(*outer_extents): + src_elem_offset, dst_elem_offset = T.meta_var(compute_offsets(loop_vars)) - src_buf_w = Tx.decl_buffer( + src_buf_w = T.decl_buffer( src_buf.shape, src_buf.dtype, src_buf.data, elem_offset=src_buf.elem_offset + src_elem_offset, scope=src_buf.scope(), layout=src_tile, ) - dst_buf_w = Tx.decl_buffer( + dst_buf_w = T.decl_buffer( dst_buf.shape, dst_buf.dtype, dst_buf.data, elem_offset=dst_buf.elem_offset + dst_elem_offset, scope=dst_buf.scope(), @@ -177,8 +177,8 @@ def impl(): ) src_ptr = src_buf_w.ptr_to(src_st) - cluster_dst = Tx.ptx.map_shared_rank(dst_buf_w.ptr_to(dst_st), remote_cta_id) - Tx.ptx.cp_async.bulk.s2c(cluster_dst, src_ptr, chunk_bytes, remote_mbar) + cluster_dst = T.ptx.map_shared_rank(dst_buf_w.ptr_to(dst_st), remote_cta_id) + T.ptx.cp_async.bulk.s2c(cluster_dst, src_ptr, chunk_bytes, remote_mbar) # fmt: on return impl diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py index 1742c53bc191..8c86f75ac8c4 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py @@ -19,14 +19,14 @@ (SASS: ``LDGSTS``). Shares the partition / layout-alignment algorithm with -``cuda/copy/gmem_smem.py`` (sync ``Tx.copy`` global ↔ shared); differs at +``cuda/copy/gmem_smem.py`` (sync ``T.copy`` global ↔ shared); differs at emit time only: * direction: ``cp.async`` is global → shared only (hardware restriction). * cp_size: PTX ``cp.async`` only accepts 4 / 8 / 16 bytes, so the vec-width candidate set is restricted to ``{32, 64, 128}`` bits. -* emit: ``Tx.evaluate(Tx.ptx.cp_async(dst, src, cp_size))`` instead of the - synchronous ``Tx.cuda.copy_{vec_bits}b(dst, src)``. +* emit: ``T.evaluate(T.ptx.cp_async(dst, src, cp_size))`` instead of the + synchronous ``T.cuda.copy_{vec_bits}b(dst, src)``. Note: ``cp.async`` does **not** sync at emit time — caller is responsible for ``commit_group`` / ``wait_group`` / ``cta_sync`` plumbing around the @@ -34,7 +34,7 @@ """ from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx import Var as _TirVar from tvm.tirx.expr import IntImm as _IntImm @@ -247,17 +247,17 @@ def _s_off(f, s_lin): v0 = _IntImm("int32", 0) # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): tid = _decl_tid() _setup_swizzle(tid) - for f in Tx.unroll(total_outer): + for f in T.unroll(total_outer): s_lin = s_p.apply(f, tid, v0, shape=apply_shape)["m"] g_lin = g_p.apply(f, tid, v0, shape=apply_shape)["m"] s_off = _s_off(f, s_lin) s_ptr = _ptr_off(s_buf.ptr_to(s_zero), s_off) g_ptr = _ptr_off(g_buf.ptr_to(g_zero), g_lin) - Tx.evaluate(Tx.ptx.cp_async(s_ptr, g_ptr, cp_size)) + T.evaluate(T.ptx.cp_async(s_ptr, g_ptr, cp_size)) # cp.async is caller-synced — no cta_sync here (commit_group / # wait_group / cta_sync are the caller's responsibility). # fmt: on diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py index b06a62f60338..3a9d81947804 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py @@ -60,7 +60,7 @@ import tvm from tvm.arith import Analyzer from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx.layout import ComposeLayout, SwizzleLayout, TCol, TileLayout, TLane from tvm.tirx.layout import m as m_axis @@ -378,7 +378,7 @@ def _get_or_create_desc(sctx, s_buf, ldo, sdo, swizzle): return cached desc_buf = tvm.tirx.decl_buffer((1,), "uint64", name="cp_desc", scope="local") - encode_call = Tx.ptx.tcgen05.encode_matrix_descriptor( + encode_call = T.ptx.tcgen05.encode_matrix_descriptor( desc_buf.data, s_buf.ptr_to([0] * len(s_buf.shape)), ldo, sdo, swizzle ) wrap = SeqStmt([AllocBuffer(desc_buf), Evaluate(encode_call)]) @@ -410,17 +410,17 @@ def copy_smem_tmem_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> Pr t_addr = t_buf.allocated_addr from tvm.tirx.operator.tile_primitive.cuda.common import smem_desc_add_16B_offset - # Flatten the N-D middle iteration into a single Tx.unroll. Each iteration's + # Flatten the N-D middle iteration into a single T.unroll. Each iteration's # per-dim index is (flat // stride) % extent, summed into the t/s offsets. # Works uniformly for n_mid ∈ {0, 1, 2, ...}; total == 1 (no middle dims) is - # special-cased to avoid a degenerate Tx.unroll(1). + # special-cased to avoid a degenerate T.unroll(1). total = functools.reduce(operator.mul, [n for n, _, _ in middle_iters], 1) # fmt: off if total == 1: - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - Tx.ptx.tcgen05.cp( + T.ptx.tcgen05.cp( t_addr[0] + t_col0, smem_desc_add_16B_offset(desc_buf[0], init_off_16B), shape="32x128b", cta_group=cta_group, multicast="warpx4", @@ -437,11 +437,11 @@ def compute_offsets(flat): s_off = s_off + idx * s_step return t_off, s_off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - for flat in Tx.unroll(total): - t_off, s_off = Tx.meta_var(compute_offsets(flat)) - Tx.ptx.tcgen05.cp( + for flat in T.unroll(total): + t_off, s_off = T.meta_var(compute_offsets(flat)) + T.ptx.tcgen05.cp( t_addr[0] + t_col0 + t_off, smem_desc_add_16B_offset(desc_buf[0], init_off_16B + s_off), shape="32x128b", cta_group=cta_group, multicast="warpx4", diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py index ff270a867fff..ffd5e18a3a5c 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py @@ -25,7 +25,7 @@ import tvm from tvm.arith import Analyzer from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, PrimFunc from tvm.tirx.layout import ( S, @@ -113,7 +113,7 @@ def _classify_tmem_datapath(tmem_buf): # Compatibility matrix between the TMEM buffer's datapath layout and the -# tcgen05 ld/st atom requested by ``Tx.copy_async``: +# tcgen05 ld/st atom requested by ``T.copy_async``: # # datapath x atom | accepted? | rationale # ---------------------------- | --------- | -------------------------------- @@ -279,15 +279,14 @@ def _emit_32x32b_path( # assert analyzer.can_prove_equal(local_st[1], 0) assert analyzer.can_prove_equal(local_extent[1], width) - op = Tx.ptx.tcgen05.ld if direction == "tmem2local" else Tx.ptx.tcgen05.st + op = T.ptx.tcgen05.ld if direction == "tmem2local" else T.ptx.tcgen05.st # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.warp(): - local_storage = local_buf.view(local_buf.shape[1] * elem_per_32b, layout=TileLayout(S[num * elem_per_32b])) # noqa: E501 - local_32b = local_storage.view("uint32") - op(tmem_buf.allocated_addr[0], *[local_32b[local_st[1] // elem_per_32b+i] for i in range(num)], shape="32x32b", num=num, row=0, col=offset_32b) # noqa: E501 + local_storage = local_buf.view(local_buf.shape[1] * elem_per_32b, layout=TileLayout(S[num * elem_per_32b])) # noqa: E501 + local_32b = local_storage.view("uint32") + op(tmem_buf.allocated_addr[0], *[local_32b[local_st[1] // elem_per_32b+i] for i in range(num)], shape="32x32b", num=num, row=0, col=offset_32b) # noqa: E501 # fmt: on return impl @@ -394,7 +393,7 @@ def _emit_16xnb_path( local_col_off_elems = local_col_off is_load = direction == "tmem2local" - op = Tx.ptx.tcgen05.ld if is_load else Tx.ptx.tcgen05.st + op = T.ptx.tcgen05.ld if is_load else T.ptx.tcgen05.st # We intentionally do *not* emit ``.pack::16b`` / ``.unpack::16b`` for # 16-bit dtypes. That qualifier would store one 16-bit element per 32-bit # TMEM cell (LOW half only, HIGH half wasted) — fine for some CUTLASS @@ -405,21 +404,20 @@ def _emit_16xnb_path( # the layout factory's iters describe that packing. # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.warp(): - # Per-thread 1-D flat view of the local storage, then a uint32 view - # for the register-pointer arguments of the PTX builtin. - local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems])) - local_32b = local_storage.view("uint32") - local_reg_base = local_col_off_elems // elem_per_32b - for slab in range(n_slabs): - reg_base = slab * regs_per_thread_per_slab - op( - tmem_buf.allocated_addr[0], - *[local_32b[local_reg_base + reg_base + i] for i in range(regs_per_thread_per_slab)], # noqa: E501 - shape=shape, num=num, row=slab * 16, col=col_off_32b, - ) + # Per-thread 1-D flat view of the local storage, then a uint32 view + # for the register-pointer arguments of the PTX builtin. + local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems])) + local_32b = local_storage.view("uint32") + local_reg_base = local_col_off_elems // elem_per_32b + for slab in range(n_slabs): + reg_base = slab * regs_per_thread_per_slab + op( + tmem_buf.allocated_addr[0], + *[local_32b[local_reg_base + reg_base + i] for i in range(regs_per_thread_per_slab)], # noqa: E501 + shape=shape, num=num, row=slab * 16, col=col_off_32b, + ) # fmt: on return impl @@ -429,9 +427,9 @@ def impl(): # When: one buffer is in tmem (tensor memory, Blackwell SM100+) and the other # is in local scope, at warpgroup exec scope. # -# Emits: Tx.ptx.tcgen05.ld / Tx.ptx.tcgen05.st (async). The caller is -# responsible for issuing the matching ``Tx.ptx.tcgen05.wait.ld`` / -# ``Tx.ptx.tcgen05.wait.st`` when synchronization is required. +# Emits: T.ptx.tcgen05.ld / T.ptx.tcgen05.st (async). The caller is +# responsible for issuing the matching ``T.ptx.tcgen05.wait.ld`` / +# ``T.ptx.tcgen05.wait.st`` when synchronization is required. @register_dispatch( "copy_async", "cuda", diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py index ae6e78ada911..7e19713e492b 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py @@ -48,7 +48,8 @@ import tvm from tvm.arith import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx import Buffer, PrimFunc from tvm.tirx.layout import ComposeLayout, Layout, S, SwizzleLayout, TileLayout from tvm.tirx.operator.tile_primitive import ( @@ -1166,17 +1167,17 @@ def val_key(value) -> str: tensor_map = cached_tensormap tensormap_is_cached = True else: - tensor_map = Tx.Var( - g_buf.data.name + "_tensormap", dtype=Tx.handle("tensormap").type_annotation + tensor_map = T.Var( + g_buf.data.name + "_tensormap", dtype=T.handle("tensormap").type_annotation ) tensormap_is_cached = False # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - for loop_vars in Tx.unroll(flat_total_extent): - s_offset, tma_coords = Tx.meta_var(compute_offsets_and_tma_coords(loop_vars)) - s_buf_w_offset = Tx.decl_buffer( + for loop_vars in T.unroll(flat_total_extent): + s_offset, tma_coords = T.meta_var(compute_offsets_and_tma_coords(loop_vars)) + s_buf_w_offset = T.decl_buffer( s_buf.shape, s_buf.dtype, s_buf.data, @@ -1186,11 +1187,11 @@ def impl(): ) if direction == "g2s": - Tx.ptx.cp_async.bulk.tensor.g2c( + T.ptx.cp_async.bulk.tensor.g2c( plan.rank, s_buf_w_offset.ptr_to(s_st), mbar, - Tx.address_of(tensor_map), + T.address_of(tensor_map), cta_mask, cta_group, op_call.config.get("cache_hint", ""), @@ -1198,18 +1199,18 @@ def impl(): ) else: if use_tma_reduce is None: - Tx.ptx.cp_async.bulk.tensor.s2g( + T.ptx.cp_async.bulk.tensor.s2g( plan.rank, s_buf_w_offset.ptr_to(s_st), - Tx.address_of(tensor_map), + T.address_of(tensor_map), op_call.config.get("cache_hint", ""), *tma_coords, ) else: - Tx.ptx.cp_async.bulk.tensor.s2g_reduce( + T.ptx.cp_async.bulk.tensor.s2g_reduce( plan.rank, s_buf_w_offset.ptr_to(s_st), - Tx.address_of(tensor_map), + T.address_of(tensor_map), op_call.config.get("cache_hint", ""), use_tma_reduce, *tma_coords, @@ -1218,10 +1219,10 @@ def impl(): if not tensormap_is_cached: # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def create_tensor_map(): - Tx.Bind(Tx.tvm_stack_alloca("tensormap", 1), var=tensor_map) - Tx.call_packed( + T.Bind(T.tvm_stack_alloca("tensormap", 1), var=tensor_map) + T.call_packed( "runtime.cuTensorMapEncodeTiled", tensor_map, plan.elem_dtype, @@ -1250,10 +1251,10 @@ def create_tensor_map(): warp_id_in_cta = sctx.launch_params["warp_id_in_cta"].var # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def prefetch_tensor_map(): if warp_id_in_cta == 0: - Tx.ptx.prefetch_tensormap(Tx.address_of(tensor_map)) + T.ptx.prefetch_tensormap(T.address_of(tensor_map)) Tx.tvm_kernel_replace_point() # fmt: on diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/_common.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/_common.py index 57494ac9e895..97a62a040a49 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/_common.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/_common.py @@ -32,7 +32,7 @@ from tvm.arith.analyzer import Analyzer from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion from tvm.tirx.layout import Axis, Iter, TileLayout @@ -351,16 +351,16 @@ def fetch_src_value(src, fused, dst_indices, dst_start, dst_extent): def emit_scope_sync(scope_kind: str): - """Returns an ``@Tx.inline`` sync helper matched to the exec scope.""" + """Returns an ``@T.inline`` sync helper matched to the exec scope.""" - @Tx.inline + @T.inline def sync(): if scope_kind == "cta": - Tx.cuda.cta_sync() + T.cuda.cta_sync() elif scope_kind == "warpgroup": - Tx.cuda.warpgroup_sync(8) + T.cuda.warpgroup_sync(8) elif scope_kind == "warp": - Tx.cuda.warp_sync() + T.cuda.warp_sync() return sync diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/__init__.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/__init__.py index a552d7ddf35c..2e82773b9678 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/__init__.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/__init__.py @@ -82,7 +82,7 @@ class VecImpl: # dst_ptr: typed ptr to ``vec_len`` consecutive dst elements # src_ptrs[i]: typed ptr to ``vec_len`` consecutive src[i] elements, # OR a scalar Expr if src[i].is_scalar. - # Runs in Python at @Tx.prim_func build time — branching on src kind is a + # Runs in Python at @T.prim_func build time -- branching on src kind is a # normal Python ``if``, not a TVMScript shape limitation. This is what # collapses the old 4x2 shape-explosion in schema.py's factories. emit: Callable diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/unary.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/unary.py index b81f1a07809a..1a016bc567a4 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/unary.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/unary.py @@ -17,7 +17,7 @@ """Unary elementwise ops: zero / fill / reciprocal / sqrt / exp / exp2 / silu. -All carry the same ``Tx.(dst, src[, bias, scale])`` shape (bias / scale +All carry the same ``T.(dst, src[, bias, scale])`` shape (bias / scale optional; ``silu`` ignores bias/scale to preserve legacy behavior). """ @@ -26,7 +26,7 @@ from typing import Any from tvm.ir.expr import PrimExpr -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, TilePrimitiveCall from tvm.tirx.expr import FloatImm @@ -34,7 +34,7 @@ def _parse_unary(op: TilePrimitiveCall) -> tuple[Plan | None, str | None]: - """Tx.(dst, src[, bias, scale]) → Plan.""" + """T.(dst, src[, bias, scale]) → Plan.""" _dst: BufferRegion = op.args[0] _src = op.args[1] _bias = op.args[2] if len(op.args) > 2 else None @@ -71,7 +71,7 @@ def _check_unary_extras(extras: dict, compute_dtype: str) -> tuple[bool, str | N def _with_bias_scale(raw_op): - """Wrap ``raw_op`` (e.g. ``Tx.exp``) into a compute that applies bias/scale first.""" + """Wrap ``raw_op`` (e.g. ``T.exp``) into a compute that applies bias/scale first.""" def compute(src_vals, extras, dt): x = src_vals[0] @@ -97,21 +97,21 @@ def _compute_fill(src_vals, extras, dt): def _compute_reciprocal(src_vals, extras, dt): x = src_vals[0] - return Tx.FloatImm(x.dtype, 1.0) / x + return T.FloatImm(x.dtype, 1.0) / x def _compute_silu(src_vals, extras, dt): # Legacy: silu doesn't apply bias/scale. x = src_vals[0] - return x / (Tx.FloatImm(x.dtype, 1.0) + Tx.exp(Tx.FloatImm(x.dtype, 0.0) - x)) + return x / (T.FloatImm(x.dtype, 1.0) + T.exp(T.FloatImm(x.dtype, 0.0) - x)) UNARY_OPS: dict[str, OpSpec] = { "zero": OpSpec("zero", _parse_unary, _compute_zero, _check_unary_extras), "fill": OpSpec("fill", _parse_unary, _compute_fill, _check_unary_extras), "reciprocal": OpSpec("reciprocal", _parse_unary, _compute_reciprocal, _check_unary_extras), - "sqrt": OpSpec("sqrt", _parse_unary, _with_bias_scale(Tx.sqrt), _check_unary_extras), - "exp": OpSpec("exp", _parse_unary, _with_bias_scale(Tx.exp), _check_unary_extras), - "exp2": OpSpec("exp2", _parse_unary, _with_bias_scale(Tx.exp2), _check_unary_extras), + "sqrt": OpSpec("sqrt", _parse_unary, _with_bias_scale(T.sqrt), _check_unary_extras), + "exp": OpSpec("exp", _parse_unary, _with_bias_scale(T.exp), _check_unary_extras), + "exp2": OpSpec("exp2", _parse_unary, _with_bias_scale(T.exp2), _check_unary_extras), "silu": OpSpec("silu", _parse_unary, _compute_silu, _check_unary_extras), } diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py index 50b2544f9d45..063e4c397903 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py @@ -21,7 +21,7 @@ carries thread-axis info (the "anchor" operand). The region slice is absorbed into the sliced layout up front via ``align_operands_to_anchor`` — emit operates on a flat 1D per-thread view and indexes it with a scalar offset, so -codegen never sees multi-dim ``get_indices`` inside ``Tx.vectorized``. +codegen never sees multi-dim ``get_indices`` inside ``T.vectorized``. Two paths inside emit: * induced (anchor exists) — atom-based, exactly mirrors copy reg.py @@ -35,7 +35,7 @@ import operator from tvm.arith import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc, TilePrimitiveCall from tvm.tirx.layout import TileLayout from tvm.tirx.operator.tile_primitive import DispatchContext @@ -285,7 +285,7 @@ def _make_views_meta(per_op_carved, per_thread_total): iter strides at codegen time. """ return { - op_br: Tx.decl_buffer( + op_br: T.decl_buffer( (per_thread_total,), op_br.buffer.dtype, op_br.buffer.data, @@ -297,7 +297,7 @@ def _make_views_meta(per_op_carved, per_thread_total): # ----------------------------------------------------------------------------- -# Emit — packed (one PTX/CUDA call per outer chunk; no Tx.vectorized inside) +# Emit — packed (one PTX/CUDA call per outer chunk; no T.vectorized inside) # ----------------------------------------------------------------------------- def _emit_induced_packed( plan, vec_impl, vec_len, outer_total, per_thread_total, per_op_carved, anchor_br @@ -306,10 +306,10 @@ def _emit_induced_packed( srcs = plan.srcs dst_br = plan.dst - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - views = Tx.meta_var(_make_views_meta(per_op_carved, per_thread_total)) - # Serial loop (not Tx.unroll): Tx.unroll materializes each per-iter + views = T.meta_var(_make_views_meta(per_op_carved, per_thread_total)) + # Serial loop (not T.unroll): T.unroll materializes each per-iter # ``dst_lane_indices`` / ``src_args`` buffer as a fresh int[1] # declaration, multiplying by outer_total. ptxas unrolls the # static-bound loop without that scratch explosion. @@ -317,7 +317,7 @@ def impl(): # Pass logical 1D coord; each buffer's own layout maps it to # physical at access time (handles wgmma, broadcast, etc.). dst_lane_indices = [[f * vec_len + k] for k in range(vec_len)] - src_args = Tx.meta_var( + src_args = T.meta_var( [ src.scalar if src.is_scalar @@ -328,14 +328,14 @@ def impl(): for src in srcs ] ) - Tx.evaluate(vec_impl.emit(views[dst_br], dst_lane_indices, src_args, extras)) + T.evaluate(vec_impl.emit(views[dst_br], dst_lane_indices, src_args, extras)) return impl # ----------------------------------------------------------------------------- # Emit — scalar fallback (vec_len = 1; one element per outer iter; no -# Tx.vectorized inside, so no codegen vec-packing of multi-dim indices). +# T.vectorized inside, so no codegen vec-packing of multi-dim indices). # ----------------------------------------------------------------------------- def _emit_induced_scalar( plan, spec, outer_total, per_thread_total, per_op_carved, anchor_br @@ -346,16 +346,16 @@ def _emit_induced_scalar( dst_dtype = dst_br.buffer.dtype compute = spec.compute_scalar - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - views = Tx.meta_var(_make_views_meta(per_op_carved, per_thread_total)) - # Serial loop (not Tx.unroll) — see _emit_induced_packed for why. + views = T.meta_var(_make_views_meta(per_op_carved, per_thread_total)) + # Serial loop (not T.unroll) — see _emit_induced_packed for why. for f in range(outer_total): # Logical 1D coord = f (vec_len = 1 in scalar path); each # buffer's layout maps to physical at access time. - src_vals = Tx.meta_var( + src_vals = T.meta_var( [src.scalar if src.is_scalar else views[src.buf_region][f] for src in srcs] ) - views[dst_br][f] = Tx.cast(compute(src_vals, extras, dst_dtype), dst_dtype) + views[dst_br][f] = T.cast(compute(src_vals, extras, dst_dtype), dst_dtype) return impl diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py index 2b3fa1acfe5f..3ac7405d7734 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py @@ -31,7 +31,7 @@ from __future__ import annotations from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc, TilePrimitiveCall from tvm.tirx.operator.tile_primitive import DispatchContext from tvm.tirx.operator.tile_primitive.dispatcher import fail @@ -159,7 +159,7 @@ def emit_smem(op_call: TilePrimitiveCall, spec, sctx: DispatchContext) -> PrimFu def _tid_expr(sctx: DispatchContext): """Per-scope tid expr. ``thread`` scope returns 0; collective scopes use - ``_axis_decl`` (Tx.lane_id / Tx.thread_id_in_wg / threadIdx.x).""" + ``_axis_decl`` (T.lane_id / T.thread_id_in_wg / threadIdx.x).""" if sctx.scope_kind == "thread": return 0 axis_name = _TID_AXIS_FOR_SCOPE[sctx.scope_kind] @@ -197,18 +197,18 @@ def _emit_packed(plan, vec_impl, vec_chunk, total, thread_cnt, sctx) -> PrimFunc sync = emit_scope_sync(sctx.scope_kind) n_outer = (total + vec_chunk * thread_cnt - 1) // (vec_chunk * thread_cnt) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): tid = _tid_expr(sctx) - for s in Tx.serial(0, n_outer): + for s in T.serial(0, n_outer): # First lane's fused index for this thread, this chunk. - fused0 = Tx.meta_var(s * vec_chunk * thread_cnt + tid * vec_chunk) + fused0 = T.meta_var(s * vec_chunk * thread_cnt + tid * vec_chunk) # Predicate the call (skip the trailing partial chunk). if fused0 + vec_chunk <= total: - dst_lane_indices = Tx.meta_var( + dst_lane_indices = T.meta_var( [get_indices(fused0 + k, dst_st, dst_ext) for k in range(vec_chunk)] ) - src_args = Tx.meta_var( + src_args = T.meta_var( [ srcs[i].scalar if srcs[i].is_scalar @@ -226,7 +226,7 @@ def impl(): for i in range(len(srcs)) ] ) - Tx.evaluate(vec_impl.emit(dst_buf, dst_lane_indices, src_args, extras)) + T.evaluate(vec_impl.emit(dst_buf, dst_lane_indices, src_args, extras)) sync() return impl @@ -245,18 +245,18 @@ def _emit_scalar(plan, spec, vec_chunk, total, thread_cnt, sctx) -> PrimFunc: sync = emit_scope_sync(sctx.scope_kind) n_outer = (total + vec_chunk * thread_cnt - 1) // (vec_chunk * thread_cnt) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): tid = _tid_expr(sctx) - for s in Tx.serial(0, n_outer): - for vec in Tx.vectorized(vec_chunk): - fused = Tx.meta_var(s * vec_chunk * thread_cnt + tid * vec_chunk + vec) + for s in T.serial(0, n_outer): + for vec in T.vectorized(vec_chunk): + fused = T.meta_var(s * vec_chunk * thread_cnt + tid * vec_chunk + vec) if fused < total: - dst_idx = Tx.meta_var(get_indices(fused, dst_st, dst_ext)) - src_vals = Tx.meta_var( + dst_idx = T.meta_var(get_indices(fused, dst_st, dst_ext)) + src_vals = T.meta_var( [fetch_src_value(src, fused, dst_idx, dst_st, dst_ext) for src in srcs] ) - dst_buf[tuple(dst_idx)] = Tx.cast( + dst_buf[tuple(dst_idx)] = T.cast( compute(src_vals, extras, dst_dtype), dst_dtype ) sync() diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/__init__.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/__init__.py index 6c1dd6bdc0da..1aa4dcb79158 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/__init__.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/__init__.py @@ -34,7 +34,7 @@ with per-lane indices * extras: dict (rounding_mode, etc.) - Returns the PTX/CUDA call result; the schedule wraps in ``Tx.evaluate`` at + Returns the PTX/CUDA call result; the schedule wraps in ``T.evaluate`` at the call site. All Python-side shape branching (scalar vs buffer src) happens in this emit function -- collapses the old 4x2 schema.py factory explosion. """ diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/binary_f32x2.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/binary_f32x2.py index d7bf422acdc1..d8e995b45a65 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/binary_f32x2.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/binary_f32x2.py @@ -19,15 +19,15 @@ PTX op family: ``{add,sub,mul}..ftz.f32x2``. Each call processes 2 f32s per operand. The old ``_make_binary_packed_f32x2_factory`` (240+ lines, 8 -``@Tx.prim_func`` shape combos per op) collapses to one ``emit`` per op +``@T.prim_func`` shape combos per op) collapses to one ``emit`` per op because operand-shape branching is now Python-level (outside any -``@Tx.prim_func``). +``@T.prim_func``). """ from __future__ import annotations from tvm.ir.expr import PrimExpr -from tvm.script import tirx as Tx +from tvm.script import tirx as T from ..ops import VecImpl @@ -70,15 +70,15 @@ def applies(op_call, sctx, plan): def _emit_binary_f32x2_for(op_name): - op_func = getattr(Tx.ptx, f"{op_name}_f32x2") + op_func = getattr(T.ptx, f"{op_name}_f32x2") def emit(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr: a_arg, b_arg = src_args rm = extras.get("rounding_mode", "rz") return op_func( - Tx.address_of(dst_buf[tuple(dst_lane_indices[0])]), - Tx.cuda.make_float2(_lane(a_arg, 0), _lane(a_arg, 1)), - Tx.cuda.make_float2(_lane(b_arg, 0), _lane(b_arg, 1)), + T.address_of(dst_buf[tuple(dst_lane_indices[0])]), + T.cuda.make_float2(_lane(a_arg, 0), _lane(a_arg, 1)), + T.cuda.make_float2(_lane(b_arg, 0), _lane(b_arg, 1)), rounding=rm, ftz=True, ) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/cast_vec2.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/cast_vec2.py index 5bd7c5a34f3d..46292761b28f 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/cast_vec2.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/cast_vec2.py @@ -26,7 +26,7 @@ from __future__ import annotations from tvm.ir.expr import PrimExpr -from tvm.script import tirx as Tx +from tvm.script import tirx as T from ..ops import VecImpl @@ -74,10 +74,10 @@ def _emit_cast_vec2(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr: src_buf, src_lane_indices = src_arg func_name = _intrinsic_name(src_buf.dtype, dst_buf.dtype) source_code = _intrinsic_source(src_buf.dtype, dst_buf.dtype) - return Tx.cuda.func_call( + return T.cuda.func_call( func_name, - Tx.address_of(dst_buf[tuple(dst_lane_indices[0])]), - Tx.address_of(src_buf[tuple(src_lane_indices[0])]), + T.address_of(dst_buf[tuple(dst_lane_indices[0])]), + T.address_of(src_buf[tuple(src_lane_indices[0])]), source_code=source_code, ) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/fma_f32x2.py b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/fma_f32x2.py index 3435b9799ff6..f47476d6a5ce 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/fma_f32x2.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/fma_f32x2.py @@ -24,7 +24,7 @@ from __future__ import annotations from tvm.ir.expr import PrimExpr -from tvm.script import tirx as Tx +from tvm.script import tirx as T from ..ops import VecImpl from .binary_f32x2 import _lane @@ -61,11 +61,11 @@ def _fma_f32x2_applies(op_call, sctx, plan): def _emit_fma_f32x2(dst_buf, dst_lane_indices, src_args, extras) -> PrimExpr: a_arg, b_arg, c_arg = src_args rm = extras.get("rounding_mode", "rz") - return Tx.ptx.fma_f32x2( - Tx.address_of(dst_buf[tuple(dst_lane_indices[0])]), - Tx.cuda.make_float2(_lane(a_arg, 0), _lane(a_arg, 1)), - Tx.cuda.make_float2(_lane(b_arg, 0), _lane(b_arg, 1)), - Tx.cuda.make_float2(_lane(c_arg, 0), _lane(c_arg, 1)), + return T.ptx.fma_f32x2( + T.address_of(dst_buf[tuple(dst_lane_indices[0])]), + T.cuda.make_float2(_lane(a_arg, 0), _lane(a_arg, 1)), + T.cuda.make_float2(_lane(b_arg, 0), _lane(b_arg, 1)), + T.cuda.make_float2(_lane(c_arg, 0), _lane(c_arg, 1)), rounding=rm, ftz=True, ) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/exec_scope_utils.py b/python/tvm/tirx/operator/tile_primitive/cuda/exec_scope_utils.py index 0b274dd488bb..54deb5108c68 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/exec_scope_utils.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/exec_scope_utils.py @@ -18,7 +18,7 @@ from collections.abc import Callable -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext from tvm.tirx.stmt import TilePrimitiveCall @@ -29,7 +29,7 @@ def macro_or_prim_func(macro: Callable, need_macro: bool = False) -> Callable: if need_macro: return macro - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def func(): macro() @@ -48,7 +48,7 @@ def thread_selector(sctx: DispatchContext, inner_impl, macro: bool = False) -> C The dispatch context. Only ``sctx.scope_kind`` is consulted; the caller is responsible for having narrowed into the desired scope via an ``if`` guard with a canonical thread-filter predicate before reaching here. - inner_impl : Tx.inline + inner_impl : T.inline The body to execute inside the selected thread. macro : bool If True, return the macro directly; otherwise wrap it in a ``prim_func``. @@ -59,35 +59,31 @@ def thread_selector(sctx: DispatchContext, inner_impl, macro: bool = False) -> C return macro_or_prim_func(inner_impl, need_macro=macro) if name == "cta": - @Tx.inline() + @T.inline() def impl(): - Tx.lane_id([32]) - if Tx.ptx.elect_sync(): - with Tx.thread(): - inner_impl() + T.lane_id([32]) + if T.ptx.elect_sync(): + inner_impl() return macro_or_prim_func(impl, need_macro=macro) if name == "warp": - @Tx.inline() + @T.inline() def impl(): - Tx.lane_id([32]) - if Tx.ptx.elect_sync(): - with Tx.thread(): - inner_impl() + T.lane_id([32]) + if T.ptx.elect_sync(): + inner_impl() return macro_or_prim_func(impl, need_macro=macro) if name == "warpgroup": - @Tx.inline() + @T.inline() def impl(): - warp_id = Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) + warp_id = T.warp_id_in_wg([4]) + T.lane_id([32]) if warp_id == 0: - with Tx.warp(): - if Tx.ptx.elect_sync(): - with Tx.thread(): - inner_impl() + if T.ptx.elect_sync(): + inner_impl() return macro_or_prim_func(impl, need_macro=macro) raise ValueError(f"thread_selector: unsupported exec_scope {name!r}") diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm/mma_m16n8k_.py b/python/tvm/tirx/operator/tile_primitive/cuda/gemm/mma_m16n8k_.py index 6d1f9183caac..b069556843e0 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/gemm/mma_m16n8k_.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/gemm/mma_m16n8k_.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from tvm.arith.analyzer import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx.layout import TileLayout from tvm.tirx.operator.tile_primitive import ( @@ -524,7 +524,7 @@ def _slice_group(buf, region, shape2d, name): ) # Emit one mma per (m, n) output tile, accumulating over K. The tile / init / - # K loops use Tx.unroll: the UnrollLoop pass fully expands them in TIR (their + # K loops use T.unroll: the UnrollLoop pass fully expands them in TIR (their # bounds are compile-time constants), so the local-buffer indices resolve to # static register slots -- mma register operands must be constant. # @@ -548,23 +548,23 @@ def _slice_group(buf, region, shape2d, name): n_rN = inst.n // 4 n_kHi = inst.k // (4 * inst.k_pack) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): d_local = D.local(*d_shape, layout=D_reg) c_local = C.local(*c_shape, layout=C_reg) a_local = A.local(*a_shape, layout=A_reg) b_local = B.local(*b_shape, layout=B_reg) - for m in Tx.unroll(M_tiles): - for n in Tx.unroll(N_tiles): + for m in T.unroll(M_tiles): + for n in T.unroll(N_tiles): # Initialize D[m, n]: copy C (beta==1) or clear to 0 (beta==0). - for rM in Tx.unroll(n_rM): - for rN in Tx.unroll(n_rN): + for rM in T.unroll(n_rM): + for rN in T.unroll(n_rN): if use_c: d_local[m, n, rM, rN] = c_local[m, n, rM, rN] else: - d_local[m, n, rM, rN] = Tx.float32(0) + d_local[m, n, rM, rN] = T.float32(0) # Accumulate over K in place: d = a·b + d. - for k in Tx.unroll(K_tiles): + for k in T.unroll(K_tiles): # D: 4 f32 in PTX order c_id = 2*rM + rN. d_ptrs = [ d_local.ptr_to([m, n, rM, rN]) for rM in range(n_rM) for rN in range(n_rN) @@ -578,7 +578,7 @@ def impl(): # B: b32 regs in PTX order b32 = kHi. b_ptrs = [b_local.ptr_to([k, n, kHi, 0]) for kHi in range(n_kHi)] # Accumulate in place into D's own regs: c = d. - Tx.ptx.mma( + T.ptx.mma( shape_str, "row", "col", diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py index a439355a9771..5aac270467e3 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py @@ -28,7 +28,7 @@ import tvm from tvm.arith.analyzer import Analyzer from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx.layout import ( ComposeLayout, @@ -87,7 +87,7 @@ def _encode_instr_descriptor_dense_uint32( See ``python/tvm/tirx/operator/intrinsics/cuda/header.py:InstrDescriptor`` for the bit layout. Lets the dispatcher pass a literal ``uint32`` to - ``Tx.ptx.tcgen05.mma`` instead of allocating + encoding a per-dispatch + ``T.ptx.tcgen05.mma`` instead of allocating + encoding a per-dispatch local descriptor on every gemm_async call (which forces an inline ``asm`` block that ptxas cannot hoist out of the i_kv loop body). """ @@ -728,8 +728,8 @@ def _make_lo_uniform(desc): d->lo = __shfl_sync(0xffffffff, d->lo, 0); }} """ - return Tx.cuda.func_call( - func_name, Tx.address_of(desc), source_code=source_code, return_type="void" + return T.cuda.func_call( + func_name, T.address_of(desc), source_code=source_code, return_type="void" ) def _make_desc_wrap(desc_buf, smem_buf, base, ldo, sdo, swizzle_val): @@ -771,7 +771,7 @@ def _make_desc(smem_buf, base, ldo, sdo, swizzle_val, name): if not a_is_tmem: A_base = [0] * len(A_buffer.shape) descA_buf = _make_desc(A_buffer, A_base, A_ldo, A_sdo, A_swizzle_mode.value, "descA") - elect_pred = Tx.ptx.elect_sync() if warp_scope else True + elect_pred = T.ptx.elect_sync() if warp_scope else True # Helper: compute B descriptor value for a given (ni, ki) tile def _b_desc_val(descB_in, ni, ki): @@ -789,7 +789,7 @@ def _a_operand(mi, ki, descA_in=None): # A is [M, K] non-transposed: M→TLane (rows), K→TCol (cols) a_row = mi * M_mma a_col = A_tmem_offset_32b + ki * (MMA_K // A_elem_per_32b) - return Tx.cuda.get_tmem_addr(A_tmem_addr, a_row, a_col) + return T.cuda.get_tmem_addr(A_tmem_addr, a_row, a_col) else: A_linear = ( ki * MMA_K * A_extent[-1] + mi * M_mma @@ -831,11 +831,11 @@ def _a_operand(mi, ki, descA_in=None): # Build main_impl: descA_in is None when A is in TMEM (ignored by _a_operand). # fmt: off if is_block_scaled: - @Tx.inline + @T.inline def main_impl(descA_in, descB_in, descI_in): - for mi in Tx.unroll(M_tiles): - for ni in Tx.unroll(N_tiles): - for ki in Tx.unroll(K_iters): + for mi in T.unroll(M_tiles): + for ni in T.unroll(N_tiles): + for ki in T.unroll(K_iters): a_val = _a_operand(mi, ki, descA_in) descB_val = _b_desc_val(descB_in, ni, ki) should_accum = tvm.tirx.any(ki != 0, accum_expr) @@ -846,12 +846,12 @@ def main_impl(descA_in, descB_in, descI_in): sfa_addr = sfa_base + tvm.tirx.floordiv(sfa_tcol, SFA_elem_per_col) sfb_addr = sfb_base + tvm.tirx.floordiv(sfb_tcol, SFB_elem_per_col) if needs_sf_id: - sf_id = Tx.meta_var(analyzer.simplify(tvm.tirx.floormod(sfa_tcol, SFA_elem_per_col))) # noqa: E501 - Tx.cuda.runtime_instr_desc(Tx.address_of(descI_in), sf_id) + sf_id = T.meta_var(analyzer.simplify(tvm.tirx.floormod(sfa_tcol, SFA_elem_per_col))) # noqa: E501 + T.cuda.runtime_instr_desc(T.address_of(descI_in), sf_id) tmem_col = tmem_offset_32b + ni * (N_mma_phys_cols // C_elem_per_32b) if elect_pred: - Tx.ptx.tcgen05.mma.block_scale( - Tx.cuda.get_tmem_addr(tmem_addr, mi * M_mma, tmem_col), + T.ptx.tcgen05.mma.block_scale( + T.cuda.get_tmem_addr(tmem_addr, mi * M_mma, tmem_col), a_val, descB_val, sfa_addr, sfb_addr, descI_in, @@ -861,28 +861,28 @@ def main_impl(descA_in, descB_in, descI_in): enable_input_d=should_accum, ) else: - # Wrap each per-MMA operand in ``Tx.meta_var`` so the parser inlines - # the value directly into the ``Tx.ptx.tcgen05.mma`` call instead of + # Wrap each per-MMA operand in ``T.meta_var`` so the parser inlines + # the value directly into the ``T.ptx.tcgen05.mma`` call instead of # materializing it into a fresh ``alignas(64) T x[1]; x[0] = expr`` # local. Without this wrap each unrolled MMA emits 4 throw-away # 1-element local arrays (``a_val_ptr``, ``descB_val_ptr``, # ``should_accum_ptr``, ``tmem_col_ptr``) which ptxas cannot fold # back into the operand and the resulting LMEM round-trips show up # on the fa4 hot path. - @Tx.inline + @T.inline def main_impl(descA_in, descB_in, descI_in): - for mi in Tx.unroll(M_tiles): - for ni in Tx.unroll(N_tiles): - for ki in Tx.unroll(K_iters): - a_val = Tx.meta_var(_a_operand(mi, ki, descA_in)) - descB_val = Tx.meta_var(_b_desc_val(descB_in, ni, ki)) - should_accum = Tx.meta_var(tvm.tirx.any(ki != 0, accum_expr)) - tmem_col = Tx.meta_var( + for mi in T.unroll(M_tiles): + for ni in T.unroll(N_tiles): + for ki in T.unroll(K_iters): + a_val = T.meta_var(_a_operand(mi, ki, descA_in)) + descB_val = T.meta_var(_b_desc_val(descB_in, ni, ki)) + should_accum = T.meta_var(tvm.tirx.any(ki != 0, accum_expr)) + tmem_col = T.meta_var( tmem_offset_32b + ni * (N_mma_phys_cols // C_elem_per_32b) ) if elect_pred: - Tx.ptx.tcgen05.mma( - Tx.cuda.get_tmem_addr(tmem_addr, mi * M_mma, tmem_col), + T.ptx.tcgen05.mma( + T.cuda.get_tmem_addr(tmem_addr, mi * M_mma, tmem_col), a_val, descB_val, descI_in, d_dtype="float32", a_dtype=A_type, b_dtype=B_type, use_a_tmem=a_is_tmem, cta_group=cta_group, @@ -892,14 +892,14 @@ def main_impl(descA_in, descB_in, descI_in): descA_val = None if a_is_tmem else descA_buf[0] if descI is not None: - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): main_impl(descA_val, descB_buf[0], descI) elif is_block_scaled: - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - descI_local: Tx.uint32 - Tx.ptx.tcgen05.encode_instr_descriptor_block_scaled(Tx.address_of(descI_local), d_dtype=C_type, a_dtype=A_type, b_dtype=B_type, sfa_dtype=SFA_type, sfb_dtype=SFB_type, # noqa: E501, F821 + descI_local: T.uint32 + T.ptx.tcgen05.encode_instr_descriptor_block_scaled(T.address_of(descI_local), d_dtype=C_type, a_dtype=A_type, b_dtype=B_type, sfa_dtype=SFA_type, sfb_dtype=SFB_type, # noqa: E501, F821 sfa_tmem_addr=SFA_init_addr, sfb_tmem_addr=SFB_init_addr, # noqa: E501 M=M_mma * cta_group, N=N_mma, K=MMA_K, trans_a=a_mn_major, trans_b=b_mn_major, n_cta_groups=cta_group) # noqa: E501 main_impl(descA_val, descB_buf[0], descI_local) # noqa: F821 @@ -920,7 +920,7 @@ def impl(): ) descI_const = tvm.tirx.const(descI_value, "uint32") - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): main_impl(descA_val, descB_buf[0], descI_const) # fmt: on @@ -939,10 +939,10 @@ def impl(): # # After (encodes instruction descriptor + calls tcgen05.mma): # descI_local: uint32 -# Tx.ptx.tcgen05.encode_instr_descriptor( +# T.ptx.tcgen05.encode_instr_descriptor( # &descI_local, C_type="f32", A_type="f16", B_type="f16", # M=64, N=256, MMA_K=64, transA=False, transB=True, cta_group=1) -# Tx.ptx.tcgen05.mma(descA_buf[0], descB_buf[0], descI_local) +# T.ptx.tcgen05.mma(descA_buf[0], descB_buf[0], descI_local) # # Before (TilePrimitiveCall — block-scaled fp8 MMA): # Tx.gemm_async(C_tmem, A_smem, B_smem, @@ -950,7 +950,7 @@ def impl(): # # A/B: shared float8_e4m3, SFA/SFB: tmem float8_e8m0fnu # # After (adds scale factor descriptors): -# Tx.ptx.tcgen05.mma(descA, descB, descI, +# T.ptx.tcgen05.mma(descA, descB, descI, # scale_A=sfA_desc, scale_B=sfB_desc) # # Scale factor layout (sf_tmem_layout) must match tcgen05 hardware requirements: diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py index 3907fe150201..8bee97aae546 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py @@ -74,7 +74,7 @@ import math from tvm.runtime import DataType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer, BufferRegion, IntImm, PrimFunc from tvm.tirx.layout import TileLayout, _flatten_coord from tvm.tirx.operator.tile_primitive import DispatchContext, fail, register_dispatch @@ -306,27 +306,27 @@ def _project(iter_idx, st_list): dtype = src_buf.dtype # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - warp_size = Tx.meta_var(32) - lane_id = Tx.meta_var(tid_x % warp_size) - regs = Tx.alloc_buffer((P,), dtype, scope="local") + warp_size = T.meta_var(32) + lane_id = T.meta_var(tid_x % warp_size) + regs = T.alloc_buffer((P,), dtype, scope="local") # Phase 1: read via L_src - for r in Tx.unroll(0, P): - j = Tx.meta_var(r ^ ((lane_id >> shift) & mask)) - flat = Tx.meta_var(lane_id + j * warp_size) - iter_idx = Tx.meta_var(get_indices(flat, [0] * len(extent), extent)) - src_idx = Tx.meta_var(_project(iter_idx, src_st)) + for r in T.unroll(0, P): + j = T.meta_var(r ^ ((lane_id >> shift) & mask)) + flat = T.meta_var(lane_id + j * warp_size) + iter_idx = T.meta_var(get_indices(flat, [0] * len(extent), extent)) + src_idx = T.meta_var(_project(iter_idx, src_st)) regs[r] = src_buf[tuple(src_idx)] - Tx.cuda.warp_sync() + T.cuda.warp_sync() # Phase 2: write via L_dst - for r in Tx.unroll(0, P): - j = Tx.meta_var(r ^ ((lane_id >> shift) & mask)) - flat = Tx.meta_var(lane_id + j * warp_size) - iter_idx = Tx.meta_var(get_indices(flat, [0] * len(extent), extent)) - dst_idx = Tx.meta_var(_project(iter_idx, dst_st)) + for r in T.unroll(0, P): + j = T.meta_var(r ^ ((lane_id >> shift) & mask)) + flat = T.meta_var(lane_id + j * warp_size) + iter_idx = T.meta_var(get_indices(flat, [0] * len(extent), extent)) + dst_idx = T.meta_var(_project(iter_idx, dst_st)) dst_buf[tuple(dst_idx)] = regs[r] - Tx.cuda.warp_sync() + T.cuda.warp_sync() # fmt: on return impl @@ -344,7 +344,7 @@ def impl(): # projects back onto ``buf.shape`` via mixed-radix grouping for the emit. # # Before (TilePrimitiveCall): -# with Tx.warp(): +# with T.warp(): # # SFA_smem: u32 (PIPE, BLK_SFA//32, 32), layout shard 4D # # (PIPE, BLK_SFA//128, 4, 32) strides (BLK_SFA, 128, 32, 1) # # SFA_post: same shape; layout shard 4D, strides (BLK_SFA, 128, 1, 4) @@ -352,19 +352,19 @@ def impl(): # # After (BLK_SFA=128, P=4, k=2, shift=3): # lane_id = threadIdx.x % 32 -# regs = Tx.alloc_buffer((4,), "uint32", scope="local") -# for r in Tx.unroll(4): +# regs = T.alloc_buffer((4,), "uint32", scope="local") +# for r in T.unroll(4): # j = r ^ ((lane_id >> 3) & 0x3) # flat = lane_id + j * 32 # (g, l) = decompose(flat, extent=[4, 32]) # regs[r] = src[ks, g, l] -# Tx.cuda.warp_sync() -# for r in Tx.unroll(4): +# T.cuda.warp_sync() +# for r in T.unroll(4): # j = r ^ ((lane_id >> 3) & 0x3) # flat = lane_id + j * 32 # (g, l) = decompose(flat, extent=[4, 32]) # dst[ks, g, l] = regs[r] -# Tx.cuda.warp_sync() +# T.cuda.warp_sync() @register_dispatch( "permute_layout", "cuda", diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py index 9fe7f152704e..b05618f15371 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py @@ -25,12 +25,11 @@ (_emit_reduction_local_thread_wise): Before: - with Tx.thread(): - Tx.sum(B_local[0:2, 0:3], A_local[0:2, 0:3, 0:4], [-1], False) + Tx.sum(B_local[0:2, 0:3], A_local[0:2, 0:3, 0:4], [-1], False) After (scheduled PrimFunc, spatial_len=6, reduction_len=4): for spa in range(6): - B_local[spa] = Tx.float32(0.0) # init (skipped if accum) + B_local[spa] = T.float32(0.0) # init (skipped if accum) for red in range(4): B_local[spa] = B_local[spa] + A_local[spa * 4 + red] @@ -44,15 +43,14 @@ accum=True + shuffle: saves old dst before reduce+shuffle, combines after (warp only). Before: - with Tx.warp(): - Tx.sum(red_view[0:16, 0:4], acc_view[0:16, 0:128], [-1], False, - thread_reduce=True) + Tx.warp.sum(red_view[0:16, 0:4], acc_view[0:16, 0:128], [-1], False, + thread_reduce=True) After (scheduled PrimFunc, local_total=2, local_red=32, 2 shuffle steps): src_local = acc_view.view(64) dst_local = red_view.view(2) for spa in range(2): - dst_local[spa] = Tx.float32(0.0) + dst_local[spa] = T.float32(0.0) for red in range(32): dst_local[spa] = dst_local[spa] + src_local[...] dst_local[spa] = dst_local[spa] + shfl_xor(..., 1, 32, 32) @@ -64,7 +62,7 @@ from typing import Any from tvm.arith.analyzer import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.layout import TileLayout, laneid from tvm.tirx.operator.tile_primitive import DispatchContext, fail @@ -137,15 +135,14 @@ def _gen_warp_shuffle_reduce(src, dst, reduce_width, local_elems, accum, op_type op_str = _REDUCE_OP_TO_STR[op_type] # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.thread(): - src_local = src.local(local_elems) - dst_local = dst.local(local_elems) - for k in Tx.serial(local_elems): - if not is_same_buffer: - dst_local[k] = src_local[k] - dst_local[k] = Tx.cuda.warp_reduce(dst_local[k], op_str, reduce_width) + src_local = src.local(local_elems) + dst_local = dst.local(local_elems) + for k in T.serial(local_elems): + if not is_same_buffer: + dst_local[k] = src_local[k] + dst_local[k] = T.cuda.warp_reduce(dst_local[k], op_str, reduce_width) # fmt: on return impl @@ -271,16 +268,15 @@ def get_src_indices(spa_fused, red_fused): return full # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.thread(): - for spa in Tx.serial(spatial_len): - dst_idx = Tx.meta_var(get_indices(spa, dst_st, dst_extent)) - if not accum: - dst[tuple(dst_idx)] = init_value - for red in Tx.serial(reduction_len): - src_idx = Tx.meta_var(get_src_indices(spa, red)) - dst[tuple(dst_idx)] = op_func(dst[tuple(dst_idx)], src[tuple(src_idx)]) + for spa in T.serial(spatial_len): + dst_idx = T.meta_var(get_indices(spa, dst_st, dst_extent)) + if not accum: + dst[tuple(dst_idx)] = init_value + for red in T.serial(reduction_len): + src_idx = T.meta_var(get_src_indices(spa, red)) + dst[tuple(dst_idx)] = op_func(dst[tuple(dst_idx)], src[tuple(src_idx)]) # fmt: on return impl @@ -346,10 +342,10 @@ def _get_src_local_index(dst_fused, red_fused): in_place = dst.same_as(src) def shuffle_data(mask, dst_local, dst_idx): - @Tx.inline + @T.inline def inner_shuffle(v, shuffle_mask): dst_local[tuple(dst_idx)] = op_func( - v, Tx.tvm_warp_shuffle_xor(mask, v, shuffle_mask, 32, 32) + v, T.tvm_warp_shuffle_xor(mask, v, shuffle_mask, 32, 32) ) for i in range(len(shuffle_masks)): @@ -359,43 +355,41 @@ def inner_shuffle(v, shuffle_mask): # fmt: off if need_save_accum: - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.thread(): - src_local = src.local(*src_local_shape) - dst_local = dst.local(*dst_local_shape) - old_val = Tx.alloc_buffer([1], dtype, scope="local") - - for spa in Tx.serial(dst_local_total): - dst_idx = Tx.meta_var(get_indices(spa, dst_local_st, dst_local_ext)) - old_val[0] = dst_local[tuple(dst_idx)] - if not in_place: - dst_local[tuple(dst_idx)] = init_value - for red in Tx.serial(reduction_local_total): - src_idx = Tx.meta_var(_get_src_local_index(spa, red)) - dst_local[tuple(dst_idx)] = op_func(dst_local[tuple(dst_idx)], src_local[tuple(src_idx)]) # noqa: E501 - if shuffle: - mask = Tx.tvm_warp_activemask() - shuffle_data(mask, dst_local, dst_idx) - dst_local[tuple(dst_idx)] = op_func(dst_local[tuple(dst_idx)], old_val[0]) + src_local = src.local(*src_local_shape) + dst_local = dst.local(*dst_local_shape) + old_val = T.alloc_buffer([1], dtype, scope="local") + + for spa in T.serial(dst_local_total): + dst_idx = T.meta_var(get_indices(spa, dst_local_st, dst_local_ext)) + old_val[0] = dst_local[tuple(dst_idx)] + if not in_place: + dst_local[tuple(dst_idx)] = init_value + for red in T.serial(reduction_local_total): + src_idx = T.meta_var(_get_src_local_index(spa, red)) + dst_local[tuple(dst_idx)] = op_func(dst_local[tuple(dst_idx)], src_local[tuple(src_idx)]) # noqa: E501 + if shuffle: + mask = T.tvm_warp_activemask() + shuffle_data(mask, dst_local, dst_idx) + dst_local[tuple(dst_idx)] = op_func(dst_local[tuple(dst_idx)], old_val[0]) else: - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.thread(): - src_local = src.local(*src_local_shape) - dst_local = dst.local(*dst_local_shape) - - for spa in Tx.serial(dst_local_total): - dst_idx = Tx.meta_var(get_indices(spa, dst_local_st, dst_local_ext)) - if not in_place: - if not accum: - dst_local[tuple(dst_idx)] = init_value - for red in Tx.serial(reduction_local_total): - src_idx = Tx.meta_var(_get_src_local_index(spa, red)) - dst_local[tuple(dst_idx)] = op_func(dst_local[tuple(dst_idx)], src_local[tuple(src_idx)]) # noqa: E501 - if shuffle: - mask = Tx.tvm_warp_activemask() - shuffle_data(mask, dst_local, dst_idx) + src_local = src.local(*src_local_shape) + dst_local = dst.local(*dst_local_shape) + + for spa in T.serial(dst_local_total): + dst_idx = T.meta_var(get_indices(spa, dst_local_st, dst_local_ext)) + if not in_place: + if not accum: + dst_local[tuple(dst_idx)] = init_value + for red in T.serial(reduction_local_total): + src_idx = T.meta_var(_get_src_local_index(spa, red)) + dst_local[tuple(dst_idx)] = op_func(dst_local[tuple(dst_idx)], src_local[tuple(src_idx)]) # noqa: E501 + if shuffle: + mask = T.tvm_warp_activemask() + shuffle_data(mask, dst_local, dst_idx) # fmt: on return impl diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py index 587688a324d8..8bee09ecc3f0 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py @@ -28,11 +28,10 @@ Each group of threads reduces one spatial position via shfl_xor. Before: - with Tx.cta(): - Tx.sum(B_smem[0:4], A_smem[0:4, 0:8], [-1], False) + Tx.cta.sum(B_smem[0:4], A_smem[0:4, 0:8], [-1], False) After (scheduled PrimFunc, group_size=8, spatial_par=4): - thread_data[0] = Tx.float32(0.0) + thread_data[0] = T.float32(0.0) thread_data[0] = thread_data[0] + A_smem[tid_in_scope] # gather # log2(8) = 3 shuffle-xor steps with width=8 thread_data[0] = thread_data[0] + shfl_xor(thread_data[0], 1, 8, 32) @@ -45,12 +44,11 @@ Before: if tid == 65: - with Tx.thread(): - Tx.sum(B_smem[0:4], A_smem[0:4, 0:8], [-1], False) + Tx.sum(B_smem[0:4], A_smem[0:4, 0:8], [-1], False) After (scheduled PrimFunc): for spa in range(4): - B_smem[spa] = Tx.float32(0.0) # init (skipped if accum) + B_smem[spa] = T.float32(0.0) # init (skipped if accum) for red in range(8): B_smem[spa] = B_smem[spa] + A_smem[spa * 8 + red] """ @@ -60,7 +58,7 @@ import operator from tvm.arith.analyzer import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail from tvm.tirx.operator.tile_primitive.dispatcher import predicate, register_dispatch @@ -169,46 +167,46 @@ def get_tid_in_scope(): return 0 def shuffle_data(thread_data): - @Tx.inline + @T.inline def inner_shuffle(mask, v, shuffle_mask): - v[0] = op_func(v[0], Tx.tvm_warp_shuffle_xor(mask, v[0], shuffle_mask, group_size, 32)) + v[0] = op_func(v[0], T.tvm_warp_shuffle_xor(mask, v[0], shuffle_mask, group_size, 32)) if n_shuffles > 0: - mask = Tx.tvm_warp_activemask() + mask = T.tvm_warp_activemask() for i in range(n_shuffles): inner_shuffle(mask, thread_data, 1 << i) - @Tx.inline + @T.inline def sync(): if exec_scope_name == "cta": - Tx.cuda.cta_sync() + T.cuda.cta_sync() elif exec_scope_name == "warpgroup": - Tx.cuda.warpgroup_sync(8) # TODO: fix this hardcoded value + T.cuda.warpgroup_sync(8) # TODO: fix this hardcoded value elif exec_scope_name == "warp": - Tx.cuda.warp_sync() + T.cuda.warp_sync() elif exec_scope_name == "thread": pass # fmt: off - @Tx.prim_func + @T.prim_func def impl(): tid_in_scope = get_tid_in_scope() - thread_data = Tx.alloc_buffer([1], dtype=dtype, scope="local") - group_id = Tx.meta_var(Tx.floordiv(tid_in_scope, group_size)) - lane_in_grp = Tx.meta_var(tid_in_scope % group_size) - for step in Tx.serial(Tx.ceildiv(spatial_len, spatial_par)): - spa_fused = Tx.meta_var(step * spatial_par + group_id) + thread_data = T.alloc_buffer([1], dtype=dtype, scope="local") + group_id = T.meta_var(T.floordiv(tid_in_scope, group_size)) + lane_in_grp = T.meta_var(tid_in_scope % group_size) + for step in T.serial(T.ceildiv(spatial_len, spatial_par)): + spa_fused = T.meta_var(step * spatial_par + group_id) if spa_fused < spatial_len: thread_data[0] = init_value - for t in Tx.serial(Tx.ceildiv(reduction_len, group_size)): - red_fused = Tx.meta_var(t * group_size + lane_in_grp) + for t in T.serial(T.ceildiv(reduction_len, group_size)): + red_fused = T.meta_var(t * group_size + lane_in_grp) if red_fused < reduction_len: - src_indices = Tx.meta_var(build_src_indices(spa_fused, red_fused, spatial_dims, reduce_dims, src_extent, src_st)) # noqa: E501 + src_indices = T.meta_var(build_src_indices(spa_fused, red_fused, spatial_dims, reduce_dims, src_extent, src_st)) # noqa: E501 thread_data[0] = op_func(thread_data[0], src[tuple(src_indices)]) shuffle_data(thread_data) if lane_in_grp == 0: - dst_indices = Tx.meta_var(get_indices(spa_fused, dst_st, dst_extent)) - dst[tuple(dst_indices)] = Tx.if_then_else(Tx.bool(accum), op_func(dst[tuple(dst_indices)], thread_data[0]), thread_data[0]) # noqa: E501 + dst_indices = T.meta_var(get_indices(spa_fused, dst_st, dst_extent)) + dst[tuple(dst_indices)] = T.if_then_else(T.bool(accum), op_func(dst[tuple(dst_indices)], thread_data[0]), thread_data[0]) # noqa: E501 sync() # fmt: on @@ -238,14 +236,14 @@ def _emit_reduction_shared_thread( assert op_func is not None init_value = reduce_default_value_table(dtype).get(reduce_op) - @Tx.prim_func + @T.prim_func def impl(): - for spa_fused in Tx.serial(spatial_len): - dst_indices = Tx.meta_var(get_indices(spa_fused, dst_st, dst_extent)) + for spa_fused in T.serial(spatial_len): + dst_indices = T.meta_var(get_indices(spa_fused, dst_st, dst_extent)) if not accum: dst[tuple(dst_indices)] = init_value - for red_fused in Tx.serial(reduction_len): - src_indices = Tx.meta_var( + for red_fused in T.serial(reduction_len): + src_indices = T.meta_var( build_src_indices( spa_fused, red_fused, spatial_dims, reduce_dims, src_extent, src_st ) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py index 70de6b37fab3..5b8540ecac25 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py @@ -23,24 +23,21 @@ SM100+ (uses packed PTX instructions not available on older GPUs). Before (TilePrimitiveCall -- sum example): - with Tx.thread(): - Tx.sum(dst_local[0:1], src_local[0:32]) # float32, reduce 32 -> 1 + Tx.sum(dst_local[0:1], src_local[0:32]) # float32, reduce 32 -> 1 (thread scope) After -- packed_add_sum (uses add.f32x2 to reduce pairs): - with Tx.thread(): - # Iteratively reduce: 32 -> 16 -> 8 -> 4 -> 2 -> 1 - # Each step: add.f32x2 combines adjacent pairs - for i in Tx.serial(16): - Tx.cuda.func_call("add_f32x2", &buf[i*2], &buf[i*2], &buf[i*2+2]) - # ... repeat halving until scalar result - dst_local[0] = buf[0] + # Iteratively reduce: 32 -> 16 -> 8 -> 4 -> 2 -> 1 + # Each step: add.f32x2 combines adjacent pairs + for i in T.serial(16): + T.cuda.func_call("add_f32x2", &buf[i*2], &buf[i*2], &buf[i*2+2]) + # ... repeat halving until scalar result + dst_local[0] = buf[0] After -- 3input_maxmin (uses 3-input PTX max/min): - with Tx.thread(): - # Tree reduction with 3-input instructions: - # max(a, b, c) in one PTX instruction - for i in Tx.serial(n // 3): - Tx.cuda.func_call("max3_f32", &buf[i*3], &buf[i*3+1], &buf[i*3+2]) + # Tree reduction with 3-input instructions: + # max(a, b, c) in one PTX instruction + for i in T.serial(n // 3): + T.cuda.func_call("max3_f32", &buf[i*3], &buf[i*3+1], &buf[i*3+2]) With accum=True: accumulator folded into first element/pair of the reduction. """ @@ -48,7 +45,7 @@ import functools import operator -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext from tvm.tirx.operator.tile_primitive.dispatcher import predicate, register_dispatch @@ -91,54 +88,53 @@ def _emit_reduction_local_thread_packed_add_sum( remainder_base = num_full_chunks * 8 # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.thread(): - local_sum = Tx.alloc_buffer([8], dtype, scope="local") - # First pass: copy first 8 elements (with optional accumulator) - for i in Tx.unroll(8): - if accum and i == 0: - local_sum[i] = src[src_base + i] + dst[tuple(dst_st)] - else: - local_sum[i] = src[src_base + i] - - # Process remaining full chunks of 8 - for outer in Tx.serial(num_full_chunks - 1): - for j in Tx.unroll(4): - Tx.ptx.add_f32x2( - Tx.address_of(local_sum[2 * j]), - Tx.cuda.make_float2(local_sum[2 * j], local_sum[2 * j + 1]), - Tx.cuda.make_float2( - src[src_base + 8 * (outer + 1) + 2 * j], - src[src_base + 8 * (outer + 1) + 2 * j + 1], - ), - ftz=True, - ) - - # Handle remainder elements (0 to 7) - for i in Tx.serial(remainder): - local_sum[0] = local_sum[0] + src[src_base + remainder_base + i] - - # Final packed add sum: 8 -> 4 -> 2 -> 1 - Tx.ptx.add_f32x2( - Tx.address_of(local_sum[0]), - Tx.cuda.make_float2(local_sum[0], local_sum[1]), - Tx.cuda.make_float2(local_sum[2], local_sum[3]), - ftz=True, - ) - Tx.ptx.add_f32x2( - Tx.address_of(local_sum[4]), - Tx.cuda.make_float2(local_sum[4], local_sum[5]), - Tx.cuda.make_float2(local_sum[6], local_sum[7]), - ftz=True, - ) - Tx.ptx.add_f32x2( - Tx.address_of(local_sum[0]), - Tx.cuda.make_float2(local_sum[0], local_sum[1]), - Tx.cuda.make_float2(local_sum[4], local_sum[5]), - ftz=True, - ) - dst[tuple(dst_st)] = local_sum[0] + local_sum[1] + local_sum = T.alloc_buffer([8], dtype, scope="local") + # First pass: copy first 8 elements (with optional accumulator) + for i in T.unroll(8): + if accum and i == 0: + local_sum[i] = src[src_base + i] + dst[tuple(dst_st)] + else: + local_sum[i] = src[src_base + i] + + # Process remaining full chunks of 8 + for outer in T.serial(num_full_chunks - 1): + for j in T.unroll(4): + T.ptx.add_f32x2( + T.address_of(local_sum[2 * j]), + T.cuda.make_float2(local_sum[2 * j], local_sum[2 * j + 1]), + T.cuda.make_float2( + src[src_base + 8 * (outer + 1) + 2 * j], + src[src_base + 8 * (outer + 1) + 2 * j + 1], + ), + ftz=True, + ) + + # Handle remainder elements (0 to 7) + for i in T.serial(remainder): + local_sum[0] = local_sum[0] + src[src_base + remainder_base + i] + + # Final packed add sum: 8 -> 4 -> 2 -> 1 + T.ptx.add_f32x2( + T.address_of(local_sum[0]), + T.cuda.make_float2(local_sum[0], local_sum[1]), + T.cuda.make_float2(local_sum[2], local_sum[3]), + ftz=True, + ) + T.ptx.add_f32x2( + T.address_of(local_sum[4]), + T.cuda.make_float2(local_sum[4], local_sum[5]), + T.cuda.make_float2(local_sum[6], local_sum[7]), + ftz=True, + ) + T.ptx.add_f32x2( + T.address_of(local_sum[0]), + T.cuda.make_float2(local_sum[0], local_sum[1]), + T.cuda.make_float2(local_sum[4], local_sum[5]), + ftz=True, + ) + dst[tuple(dst_st)] = local_sum[0] + local_sum[1] # fmt: on return impl @@ -162,9 +158,7 @@ def _emit_reduction_local_thread_3input_maxmin( reduction_len = functools.reduce(operator.mul, src_extent, 1) op_func = reduce_op_table[reduce_op] - reduce3_func = ( - Tx.ptx.reduce3_max_f32 if reduce_op == ReduceOpType.MAX else Tx.ptx.reduce3_min_f32 - ) + reduce3_func = T.ptx.reduce3_max_f32 if reduce_op == ReduceOpType.MAX else T.ptx.reduce3_min_f32 src_base = src_st[0] num_full_chunks = reduction_len // 8 @@ -172,33 +166,32 @@ def _emit_reduction_local_thread_3input_maxmin( remainder_base = num_full_chunks * 8 # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def impl(): - with Tx.thread(): - temp = Tx.alloc_buffer([4], dtype, scope="local") - # First pass: process first 8 elements into 4 temps - for i in Tx.unroll(4): - if accum and i == 0: - temp[i] = reduce3_func(src[src_base + 2 * i], src[src_base + 2 * i + 1], dst[tuple(dst_st)]) # noqa: E501 - else: - temp[i] = op_func(src[src_base + 2 * i], src[src_base + 2 * i + 1]) - - # Process remaining full chunks of 8 - for outer in Tx.serial(num_full_chunks - 1): - for i in Tx.unroll(4): - temp[i] = reduce3_func( - temp[i], - src[src_base + 8 * (outer + 1) + 2 * i], - src[src_base + 8 * (outer + 1) + 2 * i + 1], - ) - - # Process remainder elements (0 to 7 elements) - for i in Tx.serial(remainder): - temp[0] = op_func(temp[0], src[src_base + remainder_base + i]) - - # Final merge: combine 4 temps into result - dst[tuple(dst_st)] = op_func(temp[0], temp[1]) - dst[tuple(dst_st)] = reduce3_func(dst[tuple(dst_st)], temp[2], temp[3]) + temp = T.alloc_buffer([4], dtype, scope="local") + # First pass: process first 8 elements into 4 temps + for i in T.unroll(4): + if accum and i == 0: + temp[i] = reduce3_func(src[src_base + 2 * i], src[src_base + 2 * i + 1], dst[tuple(dst_st)]) # noqa: E501 + else: + temp[i] = op_func(src[src_base + 2 * i], src[src_base + 2 * i + 1]) + + # Process remaining full chunks of 8 + for outer in T.serial(num_full_chunks - 1): + for i in T.unroll(4): + temp[i] = reduce3_func( + temp[i], + src[src_base + 8 * (outer + 1) + 2 * i], + src[src_base + 8 * (outer + 1) + 2 * i + 1], + ) + + # Process remainder elements (0 to 7 elements) + for i in T.serial(remainder): + temp[0] = op_func(temp[0], src[src_base + remainder_base + i]) + + # Final merge: combine 4 temps into result + dst[tuple(dst_st)] = op_func(temp[0], temp[1]) + dst[tuple(dst_st)] = reduce3_func(dst[tuple(dst_st)], temp[2], temp[3]) # fmt: on return impl diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py index f575aa7cf42f..b53b5d181068 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py @@ -22,7 +22,7 @@ import operator from tvm.arith.analyzer import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion from tvm.tirx.operator.tile_primitive import DispatchContext from tvm.tirx.stmt import TilePrimitiveCall @@ -32,16 +32,16 @@ reduce_op_table = { ReduceOpType.SUM: lambda a, b: a + b, - ReduceOpType.MAX: Tx.max, - ReduceOpType.MIN: Tx.min, + ReduceOpType.MAX: T.max, + ReduceOpType.MIN: T.min, } def reduce_default_value_table(dtype): return { ReduceOpType.SUM: 0.0, - ReduceOpType.MAX: Tx.min_value(dtype), - ReduceOpType.MIN: Tx.max_value(dtype), + ReduceOpType.MAX: T.min_value(dtype), + ReduceOpType.MIN: T.max_value(dtype), } diff --git a/python/tvm/tirx/operator/tile_primitive/ops.py b/python/tvm/tirx/operator/tile_primitive/ops.py index 97f16def6e55..21e02793fd02 100644 --- a/python/tvm/tirx/operator/tile_primitive/ops.py +++ b/python/tvm/tirx/operator/tile_primitive/ops.py @@ -19,12 +19,56 @@ from tvm.ir import Op from tvm.tirx import PrimExpr -from tvm.tirx.stmt import TilePrimitiveCall, _ffi_api, normalize_const_arg +from tvm.tirx.stmt import TilePrimitiveCall + +_DISPATCH_OPS = { + "zero", + "sqrt", + "exp", + "exp2", + "reciprocal", + "add", + "sub", + "mul", + "fdiv", + "maximum", + "minimum", + "copy", + "fill", + "gemm", + "sum", + "max", + "min", + "memset", + "reduce_negate", + "binary_reduce", + "unary_reduce", + "binary_chain", + "select", + "cast", + "fma", + "silu", +} +_COMPOSE_OPS = {"compose_op"} +_ASYNC_OPS = {"copy_async", "gemm_async"} +_MARKER_OPS = {"tvm_kernel_replace_point"} + + +def _tile_primitive_kind(op_name: str) -> str: + if op_name in _DISPATCH_OPS: + return "dispatch" + if op_name in _COMPOSE_OPS: + return "compose" + if op_name in _ASYNC_OPS: + return "async" + if op_name in _MARKER_OPS: + return "marker" + return "dispatch" def get_tirx_op(op_name: str): assert isinstance(op_name, str) - return Op.get("tirx." + op_name) + return Op.get("tirx.tile." + op_name) class ArgProperty: @@ -551,30 +595,6 @@ def dsts(self) -> list[PrimExpr]: ) -def _register_permute_layout_op(): - """Register tirx.permute_layout dynamically (Python-only, no C++ rebuild). - - Mirrors the TIRX_DEFINE_DISPATCH_OP macro: marks the op as a TIRx op - and a dispatch op so the well-formed verifier and printer accept it. - """ - - tirx_name = "tirx.permute_layout" - try: - return Op.get(tirx_name) - except Exception: - from tvm.ir import _ffi_api as ir_ffi - from tvm.ir.op import register_op_attr - - ir_ffi.RegisterOp(tirx_name, "Permute the physical layout of a buffer in-place.") - register_op_attr(tirx_name, "TIsTIRxOp", True) - register_op_attr(tirx_name, "TIsDispatchOp", True) - register_op_attr(tirx_name, "TScriptPrinterName", "permute_layout") - return Op.get(tirx_name) - - -_register_permute_layout_op() - - class PermuteLayout(TilePrimitiveCall): """Move data so the buffer's bytes are arranged under a different layout. @@ -605,25 +625,3 @@ def srcs(self) -> list[PrimExpr]: @property def dsts(self) -> list[PrimExpr]: return [self.dst] - - -class GenericOp(TilePrimitiveCall): - """Generic operator for dynamically-resolved TIRx ops.""" - - def __init__(self, *args, op_name=None, workspace=None, config=None, dispatch=None): - workspace = workspace or {} - config = config or {} - tirx_name = f"tirx.{op_name}" - try: - resolved_op = Op.get(tirx_name) - except Exception: - from tvm.ir import _ffi_api as ir_ffi - from tvm.ir.op import register_op_attr - - ir_ffi.RegisterOp(tirx_name, f"Dynamic tirx op: {op_name}") - register_op_attr(tirx_name, "TIsTIRxOp", True) - resolved_op = Op.get(tirx_name) - args = list(map(normalize_const_arg, args)) - self.__init_handle_by_constructor__( - _ffi_api.TilePrimitiveCall, resolved_op, args, workspace, config, dispatch - ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/binary/default.py b/python/tvm/tirx/operator/tile_primitive/trn/binary/default.py index 09b70ce16667..3fa565b1f41d 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/binary/default.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/binary/default.py @@ -17,7 +17,7 @@ """Implementation of binary operator dispatches.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import FloatImm, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail from tvm.tirx.stmt import TilePrimitiveCall @@ -32,8 +32,8 @@ def binary_trn( op: TilePrimitiveCall, binary_op: MapOpType, sctx: DispatchContext ) -> PrimFunc | None: """Generate a binary operation schedule for Trainium.""" - if not (sctx.is_trn() and sctx.scope_kind == "kernel"): - fail("requires Trainium target and kernel exec_scope") + if not (sctx.is_trn() and sctx.scope_kind == "thread"): + fail("requires Trainium target and thread exec_scope") assert binary_op in binary_map_ops, f"Unsupported binary operation {binary_op}" @@ -53,9 +53,9 @@ def binary_trn( dst, src1 = _dst.buffer, _src1.buffer src2 = None if CONST is not None else _src2.buffer - p_var = Tx.Var("P", "int32") - b_var = Tx.Var("B", "int32") - f_var = Tx.Var("F", "int32") + p_var = T.Var("P", "int32") + b_var = T.Var("B", "int32") + f_var = T.Var("F", "int32") p_size = dst.layout.size("P") inst_size_limit = op.config.get("max_inst_size", 512) inst_repr.bound_inst_size(inst_size_limit, analyzer) @@ -66,26 +66,26 @@ def binary_trn( opcode = binary_map_ops[binary_op] # Select appropriate NKI function based on instruction type - _func = Tx.nki.tensortensor if inst_types[0] == InstType.TENSOR_TENSOR else Tx.nki.tensorscalar + _func = T.nki.tensortensor if inst_types[0] == InstType.TENSOR_TENSOR else T.nki.tensorscalar def func(*args): return _func(*args, reverse[0]) if inst_types[0] == InstType.TENSOR_SCALAR else _func(*args) # Define the implementation function - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, b_var: b_loop}) if inst_gen.make_guard(_dst): - dst_indices = Tx.meta_var(inst_gen.generate_indices(_dst)) - src1_indices = Tx.meta_var(inst_gen.generate_indices(_src1)) + dst_indices = T.meta_var(inst_gen.generate_indices(_dst)) + src1_indices = T.meta_var(inst_gen.generate_indices(_src1)) if CONST is None: - src2_indices = Tx.meta_var(inst_gen.generate_indices(_src2)) - Tx.evaluate( + src2_indices = T.meta_var(inst_gen.generate_indices(_src2)) + T.evaluate( func( dst[tuple(dst_indices)], src1[tuple(src1_indices)], @@ -94,7 +94,7 @@ def impl(): ) ) else: - Tx.evaluate( + T.evaluate( func( dst[tuple(dst_indices)], src1[tuple(src1_indices)], diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_chain.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_chain.py index 551731770df3..daa64fe8adb9 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_chain.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_chain.py @@ -17,7 +17,7 @@ """Implementation of BinaryChain dispatch.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc, TilePrimitiveCall from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch from tvm.tirx.operator.tile_primitive.ops import BinaryChain @@ -56,9 +56,9 @@ def binary_chain_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | if reverse[0]: srcs[0], srcs[1] = srcs[1], srcs[0] - p_var = Tx.Var("P", "int32") - b_var = Tx.Var("B", "int32") - f_var = Tx.Var("F", "int32") + p_var = T.Var("P", "int32") + b_var = T.Var("B", "int32") + f_var = T.Var("F", "int32") p_size = output.buffer.layout.size("P") inst_size_limit = op.config.get("max_inst_size", 512) inst_repr.bound_inst_size(inst_size_limit, analyzer) @@ -72,9 +72,9 @@ def binary_chain_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | # Determine operation function based on instruction type func = ( - Tx.nki.scalar_tensor_scalar + T.nki.scalar_tensor_scalar if inst_types[1] == InstType.TENSOR_SCALAR - else Tx.nki.scalar_tensor_tensor + else T.nki.scalar_tensor_tensor ) # Helper function to get source indices @@ -90,17 +90,17 @@ def get_srcs(inst_gen): # Create implementation # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, b_var: b_loop}) - dst_indices = Tx.meta_var(inst_gen.generate_indices(output)) - srcs = Tx.meta_var(get_srcs(inst_gen)) + dst_indices = T.meta_var(inst_gen.generate_indices(output)) + srcs = T.meta_var(get_srcs(inst_gen)) if inst_gen.make_guard(output): - Tx.evaluate(func(dst[tuple(dst_indices)], *srcs, opcode0, opcode1, reverse[0], reverse[1])) # noqa: E501 + T.evaluate(func(dst[tuple(dst_indices)], *srcs, opcode0, opcode1, reverse[0], reverse[1])) # noqa: E501 # fmt: on return impl @@ -115,7 +115,7 @@ def impl(): predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_reduce.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_reduce.py index 770343c10d2d..d0c64d415331 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_reduce.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_reduce.py @@ -17,7 +17,7 @@ """Implementation of BinaryReduce dispatch.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc, TilePrimitiveCall from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch from tvm.tirx.operator.tile_primitive.ops import BinaryReduce @@ -73,10 +73,10 @@ def binary_reduce_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc binary_input1, binary_input2 = binary_input2, binary_input1 # Generate intermediate buffer for reduction if needed - p_var = Tx.Var("P", "int32") - f_var = Tx.Var("F", "int32") - reduction_b_var = Tx.Var("rB", "int32") - spatial_b_var = Tx.Var("sB", "int32") + p_var = T.Var("P", "int32") + f_var = T.Var("F", "int32") + reduction_b_var = T.Var("rB", "int32") + spatial_b_var = T.Var("sB", "int32") p_size = binary_output.buffer.layout.size("P") inst_gen.bind_inst_iter(binary_output, p_var, p_size, 1, False) inst_gen.bind_inst_iter(binary_output, f_var, inst_repr.size, inst_repr.stride, True) @@ -100,49 +100,49 @@ def binary_reduce_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc if reduction_b_extent == 1: # Direct implementation without intermediate buffer # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, spatial_b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, spatial_b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, spatial_b_var: b_loop}) # noqa: E501 - src_1_indices = Tx.meta_var(inst_gen.generate_indices(binary_input1)) - vec_dst_idx = Tx.meta_var(inst_gen.generate_indices(binary_output)) - reduce_dst_idx = Tx.meta_var(inst_gen.generate_indices(reduce_output)) + src_1_indices = T.meta_var(inst_gen.generate_indices(binary_input1)) + vec_dst_idx = T.meta_var(inst_gen.generate_indices(binary_output)) + reduce_dst_idx = T.meta_var(inst_gen.generate_indices(reduce_output)) if inst_gen.make_guard(binary_output): if CONST is None: - src_2_indices = Tx.meta_var(inst_gen.generate_indices(binary_input2)) # noqa: E501 - Tx.nki.tensorscalar_reduce(dst2[tuple(reduce_dst_idx)], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], src2[tuple(src_2_indices)], binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 + src_2_indices = T.meta_var(inst_gen.generate_indices(binary_input2)) # noqa: E501 + T.nki.tensorscalar_reduce(dst2[tuple(reduce_dst_idx)], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], src2[tuple(src_2_indices)], binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 else: - Tx.nki.tensorscalar_reduce(dst2[tuple(reduce_dst_idx)], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], CONST, binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 + T.nki.tensorscalar_reduce(dst2[tuple(reduce_dst_idx)], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], CONST, binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 # fmt: on else: # Implementation with intermediate buffer # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, spatial_b_extent): - for reduction_b_loop in Tx.serial(0, reduction_b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, spatial_b_extent): + for reduction_b_loop in T.serial(0, reduction_b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, spatial_b_var: b_loop, reduction_b_var: reduction_b_loop}) # noqa: E501 if inst_gen.make_guard(binary_output): - src_1_indices = Tx.meta_var(inst_gen.generate_indices(binary_input1)) # noqa: E501 - vec_dst_idx = Tx.meta_var(inst_gen.generate_indices(binary_output)) # noqa: E501 + src_1_indices = T.meta_var(inst_gen.generate_indices(binary_input1)) # noqa: E501 + vec_dst_idx = T.meta_var(inst_gen.generate_indices(binary_output)) # noqa: E501 if CONST is None: - src_2_indices = Tx.meta_var(inst_gen.generate_indices(binary_input2)) # noqa: E501 - Tx.nki.tensorscalar_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], src2[tuple(src_2_indices)], binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 + src_2_indices = T.meta_var(inst_gen.generate_indices(binary_input2)) # noqa: E501 + T.nki.tensorscalar_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], src2[tuple(src_2_indices)], binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 else: - Tx.nki.tensorscalar_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], CONST, binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, reduction_b_extent, annotations={nki_dim: "F"}): + T.nki.tensorscalar_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(vec_dst_idx)], src1[tuple(src_1_indices)], CONST, binary_opcode, reduce_opcode, reverse[0]) # noqa: E501 + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, reduction_b_extent, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, spatial_b_var: b_loop}) if inst_gen.make_guard(reduce_output): - dst_2_indices = Tx.meta_var(inst_gen.generate_indices(reduce_output)) # noqa: E501 - Tx.nki.tensorreduce(dst2[tuple(dst_2_indices)], intermediate_buffer[p_loop, f_loop], reduce_opcode, False, -1) # noqa: E501 + dst_2_indices = T.meta_var(inst_gen.generate_indices(reduce_output)) + T.nki.tensorreduce(dst2[tuple(dst_2_indices)], intermediate_buffer[p_loop, f_loop], reduce_opcode, False, -1) # noqa: E501 # fmt: on return impl @@ -158,7 +158,7 @@ def impl(): predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/compose_op.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/compose_op.py index 86f39230b365..5fb5a9a20133 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/compose_op.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/compose_op.py @@ -37,7 +37,7 @@ def compose_op_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | N predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/reduce_negate.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/reduce_negate.py index 4112eb1042b9..986e91a2b84d 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/reduce_negate.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/reduce_negate.py @@ -41,7 +41,7 @@ def reduce_negate_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py index 1fc801403842..a7c9f86c7b7a 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py @@ -17,7 +17,7 @@ """Implementation of UnaryReduce dispatch.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc, TilePrimitiveCall from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch from tvm.tirx.operator.tile_primitive.ops import UnaryReduce @@ -68,10 +68,10 @@ def unary_reduce_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | inst_size_limit = op.config.get("max_inst_size", None) inst_repr.bound_inst_size(inst_size_limit, analyzer) - p_var = Tx.Var("P", "int32") - f_var = Tx.Var("F", "int32") - reduction_b_var = Tx.Var("rB", "int32") - spatial_b_var = Tx.Var("sB", "int32") + p_var = T.Var("P", "int32") + f_var = T.Var("F", "int32") + reduction_b_var = T.Var("rB", "int32") + spatial_b_var = T.Var("sB", "int32") p_size = unary_output.buffer.layout.size("P") inst_gen.bind_inst_iter(unary_output, p_var, p_size, 1, False) inst_gen.bind_inst_iter(unary_output, f_var, inst_repr.size, inst_repr.stride, True) @@ -97,22 +97,22 @@ def unary_reduce_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | if reduction_b_extent == 1: # Direct implementation without intermediate buffer # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, spatial_b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, spatial_b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, spatial_b_var: b_loop}) # noqa: E501 - src_1_indices = Tx.meta_var(inst_gen.generate_indices(unary_input)) - dst_1_indices = Tx.meta_var(inst_gen.generate_indices(unary_output)) - dst_2_indices = Tx.meta_var(inst_gen.generate_indices(reduce_output)) + src_1_indices = T.meta_var(inst_gen.generate_indices(unary_input)) + dst_1_indices = T.meta_var(inst_gen.generate_indices(unary_output)) + dst_2_indices = T.meta_var(inst_gen.generate_indices(reduce_output)) if inst_gen.make_guard(unary_output): if isinstance(bias, BufferRegion): - src_bias_indices = Tx.meta_var(inst_gen.generate_indices(bias)) - Tx.evaluate(Tx.nki.activation_reduce(dst2[tuple(dst_2_indices)], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[tuple(src_bias_indices)], scale)) # noqa: E501 + src_bias_indices = T.meta_var(inst_gen.generate_indices(bias)) + T.evaluate(T.nki.activation_reduce(dst2[tuple(dst_2_indices)], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[tuple(src_bias_indices)], scale)) # noqa: E501 else: - Tx.evaluate(Tx.nki.activation_reduce(dst2[tuple(dst_2_indices)], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[p_loop, f_loop], scale)) # noqa: E501 + T.evaluate(T.nki.activation_reduce(dst2[tuple(dst_2_indices)], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[p_loop, f_loop], scale)) # noqa: E501 # fmt: on import tvm @@ -122,30 +122,30 @@ def impl(): return mod["main"] else: # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, spatial_b_extent): - for reduction_b_loop in Tx.serial(0, reduction_b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, spatial_b_extent): + for reduction_b_loop in T.serial(0, reduction_b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, spatial_b_var: b_loop, reduction_b_var: reduction_b_loop}) # noqa: E501 - src_1_indices = Tx.meta_var(inst_gen.generate_indices(unary_input)) - dst_1_indices = Tx.meta_var(inst_gen.generate_indices(unary_output)) + src_1_indices = T.meta_var(inst_gen.generate_indices(unary_input)) + dst_1_indices = T.meta_var(inst_gen.generate_indices(unary_output)) if inst_gen.make_guard(unary_output): if isinstance(bias, BufferRegion): - src_bias_indices = Tx.meta_var(inst_gen.generate_indices(bias)) # noqa: E501 - Tx.evaluate(Tx.nki.activation_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[tuple(src_bias_indices)], scale)) # noqa: E501 + src_bias_indices = T.meta_var(inst_gen.generate_indices(bias)) # noqa: E501 + T.evaluate(T.nki.activation_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[tuple(src_bias_indices)], scale)) # noqa: E501 else: - Tx.evaluate(Tx.nki.activation_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[p_loop, f_loop], scale)) # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, reduction_b_extent, annotations={nki_dim: "F"}): + T.evaluate(T.nki.activation_reduce(intermediate_buffer[p_loop, reduction_b_loop], dst1[tuple(dst_1_indices)], src[tuple(src_1_indices)], unary_opcode, reduce_opcode, bias_buffer[p_loop, f_loop], scale)) # noqa: E501 + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, reduction_b_extent, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, spatial_b_var: b_loop}) if inst_gen.make_guard(reduce_output): - dst_2_indices = Tx.meta_var(inst_gen.generate_indices(reduce_output)) # noqa: E501 + dst_2_indices = T.meta_var(inst_gen.generate_indices(reduce_output)) # TODO: we should use nki.activation_reduce as second stage reduction # noqa: E501 - Tx.evaluate(Tx.nki.tensorreduce(dst2[tuple(dst_2_indices)], intermediate_buffer[p_loop, f_loop], reduce_opcode, False, -1)) # noqa: E501 + T.evaluate(T.nki.tensorreduce(dst2[tuple(dst_2_indices)], intermediate_buffer[p_loop, f_loop], reduce_opcode, False, -1)) # noqa: E501 # fmt: on return impl @@ -160,7 +160,7 @@ def impl(): predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py index 0dd59240ad2d..9fbaa524fd2e 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py @@ -23,20 +23,20 @@ # Operation code mappings opcode_table = { - Op.get("tirx.add"): "add", - Op.get("tirx.sub"): "sub", - Op.get("tirx.mul"): "mul", - Op.get("tirx.maximum"): "max", - Op.get("tirx.minimum"): "min", - Op.get("tirx.sqrt"): "sqrt", - Op.get("tirx.sum"): "add", - Op.get("tirx.max"): "max", - Op.get("tirx.min"): "min", - Op.get("tirx.exp"): "exp", + Op.get("tirx.tile.add"): "add", + Op.get("tirx.tile.sub"): "sub", + Op.get("tirx.tile.mul"): "mul", + Op.get("tirx.tile.maximum"): "max", + Op.get("tirx.tile.minimum"): "min", + Op.get("tirx.tile.sqrt"): "sqrt", + Op.get("tirx.tile.sum"): "add", + Op.get("tirx.tile.max"): "max", + Op.get("tirx.tile.min"): "min", + Op.get("tirx.tile.exp"): "exp", } optype_table = { - Op.get("tirx.sum"): ReduceOpType.SUM, - Op.get("tirx.max"): ReduceOpType.MAX, - Op.get("tirx.min"): ReduceOpType.MIN, + Op.get("tirx.tile.sum"): ReduceOpType.SUM, + Op.get("tirx.tile.max"): ReduceOpType.MAX, + Op.get("tirx.tile.min"): ReduceOpType.MIN, } diff --git a/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py b/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py index 323c80a40bc2..0005723ec193 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py @@ -17,7 +17,8 @@ """Implementation of copy operator dispatchs.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx import PrimFunc from tvm.tirx.operator.tile_primitive import ( DispatchContext, @@ -41,11 +42,11 @@ def transpose_schedule( inst_repr_dst, inst_repr_src = inst_gen.find_max_inst_size_transpose(dst_region, src_region) - lhs_f = Tx.Var("lhs_F", "int32") - lhs_p = Tx.Var("lhs_P", "int32") - dst_f = Tx.Var("dst_F", "int32") - b_var = Tx.Var("B", "int32") - extend_b = Tx.Var("extend_B", "int32") + lhs_f = T.Var("lhs_F", "int32") + lhs_p = T.Var("lhs_P", "int32") + dst_f = T.Var("dst_F", "int32") + b_var = T.Var("B", "int32") + extend_b = T.Var("extend_B", "int32") p_size = src_region.buffer.layout.size("P") lhs_f_size = dst_region.buffer.layout.size("P") rhs_f_size = p_size @@ -88,17 +89,17 @@ def transpose_schedule( assert sctx.alloc_only, ( "Identity tensor must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 ) - identity_tensor = Tx.buffer( + identity_tensor = T.buffer( (p_size, rhs_f_size), src_region.buffer.dtype, scope="trn.sbuf", buffer_name="identity" ) sctx.add_alloc_buffer(identity_tensor) - @Tx.prim_func + @T.prim_func def identity_init(): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for rhs_f_loop in Tx.serial(0, rhs_f_size, annotations={nki_dim: "F"}): - Tx.evaluate(Tx.nki.identity(identity_tensor[p_loop, rhs_f_loop], p_size)) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for rhs_f_loop in T.serial(0, rhs_f_size, annotations={nki_dim: "F"}): + T.evaluate(T.nki.identity(identity_tensor[p_loop, rhs_f_loop], p_size)) Tx.tvm_kernel_replace_point() sctx.add_init_stmt(identity_init.body) @@ -110,13 +111,13 @@ def identity_init(): src_buffer = src_region.buffer if dst_buffer.scope() == "trn.psum": - @Tx.prim_func + @T.prim_func def transpose_psum_output(): - for b_loop in Tx.serial(0, b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for lhs_f_loop in Tx.serial(0, lhs_f_size, annotations={nki_dim: "lhs_F"}): - for rhs_f_loop in Tx.serial( + for b_loop in T.serial(0, b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for lhs_f_loop in T.serial(0, lhs_f_size, annotations={nki_dim: "lhs_F"}): + for rhs_f_loop in T.serial( 0, rhs_f_size, annotations={nki_dim: "rhs_F"} ): inst_gen.set_bind_map( @@ -126,13 +127,13 @@ def transpose_psum_output(): inst_gen.set_bind_map( src_region, {b_var: b_loop, lhs_f: lhs_f_loop, lhs_p: p_loop} ) - src_indices = Tx.meta_var(inst_gen.generate_indices(src_region)) - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst_region)) - src_guard = Tx.meta_var(inst_gen.make_guard(src_region)) - dst_guard = Tx.meta_var(inst_gen.make_guard(dst_region)) + src_indices = T.meta_var(inst_gen.generate_indices(src_region)) + dst_indices = T.meta_var(inst_gen.generate_indices(dst_region)) + src_guard = T.meta_var(inst_gen.make_guard(src_region)) + dst_guard = T.meta_var(inst_gen.make_guard(dst_region)) if src_guard and dst_guard: - Tx.evaluate( - Tx.nki.matmul( + T.evaluate( + T.nki.matmul( dst_buffer[tuple(dst_indices)], src_buffer[tuple(src_indices)], identity_tensor[p_loop, rhs_f_loop], @@ -145,7 +146,7 @@ def transpose_psum_output(): assert sctx.alloc_only, ( "Accumulation psum buffer must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 ) - acc_psum = Tx.buffer( + acc_psum = T.buffer( (max_psum_banks, p_size, largest_psum_per_bank), "float32", scope="trn.psum", @@ -160,27 +161,27 @@ def transpose_psum_output(): max_psum_slots = acc_psum.shape[0] # fmt: off - @Tx.prim_func + @T.prim_func def transpose_sbuf_output(): - for b_loop in Tx.serial(0, b_extent): - for extend_b_loop in Tx.serial(0, extend_len): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for lhs_f_loop in Tx.serial(0, lhs_f_size, annotations={nki_dim: "lhs_F"}): - for rhs_f_loop in Tx.serial(0, rhs_f_size, annotations={nki_dim: "rhs_F"}): # noqa: E501 + for b_loop in T.serial(0, b_extent): + for extend_b_loop in T.serial(0, extend_len): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for lhs_f_loop in T.serial(0, lhs_f_size, annotations={nki_dim: "lhs_F"}): + for rhs_f_loop in T.serial(0, rhs_f_size, annotations={nki_dim: "rhs_F"}): # noqa: E501 inst_gen.set_bind_map(src_region, {b_var: b_loop, lhs_f: lhs_f_loop, lhs_p: p_loop, extend_b: extend_b_loop}) # noqa: E501 - src_indices = Tx.meta_var(inst_gen.generate_indices(src_region)) - src_guard = Tx.meta_var(inst_gen.make_guard(src_region)) + src_indices = T.meta_var(inst_gen.generate_indices(src_region)) + src_guard = T.meta_var(inst_gen.make_guard(src_region)) if src_guard: - Tx.evaluate(Tx.nki.matmul(acc_psum[b_loop % max_psum_slots, lhs_f_loop,extend_b_loop * rhs_f_size + rhs_f_loop], src_buffer[tuple(src_indices)], identity_tensor[p_loop, rhs_f_loop])) # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, rhs_f_size * extend_len, annotations={nki_dim: "F"}): + T.evaluate(T.nki.matmul(acc_psum[b_loop % max_psum_slots, lhs_f_loop,extend_b_loop * rhs_f_size + rhs_f_loop], src_buffer[tuple(src_indices)], identity_tensor[p_loop, rhs_f_loop])) # noqa: E501 + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, rhs_f_size * extend_len, annotations={nki_dim: "F"}): inst_gen.set_bind_map(dst_region, {b_var: b_loop, lhs_f: p_loop, dst_f: f_loop % rhs_f_size, extend_b: f_loop // rhs_f_size}) # noqa: E501 - dst_guard = Tx.meta_var(inst_gen.make_guard(dst_region)) - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst_region)) + dst_guard = T.meta_var(inst_gen.make_guard(dst_region)) + dst_indices = T.meta_var(inst_gen.generate_indices(dst_region)) if dst_guard: - Tx.evaluate(Tx.nki.tensor_copy(dst_buffer[tuple(dst_indices)], acc_psum[b_loop % max_psum_slots, p_loop, f_loop])) # noqa: E501 + T.evaluate(T.nki.tensor_copy(dst_buffer[tuple(dst_indices)], acc_psum[b_loop % max_psum_slots, p_loop, f_loop])) # noqa: E501 # fmt: on return transpose_sbuf_output @@ -188,8 +189,8 @@ def transpose_sbuf_output(): def copy_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: """Schedule copy operation between global and shared memory on CUDA.""" # Basic validation checks - if sctx.scope_kind != "kernel": - fail("requires kernel exec_scope for TRN copy") + if sctx.scope_kind != "thread": + fail("requires thread exec_scope for TRN copy") dst_region, src_region = op.args src, dst = src_region.buffer, dst_region.buffer @@ -201,9 +202,9 @@ def copy_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: src.scope() in ["global", "trn.sbuf", "trn.psum"], dst.scope() in ["global", "trn.sbuf", "trn.psum"], src.scope() != "global" or dst.scope() != "global", - (src.scope() == "global" and isinstance(src.layout, Tx.TileLayout)) + (src.scope() == "global" and isinstance(src.layout, T.TileLayout)) or (src.scope() in ["trn.sbuf", "trn.psum"] and src.layout.is_trainium()), - (dst.scope() == "global" and isinstance(dst.layout, Tx.TileLayout)) + (dst.scope() == "global" and isinstance(dst.layout, T.TileLayout)) or (dst.scope() in ["trn.sbuf", "trn.psum"] and dst.layout.is_trainium()), ] ) @@ -242,21 +243,21 @@ def copy_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: src_to_dst = False if src.scope() == "global": - func = Tx.nki.load + func = T.nki.load elif dst.scope() == "global": - func = Tx.nki.store + func = T.nki.store else: - func = Tx.nki.tensor_copy + func = T.nki.tensor_copy - if func == Tx.nki.tensor_copy: + if func == T.nki.tensor_copy: inst_size_limit = op.config.get("max_inst_size", 512) inst.bound_inst_size(inst_size_limit, analyzer) else: assert "max_inst_size" not in op.config, "max_inst_size is not supported for load/store" - p_var = Tx.Var("P", "int32") - f_var = Tx.Var("F", "int32") - b_var = Tx.Var("B", "int32") + p_var = T.Var("P", "int32") + f_var = T.Var("F", "int32") + b_var = T.Var("B", "int32") if src_to_dst: from_region, _to_region = src_region, dst_region else: @@ -267,17 +268,17 @@ def copy_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: b_extent = inst_gen.fill_in_block_dim(from_region, b_var) # fmt: off - @Tx.prim_func + @T.prim_func def impl(): # the additional b loop is to satisfy hardware instuction size limit - for b_loop in Tx.serial(0, b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({b_var: b_loop, p_var: p_loop, f_var: f_loop}) if inst_gen.make_guard(dst_region): - src_indices = Tx.meta_var(inst_gen.generate_indices(src_region)) - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst_region)) + src_indices = T.meta_var(inst_gen.generate_indices(src_region)) + dst_indices = T.meta_var(inst_gen.generate_indices(dst_region)) func(dst[tuple(dst_indices)], src[tuple(src_indices)]) # fmt: on return impl @@ -293,7 +294,7 @@ def impl(): predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/dim_utils.py b/python/tvm/tirx/operator/tile_primitive/trn/dim_utils.py index 4b77bd1c3c3e..ff7064f63d68 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/dim_utils.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/dim_utils.py @@ -20,7 +20,7 @@ from collections import namedtuple from tvm.arith.analyzer import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion # Represents the part of data iter covered by the buffer region @@ -34,14 +34,14 @@ def normalize_and_group(layout, shape): Parameters ---------- - layout : Union[Tx.TrainiumLayout, Tx.TileLayout] + layout : Union[T.TrainiumLayout, T.TileLayout] The layout to normalize shape : List[int] The shape to normalize with Returns ------- - Tuple[Union[Tx.TrainiumLayout, Tx.TileLayout], List[int]] : + Tuple[Union[T.TrainiumLayout, T.TileLayout], List[int]] : Normalized layout and separators Raises @@ -49,7 +49,7 @@ def normalize_and_group(layout, shape): ValueError : If layout is not a valid layout type """ - if isinstance(layout, Tx.TileLayout): + if isinstance(layout, T.TileLayout): return layout.canonicalize().group(shape) else: raise ValueError("Invalid layout") diff --git a/python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py b/python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py index 22c3c3cd7f77..ca572ba781da 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py @@ -22,7 +22,7 @@ from tvm.arith.analyzer import Analyzer from tvm.ir import assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import ( DispatchContext, @@ -110,8 +110,8 @@ def get_pf_dim_from_buffer_region( def matmul_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: """Schedule GEMM operation on Trainium.""" # Basic validation checks - if not (sctx.is_trn() and sctx.scope_kind == "kernel"): - fail("requires Trainium target and kernel exec_scope") + if not (sctx.is_trn() and sctx.scope_kind == "thread"): + fail("requires Trainium target and thread exec_scope") # Extract arguments ( @@ -199,12 +199,12 @@ def matmul_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: inst_repr = inst_gen.find_max_inst_size_from_one_region(B_buffer_region, [rhs_f_dim]) inst_repr = inst_gen.fit_inst_tile_to_region(inst_repr, C_buffer_region, [acc_f_dim]) inst_repr.bound_inst_size(512, analyzer) - rhs_f = Tx.Var("rhs_f", "int32") - lhs_f = Tx.Var("lhs_f", "int32") - p = Tx.Var("p", "int32") - reduction_b = Tx.Var("reduction_b", "int32") - lhs_b = Tx.Var("lhs_b", "int32") - rhs_b = Tx.Var("rhs_b", "int32") + rhs_f = T.Var("rhs_f", "int32") + lhs_f = T.Var("lhs_f", "int32") + p = T.Var("p", "int32") + reduction_b = T.Var("reduction_b", "int32") + lhs_b = T.Var("lhs_b", "int32") + rhs_b = T.Var("rhs_b", "int32") lhs_f_size = C.layout.size("P") inst_gen.bind_inst_iter( B_buffer_region, rhs_f, inst_repr.size, inst_repr.stride, is_free_dim=True @@ -218,29 +218,29 @@ def matmul_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: # FIXME: we need to lower the guard to things like matmul(lhs[...][lhs_guard], rhs[...][rhs_guard], mask=p_guard) # noqa: E501 # so we need to separate the guard for lhs_f, rhs_f and p # fmt: off - @Tx.inline + @T.inline def matmul_inst_macro(lhs_b_loop, rhs_b_loop, reduction_b_loop, acc, C_as_output, max_psum_slots): # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={"nki_dim": "P"}): - for lhs_f_loop in Tx.serial(0, lhs_f_size, annotations={"nki_dim": "lhs_F"}): - for rhs_f_loop in Tx.serial(0, inst_repr.size, annotations={"nki_dim": "rhs_F"}): # noqa: E501 - b_idx = Tx.meta_var(lhs_b_loop * rhs_b_extent + rhs_b_loop) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={"nki_dim": "P"}): + for lhs_f_loop in T.serial(0, lhs_f_size, annotations={"nki_dim": "lhs_F"}): + for rhs_f_loop in T.serial(0, inst_repr.size, annotations={"nki_dim": "rhs_F"}): + b_idx = T.meta_var(lhs_b_loop * rhs_b_extent + rhs_b_loop) inst_gen.set_bind_map(A_buffer_region, {lhs_b: lhs_b_loop, lhs_f: lhs_f_loop, p: p_loop, reduction_b: reduction_b_loop}) # noqa: E501 inst_gen.set_bind_map(B_buffer_region, {rhs_b: rhs_b_loop, rhs_f: rhs_f_loop, p: p_loop, reduction_b: reduction_b_loop}) # noqa: E501 inst_gen.set_bind_map(C_buffer_region, {lhs_f: lhs_f_loop, rhs_f: rhs_f_loop, lhs_b: lhs_b_loop, rhs_b: rhs_b_loop}) # noqa: E501 - lhs_indices = Tx.meta_var(inst_gen.generate_indices(A_buffer_region)) - rhs_indices = Tx.meta_var(inst_gen.generate_indices(B_buffer_region)) - C_indices = Tx.meta_var(inst_gen.generate_indices(C_buffer_region)) + lhs_indices = T.meta_var(inst_gen.generate_indices(A_buffer_region)) + rhs_indices = T.meta_var(inst_gen.generate_indices(B_buffer_region)) + C_indices = T.meta_var(inst_gen.generate_indices(C_buffer_region)) if inst_gen.make_guard(A_buffer_region) and inst_gen.make_guard(B_buffer_region): # noqa: E501 if C_as_output: - Tx.evaluate(Tx.nki.matmul(acc[C_indices], A[lhs_indices], B[rhs_indices])) # noqa: E501 + T.evaluate(T.nki.matmul(acc[C_indices], A[lhs_indices], B[rhs_indices])) # noqa: E501 else: - Tx.evaluate(Tx.nki.matmul(acc[b_idx % max_psum_slots, lhs_f_loop, rhs_f_loop], A[lhs_indices], B[rhs_indices])) # noqa: E501 + T.evaluate(T.nki.matmul(acc[b_idx % max_psum_slots, lhs_f_loop, rhs_f_loop], A[lhs_indices], B[rhs_indices])) # noqa: E501 if C.scope() == "trn.psum": - @Tx.prim_func + @T.prim_func def impl_C_psum(): - for lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(lhs_b_extent, rhs_b_extent, reduction_b_extent): # noqa: E501 + for lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(lhs_b_extent, rhs_b_extent, reduction_b_extent): # noqa: E501 matmul_inst_macro(lhs_b_loop, rhs_b_loop, reduction_b_loop, C, True, None) return impl_C_psum @@ -253,7 +253,7 @@ def impl_C_psum(): acc_psum_shape = (max_psum_banks, p_size, largest_psum_per_bank) if "acc_psum" not in op.workspace: assert sctx.alloc_only, "Accumulation psum buffer must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 - acc_psum = Tx.buffer( + acc_psum = T.buffer( acc_psum_shape, "float32", scope="trn.psum", @@ -267,19 +267,19 @@ def impl_C_psum(): check_workspace_buffer(acc_psum, (p_size, largest_psum_per_bank), "trn.psum") max_psum_slots = acc_psum.shape[0] - @Tx.prim_func + @T.prim_func def impl_C_sbuf(): - for lhs_b_loop, rhs_b_loop in Tx.grid(lhs_b_extent, rhs_b_extent): - for reduction_b_loop in Tx.serial(0, reduction_b_extent): + for lhs_b_loop, rhs_b_loop in T.grid(lhs_b_extent, rhs_b_extent): + for reduction_b_loop in T.serial(0, reduction_b_extent): matmul_inst_macro(lhs_b_loop, rhs_b_loop, reduction_b_loop, acc_psum, False, max_psum_slots) # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for lhs_f_loop in Tx.serial(0, lhs_f_size, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(0, inst_repr.size, annotations={"nki_dim": "F"}): - b_idx = Tx.meta_var(lhs_b_loop * rhs_b_extent + rhs_b_loop) + with T.attr(0, "tensorized_nki_instruction", 1): + for lhs_f_loop in T.serial(0, lhs_f_size, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(0, inst_repr.size, annotations={"nki_dim": "F"}): + b_idx = T.meta_var(lhs_b_loop * rhs_b_extent + rhs_b_loop) inst_gen.set_bind_map(C_buffer_region, {lhs_f: lhs_f_loop, rhs_f: rhs_f_loop, lhs_b: lhs_b_loop, rhs_b: rhs_b_loop}) # noqa: E501 if inst_gen.make_guard(C_buffer_region): - acc_indices = Tx.meta_var(inst_gen.generate_indices(C_buffer_region)) - Tx.evaluate(Tx.nki.tensor_copy(C[acc_indices], acc_psum[b_idx % max_psum_slots, lhs_f_loop, rhs_f_loop])) # noqa: E501 + acc_indices = T.meta_var(inst_gen.generate_indices(C_buffer_region)) + T.evaluate(T.nki.tensor_copy(C[acc_indices], acc_psum[b_idx % max_psum_slots, lhs_f_loop, rhs_f_loop])) # noqa: E501 # fmt: on return impl_C_sbuf @@ -294,7 +294,7 @@ def impl_C_sbuf(): predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py b/python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py index 11c9edca8f75..58163d4b148a 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py @@ -26,7 +26,7 @@ import tvm from tvm.arith.analyzer import Analyzer from tvm.ir import Range -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimExpr, Var from tvm.tirx.expr_functor import ExprMutator from tvm.tirx.layout import Iter @@ -42,13 +42,13 @@ class LogicalIterDim: @staticmethod def default(): - return LogicalIterDim(1, 1, Tx.int32(0)) + return LogicalIterDim(1, 1, T.int32(0)) LogicalIterList = tuple[tuple[tuple[LogicalIterDim]]] -def to_int_list(intimm_list: list[Tx.IntImm]): +def to_int_list(intimm_list: list[T.IntImm]): return [int(i) for i in intimm_list] @@ -532,7 +532,7 @@ def make_guard(self, buffer_region: BufferRegion): ] axes = self.generate_axes(buffer_region) guard = reduce( - Tx.And, + T.And, [axes[i] < r.extent for i, r in enumerate(buffer_region.region) if i in relaxed_dims], True, ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py b/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py index bfcbb5bc27e5..90b97aeb62fd 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py @@ -17,7 +17,8 @@ from typing import Any -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx import Buffer, FloatImm, Stmt from tvm.tirx.operator.tile_primitive.dispatch_context import DispatchContext from tvm.tirx.operator.tile_primitive.ops import ( @@ -53,14 +54,14 @@ def alloc_const_bias_trn( return {"const_bias": ("const_bias", bias.value)} else: new_shape = (par_size, max_inst_size) - new_buffer = Tx.buffer(new_shape, dtype=bias.dtype, scope="trn.sbuf", buffer_name="const_bias") + new_buffer = T.buffer(new_shape, dtype=bias.dtype, scope="trn.sbuf", buffer_name="const_bias") - @Tx.prim_func + @T.prim_func def const_bias_init(): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, par_size, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, max_inst_size, annotations={nki_dim: "F"}): - Tx.evaluate(Tx.nki.memset(new_buffer[p_loop, f_loop], bias)) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, par_size, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, max_inst_size, annotations={nki_dim: "F"}): + T.evaluate(T.nki.memset(new_buffer[p_loop, f_loop], bias)) Tx.tvm_kernel_replace_point() buffer_dict[("const_bias", bias.value)] = (new_buffer, const_bias_init.body) @@ -101,16 +102,16 @@ def alloc_identity_trn( return {"identity": "identity"} else: new_shape = (par_size, par_size) - new_buffer = Tx.buffer( + new_buffer = T.buffer( new_shape, dtype=op.srcs[0].buffer.dtype, scope="trn.sbuf", buffer_name="identity" ) - @Tx.prim_func + @T.prim_func def identity_init(): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, par_size, annotations={nki_dim: "P"}): - for rhs_f_loop in Tx.serial(0, par_size, annotations={nki_dim: "F"}): - Tx.evaluate(Tx.nki.identity(new_buffer[p_loop, rhs_f_loop], par_size)) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, par_size, annotations={nki_dim: "P"}): + for rhs_f_loop in T.serial(0, par_size, annotations={nki_dim: "F"}): + T.evaluate(T.nki.identity(new_buffer[p_loop, rhs_f_loop], par_size)) Tx.tvm_kernel_replace_point() buffer_dict["identity"] = (new_buffer, identity_init.body) @@ -123,7 +124,7 @@ def alloc_acc_psum_trn( if "acc_psum" in op.workspace or op.dsts[0].buffer.scope() == "trn.psum": return {} par_size = op.dsts[0].buffer.layout.size("P") - acc_psum = Tx.buffer( + acc_psum = T.buffer( (8, par_size, 512), "float32", scope="trn.psum", diff --git a/python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py b/python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py index c76aa39fce62..1d9840e5d674 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py @@ -17,7 +17,7 @@ """Shared helpers for reduction schedules.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail from tvm.tirx.stmt import TilePrimitiveCall @@ -48,7 +48,7 @@ def generate_intermediate_buffer( assert sctx.alloc_only, ( "Partial reduce buffer must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 ) - intermediate_buffer = Tx.buffer( + intermediate_buffer = T.buffer( intermediate_shape, dtype=dst_buffer_region.buffer.dtype, scope="trn.sbuf", @@ -73,8 +73,8 @@ def reduction_trn( Returns: Optional[PrimFunc]: The scheduled function, or None if not applicable. """ - if not (sctx.is_trn() and sctx.scope_kind == "kernel"): - fail("requires Trainium target and kernel exec_scope") + if not (sctx.is_trn() and sctx.scope_kind == "thread"): + fail("requires Trainium target and thread exec_scope") dst_buffer_region, src_buffer_region, axes, accum = op.args[:4] assert not accum, "Accumulation is not supported for reduction on Trainium" @@ -109,10 +109,10 @@ def reduction_trn( # Get partition size and extents p_size = src.layout.size("P") - f_var = Tx.Var("F", "int32") - p_var = Tx.Var("P", "int32") - spatial_b_var = Tx.Var("sB", "int32") - reduction_b_var = Tx.Var("rB", "int32") + f_var = T.Var("F", "int32") + p_var = T.Var("P", "int32") + spatial_b_var = T.Var("sB", "int32") + reduction_b_var = T.Var("rB", "int32") inst_gen.bind_inst_iter(src_buffer_region, f_var, inst_repr.size, inst_repr.stride, True) inst_gen.bind_inst_iter(src_buffer_region, p_var, p_size, 1, False) reduction_b_extent = inst_gen.fill_in_block_dim(src_buffer_region, reduction_b_var, axes) @@ -129,38 +129,38 @@ def reduction_trn( # fmt: off # Single-stage reduction implementation if reduction_b_extent == 1: - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, spatial_b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, spatial_b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, spatial_b_var: b_loop}) # noqa: E501 if inst_gen.make_guard(src_buffer_region): - src_indices = Tx.meta_var(inst_gen.generate_indices(src_buffer_region)) # noqa: E501 - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst_buffer_region)) # noqa: E501 - Tx.evaluate(Tx.nki.tensorreduce(dst[tuple(dst_indices)], src[tuple(src_indices)], opcode, negate, -1)) # noqa: E501 + src_indices = T.meta_var(inst_gen.generate_indices(src_buffer_region)) # noqa: E501 + dst_indices = T.meta_var(inst_gen.generate_indices(dst_buffer_region)) # noqa: E501 + T.evaluate(T.nki.tensorreduce(dst[tuple(dst_indices)], src[tuple(src_indices)], opcode, negate, -1)) # noqa: E501 return impl # Two-stage reduction implementation else: - @Tx.prim_func + @T.prim_func def two_stage_reduction(): - for b_loop in Tx.serial(0, spatial_b_extent): - for reduction_b_loop in Tx.serial(0, reduction_b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, spatial_b_extent): + for reduction_b_loop in T.serial(0, reduction_b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, spatial_b_var: b_loop, reduction_b_var: reduction_b_loop}) # noqa: E501 if inst_gen.make_guard(src_buffer_region): - src_indices = Tx.meta_var(inst_gen.generate_indices(src_buffer_region)) # noqa: E501 - Tx.evaluate(Tx.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], src[src_indices], opcode, False, -1)) # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, reduction_b_extent, annotations={nki_dim: "F"}): + src_indices = T.meta_var(inst_gen.generate_indices(src_buffer_region)) # noqa: E501 + T.evaluate(T.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], src[src_indices], opcode, False, -1)) # noqa: E501 + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, reduction_b_extent, annotations={nki_dim: "F"}): inst_gen.set_bind_map(src_buffer_region, {p_var: p_loop, f_var: 0, spatial_b_var: b_loop, reduction_b_var: f_loop}) # noqa: E501 inst_gen.set_bind_map(dst_buffer_region, {p_var: p_loop, spatial_b_var: b_loop}) # noqa: E501 if inst_gen.make_guard(src_buffer_region): - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst_buffer_region)) # noqa: E501 - Tx.evaluate(Tx.nki.tensorreduce(dst[dst_indices], intermediate_buffer[p_loop, f_loop], opcode, negate, -1)) # noqa: E501 + dst_indices = T.meta_var(inst_gen.generate_indices(dst_buffer_region)) # noqa: E501 + T.evaluate(T.nki.tensorreduce(dst[dst_indices], intermediate_buffer[p_loop, f_loop], opcode, negate, -1)) # noqa: E501 return two_stage_reduction # fmt: on diff --git a/python/tvm/tirx/operator/tile_primitive/trn/select/default.py b/python/tvm/tirx/operator/tile_primitive/trn/select/default.py index 54de3005a3db..27136a3ac342 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/select/default.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/select/default.py @@ -17,7 +17,7 @@ """Implementation of select schedules.""" -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import BufferRegion, FloatImm, PrimFunc, TilePrimitiveCall from tvm.tirx.operator.tile_primitive import ( DispatchContext, @@ -34,8 +34,8 @@ def select_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: """Generate schedule for select operation on Trainium.""" - if sctx.scope_kind != "kernel": - fail("requires kernel exec_scope for TRN select") + if sctx.scope_kind != "thread": + fail("requires thread exec_scope for TRN select") op = TilePrimitiveCall.downcast(op) assert isinstance(op, Select), f"{op} is not a Select" @@ -94,9 +94,9 @@ def select_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: inst_repr = inst_gen.restrict_inst_to_one_dim(inst_repr) inst_repr.bound_inst_size(op.config.get("max_inst_size", 512), analyzer) - p_var = Tx.Var("p", "int32") - b_var = Tx.Var("b", "int32") - f_var = Tx.Var("f", "int32") + p_var = T.Var("p", "int32") + b_var = T.Var("b", "int32") + f_var = T.Var("f", "int32") p_size = dst.buffer.layout.size("P") inst_gen.bind_inst_iter(dst, f_var, inst_repr.size, inst_repr.stride, True) inst_gen.bind_inst_iter(dst, p_var, p_size, 1, False) @@ -107,18 +107,18 @@ def select_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: true_value_buffer = true_value.buffer # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({f_var: f_loop, p_var: p_loop, b_var: b_loop}) if inst_gen.make_guard(dst): - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst)) - true_value_indices = Tx.meta_var(inst_gen.generate_indices(true_value)) - pred = Tx.meta_var(analyzer.simplify(op.predicate.apply(inst_gen.generate_axes(dst)))) # noqa: E501 - Tx.evaluate(Tx.nki.affine_select(dst_buffer[tuple(dst_indices)], pred, true_value_buffer[tuple(true_value_indices)], false_value)) # noqa: E501 + dst_indices = T.meta_var(inst_gen.generate_indices(dst)) + true_value_indices = T.meta_var(inst_gen.generate_indices(true_value)) + pred = T.meta_var(analyzer.simplify(op.predicate.apply(inst_gen.generate_axes(dst)))) # noqa: E501 + T.evaluate(T.nki.affine_select(dst_buffer[tuple(dst_indices)], pred, true_value_buffer[tuple(true_value_indices)], false_value)) # noqa: E501 # fmt: on return impl @@ -134,7 +134,7 @@ def impl(): predicate( "exec_scope", lambda op, sctx: ( - sctx.scope_kind == "kernel", + sctx.scope_kind == "thread", f"unsupported exec_scope {sctx.scope_kind}", ), ) diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/default.py b/python/tvm/tirx/operator/tile_primitive/trn/unary/default.py index 0b7c9badd25a..e336daa717a3 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/default.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/unary/default.py @@ -35,8 +35,8 @@ def unary_trn(op: TilePrimitiveCall, unary_op: MapOpType, sctx: DispatchContext) -> PrimFunc | None: """Schedule unary operation on Trainium.""" # Check execution environment - if not (sctx.is_trn() and sctx.scope_kind == "kernel"): - fail("requires Trainium target and kernel exec_scope") + if not (sctx.is_trn() and sctx.scope_kind == "thread"): + fail("requires Trainium target and thread exec_scope") # Extract operation arguments dst_buffer_region, _src = op.args diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py b/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py index 33ee83eb6a92..24d1704923de 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py @@ -18,7 +18,8 @@ """Shared helpers, op tables, and validation functions for unary operator dispatches.""" from tvm.arith.analyzer import Analyzer -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx import BufferRegion, FloatImm from ...common import MapOpType @@ -100,15 +101,15 @@ def get_const_bias_tensor(bias, shape, dtype, workspace, sctx): "Constant bias tensor must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 ) # Create new bias buffer - bias_buffer = Tx.buffer(shape, dtype, scope="trn.sbuf", buffer_name="const_bias") + bias_buffer = T.buffer(shape, dtype, scope="trn.sbuf", buffer_name="const_bias") sctx.add_alloc_buffer(bias_buffer) - @Tx.prim_func + @T.prim_func def const_bias_init(): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, shape[0], annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, shape[1], annotations={nki_dim: "F"}): - Tx.evaluate(Tx.nki.memset(bias_buffer[p_loop, f_loop], bias)) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, shape[0], annotations={nki_dim: "P"}): + for f_loop in T.serial(0, shape[1], annotations={nki_dim: "F"}): + T.evaluate(T.nki.memset(bias_buffer[p_loop, f_loop], bias)) Tx.tvm_kernel_replace_point() sctx.add_init_stmt(const_bias_init.body) @@ -141,9 +142,9 @@ def generate_unary_func( inst_size_limit = config.get("max_inst_size", 512) inst_repr.bound_inst_size(inst_size_limit, analyzer) - f_var = Tx.Var("F", "int32") - p_var = Tx.Var("P", "int32") - b_var = Tx.Var("B", "int32") + f_var = T.Var("F", "int32") + p_var = T.Var("P", "int32") + b_var = T.Var("B", "int32") inst_gen.bind_inst_iter(dst_buffer_region, f_var, inst_repr.size, inst_repr.stride, True) inst_gen.bind_inst_iter(dst_buffer_region, p_var, p_size, 1, False) b_extent = inst_gen.fill_in_block_dim(dst_buffer_region, b_var) @@ -164,26 +165,26 @@ def generate_unary_func( bias_buffer = bias.buffer # fmt: off - @Tx.prim_func + @T.prim_func def impl(): - for b_loop in Tx.serial(0, b_extent): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, p_size, annotations={nki_dim: "P"}): - for f_loop in Tx.serial(0, inst_repr.size, annotations={nki_dim: "F"}): + for b_loop in T.serial(0, b_extent): + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): + for f_loop in T.serial(0, inst_repr.size, annotations={nki_dim: "F"}): inst_gen.set_bind_map_all({p_var: p_loop, f_var: f_loop, b_var: b_loop}) - dst_indices = Tx.meta_var(inst_gen.generate_indices(dst_buffer_region)) + dst_indices = T.meta_var(inst_gen.generate_indices(dst_buffer_region)) if inst_gen.make_guard(dst_buffer_region): if unary_op == MapOpType.FILL: - Tx.evaluate(Tx.nki.memset(dst[tuple(dst_indices)], _src)) + T.evaluate(T.nki.memset(dst[tuple(dst_indices)], _src)) else: - src_indices = Tx.meta_var(inst_gen.generate_indices(_src)) + src_indices = T.meta_var(inst_gen.generate_indices(_src)) if unary_op == MapOpType.RECIPROCAL: - Tx.evaluate(Tx.nki.reciprocal(dst[tuple(dst_indices)], src[tuple(src_indices)])) # noqa: E501 + T.evaluate(T.nki.reciprocal(dst[tuple(dst_indices)], src[tuple(src_indices)])) # noqa: E501 elif isinstance(bias, BufferRegion): - bias_indices = Tx.meta_var(inst_gen.generate_indices(bias)) - Tx.evaluate(Tx.nki.activation(dst[tuple(dst_indices)], src[tuple(src_indices)], opcode, scale=scale, bias=bias_buffer[tuple(bias_indices)])) # noqa: E501 + bias_indices = T.meta_var(inst_gen.generate_indices(bias)) + T.evaluate(T.nki.activation(dst[tuple(dst_indices)], src[tuple(src_indices)], opcode, scale=scale, bias=bias_buffer[tuple(bias_indices)])) # noqa: E501 else: - Tx.evaluate(Tx.nki.activation(dst[tuple(dst_indices)], src[tuple(src_indices)], opcode, scale=scale, bias=bias_buffer[p_loop, f_loop])) # noqa: E501 + T.evaluate(T.nki.activation(dst[tuple(dst_indices)], src[tuple(src_indices)], opcode, scale=scale, bias=bias_buffer[p_loop, f_loop])) # noqa: E501 # fmt: on return impl diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py b/python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py index fac26a85f10e..399d8cfa6d11 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py @@ -33,8 +33,8 @@ def unary_with_bias_scale_trn( ) -> PrimFunc | None: """Schedule unary operation with bias and scale on Trainium.""" # Check execution environment - if not (sctx.is_trn() and sctx.scope_kind == "kernel"): - fail("requires Trainium target and kernel exec_scope") + if not (sctx.is_trn() and sctx.scope_kind == "thread"): + fail("requires Trainium target and thread exec_scope") # Extract operation arguments with defaults dst_buffer_region, src_buffer_region, _bias, scale = op.args diff --git a/python/tvm/tirx/script/__init__.py b/python/tvm/tirx/script/__init__.py index 57877f4e73b8..8abbcda38781 100644 --- a/python/tvm/tirx/script/__init__.py +++ b/python/tvm/tirx/script/__init__.py @@ -30,51 +30,8 @@ from .parser import macro except ImportError: macro = None -from .builder.ir import TensorMap, meta_class -from .builder.tirx import * - - -def __getattr__(name: str): - """Resolve undefined attributes as dynamic TilePrimitiveCall ops. - - Registers ``tirx.`` lazily so the op is available for IR walks - after the prim_func is built. - """ - if name.startswith("_"): - raise AttributeError(f"module 'tvm.tirx.script' has no attribute {name!r}") - import tvm_ffi - - from tvm.ir import Op - from tvm.tirx.stmt import TilePrimitiveCall +from tvm.tirx.lang.alloc_pool import SMEMPool, TMEMPool, TMEMStages - op_name = "tirx." + name - _register_op = tvm_ffi.get_global_func("ir.RegisterOp") - from tvm.ir import register_op_attr - - def _fn(*args, workspace=None, config=None, dispatch=None, **kwargs): - try: - op = Op.get(op_name) - except Exception: - _register_op(op_name, "") - register_op_attr(op_name, "TIsTIRxOp", True) - op = Op.get(op_name) - if workspace is None: - workspace = {} - if config is None: - config = kwargs or {} - # Convert Buffer args to BufferRegion (covers full extent) - from tvm.tirx import Buffer as _TBuffer - - new_args = [] - for a in args: - if isinstance(a, _TBuffer): - slices = [slice(None) for _ in range(len(a.shape))] - a = a[slices] - new_args.append(a) - # Insert into the active frame using same FFI hook as registered ops. - from .builder.tirx import f_insert as _f_insert - - return _f_insert(TilePrimitiveCall(*new_args, op=op, workspace=workspace, config=config)) - - _fn.__name__ = name - return _fn +from . import tile +from .builder.ir import TensorMap, meta_class +from .tile import cluster, cta, thread, warp, warpgroup, wg diff --git a/python/tvm/tirx/script/builder/__init__.py b/python/tvm/tirx/script/builder/__init__.py index 35f53fb49fc1..5cada7493a95 100644 --- a/python/tvm/tirx/script/builder/__init__.py +++ b/python/tvm/tirx/script/builder/__init__.py @@ -21,4 +21,6 @@ from .ir import boolean as bool # pylint: disable=redefined-builtin from .ir import buffer as Buffer from .utils import buffer_proxy, frame_scope, seq_scope -from .tirx import * +from tvm.tirx.lang.alloc_pool import SMEMPool, TMEMPool, TMEMStages +from . import tirx as tile +from .tirx import cluster, cta, thread, warp, warpgroup, wg diff --git a/python/tvm/tirx/script/builder/frame.py b/python/tvm/tirx/script/builder/frame.py index 21920e893448..d36fd5364bf8 100644 --- a/python/tvm/tirx/script/builder/frame.py +++ b/python/tvm/tirx/script/builder/frame.py @@ -34,18 +34,6 @@ class PrimFuncFrame(TIRFrame): ... class SBlockFrame(TIRFrame): ... -@_register_object("script.ir_builder.tirx.ExecScopeFrame") -class ExecScopeFrame(TIRFrame): - """A frame that represents an execution scope (e.g. cta, warp, thread). - - When exiting this frame, it produces an ExecScopeStmt wrapping the body. - To narrow execution to a subset of the scope, wrap the ``with`` in an - ``if`` guard with a canonical thread-filter predicate -- e.g. - ``if lo <= var and var < hi:`` -- recognized by the lowering pass (see - ``src/tirx/analysis/filter_canonical.h``). - """ - - @_register_object("script.ir_builder.tirx.SBlockInitFrame") class BlockInitFrame(TIRFrame): ... diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 318a2d2bb652..7f527413e375 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -568,43 +568,6 @@ def sblock(name: str = "", no_realize: bool = False, exec_scope: str = "") -> fr return _ffi_api.Block(name, no_realize, exec_scope) # type: ignore[attr-defined] # pylint: disable=no-member -def _scope_guards(args: tuple[Any, ...]) -> list[PrimExpr]: - if not args: - return [] - if len(args) == 1: - return [args[0]] - raise ValueError( - "Exec scope guards expect no args or one predicate expression. " - "Use `with Tx.scope((0 <= var) & (var < hi))` for structural predicates, " - "or `with Tx.scope(Tx.filter(var, opaque_selector))` when a selector annotation is needed." - ) - - -def cluster(*guards: Any) -> frame.ExecScopeFrame: - """Open a ``cluster``-level execution scope.""" - return _ffi_api.Cluster(_scope_guards(guards)) # type: ignore[attr-defined] # pylint: disable=no-member - - -def cta(*guards: Any) -> frame.ExecScopeFrame: - """Open a ``cta``-level execution scope.""" - return _ffi_api.CTA(_scope_guards(guards)) # type: ignore[attr-defined] # pylint: disable=no-member - - -def warpgroup(*guards: Any) -> frame.ExecScopeFrame: - """Open a ``warpgroup``-level execution scope.""" - return _ffi_api.WarpGroup(_scope_guards(guards)) # type: ignore[attr-defined] # pylint: disable=no-member - - -def warp(*guards: Any) -> frame.ExecScopeFrame: - """Open a ``warp``-level execution scope.""" - return _ffi_api.Warp(_scope_guards(guards)) # type: ignore[attr-defined] # pylint: disable=no-member - - -def thread(*guards: Any) -> frame.ExecScopeFrame: - """Open a ``thread``-level execution scope.""" - return _ffi_api.Thread(_scope_guards(guards)) # type: ignore[attr-defined] # pylint: disable=no-member - - def device_entry() -> None: """Mark the device-region entry within the enclosing PrimFunc body. @@ -612,16 +575,16 @@ def device_entry() -> None: accumulate into an ``AttrStmt("tirx.device_entry", True, body=...)``; the wrapping is closed by the PrimFunc frame at function end. - Anything written before this marker is host code (e.g. ``Tx.match_buffer``); + Anything written before this marker is host code (e.g. ``T.match_buffer``); anything after is device code. Example:: - @Tx.prim_func + @T.prim_func def kernel(...): - A = Tx.match_buffer(...) - Tx.device_entry() # device region starts here - bx = Tx.cta_id([SM_COUNT]) # standalone scope-id def + A = T.match_buffer(...) + T.device_entry() # device region starts here + bx = T.cta_id([SM_COUNT]) # standalone scope-id def ... """ attr_frame = _ffi_api.DeviceEntry() # type: ignore[attr-defined] # pylint: disable=no-member @@ -631,17 +594,16 @@ def kernel(...): def elected(): - """Stub that rejects the removed ``Tx.elected()`` sugar. + """Stub that rejects the removed ``T.elected()`` sugar. Write the explicit form instead:: - if Tx.ptx.elect_sync(): - with Tx.thread(): - ... + if T.ptx.elect_sync(): + ... # thread is the default scope """ raise RuntimeError( - "Tx.elected() is no longer available. Write explicitly: " - "`if Tx.ptx.elect_sync(): with Tx.thread():`" + "T.elected() is no longer available. Write explicitly: " + "`if T.ptx.elect_sync(): ...` (thread is the default scope)" ) @@ -924,7 +886,7 @@ def wg_reg_tile(elem_per_thread: int, dtype: str = "float32") -> Buffer: Sugar for the recurring pattern:: - Tx.alloc_buffer( + T.alloc_buffer( (128, elem_per_thread), dtype, layout=wg_local_layout(elem_per_thread), scope="local", @@ -1532,7 +1494,7 @@ def as_var(self, rhs_dtype=None): """Resolve to a tir.Var.""" if self.type_spec is not None: if isinstance(self.type_spec, Var): - return self.type_spec # Already a Var (e.g. Tx.handle(...)) + return self.type_spec # Already a Var (e.g. T.handle(...)) elif callable(self.type_spec): return self.type_spec() # e.g. T.int32() -> Var elif isinstance(self.type_spec, Type): @@ -1551,8 +1513,8 @@ def as_var(self, rhs_dtype=None): class LocalVectorAnnotation: """Marker for local vector/tensor allocation via type annotation subscript. - Created when a DtypeConstructor is subscripted, e.g. ``Tx.float32[N]`` or - ``Tx.float32[M, N]``. The parser's ``visit_ann_assign`` recognises this + Created when a DtypeConstructor is subscripted, e.g. ``T.float32[N]`` or + ``T.float32[M, N]``. The parser's ``visit_ann_assign`` recognises this object and lowers it to ``T.alloc_local(shape=..., dtype=...)``. """ @@ -1568,10 +1530,10 @@ class DtypeConstructor: Replaces the plain functions previously returned by ``func_gen``. - * ``Tx.float32()`` — same FFI call as before (returns ``Var``). - * ``Tx.float32[N]`` — returns ``LocalVectorAnnotation("float32", (N,))``. - * ``Tx.float32[M, N]`` — returns ``LocalVectorAnnotation("float32", (M, N))``. - * ``x: Tx.float32`` — parser calls this object, gets a ``Var``. + * ``T.float32()`` — same FFI call as before (returns ``Var``). + * ``T.float32[N]`` — returns ``LocalVectorAnnotation("float32", (N,))``. + * ``T.float32[M, N]`` — returns ``LocalVectorAnnotation("float32", (M, N))``. + * ``x: T.float32`` — parser calls this object, gets a ``Var``. """ def __init__(self, ffi_name: str, dtype_str: str): @@ -1898,11 +1860,11 @@ def alloc_tcgen05_ldst_frag(instr_shape, tensor_shape, dtype): Examples -------- M=128 readback (existing dispatch): - ``frag = Tx.alloc_tcgen05_ldst_frag("32x32b", (128, 64), "float32")`` + ``frag = T.alloc_tcgen05_ldst_frag("32x32b", (128, 64), "float32")`` ``Tx.copy_async(frag[:, :], tmem[:, 0:64])`` M=64 readback (.16x64b dispatch): - ``frag = Tx.alloc_tcgen05_ldst_frag("16x64b", (64, 64), "float32")`` + ``frag = T.alloc_tcgen05_ldst_frag("16x64b", (64, 64), "float32")`` ``Tx.copy_async(frag[:, :], tmem[0:64, 0:64])`` """ from tvm.tirx.layout import tcgen05_atom_layout # local import to avoid cycle @@ -3017,10 +2979,20 @@ def wrapped(*args, **kwargs): return wrapped +def _ptx_ldg32(reg, guard, addr, local_addr): + if isinstance(addr, Buffer): + addr = addr[0] + return _tir_op.call_intrin(reg.dtype, "tirx.ptx.ldg32", reg, guard, addr, local_addr) + + +_ptx_ldg32.__tir_op_name__ = "ptx.ldg32" + + class PTXNamespace: """The PTX instruction submodule.""" def __init__(self): + self.ldg32 = _ptx_ldg32 self.ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) # Apache-compatible variant. Same lowered intrinsic as # ``ldmatrix`` but accepts the historical ``(trans, num, dtype, @@ -3393,6 +3365,7 @@ def __init__(self): self.cta_sum = _op_wrapper(_tir_op.cuda_cta_sum) self.cta_max = _op_wrapper(_tir_op.cuda_cta_max) self.cta_min = _op_wrapper(_tir_op.cuda_cta_min) + self.copy_bytes = _op_wrapper(_tir_op.cuda_copy_bytes) self.copy_128b = _op_wrapper(_tir_op.cuda_copy_128b) self.copy_64b = _op_wrapper(_tir_op.cuda_copy_64b) self.copy_32b = _op_wrapper(_tir_op.cuda_copy_32b) @@ -3440,6 +3413,85 @@ def __init__(self): self.hmin2 = _op_wrapper(_tir_op.cuda_hmin2) self.hmax2 = _op_wrapper(_tir_op.cuda_hmax2) self.fp8x4_e4m3_from_float4 = _op_wrapper(_tir_op.cuda_fp8x4_e4m3_from_float4) + setattr(self, "__shfl_sync", self._shfl_sync) + setattr(self, "__shfl_up_sync", self._shfl_up_sync) + setattr(self, "__shfl_down_sync", self._shfl_down_sync) + setattr(self, "__shfl_xor_sync", self._shfl_xor_sync) + setattr(self, "__activemask", self._activemask) + + @staticmethod + def _shfl_sync(mask, var, lane, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_sync", mask, var, lane, width) + + @staticmethod + def _shfl_up_sync(mask, var, delta, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_up_sync", mask, var, delta, width) + + @staticmethod + def _shfl_down_sync(mask, var, delta, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_down_sync", mask, var, delta, width) + + @staticmethod + def _shfl_xor_sync(mask, var, lane_mask, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin( + var.dtype, "tirx.cuda.__shfl_xor_sync", mask, var, lane_mask, width + ) + + @staticmethod + def _activemask(): + return _tir_op.call_intrin("uint32", "tirx.cuda.__activemask") + + +class MetalNamespace: + """The Metal intrinsics submodule.""" + + @staticmethod + def simd_shuffle(var, lane): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle", var, lane) + + @staticmethod + def simd_shuffle_up(var, delta): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_up", var, delta) + + @staticmethod + def simd_shuffle_down(var, delta): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_down", var, delta) + + +class WebGPUNamespace: + """The WebGPU intrinsics submodule.""" + + @staticmethod + def subgroup_shuffle(var, lane): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.webgpu.subgroup_shuffle", var, lane) + + @staticmethod + def subgroup_shuffle_up(var, delta): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.webgpu.subgroup_shuffle_up", var, delta) + + @staticmethod + def subgroup_shuffle_down(var, delta): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.webgpu.subgroup_shuffle_down", var, delta) class NVSHMEMNamespace: @@ -3524,6 +3576,8 @@ def __init__(self): ptx = PTXNamespace() cuda = CUDANamespace() +metal = MetalNamespace() +webgpu = WebGPUNamespace() nvshmem = NVSHMEMNamespace() nki = NKINamespace() @@ -3534,11 +3588,23 @@ def __init__(self): # This keeps parser and printer consistent using a single registration source. # def _register_tir_namespace_printer_names(): + def register_printer_name(op_name, script_name): + try: + ir.Op.get(op_name) + except Exception: + return + try: + _register_op_attr(op_name, "TScriptPrinterName", script_name, level=20) + except Exception: + pass + def visit(ns_obj, dotted_prefix): # If the namespace object itself maps to an op via __call__ call_op = getattr(ns_obj, "__tir_call_op_name__", None) if call_op: - _register_op_attr(f"tirx.{call_op}", "TScriptPrinterName", dotted_prefix, level=20) + flat_name = f"tirx.{call_op}" + for op_name in {flat_name, _tir_op._canonical_device_intrin_name(flat_name)}: + register_printer_name(op_name, dotted_prefix) # Walk attributes to find wrapped ops and sub-namespaces for name in dir(ns_obj): if name.startswith("_"): @@ -3554,13 +3620,16 @@ def visit(ns_obj, dotted_prefix): # Wrapped op (callable with attached __tir_op_name__) op_name = getattr(val, "__tir_op_name__", None) if callable(val) and op_name: - _register_op_attr( - f"tirx.{op_name}", "TScriptPrinterName", f"{dotted_prefix}.{name}", level=20 - ) + flat_name = f"tirx.{op_name}" + script_name = f"{dotted_prefix}.{name}" + for full_op_name in {flat_name, _tir_op._canonical_device_intrin_name(flat_name)}: + register_printer_name(full_op_name, script_name) try: visit(ptx, "ptx") visit(cuda, "cuda") + visit(metal, "metal") + visit(webgpu, "webgpu") visit(nvshmem, "nvshmem") visit(nki, "nki") except Exception: @@ -3601,6 +3670,7 @@ def visit(ns_obj, dotted_prefix): floordiv = _op_wrapper(_tir_op.floordiv) floormod = _op_wrapper(_tir_op.floormod) fmod = _op_wrapper(_tir_op.fmod) +fma = _op_wrapper(_tir_op.fma) hypot = _op_wrapper(_tir_op.hypot) if_then_else = _op_wrapper(_tir_op.if_then_else) infinity = _op_wrapper(_tir_op.infinity) @@ -3923,6 +3993,7 @@ def visit(ns_obj, dotted_prefix): "floordiv", "floormod", "fmod", + "fma", "filter", "selector", "hypot", @@ -4098,6 +4169,7 @@ def visit(ns_obj, dotted_prefix): "S", "ScopeIdDef", "SwizzleLayout", + "TensorMap", "TileLayout", "Var", "add_to_parent", @@ -4106,9 +4178,7 @@ def visit(ns_obj, dotted_prefix): "alloc_scalar", "alloc_shared", "alloc_tcgen05_ldst_frag", - "cluster", "cluster_id", - "cta", "cta_id", "cta_id_in_cluster", "cta_id_in_pair", @@ -4117,6 +4187,8 @@ def visit(ns_obj, dotted_prefix): "device_entry", "lane_id", "local_scalar", + "meta_class", + "metal", "nki", "nvshmem", "ptx", @@ -4125,15 +4197,13 @@ def visit(ns_obj, dotted_prefix): "shared_scalar", "smem", "static_assert", - "thread", "thread_id", "thread_id_in_wg", "tmem", - "warp", "warp_id", "warp_id_in_wg", - "warpgroup", "warpgroup_id", + "webgpu", ] # Shorthand dtype aliases diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py index 880efe13880b..23f702ebc570 100644 --- a/python/tvm/tirx/script/builder/tirx.py +++ b/python/tvm/tirx/script/builder/tirx.py @@ -22,6 +22,7 @@ import tvm.tirx.operator as tirx_op from tvm.ir import Op from tvm.tirx import Buffer, BufferRegion, PrimExpr +from tvm.tirx.exec_scope import _SCOPE_KIND_TO_NAME, ExecScope from tvm.tirx.expr import FloatImm from tvm.tirx.lang.alloc_pool import SMEMPool, TMEMPool, TMEMStages from tvm.tirx.predicate import Predicate @@ -30,6 +31,97 @@ from .ir import decl_buffer, meta_class +def _normalize_scope(scope) -> ExecScope: + """Normalize a scope selector to an ``ExecScope``. + + Accepts an ``ExecScope`` (passed through), a scope-name ``str`` + (e.g. ``"warp"``, normalized via the FFI ctor / ``StringToScopeKind``), + or an ``int`` ``ScopeKind`` value. ``None`` resolves to the default + ``thread`` scope, keeping the default in one place. + """ + if scope is None: + return ExecScope("thread") + if isinstance(scope, ExecScope): + return scope + if isinstance(scope, str): + return ExecScope(scope) + if isinstance(scope, int): + return ExecScope(_SCOPE_KIND_TO_NAME[scope]) + raise TypeError(f"Cannot interpret {scope!r} as an execution scope") + + +class ScopedOp: + """Make a tile-primitive op callable at the default ``thread`` scope. + + A bare ``Tx.copy(...)`` emits a call at ``thread`` scope. To cooperate at a + wider scope, reach the op through a scope namespace -- ``Tx.warp.copy(...)``, + ``Tx.wg.sum(...)``, ``Tx.cta.fill(...)`` (see :class:`ScopeNamespace`). + + The wrapped ``fn`` must accept a keyword-only ``scope`` parameter that it + threads into the constructed ``TilePrimitiveCall``. + """ + + def __init__(self, fn): + self._fn = fn + functools.update_wrapper(self, fn) + + def __call__(self, *args, **kwargs): + return self._fn(*args, scope=ExecScope("thread"), **kwargs) + + def _bind(self, scope: ExecScope): + """Return a callable that emits this op at ``scope``. + + Used by :class:`ScopeNamespace`; not part of the user-facing surface. + """ + return lambda *args, **kwargs: self._fn(*args, scope=scope, **kwargs) + + +class ScopeNamespace: + """Bind a cooperation scope to every tile primitive reached through it. + + ``Tx.cluster`` / ``Tx.cta`` / ``Tx.wg`` (warpgroup) / ``Tx.warp`` are the + instances exposed on the ``Tx`` surface. Attribute access resolves a + tile-primitive op name against the public ``Tx`` surface (registered and + dynamic ops alike) and binds this namespace's scope, so + ``Tx.warp.copy(dst, src)`` emits a copy at warp scope and + ``Tx.cta.sum(out, x)`` reduces at CTA scope. A bare ``Tx.copy(...)`` (no + namespace prefix) stays at the default ``thread`` scope. + """ + + def __init__(self, scope, label: str): + self._scope = _normalize_scope(scope) + self._label = label + + def __repr__(self): + return f"" + + def __getattr__(self, name: str): + if name.startswith("_"): + raise AttributeError(name) + from tvm.tirx.script import tile as _tile_script + + op = getattr(_tile_script, name) + if not isinstance(op, ScopedOp): + # AttributeError (not TypeError) so hasattr()/getattr(..., default) + # degrade gracefully on a scope namespace. + raise AttributeError( + f"'Tx.{self._label}.{name}' is not a tile primitive; the " + f"'Tx.{self._label}.' scope prefix applies only to tile primitives" + ) + return op._bind(self._scope) + + +# Scope-prefix namespaces: ``Tx.warp.copy(...)`` / ``Tx.wg.sum(...)`` / +# ``Tx.cta.fill(...)`` / ``Tx.cluster.copy(...)``. ``wg`` == warpgroup. A bare +# ``Tx.copy(...)`` (no prefix) runs at the default ``thread`` scope. +cluster = ScopeNamespace("cluster", "cluster") +cta = ScopeNamespace("cta", "cta") +wg = ScopeNamespace("warpgroup", "wg") +warpgroup = ScopeNamespace("warpgroup", "warpgroup") # full-name alias of ``wg`` +warp = ScopeNamespace("warp", "warp") +thread = ScopeNamespace("thread", "thread") + + def _is_buffer_or_region(x): return isinstance(x, Buffer | BufferRegion) @@ -50,11 +142,13 @@ def _wrap_elem_in_tuple(e): f_insert = _ffi_api.TilePrimitiveCall # pylint: disable=no-member +@ScopedOp def zero( dst: BufferRegion | Buffer, src: BufferRegion | Buffer | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Zero out all elements in src and store to dst. @@ -78,9 +172,12 @@ def zero( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - return f_insert(tirx_op.Zero(dst, src, workspace=workspace, config=config, dispatch=dispatch)) + return f_insert( + tirx_op.Zero(dst, src, workspace=workspace, config=config, dispatch=dispatch, scope=scope) + ) +@ScopedOp def sqrt( dst: BufferRegion | Buffer, src: BufferRegion | Buffer | None = None, @@ -88,6 +185,7 @@ def sqrt( scale: FloatImm | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Sqrt all elements in src and store to dst. @@ -127,16 +225,27 @@ def sqrt( if bias is not None and isinstance(bias, Buffer): bias = _to_region(bias) return f_insert( - tirx_op.Sqrt(dst, src, bias, scale, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Sqrt( + dst, + src, + bias, + scale, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) +@ScopedOp def add( dst: BufferRegion | Buffer, src1: BufferRegion | Buffer | FloatImm, src2: BufferRegion | Buffer | FloatImm, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Add data from src1 and src2, store to dst. @@ -164,16 +273,20 @@ def add( if isinstance(src2, Buffer): src2 = _to_region(src2) return f_insert( - tirx_op.Add(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Add( + dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def sub( dst: BufferRegion | Buffer, src1: BufferRegion | Buffer, src2: BufferRegion | Buffer | FloatImm, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Sub data from src2 to src1, store to dst. @@ -201,16 +314,20 @@ def sub( if isinstance(src2, Buffer): src2 = _to_region(src2) return f_insert( - tirx_op.Sub(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Sub( + dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def mul( dst: BufferRegion | Buffer, src1: BufferRegion | Buffer | FloatImm, src2: BufferRegion | Buffer | FloatImm, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Multiply data from src1 and src2, store to dst. @@ -238,16 +355,20 @@ def mul( if isinstance(src2, Buffer): src2 = _to_region(src2) return f_insert( - tirx_op.Mul(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Mul( + dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def fdiv( dst: BufferRegion | Buffer, src1: BufferRegion | Buffer, src2: BufferRegion | Buffer | FloatImm, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """(Float) Div data from src2 to src1, store to dst. @@ -274,10 +395,13 @@ def fdiv( if isinstance(src2, Buffer): src2 = _to_region(src2) return f_insert( - tirx_op.FDiv(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.FDiv( + dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def fma( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, @@ -285,6 +409,7 @@ def fma( bias: BufferRegion | Buffer | PrimExpr, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Fused multiply-add: dst = src * scale + bias. @@ -316,12 +441,27 @@ def fma( if isinstance(bias, Buffer): bias = _to_region(bias) return f_insert( - tirx_op.FMA(dst, src, scale, bias, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.FMA( + dst, + src, + scale, + bias, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) +@ScopedOp def cast( - dst, src=None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, **kwargs + dst, + src=None, + workspace: dict[str, Buffer] | None = None, + dispatch: str | None = None, + scope: ExecScope | None = None, + **kwargs, ): """Cast — overloaded. @@ -344,14 +484,18 @@ def cast( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - return f_insert(tirx_op.Cast(dst, src, workspace=workspace, config=config, dispatch=dispatch)) + return f_insert( + tirx_op.Cast(dst, src, workspace=workspace, config=config, dispatch=dispatch, scope=scope) + ) +@ScopedOp def copy( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Copy data from src to dst. @@ -372,14 +516,18 @@ def copy( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - return f_insert(tirx_op.Copy(dst, src, workspace=workspace, config=config, dispatch=dispatch)) + return f_insert( + tirx_op.Copy(dst, src, workspace=workspace, config=config, dispatch=dispatch, scope=scope) + ) +@ScopedOp def copy_async( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): if workspace is None: @@ -388,10 +536,13 @@ def copy_async( dst = _to_region(dst) src = _to_region(src) return f_insert( - tirx_op.CopyAsync(dst, src, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.CopyAsync( + dst, src, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def gemm_async( C: BufferRegion | Buffer, A: BufferRegion | Buffer, @@ -403,6 +554,7 @@ def gemm_async( accum: bool = False, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """General matrix multiplication asynchronously. @@ -461,20 +613,32 @@ def gemm_async( workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) return f_insert( tirx_op.GemmAsync( - C, A, B, transA, transB, accum, workspace=workspace, config=config, dispatch=dispatch + C, + A, + B, + transA, + transB, + accum, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, ) ) +@ScopedOp def fill( dst: BufferRegion | Buffer, value: PrimExpr, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Fill the buffer region with the value. @@ -494,9 +658,12 @@ def fill( workspace = {} config = kwargs or {} dst = _to_region(dst) - return f_insert(tirx_op.Fill(dst, value, workspace=workspace, config=config, dispatch=dispatch)) + return f_insert( + tirx_op.Fill(dst, value, workspace=workspace, config=config, dispatch=dispatch, scope=scope) + ) +@ScopedOp def gemm( D: BufferRegion | Buffer, A: BufferRegion | Buffer, @@ -508,6 +675,7 @@ def gemm( beta: PrimExpr = 0.0, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """General matrix multiplication. @@ -563,10 +731,12 @@ def gemm( workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) +@ScopedOp def sum( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, @@ -574,6 +744,7 @@ def sum( accum: bool = False, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """ @@ -603,10 +774,20 @@ def sum( src = _to_region(src) axes = _wrap_elem_in_tuple(axes) return f_insert( - tirx_op.Sum(dst, src, axes, accum, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Sum( + dst, + src, + axes, + accum, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) +@ScopedOp def max( dst, src=None, @@ -614,6 +795,7 @@ def max( accum: bool = False, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Max — overloaded. @@ -633,10 +815,20 @@ def max( src = _to_region(src) axes = _wrap_elem_in_tuple(axes) return f_insert( - tirx_op.Max(dst, src, axes, accum, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Max( + dst, + src, + axes, + accum, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) +@ScopedOp def min( dst, src=None, @@ -644,6 +836,7 @@ def min( accum: bool = False, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Min — overloaded. @@ -662,15 +855,26 @@ def min( src = _to_region(src) axes = _wrap_elem_in_tuple(axes) return f_insert( - tirx_op.Min(dst, src, axes, accum, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Min( + dst, + src, + axes, + accum, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) +@ScopedOp def reciprocal( dst: BufferRegion | Buffer, src: BufferRegion | Buffer | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Reciprocal all elements in src and store to dst. @@ -700,15 +904,19 @@ def reciprocal( dst = _to_region(dst) src = _to_region(src) return f_insert( - tirx_op.Reciprocal(dst, src, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Reciprocal( + dst, src, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def silu( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Compute SiLU (x * sigmoid(x)) for all elements in src and store to dst. @@ -734,14 +942,18 @@ def silu( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - return f_insert(tirx_op.SiLU(dst, src, workspace=workspace, config=config, dispatch=dispatch)) + return f_insert( + tirx_op.SiLU(dst, src, workspace=workspace, config=config, dispatch=dispatch, scope=scope) + ) +@ScopedOp def memset( dst: BufferRegion | Buffer, value: PrimExpr, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Set all elements in dst to value. @@ -762,16 +974,20 @@ def memset( config = kwargs or {} dst = _to_region(dst) return f_insert( - tirx_op.Memset(dst, value, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Memset( + dst, value, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def maximum( dst: BufferRegion | Buffer, src1: BufferRegion | Buffer | FloatImm, src2: BufferRegion | Buffer | FloatImm, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Maximum all elements in src1 and src2 and store to dst. @@ -799,16 +1015,20 @@ def maximum( if isinstance(src2, Buffer): src2 = _to_region(src2) return f_insert( - tirx_op.Maximum(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Maximum( + dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def minimum( dst: BufferRegion | Buffer, src1: BufferRegion | Buffer | FloatImm, src2: BufferRegion | Buffer | FloatImm, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Minimum all elements in src1 and src2 and store to dst. @@ -836,10 +1056,13 @@ def minimum( if isinstance(src2, Buffer): src2 = _to_region(src2) return f_insert( - tirx_op.Minimum(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Minimum( + dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch, scope=scope + ) ) +@ScopedOp def exp( dst: BufferRegion | Buffer, src: BufferRegion | Buffer | None = None, @@ -847,6 +1070,7 @@ def exp( scale: FloatImm | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Exponentiate all elements in src and store to dst. @@ -884,10 +1108,20 @@ def exp( if bias is not None and isinstance(bias, Buffer): bias = _to_region(bias) return f_insert( - tirx_op.Exp(dst, src, bias, scale, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Exp( + dst, + src, + bias, + scale, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) +@ScopedOp def exp2( dst: BufferRegion | Buffer, src: BufferRegion | Buffer | None = None, @@ -895,6 +1129,7 @@ def exp2( scale: FloatImm | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Compute base-2 exponential (2^x) of all elements in src and store to dst. @@ -932,7 +1167,16 @@ def exp2( if bias is not None and isinstance(bias, Buffer): bias = _to_region(bias) return f_insert( - tirx_op.Exp2(dst, src, bias, scale, workspace=workspace, config=config, dispatch=dispatch) + tirx_op.Exp2( + dst, + src, + bias, + scale, + workspace=workspace, + config=config, + dispatch=dispatch, + scope=scope, + ) ) @@ -962,6 +1206,7 @@ def tvm_kernel_replace_point(): return f_insert(tirx_op.KernelReplacePoint(workspace={}, config={})) +@ScopedOp def binary_reduce( binary_output: BufferRegion | Buffer, reduce_output: BufferRegion | Buffer, @@ -972,6 +1217,7 @@ def binary_reduce( reduce_axes: int | tuple[int] = -1, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Combine a binary operation with a reduction operation. @@ -1033,10 +1279,12 @@ def binary_reduce( workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) +@ScopedOp def unary_reduce( unary_output: BufferRegion | Buffer, reduce_output: BufferRegion | Buffer, @@ -1048,6 +1296,7 @@ def unary_reduce( reduce_axes: int | tuple[int] = -1, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Combine a unary operation with a reduction operation. @@ -1114,10 +1363,12 @@ def unary_reduce( workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) +@ScopedOp def binary_chain( output: BufferRegion | Buffer, data: BufferRegion | Buffer, @@ -1128,6 +1379,7 @@ def binary_chain( reverse1: bool = False, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Chain multiple binary operations together. @@ -1194,10 +1446,12 @@ def binary_chain( workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) +@ScopedOp def reduce_negate( output: BufferRegion | Buffer, input: BufferRegion | Buffer, @@ -1206,6 +1460,7 @@ def reduce_negate( accum: bool = False, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Negate the result of a reduction operation. @@ -1253,15 +1508,18 @@ def reduce_negate( workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) +@ScopedOp def select( dst: BufferRegion | Buffer, true_value: BufferRegion | Buffer | FloatImm, false_value: BufferRegion | Buffer | FloatImm, pred: Predicate | Callable[..., PrimExpr], + scope: ExecScope | None = None, ): """Select between two values based on a predicate. @@ -1286,7 +1544,7 @@ def select( false_value = _to_region(false_value) if not isinstance(pred, Predicate): pred = Predicate(pred) - return f_insert(tirx_op.Select(dst, true_value, false_value, pred)) + return f_insert(tirx_op.Select(dst, true_value, false_value, pred, scope=scope)) def reshape(buffer: Buffer, shape: list[PrimExpr]): @@ -1325,11 +1583,13 @@ def reshape(buffer: Buffer, shape: list[PrimExpr]): ) +@ScopedOp def permute_layout( dst: BufferRegion | Buffer, src: BufferRegion | Buffer, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, **kwargs, ): """Move data so the buffer's bytes are arranged under a different layout. @@ -1368,21 +1628,26 @@ def _to_region(b): workspace=workspace, config=config, dispatch=dispatch, + scope=scope, ) ) __all__ = [ "SMEMPool", + "ScopeNamespace", + "ScopedOp", "TMEMPool", "TMEMStages", "add", "binary_chain", "binary_reduce", "cast", + "cluster", "compose_op", "copy", "copy_async", + "cta", "exp", "exp2", "fdiv", @@ -1405,7 +1670,11 @@ def _to_region(b): "sqrt", "sub", "sum", + "thread", "tvm_kernel_replace_point", "unary_reduce", + "warp", + "warpgroup", + "wg", "zero", ] diff --git a/python/tvm/tirx/script/parser/__init__.py b/python/tvm/tirx/script/parser/__init__.py index 5f6b8d38f1d0..7d59fc17545a 100644 --- a/python/tvm/tirx/script/parser/__init__.py +++ b/python/tvm/tirx/script/parser/__init__.py @@ -38,6 +38,9 @@ __all__ = _tir.__all__ + [ "Buffer", "Ptr", + "SMEMPool", + "TMEMPool", + "TMEMStages", "bool", "constexpr", "inline", diff --git a/python/tvm/tirx/script/parser/entry.py b/python/tvm/tirx/script/parser/entry.py index 9bd40f51b5a1..8f922ff851d8 100644 --- a/python/tvm/tirx/script/parser/entry.py +++ b/python/tvm/tirx/script/parser/entry.py @@ -215,12 +215,12 @@ class TIRJit: """Top-level kernel decorator with constexpr params + ``.specialize()``. Parses the function body lazily: parsing is deferred until ``.specialize()`` - supplies concrete values for the params annotated as ``Tx.constexpr``. The + supplies concrete values for the params annotated as ``T.constexpr``. The return type of ``.specialize()`` is a ``tvm.tirx.PrimFunc``, identical in - type to what ``@Tx.prim_func`` produces today. + type to what ``@T.prim_func`` produces today. Constexpr params are removed from the resulting PrimFunc's parameter list; - their values are baked into the IR (e.g. into ``Tx.Buffer((M, K), ...)`` + their values are baked into the IR (e.g. into ``T.Buffer((M, K), ...)`` shape annotations and into the body). """ @@ -240,11 +240,11 @@ def __init__( # Resolved closure vars (computed once; the function itself is the # capture point, so this never changes between specializations). self._closure_vars: dict[str, Any] = utils.inspect_function_capture(func) - # Detect which params are marked Tx.constexpr. With PEP 563 + # Detect which params are marked T.constexpr. With PEP 563 # (``from __future__ import annotations``), each annotation is a # string; we eval them one-by-one so a constexpr probe is not # blocked by sibling annotations that reference yet-undefined names - # (e.g. ``A: Tx.Buffer((N,), ...)`` referencing constexpr ``N``). + # (e.g. ``A: T.Buffer((N,), ...)`` referencing constexpr ``N``). raw_anns = getattr(func, "__annotations__", {}) or {} eval_globals = {**func.__globals__, **self._closure_vars} sig = inspect.signature(func) @@ -271,7 +271,7 @@ def specialize(self, **constexpr_kwargs) -> PrimFunc: Parameters ---------- **constexpr_kwargs - One value per ``Tx.constexpr``-annotated parameter. All such + One value per ``T.constexpr``-annotated parameter. All such parameters must be supplied; passing names that are not constexpr-annotated is an error. @@ -279,7 +279,7 @@ def specialize(self, **constexpr_kwargs) -> PrimFunc: ------- PrimFunc A concrete TIRx PrimFunc, identical in type to the output of - ``@Tx.prim_func``. + ``@T.prim_func``. """ extra = constexpr_kwargs.keys() - self.constexpr_names if extra: @@ -327,24 +327,23 @@ def jit( ) -> "TIRJit | Callable": """Decorator: capture the kernel and defer parsing until ``.specialize()``. - Use ``@Tx.jit`` (instead of ``@Tx.prim_func``) when the kernel takes - compile-time parameters annotated with ``Tx.constexpr``. The resulting + Use ``@T.jit`` (instead of ``@T.prim_func``) when the kernel takes + compile-time parameters annotated with ``T.constexpr``. The resulting object exposes ``.specialize(**constexpr_kwargs)``, which returns a ``tvm.tirx.PrimFunc``. Example:: - from tvm.script import tirx as Tx + from tvm.script import tirx as T - @Tx.jit + @T.jit def add( - A: Tx.Buffer((N,), "float32"), - B: Tx.Buffer((N,), "float32"), + A: T.Buffer((N,), "float32"), + B: T.Buffer((N,), "float32"), *, - N: Tx.constexpr, + N: T.constexpr, ): - with Tx.thread(): - ... + ... kernel = add.specialize(N=1024) # returns a PrimFunc """ @@ -503,7 +502,7 @@ def __getitem__(self, keys): class _ConstexprProxy: """Sentinel marker for compile-time (specialization-time) parameters. - Used as a parameter annotation in ``@Tx.jit`` decorated functions to mark + Used as a parameter annotation in ``@T.jit`` decorated functions to mark a parameter as constexpr — its value is supplied to ``.specialize(**kwargs)`` rather than at call time, and it is removed from the generated PrimFunc's runtime parameter list. diff --git a/python/tvm/tirx/script/parser/parser.py b/python/tvm/tirx/script/parser/parser.py index 8e5f7a1b6257..54c18db374d8 100644 --- a/python/tvm/tirx/script/parser/parser.py +++ b/python/tvm/tirx/script/parser/parser.py @@ -641,7 +641,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: if ann is None: raise if ann is _constexpr_sentinel: - # Tx.constexpr param: value was bound in extra_vars by + # T.constexpr param: value was bound in extra_vars by # TIRJit.specialize() and lives in an outer var_table # frame; do not register a runtime PrimFunc param. continue diff --git a/python/tvm/tirx/script/tile.py b/python/tvm/tirx/script/tile.py new file mode 100644 index 000000000000..bbc2c131bad0 --- /dev/null +++ b/python/tvm/tirx/script/tile.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tile primitive shorthand namespace for TIRx TVMScript.""" + +import functools + +from tvm.tirx import Buffer, BufferRegion + +from .builder import tirx as _builder + +_TILE_ARG_TYPES = (Buffer, BufferRegion) + + +def _get_arg(args, kwargs, index, name): + if len(args) > index: + return args[index] + return kwargs.get(name) + + +def _require_buffer_arg(op_name, arg_name, value): + if not isinstance(value, _TILE_ARG_TYPES): + raise TypeError( + f"Tx.{op_name} is tile-only and expects `{arg_name}` to be a Buffer " + f"or BufferRegion; use T.{op_name} for expression/builtin calls" + ) + + +def _validate_tile_call(op_name, args, kwargs): + dst = _get_arg(args, kwargs, 0, "dst") + _require_buffer_arg(op_name, "dst", dst) + + if op_name in {"cast", "max", "min", "permute_layout", "silu"}: + src = _get_arg(args, kwargs, 1, "src") + _require_buffer_arg(op_name, "src", src) + elif op_name in {"sqrt", "exp", "exp2", "reciprocal"}: + src = _get_arg(args, kwargs, 1, "src") + if src is not None: + _require_buffer_arg(op_name, "src", src) + + +def _tile_scoped_op(op_name): + scoped_op = getattr(_builder, op_name) + + @functools.wraps(scoped_op._fn) # pylint: disable=protected-access + def wrapper(*args, scope=None, **kwargs): + _validate_tile_call(op_name, args, kwargs) + return scoped_op._fn(*args, scope=scope, **kwargs) # pylint: disable=protected-access + + return _builder.ScopedOp(wrapper) + + +_SCOPED_TILE_OP_NAMES = [ + "add", + "binary_chain", + "binary_reduce", + "cast", + "copy", + "copy_async", + "exp", + "exp2", + "fdiv", + "fill", + "fma", + "gemm", + "gemm_async", + "max", + "maximum", + "memset", + "min", + "minimum", + "mul", + "permute_layout", + "reciprocal", + "reduce_negate", + "select", + "silu", + "sqrt", + "sub", + "sum", + "unary_reduce", + "zero", +] + +for _op_name in _SCOPED_TILE_OP_NAMES: + globals()[_op_name] = _tile_scoped_op(_op_name) + +cluster = _builder.ScopeNamespace("cluster", "cluster") +cta = _builder.ScopeNamespace("cta", "cta") +wg = _builder.ScopeNamespace("warpgroup", "wg") +warpgroup = _builder.ScopeNamespace("warpgroup", "warpgroup") +warp = _builder.ScopeNamespace("warp", "warp") +thread = _builder.ScopeNamespace("thread", "thread") + +compose_op = _builder.compose_op +tvm_kernel_replace_point = _builder.tvm_kernel_replace_point + +__all__ = [ + *_SCOPED_TILE_OP_NAMES, + "cluster", + "compose_op", + "cta", + "thread", + "tvm_kernel_replace_point", + "warp", + "warpgroup", + "wg", +] diff --git a/python/tvm/tirx/stmt.py b/python/tvm/tirx/stmt.py index 4972c715188a..532bf35b254a 100644 --- a/python/tvm/tirx/stmt.py +++ b/python/tvm/tirx/stmt.py @@ -815,39 +815,6 @@ def __init__( ) # type: ignore -@tvm_ffi.register_object("tirx.ExecScopeStmt") -class ExecScopeStmt(Stmt): - """ExecScopeStmt node. - - A statement that annotates the execution scope (e.g. cta, warp, thread) - for its body. This decouples the execution scope concept from SBlock. - - Parameters - ---------- - exec_scope : ExecScope - The execution scope. - - body : Stmt - The body statement under this execution scope. - - span : Optional[Span] - The location of this statement in the source code. - """ - - exec_scope: ExecScope - body: Stmt - span: Span | None - - def __init__(self, exec_scope: ExecScope, body: Stmt, span: Span | None = None) -> None: - body = _normalize_legacy_stmt(body) - self.__init_handle_by_constructor__( - _ffi_api.ExecScopeStmt, # type: ignore - exec_scope, - body, - span, - ) # type: ignore - - @tvm_ffi.register_object("tirx.ScopeIdDefStmt") class ScopeIdDefStmt(Stmt): """ScopeIdDefStmt node. @@ -975,12 +942,16 @@ class TilePrimitiveCall(Stmt): dispatch : Optional[str] The explicit variant name to dispatch to. + + scope : ExecScope + The cooperation scope of this call. Defaults to ``thread`` (an unscoped call). """ args: list[PrimExpr] workspace: dict[str, Buffer] config: dict[str, Any] dispatch: str | None + scope: ExecScope _registry: ClassVar[dict[Op, type["TilePrimitiveCall"]]] = {} def __init__( @@ -990,11 +961,14 @@ def __init__( workspace: dict[str, Buffer] | None = None, config: dict[str, Any] | None = None, dispatch: str | None = None, + scope: ExecScope | None = None, ) -> None: if workspace is None: workspace = {} if config is None: config = {} + if scope is None: + scope = ExecScope("thread") if op is None: assert self.__class__ != TilePrimitiveCall, ( "Directly instantiating TilePrimitiveCall needs to specify the op" @@ -1007,7 +981,8 @@ def __init__( args, workspace, config, - dispatch, # pylint: disable=no-member + dispatch, + scope, # pylint: disable=no-member ) def __init_subclass__(cls, **kwargs): @@ -1027,6 +1002,41 @@ def downcast(cls, instance: "TilePrimitiveCall") -> "TilePrimitiveCall": ) return new_instance + def replace(self, **changes: Any) -> "TilePrimitiveCall": + """Return a copy of this call with selected fields replaced. + + Every field that is not overridden in ``changes`` is preserved from + ``self`` (including ``scope``), so rebuilds never silently drop fields. + The returned node is downcast to the registered subclass for ``op``. + + Parameters + ---------- + **changes : Any + Field overrides; any of ``op``, ``args``, ``workspace``, ``config``, + ``dispatch``, ``scope``. + + Returns + ------- + new_call : TilePrimitiveCall + A new call with the requested fields replaced. + """ + unknown = set(changes) - {"op", "args", "workspace", "config", "dispatch", "scope"} + if unknown: + raise TypeError(f"Unknown field(s) for TilePrimitiveCall.replace: {sorted(unknown)}") + new_call = TilePrimitiveCall( + *changes.get("args", self.args), + op=changes.get("op", self.op), + workspace=changes.get("workspace", self.workspace), + config=changes.get("config", self.config), + dispatch=changes.get("dispatch", self.dispatch), + scope=changes.get("scope", self.scope), + ) + return TilePrimitiveCall.downcast(new_call) + + def with_workspace(self, workspace: dict[str, Buffer]) -> "TilePrimitiveCall": + """Return a copy with ``workspace`` replaced, preserving all other fields.""" + return self.replace(workspace=workspace) + @property def srcs(self) -> list[PrimExpr]: raise NotImplementedError("Subclass must implement this method") diff --git a/python/tvm/tirx/stmt_functor.py b/python/tvm/tirx/stmt_functor.py index c67032d4b047..33e801dd9559 100644 --- a/python/tvm/tirx/stmt_functor.py +++ b/python/tvm/tirx/stmt_functor.py @@ -53,7 +53,6 @@ def __init__(self): "tirx.Evaluate": self.visit_evaluate_, "tirx.SBlock": self.visit_block_, "tirx.SBlockRealize": self.visit_block_realize_, - "tirx.ExecScopeStmt": self.visit_exec_scope_stmt_, "tirx.ScopeIdDefStmt": self.visit_scope_id_def_stmt_, "tirx.TilePrimitiveCall": self.visit_op_call_, "tirx.AllocBuffer": self.visit_alloc_buffer_, @@ -173,10 +172,6 @@ def visit_block_realize_(self, op): """Visitor for BlockRealize nodes.""" return self.visit_stmt_default_(op) - def visit_exec_scope_stmt_(self, op): - """Visitor for ExecScopeStmt nodes.""" - return self.visit_stmt_default_(op) - def visit_scope_id_def_stmt_(self, op): """Visitor for ScopeIdDefStmt nodes.""" return self.visit_stmt_default_(op) @@ -339,10 +334,6 @@ def visit_block_realize_(self, op): self.visit_expr(op.predicate) self.visit_stmt(op.block) - def visit_exec_scope_stmt_(self, op): - """Visitor implementation for ExecScopeStmt.""" - self.visit_stmt(op.body) - def visit_scope_id_def_stmt_(self, op): """Visitor implementation for ScopeIdDefStmt. @@ -794,15 +785,6 @@ def visit_block_realize_(self, op): return tvm.tirx.SBlockRealize(iter_values, predicate, block) - def visit_exec_scope_stmt_(self, op): - """Mutator implementation for ExecScopeStmt.""" - body = self.visit_stmt(op.body) - - if body is op.body: - return op - - return tvm.tirx.ExecScopeStmt(op.exec_scope, body, op.span) - def visit_scope_id_def_stmt_(self, op): """Mutator implementation for ScopeIdDefStmt. @@ -873,7 +855,12 @@ def visit_op_call_(self, op): return op return tvm.tirx.TilePrimitiveCall( - *new_args, op=op.op, workspace=op.workspace, config=new_config, dispatch=op.dispatch + *new_args, + op=op.op, + workspace=op.workspace, + config=new_config, + dispatch=op.dispatch, + scope=op.scope, ) def visit_buffer_region_(self, op): diff --git a/python/tvm/tirx/transform/common.py b/python/tvm/tirx/transform/common.py index c1475ee4a5c3..16995c1d6c5e 100644 --- a/python/tvm/tirx/transform/common.py +++ b/python/tvm/tirx/transform/common.py @@ -160,7 +160,12 @@ def visit_op_call_(self, op): for arg in op.args: args.append(arg) return TilePrimitiveCall( - *args, op=op.op, workspace=new_workspace, config=new_config, dispatch=op.dispatch + *args, + op=op.op, + workspace=new_workspace, + config=new_config, + dispatch=op.dispatch, + scope=op.scope, ) diff --git a/python/tvm/tirx/transform/trn/private_buffer_alloc.py b/python/tvm/tirx/transform/trn/private_buffer_alloc.py index 76883b42f28d..77908210bafd 100644 --- a/python/tvm/tirx/transform/trn/private_buffer_alloc.py +++ b/python/tvm/tirx/transform/trn/private_buffer_alloc.py @@ -23,7 +23,6 @@ from tvm.tirx.stmt import ( AllocBuffer, AttrStmt, - ExecScopeStmt, For, SeqStmt, Stmt, @@ -38,17 +37,11 @@ class PrivateAllocCollector(StmtVisitor): def __init__(self, target: Target): super().__init__() self.target = target - self.exec_scope_stack_ = [] self.launch_params = {} self.var_range_map = {} self.buffer_dict = {} self.private_buf_refs = {} - def visit_exec_scope_stmt_(self, op: ExecScopeStmt): - self.exec_scope_stack_.append(op.exec_scope) - super().visit_exec_scope_stmt_(op) - self.exec_scope_stack_.pop() - def visit_attr_(self, op: AttrStmt): if op.attr_key == "thread_extent": self.launch_params[op.node.thread_tag] = op.value @@ -59,19 +52,9 @@ def visit_for_(self, op: For): super().visit_for_(op) def visit_op_call_(self, op: TilePrimitiveCall): - # Mirror tile_primitive_dispatch.cc: at the device-region root, - # dispatchers see scope_kind="kernel" so trn dispatchers that key - # off "kernel" continue to fire at the entry. - from tvm.tirx.exec_scope import ExecScope - - if not self.exec_scope_stack_: - # Inside AttrStmt(kDeviceEntry) with no inner ExecScope. - # Provide a placeholder ExecScope (not load-bearing for trn). - scope_kind = "kernel" - exec_scope = ExecScope("thread") - else: - scope_kind = self.exec_scope_stack_[-1].name - exec_scope = self.exec_scope_stack_[-1] + # Scope is a per-call field on the node; read it directly. + exec_scope = op.scope + scope_kind = op.scope.name sctx = DispatchContext( target=self.target, exec_scope=exec_scope, @@ -120,9 +103,7 @@ def visit_op_call_(self, op): return op new_workspace = dict(op.workspace) new_workspace.update(self.added_workspace[op]) - op = TilePrimitiveCall( - *op.args, op=op.op, workspace=new_workspace, config=op.config, dispatch=op.dispatch - ) + op = TilePrimitiveCall.downcast(op).with_workspace(new_workspace) return op diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc index fd7a250c120f..6361bd9a6565 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/target/cuda/codegen_cuda.cc @@ -45,6 +45,18 @@ namespace tvm { namespace codegen { +namespace { + +bool IsOp(const tirx::CallNode* call, const Op& compat_op, const char* canonical_name) { + if (call->op.same_as(compat_op)) { + return true; + } + const auto* op_node = call->op.as(); + return op_node != nullptr && op_node->name == canonical_name; +} + +} // namespace + std::string GetFP8Type(DataType type) { std::stringstream stream; int32_t lanes = type.lanes(); @@ -184,8 +196,6 @@ class ThreadIdxExtractor : public tirx::StmtVisitor { if (iv->var->name_hint == "clusterCtaIdx.z" || iv->thread_tag == "clusterCtaIdx.z") { clusterCtaIdx_z_ext = op->value; } - } else if (op->attr_key == tirx::attr::kPersistentKernel) { - is_persistent_kernel = op->value.as()->value; } StmtVisitor::VisitStmt_(op); } @@ -197,17 +207,11 @@ class ThreadIdxExtractor : public tirx::StmtVisitor { PrimExpr clusterCtaIdx_x_ext = IntImm(DataType::Int(32), 1); PrimExpr clusterCtaIdx_y_ext = IntImm(DataType::Int(32), 1); PrimExpr clusterCtaIdx_z_ext = IntImm(DataType::Int(32), 1); - bool is_persistent_kernel = false; }; void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { ThreadIdxExtractor extractor; extractor(f->body); - // Also check PrimFunc attrs for persistent kernel (decorator-level) - bool is_persistent = extractor.is_persistent_kernel; - if (!is_persistent && f->attrs->dict.count(tirx::attr::kPersistentKernel)) { - is_persistent = true; - } arith::Analyzer analyzer; PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * extractor.threadIdx_z_ext); @@ -223,8 +227,11 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { // unable to extract the number of threads per block, hence directly return return; } - if (is_persistent) { - os << " __launch_bounds__(" << threadIdx_ext_int->value << ", 1)"; + auto min_blocks_per_sm = f->GetAttr(tirx::attr::kLaunchBoundsMinBlocksPerSM); + if (min_blocks_per_sm.has_value()) { + TVM_FFI_ICHECK_GT(min_blocks_per_sm.value(), 0); + os << " __launch_bounds__(" << threadIdx_ext_int->value << ", " << min_blocks_per_sm.value() + << ")"; } else { os << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; } @@ -1005,7 +1012,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (op->op.same_as(builtin::ptx_mma())) { + } else if (IsOp(op, builtin::ptx_mma(), "tirx.ptx.mma")) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col // arg 2: B layout: row/col @@ -1040,7 +1047,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); this->stream << asm_code; - } else if (op->op.same_as(builtin::ptx_mma_sp())) { + } else if (IsOp(op, builtin::ptx_mma_sp(), "tirx.ptx.mma_sp")) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col // arg 2: B layout: row/col @@ -1136,7 +1143,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; os << dst << "[" << dst_offset << " + i] = 0.0;"; os << "}\n"; - } else if (op->op.same_as(tvm::tirx::builtin::ptx_mma_legacy())) { + } else if (IsOp(op, tvm::tirx::builtin::ptx_mma_legacy(), "tirx.ptx.mma_legacy")) { // args: shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, // a_ptr_var, a_offset, b_ptr_var, b_offset, // c_ptr_var, c_offset, saturate, [bit_op] @@ -1159,7 +1166,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->stream << PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); - } else if (op->op.same_as(tvm::tirx::builtin::ptx_ldmatrix_legacy())) { + } else if (IsOp(op, tvm::tirx::builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) { // args: trans, num, type, local_ptr_var, local_offset, smem_ptr_var, smem_offset codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 7U); @@ -1236,7 +1243,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; os << dst << "[" << dst_offset << " + i] = 0.0;"; os << "}\n"; - } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { + } else if (IsOp(op, builtin::ptx_cp_async_bulk(), "tirx.ptx.cp_async_bulk")) { codegen_tags_.insert("cast_smem_ptr_to_int"); std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); @@ -1250,7 +1257,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string barrier_arr = barrier_name_ + "_" + std::to_string(barrier_arr_id); std::string barrier = barrier_arr + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); - } else if (op->op.same_as(builtin::ptx_cp_async_mbarrier_arrive())) { + } else if (IsOp(op, builtin::ptx_cp_async_mbarrier_arrive(), + "tirx.ptx.cp_async_mbarrier_arrive")) { codegen_tags_.insert("cast_smem_ptr_to_int"); int barrier_arr_id = Downcast(op->args[0])->value; int barrier_id = Downcast(op->args[1])->value; @@ -1260,7 +1268,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string barrier_arr = barrier_name_ + "_" + std::to_string(barrier_arr_id); std::string barrier = barrier_arr + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBarrierAsm(barrier); - } else if (op->op.same_as(builtin::ptx_ldg32())) { + } else if (IsOp(op, builtin::ptx_ldg32(), "tirx.ptx.ldg32")) { /* asm volatile ( "{.reg .pred p;\n" @@ -1522,7 +1530,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "}\n" << "// print_buffer ends\n"; - } else if (op->op.same_as(builtin::cuda_func_call())) { + } else if (op->op.same_as(builtin::cuda_func_call()) || + (op->op.as() && op->op.as().value()->name == "tirx.cuda.func_call")) { print_cuda_func_call(op, os); } else if (op->op.same_as(builtin::thread_return())) { os << "return"; diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/target/cuda/intrin_rule_cuda.cc index c56da3046bc0..a9aadf1aeed8 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/target/cuda/intrin_rule_cuda.cc @@ -258,7 +258,7 @@ TVM_REGISTER_OP("tirx.tvm_warp_activemask") TVM_REGISTER_OP("tirx.fmod") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -// Register low-level builtin ops. +// Register low-level CUDA device intrinsics. // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .set_num_inputs(4) @@ -266,6 +266,9 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_sync"), 10) .set_attr("TGlobalSymbol", "__shfl_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -276,6 +279,10 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_up_sync"), + 10) .set_attr("TGlobalSymbol", "__shfl_up_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -286,6 +293,10 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_down_sync"), + 10) .set_attr("TGlobalSymbol", "__shfl_down_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -296,12 +307,19 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_xor_sync") .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane_mask", "Expr", "The lane mask.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_xor_sync"), + 10) .set_attr("TGlobalSymbol", "__shfl_xor_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); TVM_REGISTER_OP("tirx.cuda.__activemask") .set_num_inputs(0) + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TScriptPrinterName", ffi::String("cuda.__activemask"), 10) .set_attr("TGlobalSymbol", "__activemask") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) .set_attr("cuda.need_warp_shuffle", true); diff --git a/src/target/hexagon/llvm/intrin_rule_hexagon.cc b/src/target/hexagon/llvm/intrin_rule_hexagon.cc index cc5328b2b3f4..23a4e6a52a14 100644 --- a/src/target/hexagon/llvm/intrin_rule_hexagon.cc +++ b/src/target/hexagon/llvm/intrin_rule_hexagon.cc @@ -96,6 +96,7 @@ TVM_REGISTER_OP("tirx.round") DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tirx.ctpop") + .set_attr("TIRxOpCategory", ffi::String("builtin"), 1) .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); TVM_REGISTER_OP("tirx.tanh") diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 8d35ef87238f..c89aca5f76a5 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -199,6 +199,7 @@ TVM_REGISTER_OP("tirx.sigmoid") }); TVM_REGISTER_OP("tirx.isfinite") + .set_attr("TIRxOpCategory", ffi::String("builtin"), 1) .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); @@ -206,6 +207,7 @@ TVM_REGISTER_OP("tirx.isfinite") }); TVM_REGISTER_OP("tirx.isinf") + .set_attr("TIRxOpCategory", ffi::String("builtin"), 1) .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 44308be5ba2f..8edbc17ce5e2 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2112,8 +2112,6 @@ void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { void CodeGenLLVM::VisitStmt_(const DeclBufferNode* op) { EmitDebugLocation(op); } -void CodeGenLLVM::VisitStmt_(const ExecScopeStmtNode* op) { VisitStmt(op->body); } - void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { EmitDebugLocation(op); MakeValue(op->value); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 61d7da8ce402..b57a1a446bcf 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -231,7 +231,6 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const DeclBufferNode* op) override; - void VisitStmt_(const ExecScopeStmtNode* op) override; // Get constant string llvm::Constant* GetConstString(const std::string& str); diff --git a/src/target/metal/intrin_rule_metal.cc b/src/target/metal/intrin_rule_metal.cc index 217ad164b8e1..54417c6cdc94 100644 --- a/src/target/metal/intrin_rule_metal.cc +++ b/src/target/metal/intrin_rule_metal.cc @@ -138,11 +138,15 @@ TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up") TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); -// Register low-level builtin ops. +// Register low-level Metal device intrinsics. TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("metal"), + 10) + .set_attr("TScriptPrinterName", ffi::String("metal.simd_shuffle"), 10) .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -150,6 +154,11 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("metal"), + 10) + .set_attr("TScriptPrinterName", ffi::String("metal.simd_shuffle_up"), + 10) .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -157,6 +166,11 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("metal"), + 10) + .set_attr("TScriptPrinterName", + ffi::String("metal.simd_shuffle_down"), 10) .set_attr("TGlobalSymbol", "simd_shuffle_down") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a52c9fc4c2ef..b398af65684f 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -836,8 +836,6 @@ void CodeGenC::VisitStmt_(const DeclBufferNode* op) { // DeclBuffer is a flat statement with no body — nothing to emit. } -void CodeGenC::VisitStmt_(const ExecScopeStmtNode* op) { this->PrintStmt(op->body); } - void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index eea591790e80..934d1af83a36 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -200,7 +200,6 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const DeclBufferNode* op) override; - void VisitStmt_(const ExecScopeStmtNode* op) override; /*! * \brief Print expr representing the thread tag diff --git a/src/target/source/codegen_trn.cc b/src/target/source/codegen_trn.cc index 9e43be54bcb8..6a2eb7168ff4 100644 --- a/src/target/source/codegen_trn.cc +++ b/src/target/source/codegen_trn.cc @@ -356,22 +356,26 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL TVM_FFI_ICHECK(!op->op.as()) << "CodegenTrainium does not support inter-function calls, " << "but expression " << ffi::GetRef(op) << " calls PrimFunc " << op->op; - if (op->op.same_as(builtin::nki_matmul())) { + const auto* op_node = op->op.as(); + auto is_op = [&](const Op& compat, const char* canonical_name) { + return op->op.same_as(compat) || (op_node != nullptr && op_node->name == canonical_name); + }; + if (is_op(builtin::nki_matmul(), "tirx.nki.matmul")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); std::string accum = is_one(op->args[3]) ? " += " : " = "; os << PrintExpr(op->args[0]) << accum; ctx_.is_matmul_input = true; os << "nisa.nc_matmul(" << PrintExpr(op->args[1]) << "," << PrintExpr(op->args[2]); - } else if (op->op.same_as(builtin::nki_load())) { + } else if (is_op(builtin::nki_load(), "tirx.nki.load")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << PrintExpr(op->args[0]) << " = nl.load(" << PrintExpr(op->args[1]); - } else if (op->op.same_as(builtin::nki_store())) { + } else if (is_op(builtin::nki_store(), "tirx.nki.store")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << "nl.store(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]); - } else if (op->op.same_as(builtin::nki_tensor_copy())) { + } else if (is_op(builtin::nki_tensor_copy(), "tirx.nki.tensor_copy")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << PrintExpr(op->args[0]) << " = nisa.tensor_copy(" << PrintExpr(op->args[1]); - } else if (op->op.same_as(builtin::nki_activation())) { + } else if (is_op(builtin::nki_activation(), "tirx.nki.activation")) { TVM_FFI_ICHECK_EQ(op->args.size(), 5); // nki_activation(result, data, opcode, bias, scale) TVM_FFI_ICHECK(opcode_map_.count(op->args[2].as()->value)); @@ -379,17 +383,17 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL os << PrintExpr(op->args[0]) << " = nisa.activation(op=" << nki_op << ", data=" << PrintExpr(op->args[1]) << ","; os << "bias=" << PrintExpr(op->args[3]) << ", scale=" << PrintExpr(op->args[4]); - } else if (op->op.same_as(builtin::nki_reciprocal())) { + } else if (is_op(builtin::nki_reciprocal(), "tirx.nki.reciprocal")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << PrintExpr(op->args[0]) << " = nisa.reciprocal(" << PrintExpr(op->args[1]); - } else if (op->op.same_as(builtin::nki_tensortensor())) { + } else if (is_op(builtin::nki_tensortensor(), "tirx.nki.tensortensor")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); // nki_tensortensor(result, data1, data2, opcode) TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as()->value)); std::string nki_op = opcode_map_[op->args[3].as()->value]; os << PrintExpr(op->args[0]) << " = nisa.tensor_tensor(" << PrintExpr(op->args[1]) << ", "; os << PrintExpr(op->args[2]) << ", op=" << nki_op; - } else if (op->op.same_as(builtin::nki_tensorscalar())) { + } else if (is_op(builtin::nki_tensorscalar(), "tirx.nki.tensorscalar")) { TVM_FFI_ICHECK_EQ(op->args.size(), 5); // nki_tensorscalar(result, operand0, operand1, opcode, reverse) TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as()->value)); @@ -398,13 +402,13 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL os << PrintExpr(op->args[0]) << " = nisa.tensor_scalar(" << PrintExpr(op->args[1]) << ", operand0="; os << PrintExpr(op->args[2]) << ", op0=" << nki_op << ", reverse0=" << PrintBool(reverse); - } else if (op->op.same_as(builtin::nki_memset())) { + } else if (is_op(builtin::nki_memset(), "tirx.nki.memset")) { TVM_FFI_ICHECK_GE(op->args.size(), 2); // result, value os << PrintExpr(op->args[0]) << " = " << PrintExpr(op->args[1]); TVM_FFI_ICHECK(!ctx_.mask.defined()) << "memset cannot have mask"; return; - } else if (op->op.same_as(builtin::nki_tensorreduce())) { + } else if (is_op(builtin::nki_tensorreduce(), "tirx.nki.tensorreduce")) { TVM_FFI_ICHECK(op->args.size() >= 5) << "nki_tensorreduce expects at least 5 arguments, but got " << op->args.size(); // nki_tensorreduce(result, data, opcode, negate, *axes) @@ -414,7 +418,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL Array axes(op->args.begin() + 4, op->args.end()); os << PrintExpr(op->args[0]) << " = nisa.tensor_reduce(data=" << PrintExpr(op->args[1]) << ", op=" << nki_op << ", negate=" << PrintBool(negate) << ", axis=" << axes; - } else if (op->op.same_as(builtin::nki_activation_reduce())) { + } else if (is_op(builtin::nki_activation_reduce(), "tirx.nki.activation_reduce")) { TVM_FFI_ICHECK(op->args.size() == 7) << "nki_activation_reduce expects 7 arguments, but got " << op->args.size(); // nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias, scale) @@ -426,7 +430,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", op=" << nki_op; os << ", reduce_op=" << reduce_nki_op << ", reduce_res=" << PrintExpr(op->args[0]) << ", bias=" << PrintExpr(op->args[5]) << ", scale=" << PrintExpr(op->args[6]); - } else if (op->op.same_as(builtin::nki_tensorscalar_reduce())) { + } else if (is_op(builtin::nki_tensorscalar_reduce(), "tirx.nki.tensorscalar_reduce")) { TVM_FFI_ICHECK(op->args.size() == 7) << "nki_tensorscalar_reduce expects 7 arguments, but got " << op->args.size(); // nki_tensorscalar_reduce(reduce_res, tensorscalar_res, operand0, operand1, opcode, @@ -440,7 +444,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", op0=" << nki_op << ", operand0=" << PrintExpr(op->args[3]) << ", reduce_op=" << reduce_nki_op << ", reduce_res=" << PrintExpr(op->args[0]) << ", reverse0=" << PrintBool(reverse); - } else if (op->op.same_as(builtin::nki_identity())) { + } else if (is_op(builtin::nki_identity(), "tirx.nki.identity")) { // nki_identity(result, size) TVM_FFI_ICHECK_EQ(op->args.size(), 2); auto identity_np_name = name_supply_->FreshName("identity_np"); @@ -450,7 +454,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL os << ' '; } os << PrintExpr(op->args[0]) << " = nl.load(" << identity_np_name; - } else if (op->op.same_as(builtin::nki_scalar_tensor_tensor())) { + } else if (is_op(builtin::nki_scalar_tensor_tensor(), "tirx.nki.scalar_tensor_tensor")) { TVM_FFI_ICHECK_EQ(op->args.size(), 8); // nki_scalar_tensor_tensor(result, data, operand0, operand1, opcode0, opcode1, reverse0, // reverse1) @@ -464,7 +468,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", operand0=" << PrintExpr(op->args[2]) << ", op0=" << nki_op0 << ", reverse0=" << PrintBool(reverse0) << ", operand1=" << PrintExpr(op->args[3]) << ", op1=" << nki_op1 << ", reverse1=" << PrintBool(reverse1); - } else if (op->op.same_as(builtin::nki_scalar_tensor_scalar())) { + } else if (is_op(builtin::nki_scalar_tensor_scalar(), "tirx.nki.scalar_tensor_scalar")) { TVM_FFI_ICHECK_EQ(op->args.size(), 8); // nki_scalar_tensor_scalar(result, data, operand0, operand1, opcode0, opcode1, reverse0, // reverse1) @@ -478,7 +482,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", operand0=" << PrintExpr(op->args[2]) << ", op0=" << nki_op0 << ", reverse0=" << PrintBool(reverse0) << ", operand1=" << PrintExpr(op->args[3]) << ", op1=" << nki_op1 << ", reverse1=" << PrintBool(reverse1); - } else if (op->op.same_as(builtin::nki_affine_select())) { + } else if (is_op(builtin::nki_affine_select(), "tirx.nki.affine_select")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); // nki_affine_select(result, pred, true_value, false_value) os << PrintExpr(op->args[0]) << " = nisa.affine_select(pred=" << PrintExpr(op->args[1]) diff --git a/src/target/webgpu/intrin_rule_webgpu.cc b/src/target/webgpu/intrin_rule_webgpu.cc index 889b85e56aad..14dfd7959146 100644 --- a/src/target/webgpu/intrin_rule_webgpu.cc +++ b/src/target/webgpu/intrin_rule_webgpu.cc @@ -158,11 +158,16 @@ TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") .set_attr("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle); -// Register low-level builtin ops. +// Register low-level WebGPU device intrinsics. TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("webgpu"), + 10) + .set_attr("TScriptPrinterName", + ffi::String("webgpu.subgroup_shuffle"), 10) .set_attr("TGlobalSymbol", "subgroupShuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -170,6 +175,11 @@ TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("webgpu"), + 10) + .set_attr("TScriptPrinterName", + ffi::String("webgpu.subgroup_shuffle_up"), 10) .set_attr("TGlobalSymbol", "subgroupShuffleUp") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -177,6 +187,11 @@ TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("webgpu"), + 10) + .set_attr("TScriptPrinterName", + ffi::String("webgpu.subgroup_shuffle_down"), 10) .set_attr("TGlobalSymbol", "subgroupShuffleDown") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); diff --git a/src/tirx/analysis/exec_context.cc b/src/tirx/analysis/exec_context.cc index 93c2781da210..1dcd9dc1066a 100644 --- a/src/tirx/analysis/exec_context.cc +++ b/src/tirx/analysis/exec_context.cc @@ -660,16 +660,6 @@ bool ExecContext::WithCtaAxisModulo(const std::string& axis, int64_t modulus, in return true; } -bool ExecContext::WithScopeSwitch(ScopeKind new_scope_kind, ExecContext* out, - std::string* err) const { - ExecSplit new_split; - if (!ScopeSwitch(A, new_scope_kind, &new_split, err)) return false; - out->A = A; - out->scope_kind = new_scope_kind; - out->split = std::move(new_split); - return true; -} - ffi::Map> EncodeSplitSide( const std::unordered_map& side) { ffi::Map> out; diff --git a/src/tirx/analysis/filter_canonical.cc b/src/tirx/analysis/filter_canonical.cc index c1d27c3ece84..fbf098cced98 100644 --- a/src/tirx/analysis/filter_canonical.cc +++ b/src/tirx/analysis/filter_canonical.cc @@ -44,6 +44,14 @@ bool IsBitwiseAndCall(const CallNode* call) { return call->op.same_as(tirx::builtin::bitwise_and()) && call->args.size() == 2; } +bool IsPtxElectSyncCall(const CallNode* call) { + if (call->op.same_as(tirx::builtin::ptx_elect_sync())) return true; + if (auto op = call->op.as()) { + return op.value()->name == "tirx.ptx.elect_sync"; + } + return false; +} + // Strip implicit Cast wrappers from a predicate. Bool-vs-int mixing in the // Python frontend can insert ``Cast(bool_expr)`` (e.g. when an // ``elect_sync()`` uint32 result is combined with a bool comparison via @@ -191,7 +199,7 @@ bool TryParseCompareAtom(const PrimExpr& expr, const ScopeIdPredicate& is_scope_ bool TryParseElectSyncAtom(const PrimExpr& expr, FilterAtom* out) { const auto* call = expr.as(); if (call == nullptr) return false; - if (!call->op.same_as(tirx::builtin::ptx_elect_sync())) return false; + if (!IsPtxElectSyncCall(call)) return false; out->kind = FilterAtomKind::kElectSync; out->scopeid_var = Var(); out->lo = 0; diff --git a/src/tirx/analysis/verify_tirx_well_formed.cc b/src/tirx/analysis/verify_tirx_well_formed.cc index dbc5e672507a..aabda41ade99 100644 --- a/src/tirx/analysis/verify_tirx_well_formed.cc +++ b/src/tirx/analysis/verify_tirx_well_formed.cc @@ -51,13 +51,11 @@ class ExecScopeVerifier : public Verifier { using Verifier::Visit; void VisitStmt_(const SBlockNode* op, ffi::reflection::AccessPath path) override { - Verify(false) << "TIRxError: SBlock is not allowed in tirx=True mode at " << path - << ". Use ExecScopeStmt with T.attr() instead."; + Verify(false) << "TIRxError: SBlock is not allowed in tirx=True mode at " << path; } void VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) override { - Verify(false) << "TIRxError: SBlockRealize is not allowed in tirx=True mode at " << path - << ". Use ExecScopeStmt with T.attr() instead."; + Verify(false) << "TIRxError: SBlockRealize is not allowed in tirx=True mode at " << path; } void VisitStmt_(const tirx::TilePrimitiveCallNode* op, @@ -66,13 +64,6 @@ class ExecScopeVerifier : public Verifier { Verify(tirx_op_map_.count(op->op)) << "TIRxError: TilePrimitiveCall at " << path << " has unknown TIRX op " << op->op; } - - void VisitStmt_(const ExecScopeStmtNode* op, ffi::reflection::AccessPath path) override { - // ExecScope ctor FATALs on unknown name, so a constructed scope is always - // structurally valid. Scope nesting is a perspective change rather than - // an active-set narrowing, so any ScopeKind may nest inside any other. - Verifier::VisitStmt_(op, path); - } }; class ScopeIdVerifier : public Verifier { @@ -82,18 +73,6 @@ class ScopeIdVerifier : public Verifier { private: using Verifier::Visit; - void VisitStmt_(const ExecScopeStmtNode* op, ffi::reflection::AccessPath path) override { - size_t baseline = scope_id_def_.size(); - Verifier::VisitStmt_(op, path); - size_t total = scope_id_def_.size(); - if (total > baseline) { - RunScopeIdVerify(path, baseline, /*is_root=*/false); - } - while (scope_id_def_.size() > baseline) { - scope_id_def_.pop_back(); - } - } - void VisitStmt_(const AttrStmtNode* op, ffi::reflection::AccessPath path) override { if (op->attr_key == tvm::tirx::attr::kDeviceEntry) { // Device-region marker: defs gathered from the body are verified when @@ -161,11 +140,6 @@ class LayoutVerifier : public Verifier { void VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) override { Verify(false) << "TIRxError: SBlockRealize is not allowed in tirx=True mode at " << path; } - - void VisitStmt_(const ExecScopeStmtNode* op, ffi::reflection::AccessPath path) override { - // Check buffer layouts in alloc_buffers that appear as AllocBuffer stmts - Verifier::VisitStmt_(op, path); - } }; class AsyncStructsVerifier : public Verifier { @@ -182,14 +156,6 @@ class AsyncStructsVerifier : public Verifier { void VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) override { Verify(false) << "TIRxError: SBlockRealize is not allowed in tirx=True mode at " << path; } - - void VisitStmt_(const ExecScopeStmtNode* op, ffi::reflection::AccessPath path) override { - scope_stack_.push_back(op->exec_scope); - Verifier::VisitStmt_(op, path); - scope_stack_.pop_back(); - } - - std::vector scope_stack_; }; class DeviceFuncVerifier : public Verifier { @@ -206,23 +172,6 @@ class DeviceFuncVerifier : public Verifier { void VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) override { Verify(false) << "TIRxError: SBlockRealize is not allowed in tirx=True mode at " << path; } - - void VisitStmt_(const ExecScopeStmtNode* op, ffi::reflection::AccessPath path) override { - if (!inside_root_scope_) { - // At the top level: only one root scope is allowed - Verify(!root_.has_value()) << "TIRxError: Only one root scope is allowed in device function"; - root_ = op->exec_scope; - inside_root_scope_ = true; - Verifier::VisitStmt_(op, path); - inside_root_scope_ = false; - } else { - // Already inside a root scope: nested scopes are allowed - Verifier::VisitStmt_(op, path); - } - } - - ffi::Optional root_ = std::nullopt; - bool inside_root_scope_ = false; }; bool VerifyTIRxWellFormed(const PrimFunc& func, bool assert_mode, bool device_func) { diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index b0c0f6d037d3..be782da0cd92 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -52,7 +52,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { MatchBufferRegionNode::RegisterReflection(); SBlockNode::RegisterReflection(); SBlockRealizeNode::RegisterReflection(); - ExecScopeStmtNode::RegisterReflection(); ScopeIdDefStmtNode::RegisterReflection(); } @@ -632,17 +631,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -// ExecScopeStmt -ExecScopeStmt::ExecScopeStmt(ExecScope exec_scope, Stmt body, Span span) { - TVM_FFI_ICHECK(exec_scope.defined()); - TVM_FFI_ICHECK(body.defined()); - ffi::ObjectPtr node = ffi::make_object(); - node->exec_scope = std::move(exec_scope); - node->body = std::move(body); - node->span = std::move(span); - data_ = std::move(node); -} - // ScopeIdDefStmt ScopeIdDefStmt::ScopeIdDefStmt(ScopeIdDef def, Span span) { TVM_FFI_ICHECK(def.defined()); @@ -654,9 +642,6 @@ ScopeIdDefStmt::ScopeIdDefStmt(ScopeIdDef def, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tirx.ExecScopeStmt", [](ExecScope exec_scope, Stmt body, Span span) { - return ExecScopeStmt(exec_scope, body, span); - }); refl::GlobalDef().def("tirx.ScopeIdDefStmt", [](ScopeIdDef def, Span span) { return ScopeIdDefStmt(def, span); }); } diff --git a/src/tirx/ir/stmt_functor.cc b/src/tirx/ir/stmt_functor.cc index cccbaec5567b..a28f055cca0a 100644 --- a/src/tirx/ir/stmt_functor.cc +++ b/src/tirx/ir/stmt_functor.cc @@ -147,12 +147,6 @@ void StmtVisitor::VisitStmt_(const SBlockRealizeNode* op) { this->VisitStmt(op->block); } -void StmtVisitor::VisitStmt_(const ExecScopeStmtNode* op) { - // ScopeIdDefStmts are now separate body stmts and are visited via the - // standard StmtFunctor dispatch; nothing extra to do here. - this->VisitStmt(op->body); -} - void StmtVisitor::VisitStmt_(const ScopeIdDefStmtNode* op) { // Flat stmt -- no body. Visit extents (skip deferred defs whose extents // are NullOpt) and any preferred_extents. @@ -646,16 +640,6 @@ Stmt StmtMutator::VisitStmt_(const ScopeIdDefStmtNode* op) { return Stmt(n); } -Stmt StmtMutator::VisitStmt_(const ExecScopeStmtNode* op) { - Stmt body = this->VisitStmt(op->body); - if (body.same_as(op->body)) { - return ffi::GetRef(op); - } - auto n = CopyOnWrite(op); - n->body = std::move(body); - return Stmt(n); -} - Stmt StmtMutator::VisitStmt_(const tirx::TilePrimitiveCallNode* op) { auto fmutate = [&](const ffi::Any& e) -> ffi::Any { if (e == nullptr) return e; diff --git a/src/tirx/ir/tir_visitor_with_path.cc b/src/tirx/ir/tir_visitor_with_path.cc index 795facac2b25..8638d5d786bd 100644 --- a/src/tirx/ir/tir_visitor_with_path.cc +++ b/src/tirx/ir/tir_visitor_with_path.cc @@ -327,10 +327,6 @@ void TIRVisitorWithPath::VisitStmt_(const tirx::TilePrimitiveCallNode* op, Acces } } -void TIRVisitorWithPath::VisitStmt_(const ExecScopeStmtNode* op, AccessPath path) { - Visit(op->body, path->Attr("body")); -} - void TIRVisitorWithPath::VisitStmt_(const ScopeIdDefStmtNode* op, AccessPath path) { // Flat stmt -- no body. Visit extents and preferred_extents (if present), // then push the bound Var(s) into the current scope so subsequent siblings diff --git a/src/tirx/ir/tir_visitor_with_path.h b/src/tirx/ir/tir_visitor_with_path.h index cac455467cea..33c112f98555 100644 --- a/src/tirx/ir/tir_visitor_with_path.h +++ b/src/tirx/ir/tir_visitor_with_path.h @@ -129,7 +129,6 @@ class TIRVisitorWithPath void VisitStmt_(const SBlockNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const SBlockRealizeNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const tirx::TilePrimitiveCallNode* op, ffi::reflection::AccessPath path) override; - void VisitStmt_(const ExecScopeStmtNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const ScopeIdDefStmtNode* op, ffi::reflection::AccessPath path) override; using ExprFunctor::VisitExpr; diff --git a/src/tirx/ir/tirx_stmt.cc b/src/tirx/ir/tirx_stmt.cc index ec6391dc0231..81f392048dc1 100644 --- a/src/tirx/ir/tirx_stmt.cc +++ b/src/tirx/ir/tirx_stmt.cc @@ -35,7 +35,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TilePrimitiveCallNode::RegisterReflection(); } TilePrimitiveCall::TilePrimitiveCall(tvm::Op op, ffi::Array args, ffi::Map workspace, ffi::Map config, - ffi::Optional dispatch) { + ffi::Optional dispatch, ExecScope scope) { // Check if the op is a TIRX op. static const auto& tirx_op_map = Op::GetAttrMap("TIsTIRxOp"); TVM_FFI_ICHECK_EQ(tirx_op_map.count(op), 1) @@ -47,6 +47,7 @@ TilePrimitiveCall::TilePrimitiveCall(tvm::Op op, ffi::Array args, n->workspace = std::move(workspace); n->config = std::move(config); n->dispatch = std::move(dispatch); + n->scope = std::move(scope); data_ = std::move(n); } @@ -55,8 +56,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def( "tirx.TilePrimitiveCall", [](tvm::Op op, ffi::Array args, ffi::Map workspace, - ffi::Map config, ffi::Optional dispatch) { - return TilePrimitiveCall(op, args, workspace, config, dispatch); + ffi::Map config, ffi::Optional dispatch, + ExecScope scope) { + return TilePrimitiveCall(op, args, workspace, config, dispatch, scope); }); } diff --git a/src/tirx/ir/transform.cc b/src/tirx/ir/transform.cc index 7156f421142c..1b0fe047a6a8 100644 --- a/src/tirx/ir/transform.cc +++ b/src/tirx/ir/transform.cc @@ -47,7 +47,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.use_async_copy", bool); TVM_REGISTER_PASS_CONFIG_OPTION("tirx.merge_static_smem", bool); TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_lwp", bool); TVM_REGISTER_PASS_CONFIG_OPTION("tirx.vtcm_capacity", int64_t); -TVM_REGISTER_PASS_CONFIG_OPTION("tirx.ptx_ldg32", bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.ptx.ldg32", bool); TVM_REGISTER_PASS_CONFIG_OPTION("tirx.enable_fast_math", bool); /*! diff --git a/src/tirx/op/runtime.cc b/src/tirx/op/runtime.cc index 5c1bd0077ea6..fb24a82ae605 100644 --- a/src/tirx/op/runtime.cc +++ b/src/tirx/op/runtime.cc @@ -29,11 +29,13 @@ namespace tirx { TVM_REGISTER_OP("tirx.TVMBackendAnyListSetPackedArg") .set_num_inputs(5) + .set_attr("TIRxOpCategory", ffi::String("builtin"), 1) .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); TVM_REGISTER_OP("tirx.TVMBackendAnyListMoveFromPackedReturn") .set_num_inputs(3) + .set_attr("TIRxOpCategory", ffi::String("builtin"), 1) .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); diff --git a/src/tirx/op/target_builtin/cuda.cc b/src/tirx/op/target_builtin/cuda.cc index 574c622b52a0..91a84dbda32f 100644 --- a/src/tirx/op/target_builtin/cuda.cc +++ b/src/tirx/op/target_builtin/cuda.cc @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace tirx { namespace builtin { @@ -79,8 +81,17 @@ TIRX_DEFINE_BUILTIN_FUNC(mma_store_legacy) TIRX_DEFINE_BUILTIN_FUNC(mma_fill_legacy) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIRX_DEFINE_BUILTIN_FUNC(ptx_ldg32).set_num_inputs(4).set_attr( - "TCallEffectKind", static_cast(CallEffectKind::kPure)); +const Op& ptx_ldg32() { + static const Op& op = Op::Get("tirx.ptx.ldg32"); + return op; +} + +TVM_REGISTER_OP("tirx.ptx.ldg32") + .set_num_inputs(4) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) + .set_attr("TScriptPrinterName", ffi::String("ptx.ldg32"), 20) + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx"), 10); TIRX_DEFINE_BUILTIN_FUNC(ptx_mma_sp) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) @@ -173,8 +184,24 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_elect_sync) TIRX_DEFINE_BUILTIN_FUNC(ptx_fence_mbarrier_init) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIRX_DEFINE_BUILTIN_FUNC(ptx_fetch_register) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)); +const Op& ptx_fetch_register() { + static const Op& op = Op::Get("tirx.ptx.fetch_register"); + return op; +} + +TVM_REGISTER_OP("tirx.ptx.fetch_register") + .set_num_inputs(-1) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) + .set_attr("TIRxOpCategory", ffi::String("device_intrin")) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx")) + .set_attr("TScriptPrinterName", ffi::String("ptx.fetch_register")); + +TVM_REGISTER_OP("tirx.ptx_fetch_register") + .set_num_inputs(-1) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) + .set_attr("TIRxOpCategory", ffi::String("device_intrin")) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx")) + .set_attr("TScriptPrinterName", ffi::String("ptx.fetch_register")); // griddepcontrol — programmatic dependent launch synchronization (sm_90+). // Both are memory barriers; mark kOpaque to prevent CSE/reordering. @@ -335,6 +362,222 @@ TIRX_DEFINE_BUILTIN_FUNC(nvshmem_fence) TIRX_DEFINE_BUILTIN_FUNC(nvshmem_barrier_all) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +namespace { + +struct DeviceIntrinsicRegistration { + const char* flat_name; + const char* namespace_name; + CallEffectKind effect_kind; +}; + +void RegisterDeviceIntrinsic(const DeviceIntrinsicRegistration& reg) { + std::string flat_name(reg.flat_name); + std::string namespace_name(reg.namespace_name); + std::string prefix = namespace_name + "_"; + std::string suffix = flat_name; + if (suffix.rfind(prefix, 0) == 0) { + suffix = suffix.substr(prefix.size()); + } + + std::string flat_op_name = "tirx." + flat_name; + std::string canonical_op_name = "tirx." + namespace_name + "." + suffix; + ffi::String namespace_attr(namespace_name); + ffi::String printer_name(namespace_name + "." + suffix); + int64_t effect = static_cast(reg.effect_kind); + + auto register_one = [&](const std::string& op_name) { + OpRegEntry::RegisterOrGet(op_name) + .set_name() + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), + /*plevel=*/15) + .set_attr("TDeviceIntrinsicNamespace", namespace_attr, + /*plevel=*/15) + .set_attr("TCallEffectKind", effect, /*plevel=*/15) + .set_attr("TScriptPrinterName", printer_name, /*plevel=*/15); + }; + + register_one(flat_op_name); + register_one(canonical_op_name); +} + +#define TIRX_DEVICE_INTRIN_ALIAS(OpName, Namespace, EffectKind) \ + {#OpName, #Namespace, CallEffectKind::EffectKind} + +const DeviceIntrinsicRegistration kDeviceIntrinsics[] = { + TIRX_DEVICE_INTRIN_ALIAS(cuda_atomic_add, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_atomic_cas, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_ballot_sync, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_bfloat1622float2, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_bfloat162float, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_clock64, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_cluster_sync, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_copy_bytes, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_cta_reduce, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_cta_sync, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_cvta_generic_to_shared, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_fadd2_rn, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_ffs_u32, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float22bfloat162_rn, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float22bfloat162_rn_from_float2, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float22half2, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float2_x, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float2_y, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float8tohalf8, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_float_as_uint, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_fmul2_rn, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_fp8x4_e4m3_from_float4, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_func_call, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_get_tmem_addr, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_grid_sync, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_half2float, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_half8tofloat8, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_hmax2, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_hmin2, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_ldg, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_make_float2, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_nano_sleep, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_printf, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_reduce_add_sync_u32, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_reduce_min_sync_u32, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_runtime_instr_desc, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_smem_addr_from_uint64, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_sm100_tma_2sm_mbarrier_addr, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_syncthreads_and, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_syncthreads_or, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_thread_fence, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_thread_rank, cuda, kPure), + TIRX_DEVICE_INTRIN_ALIAS(cuda_trap_when_assert_failed, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_uint_as_float, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_warp_reduce, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_warp_sync, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(cuda_warpgroup_sync, cuda, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_barrier_all, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_fence, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_getmem_nbi, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_getmem_nbi_block, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_getmem_nbi_warp, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_my_pe, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_n_pes, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_putmem_nbi, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_putmem_nbi_block, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_putmem_nbi_warp, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_putmem_signal_nbi, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_putmem_signal_nbi_block, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_putmem_signal_nbi_warp, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_quiet, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_signal_op, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(nvshmem_wait_until, nvshmem, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_add_f32, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_add_f32x2, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_add_f64, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_add_rn_f32_bf16, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_any_sync, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_atom_scalar, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_bar_arrive, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_bar_sync, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_arrive, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_commit_group, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_g2s_cluster, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_g2s_cta, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_s2g, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_s2s_cluster, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_shared_to_cluster, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_tensor_global_to_cluster, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_tensor_global_to_cluster_prefetch, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_tensor_shared_to_global, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_tensor_shared_to_global_reduce, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_wait_group, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_commit_group, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_mbarrier_arrive, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_wait_group, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_elect_sync, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_exp2, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fence, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fence_mbarrier_init, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fence_proxy_async, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fetch_register, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fma_f32, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fma_f32x2, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fma_f64, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_fns_b32, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_griddepcontrol_launch_dependents, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_griddepcontrol_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_ld, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_ld_acquire, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_ld_global_acquire, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_ld_volatile, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_ldmatrix, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_ldmatrix_legacy, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mapa, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_map_shared_rank, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_max_f32, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_arrive, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_arrive_expect_tx, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_init, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_test_wait_parity, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_once, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mma, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mma_legacy, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mma_sp, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mul_f32, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mul_f32x2, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mul_f64, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_prefetch_tensormap, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_rcp, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_red_scalar, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_reduce3_max_f32, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_reduce3_min_f32, ptx, kPure), + TIRX_DEVICE_INTRIN_ALIAS(ptx_setmaxnreg, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_st, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_st_bulk, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_stmatrix, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_sub_f32, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_sub_f32x2, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_sub_f64, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_alloc, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_commit, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_cp, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_dealloc, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_encode_instr_descriptor, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_encode_instr_descriptor_block_scaled, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_encode_matrix_descriptor, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_fence_after_thread_sync, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_fence_before_thread_sync, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_ld, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_mma, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_mma_block_scale, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_mma_sp, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_mma_sp_block_scale, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_relinquish_alloc_permit, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_shift, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_st, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_wait_ld, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_tcgen05_wait_st, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_commit_group, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_encode_matrix_descriptor, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_fence, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_mma_async_rs, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_mma_async_ss, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_noop_barrier, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_wait_group, ptx, kOpaque), +}; + +const bool kDeviceIntrinsicAliasesRegistered = []() { + for (const auto& reg : kDeviceIntrinsics) { + RegisterDeviceIntrinsic(reg); + } + return true; +}(); + +#undef TIRX_DEVICE_INTRIN_ALIAS + +} // namespace + } // namespace builtin } // namespace tirx } // namespace tvm diff --git a/src/tirx/op/target_builtin/trn.cc b/src/tirx/op/target_builtin/trn.cc index 7966e6d505b3..e9df7669cfb1 100644 --- a/src/tirx/op/target_builtin/trn.cc +++ b/src/tirx/op/target_builtin/trn.cc @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace tirx { namespace builtin { @@ -86,6 +88,65 @@ TIRX_DEFINE_BUILTIN_FUNC(nki_scalar_tensor_scalar) TIRX_DEFINE_BUILTIN_FUNC(nki_affine_select) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +namespace { + +void RegisterNKIIntrinsic(const char* flat_name) { + std::string flat(flat_name); + std::string prefix = "nki_"; + std::string suffix = flat; + if (suffix.rfind(prefix, 0) == 0) { + suffix = suffix.substr(prefix.size()); + } + + std::string flat_op_name = "tirx." + flat; + std::string canonical_op_name = "tirx.nki." + suffix; + ffi::String namespace_attr("nki"); + ffi::String printer_name("nki." + suffix); + int64_t effect = static_cast(CallEffectKind::kOpaque); + + auto register_one = [&](const std::string& op_name) { + OpRegEntry::RegisterOrGet(op_name) + .set_name() + .set_attr("TIRxOpCategory", ffi::String("device_intrin"), + /*plevel=*/15) + .set_attr("TDeviceIntrinsicNamespace", namespace_attr, + /*plevel=*/15) + .set_attr("TCallEffectKind", effect, /*plevel=*/15) + .set_attr("TScriptPrinterName", printer_name, /*plevel=*/15); + }; + + register_one(flat_op_name); + register_one(canonical_op_name); +} + +const char* kNKIIntrinsics[] = { + "nki_activation", + "nki_activation_reduce", + "nki_affine_select", + "nki_identity", + "nki_load", + "nki_matmul", + "nki_memset", + "nki_reciprocal", + "nki_scalar_tensor_scalar", + "nki_scalar_tensor_tensor", + "nki_store", + "nki_tensor_copy", + "nki_tensorreduce", + "nki_tensorscalar", + "nki_tensorscalar_reduce", + "nki_tensortensor", +}; + +const bool kNKIIntrinsicAliasesRegistered = []() { + for (const char* op_name : kNKIIntrinsics) { + RegisterNKIIntrinsic(op_name); + } + return true; +}(); + +} // namespace + } // namespace builtin } // namespace tirx } // namespace tvm diff --git a/src/tirx/op/tirx.cc b/src/tirx/op/tirx.cc index 1529780218f3..0410bb5f2157 100644 --- a/src/tirx/op/tirx.cc +++ b/src/tirx/op/tirx.cc @@ -33,15 +33,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { DispatchContextNode::RegisterReflection(); } /********************* Utils **********************/ -#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ - const Op& OpName() { \ - static const Op& op = Op::Get("tirx." #OpName); \ - return op; \ - } \ - TVM_REGISTER_OP("tirx." #OpName) \ - .set_attr("TScriptPrinterName", ffi::String(#OpName), /*plevel=*/9) - -#define TIRX_DEFINE_OP(OpName) TIRX_DEFINE_BUILTIN_FUNC(OpName).set_attr("TIsTIRxOp", true) +#define TIRX_DEFINE_TILE_FUNC(OpName) \ + const Op& OpName() { \ + static const Op& op = Op::Get("tirx.tile." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tirx.tile." #OpName) \ + .set_attr("TScriptPrinterName", ffi::String(#OpName), /*plevel=*/9) \ + .set_attr("TIRxOpCategory", ffi::String("tile_primitive"), /*plevel=*/9) \ + .set_attr("TIsTIRxOp", true) + +#define TIRX_DEFINE_TILE_OP(OpName, Kind) \ + TIRX_DEFINE_TILE_FUNC(OpName).set_attr("TTilePrimitiveKind", Kind) /********************* Context utils **********************/ template @@ -140,7 +143,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } /********************* Dispatch Ops **********************/ -#define TIRX_DEFINE_DISPATCH_OP(OpName) TIRX_DEFINE_OP(OpName).set_attr("TIsDispatchOp", true) +#define TIRX_DEFINE_DISPATCH_OP(OpName) \ + TIRX_DEFINE_TILE_OP(OpName, ffi::String("dispatch")).set_attr("TIsDispatchOp", true) TIRX_DEFINE_DISPATCH_OP(zero); TIRX_DEFINE_DISPATCH_OP(sqrt); @@ -168,20 +172,23 @@ TIRX_DEFINE_DISPATCH_OP(select); TIRX_DEFINE_DISPATCH_OP(cast); TIRX_DEFINE_DISPATCH_OP(fma); TIRX_DEFINE_DISPATCH_OP(silu); +TIRX_DEFINE_DISPATCH_OP(permute_layout); /********************* Compose Ops **********************/ -#define TIRX_DEFINE_COMPOSE_OP(OpName) TIRX_DEFINE_OP(OpName).set_attr("TIsComposeOp", true) +#define TIRX_DEFINE_COMPOSE_OP(OpName) \ + TIRX_DEFINE_TILE_OP(OpName, ffi::String("compose")).set_attr("TIsComposeOp", true) TIRX_DEFINE_COMPOSE_OP(compose_op); /********************* Async Ops **********************/ -#define TIRX_DEFINE_ASYNC_OP(OpName) TIRX_DEFINE_OP(OpName).set_attr("TIsAsyncOp", true) +#define TIRX_DEFINE_ASYNC_OP(OpName) \ + TIRX_DEFINE_TILE_OP(OpName, ffi::String("async")).set_attr("TIsAsyncOp", true) TIRX_DEFINE_ASYNC_OP(copy_async); TIRX_DEFINE_ASYNC_OP(gemm_async); /********************* Misc Ops **********************/ -TIRX_DEFINE_OP(tvm_kernel_replace_point); +TIRX_DEFINE_TILE_OP(tvm_kernel_replace_point, ffi::String("marker")); } // namespace tirx } // namespace tvm diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc index 7a3974e94d6f..d7dc9a4f91a1 100644 --- a/src/tirx/script/builder/frame.cc +++ b/src/tirx/script/builder/frame.cc @@ -69,7 +69,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); PrimFuncFrameNode::RegisterReflection(); SBlockFrameNode::RegisterReflection(); - ExecScopeFrameNode::RegisterReflection(); BlockInitFrameNode::RegisterReflection(); ForFrameNode::RegisterReflection(); AssertFrameNode::RegisterReflection(); @@ -200,22 +199,6 @@ void SBlockFrameNode::ExitWithScope() { } } -void ExecScopeFrameNode::ExitWithScope() { - TIRFrameNode::ExitWithScope(); - TVM_FFI_ICHECK(exec_scope.defined()) - << "InternalError: ExecScopeFrame must have an execution scope"; - tvm::tirx::Stmt body = AsStmt(stmts); - tvm::tirx::Stmt stmt = tvm::tirx::ExecScopeStmt(exec_scope.value(), body); - ffi::Optional guard = std::nullopt; - for (const PrimExpr& predicate : guards) { - guard = guard.defined() ? PrimExpr(guard.value() && predicate) : predicate; - } - if (guard.defined()) { - stmt = tvm::tirx::IfThenElse(guard.value(), stmt); - } - AddToParent(stmt); -} - void BlockInitFrameNode::EnterWithScope() { SBlockFrame frame = FindSBlockFrame("T.init"); if (frame->init.defined()) { @@ -328,7 +311,7 @@ void ComposeOpFrameNode::ExitWithScope() { << stmt; ops.push_back(ffi::GetRef(op_call)); } - auto compose_op_op = tvm::Op::Get("tirx.compose_op"); + auto compose_op_op = tvm::Op::Get("tirx.tile.compose_op"); AddToParent(tvm::tirx::TilePrimitiveCall(compose_op_op, ops, workspace, config, dispatch)); } diff --git a/src/tirx/script/builder/ir.cc b/src/tirx/script/builder/ir.cc index 500ac254e1cc..f3bbbc4a4ed8 100644 --- a/src/tirx/script/builder/ir.cc +++ b/src/tirx/script/builder/ir.cc @@ -197,22 +197,6 @@ SBlockFrame Block(ffi::String name, bool no_realize, ffi::String exec_scope) { void TilePrimitiveCall(tvm::tirx::TilePrimitiveCall op_call) { AddToParent(op_call); } -ExecScopeFrame ExecScopeBlock(ffi::String exec_scope_name, ffi::Array guards) { - ffi::ObjectPtr n = ffi::make_object(); - TVM_FFI_ICHECK(!exec_scope_name.empty()) << "InternalError: exec_scope_name must not be empty"; - n->exec_scope = tvm::tirx::ExecScope(exec_scope_name); - n->guards = std::move(guards); - return ExecScopeFrame(n); -} - -ExecScopeFrame Cluster(ffi::Array guards) { return ExecScopeBlock("cluster", guards); } -ExecScopeFrame WarpGroup(ffi::Array guards) { - return ExecScopeBlock("warpgroup", guards); -} -ExecScopeFrame CTA(ffi::Array guards) { return ExecScopeBlock("cta", guards); } -ExecScopeFrame Warp(ffi::Array guards) { return ExecScopeBlock("warp", guards); } -ExecScopeFrame Thread(ffi::Array guards) { return ExecScopeBlock("thread", guards); } - ffi::Array ScopeId(ffi::Optional> extents, ffi::String parent, ffi::String name, ffi::String cur) { // Determine the number of Vars to introduce. Deferred form (extents=None) @@ -678,7 +662,7 @@ AttrFrame DeviceEntry() { IRBuilder builder = IRBuilder::Current(); ffi::Optional pf_frame = builder->FindFrame(); TVM_FFI_ICHECK(pf_frame.defined()) - << "Tx.device_entry() must be called inside a @Tx.prim_func body"; + << "T.device_entry() must be called inside a @T.prim_func body"; // Capture the AttrFrame by ObjectRef value so the lambda holds a strong // reference while the callback runs. Without this, the only reference is // the IRBuilder frame stack; ``ExitWithScope`` pops itself first and the @@ -962,13 +946,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("script.ir_builder.tirx.FuncRet", FuncRet) .def("script.ir_builder.tirx.MatchBuffer", MatchBuffer) .def("script.ir_builder.tirx.Block", Block) - .def("script.ir_builder.tirx.ExecScopeBlock", ExecScopeBlock) .def("script.ir_builder.tirx.TilePrimitiveCall", TilePrimitiveCall) - .def("script.ir_builder.tirx.Cluster", Cluster) - .def("script.ir_builder.tirx.CTA", CTA) - .def("script.ir_builder.tirx.WarpGroup", WarpGroup) - .def("script.ir_builder.tirx.Warp", Warp) - .def("script.ir_builder.tirx.Thread", Thread) .def("script.ir_builder.tirx.ClusterId", [](ffi::Optional> extents, ffi::String parent) { return ClusterId(extents, parent); diff --git a/src/tirx/script/builder/utils.h b/src/tirx/script/builder/utils.h index fc0293fbfca0..4d7821a84d5a 100644 --- a/src/tirx/script/builder/utils.h +++ b/src/tirx/script/builder/utils.h @@ -99,21 +99,6 @@ inline SBlockFrame FindSBlockFrame(const ffi::String& method) { throw; } -/*! - * \brief Find the innermost ExecScopeFrame in the IRBuilder frame stack. - * \param method The method name to be printed when throwing exception. - * \return The innermost ExecScopeFrame. - */ -inline ExecScopeFrame FindExecScopeFrame(const ffi::String& method) { - if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { - return frame.value(); - } - LOG(FATAL) << "ValueError: " << method - << " must be called inside an execution scope (e.g. T.cta(), T.warp()), " - << "but no ExecScopeFrame was found"; - throw; -} - /*! * \brief Check whether the top frame in IRBuilder frame stack is IfFrame. * \param method The method name to be printed when throwing exception. diff --git a/src/tirx/script/printer/block.cc b/src/tirx/script/printer/block.cc index db6d167b5062..6d7902a4a89f 100644 --- a/src/tirx/script/printer/block.cc +++ b/src/tirx/script/printer/block.cc @@ -232,19 +232,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_REGISTER_SCRIPT_AS_REPR(tirx::SBlockNode, ReprPrintTIR); TVM_REGISTER_SCRIPT_AS_REPR(tirx::SBlockRealizeNode, ReprPrintTIR); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", - [](tirx::ExecScopeStmt stmt, AccessPath p, IRDocsifier d) - -> Doc { return ExecScopeStmtDoc(stmt, p, d, {}); }); - -TVM_SCRIPT_REPR(tirx::ExecScopeStmtNode, ReprPrintTIR); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](tirx::ScopeIdDefStmt stmt, AccessPath p, IRDocsifier d) -> Doc { // Render as ``(var1, var2, ...) = T.cta_id([ext], preferred=[...])`` - // (or the appropriate API name for the binding). Mirrors the loop - // in ``ExecScopeStmtDoc`` that handled the legacy payload form. + // (or the appropriate API name for the binding). TVM_FFI_ICHECK(!d->frames.empty()); tirx::ScopeIdDef def = stmt->def; AccessPath def_p = p->Attr("def"); diff --git a/src/tirx/script/printer/buffer.cc b/src/tirx/script/printer/buffer.cc index 2333eb89005b..7c190f941494 100644 --- a/src/tirx/script/printer/buffer.cc +++ b/src/tirx/script/printer/buffer.cc @@ -222,7 +222,7 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath PrimExpr addr = buffer->allocated_addr[0]; AccessPath addr_p = buffer_p->Attr("allocated_addr")->ArrayItem(0); if (const auto* bl = addr.as()) { - // Ensure the buffer variable is defined (may emit a Tx.Buffer(...) statement). + // Ensure the buffer variable is defined (may emit a T.Buffer(...) statement). d->AsDoc(bl->buffer, addr_p->Attr("buffer")); // Get the variable name bound to this buffer. ffi::Optional buf_var = d->GetVarDoc(bl->buffer); diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc index cd33f59e3c84..f732aa619b68 100644 --- a/src/tirx/script/printer/expr.cc +++ b/src/tirx/script/printer/expr.cc @@ -315,7 +315,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // cuda_func_call: last arg is source_code (keyword-only in the Python API). // Print it as source_code=... to enable TVMScript round-trip. - if (op->name == "tirx.cuda_func_call") { + if (op->name == "tirx.cuda_func_call" || op->name == "tirx.cuda.func_call") { int n_args = call->args.size(); ffi::Array args; // All args except the last (source_code) are positional. diff --git a/src/tirx/script/printer/stmt.cc b/src/tirx/script/printer/stmt.cc index a16f2b254be0..4d3c21c88d5d 100644 --- a/src/tirx/script/printer/stmt.cc +++ b/src/tirx/script/printer/stmt.cc @@ -94,9 +94,40 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) static const auto& dispatch_op_map = Op::GetAttrMap("TIsDispatchOp"); static const auto& compose_op_map = Op::GetAttrMap("TIsComposeOp"); static const auto& async_op_map = Op::GetAttrMap("TIsAsyncOp"); + static const auto& category_map = Op::GetAttrMap("TIRxOpCategory"); TVM_FFI_ICHECK(tirx_op_map.get(op, false)) << "Only TIRX ops can be used in tirx::TilePrimitiveCall"; ffi::String name = op_names.get(op, op->name); + // Per-call execution scope is printed as a namespace prefix on the op, + // e.g. ``T.warp.copy(...)``. ``warpgroup`` prints as ``wg``. The + // default ``thread`` scope prints through the explicit tile namespace, + // e.g. ``T.tile.copy(...)``, so canonical script only needs the full + // TIRx dialect import. ``Tx`` remains a handwritten shorthand for + // ``T.tile`` and ``T.`` tile calls. + auto scope_ns = [](tirx::ScopeKind k) -> ffi::Optional { + switch (k) { + case tirx::ScopeKind::kWarp: + return ffi::String("warp"); + case tirx::ScopeKind::kWarpgroup: + return ffi::String("wg"); + case tirx::ScopeKind::kCta: + return ffi::String("cta"); + case tirx::ScopeKind::kCluster: + return ffi::String("cluster"); + default: // kThread -> no prefix + return std::nullopt; + } + }; + auto scoped_callee = [&](const ffi::String& op_name) -> ExprDoc { + ffi::Optional ns = scope_ns(op_call->scope->kind); + if (ns.has_value()) { + return TIRx(d, ns.value())->Attr(op_name); + } + if (category_map.get(op, ffi::String("")) == "tile_primitive") { + return TIRx(d, "tile")->Attr(op_name); + } + return TIRx(d, op_name); + }; if (dispatch_op_map.get(op, false) || async_op_map.get(op, false)) { // Dispatch ops // Trim trailing None args (e.g. optional bias=None, scale=None) @@ -126,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (op_call->dispatch.has_value()) { disp = LiteralDoc::Str(op_call->dispatch.value(), p->Attr("dispatch")); } - return OpCallDoc(TIRx(d, name), args, + return OpCallDoc(scoped_callee(name), args, d->AsDoc(op_call->workspace, p->Attr("workspace")), d->AsDoc(op_call->config, p->Attr("config")), disp); } else if (compose_op_map.get(op, false)) { @@ -158,7 +189,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kw_values.push_back( d->AsDoc(kv.second, p->Attr("config")->MapItem(kv.first))); } - return ScopeDoc(std::nullopt, TIRx(d, "compose_op")->Call({}, kw_keys, kw_values), + return ScopeDoc(std::nullopt, scoped_callee("compose_op")->Call({}, kw_keys, kw_values), (*f)->stmts); } else { // Misc ops @@ -166,7 +197,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (size_t i = 0, n = op_call->args.size(); i < n; ++i) { args.push_back(d->AsDoc(op_call->args[i], p->Attr("args")->ArrayItem(i))); } - return OpCallDoc(TIRx(d, name), args, {}, {}, std::nullopt); + return OpCallDoc(scoped_callee(name), args, {}, {}, std::nullopt); } }); TVM_SCRIPT_REPR(tirx::TilePrimitiveCallNode, ReprPrintTIR); @@ -739,13 +770,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tirx::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { - if (!stmt->else_case.defined()) { - if (auto exec_scope_stmt = stmt->then_case.as()) { - ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); - return ExecScopeStmtDoc(ffi::GetRef(exec_scope_stmt), - p->Attr("then_case"), d, {cond}); - } - } ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); ffi::Array then_branch; ffi::Array else_branch; diff --git a/src/tirx/script/printer/utils.h b/src/tirx/script/printer/utils.h index 471720790b0e..6d72afd65229 100644 --- a/src/tirx/script/printer/utils.h +++ b/src/tirx/script/printer/utils.h @@ -220,17 +220,6 @@ inline ffi::String ScopeIdApiName(const tirx::ScopeBinding& binding) { return ""; } -inline Doc ExecScopeStmtDoc(tirx::ExecScopeStmt stmt, AccessPath p, IRDocsifier d, - ffi::Array call_args) { - With frame(d, stmt); - tirx::ExecScope exec_scope = stmt->exec_scope; - ffi::Array scope_call_args = call_args; - // ScopeIdDefStmts (formerly payload) are now standalone statements within - // the body and print via their own dispatch. - AsDocBody(stmt->body, p->Attr("body"), frame->get(), d); - return ScopeDoc(std::nullopt, TIR(d, exec_scope->name())->Call(scope_call_args), (*frame)->stmts); -} - /*! * \brief Find the top frame in the stack that could place a var definition * \param var The var to be defined diff --git a/src/tirx/transform/lower_tirx.cc b/src/tirx/transform/lower_tirx.cc index 7819237e8a43..c351a934e385 100644 --- a/src/tirx/transform/lower_tirx.cc +++ b/src/tirx/transform/lower_tirx.cc @@ -24,49 +24,21 @@ #include #include -#include -#include #include +#include +#include + namespace tvm { namespace tirx { namespace transform { -namespace { - -/*! - * \brief Strip ExecScopeStmt wrappers from lowered TIRX output. - * - * ExecScopeStmt is required while lowering TIRX ops and resolving scope IDs/slices. - * After those passes finish, the wrappers are no longer needed and should not be - * present in the final LowerTIRx output. - */ -class ExecScopeStripper : public StmtExprMutator { - public: - static Stmt Strip(const Stmt& stmt) { return ExecScopeStripper()(stmt); } - - private: - Stmt VisitStmt_(const ExecScopeStmtNode* op) final { return VisitStmt(op->body); } -}; - -Pass LowerTIRxStripExecScope() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = ExecScopeStripper::Strip(n->body); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tirx.LowerTIRxStripExecScope", {}); -} - -} // namespace - Pass LowerTIRx() { std::vector passes = {TilePrimitiveDispatch()}; if (std::getenv("TVM_PRINT_AFTER_TIRX_DISPATCH_OPS")) { passes.push_back(tvm::transform::PrintIR()); } passes.push_back(LowerTIRxCleanup()); - passes.push_back(LowerTIRxStripExecScope()); return tvm::transform::Sequential(passes, "tirx.LowerTIRx"); } diff --git a/src/tirx/transform/lower_tirx_cleanup.cc b/src/tirx/transform/lower_tirx_cleanup.cc index 318631fc939e..98aa794e3c3b 100644 --- a/src/tirx/transform/lower_tirx_cleanup.cc +++ b/src/tirx/transform/lower_tirx_cleanup.cc @@ -42,35 +42,6 @@ namespace tvm { namespace tirx { -class DispatchContextRemover : public StmtExprMutator { - public: - static Stmt Remove(const Stmt& stmt) { return DispatchContextRemover()(stmt); } - - private: - Stmt VisitStmt_(const ExecScopeStmtNode* op) final { - Stmt body = VisitStmt(op->body); - // Strip TIRX dispatch AttrStmts from ExecScopeStmt body - // (These are dead-code annotations that were never written but the cleanup pass - // historically erased: scope_id_extent_map, thread_var_map, tirx.warp_id_in_cta) - auto strip = [](Stmt stmt) { - while (auto attr = stmt.as()) { - if (attr->attr_key == "scope_id_extent_map" || attr->attr_key == "thread_var_map" || - attr->attr_key == "tirx.warp_id_in_cta") { - stmt = attr->body; - } else { - break; - } - } - return stmt; - }; - body = strip(body); - if (body.same_as(op->body)) { - return ffi::GetRef(op); - } - return ExecScopeStmt(op->exec_scope, body); - } -}; - class LayoutApplier : public arith::IRMutatorWithAnalyzer { public: static std::pair> Flatten( @@ -389,7 +360,6 @@ Pass LowerTIRxCleanup() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { Target target = ResolveTarget(f); auto* n = f.CopyOnWrite(); - n->body = DispatchContextRemover::Remove(n->body); std::tie(n->body, n->buffer_map) = LayoutApplier::Flatten(n->body, n->buffer_map, target); n->body = BufferOffsetRemover::Remove(n->body); return f; diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index 801554389c15..7ec104765f3a 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -34,6 +34,8 @@ #include #include +#include + #include "../../runtime/thread_storage_scope.h" #include "../analysis/var_use_def_analysis.h" #include "ir_utils.h" @@ -83,6 +85,36 @@ PrimFunc AnnotateDeviceRegionsForSplit(PrimFunc func) { // Host/device function extraction +class LaunchBoundsAttrExtractor : public StmtMutator { + public: + Stmt Extract(Stmt stmt) { + min_blocks_per_sm_.reset(); + return operator()(std::move(stmt)); + } + + std::optional min_blocks_per_sm() const { return min_blocks_per_sm_; } + + private: + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tirx::attr::kLaunchBoundsMinBlocksPerSM) { + const auto* min_blocks_per_sm = op->value.as(); + TVM_FFI_ICHECK(min_blocks_per_sm) + << tirx::attr::kLaunchBoundsMinBlocksPerSM << " expects an integer value"; + TVM_FFI_ICHECK_GT(min_blocks_per_sm->value, 0) + << tirx::attr::kLaunchBoundsMinBlocksPerSM << " must be positive"; + if (min_blocks_per_sm_.has_value()) { + TVM_FFI_ICHECK_EQ(min_blocks_per_sm_.value(), min_blocks_per_sm->value) + << "Conflicting " << tirx::attr::kLaunchBoundsMinBlocksPerSM << " values"; + } + min_blocks_per_sm_ = min_blocks_per_sm->value; + return VisitStmt(op->body); + } + return StmtMutator::VisitStmt_(op); + } + + std::optional min_blocks_per_sm_; +}; + class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, std::function var_supply, @@ -147,21 +179,24 @@ class HostDeviceSplitter : public StmtMutator { for (Buffer buf : buffers_to_declare) { body = SeqStmt::Flatten(DeclBuffer(buf), std::move(body)); } + LaunchBoundsAttrExtractor launch_bounds_attr; + body = launch_bounds_attr.Extract(std::move(body)); PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tirx::attr::kNoAlias, true}, {tirx::attr::kIsGlobalFunc, true}}); - if (cur_func_->attrs->dict.count(tvm::attr::kSTir)) { + bool is_stir = cur_func_->attrs->dict.count(tvm::attr::kSTir); + if (is_stir) { device_func = WithAttr(std::move(device_func), tvm::attr::kSTir, true); } + if (device_target->kind->name == "cuda" && launch_bounds_attr.min_blocks_per_sm().has_value()) { + device_func = WithAttr(std::move(device_func), tirx::attr::kLaunchBoundsMinBlocksPerSM, + launch_bounds_attr.min_blocks_per_sm().value()); + } auto num_inputs = cur_func_->GetAttr(tvm::attr::kNumInputs); if (num_inputs.has_value()) { device_func = WithAttr(std::move(device_func), tvm::attr::kNumInputs, num_inputs); } - auto persistent = cur_func_->GetAttr(tirx::attr::kPersistentKernel); - if (persistent.has_value()) { - device_func = WithAttr(std::move(device_func), tirx::attr::kPersistentKernel, persistent); - } GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); ffi::Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); diff --git a/src/tirx/transform/tile_primitive_dispatch.cc b/src/tirx/transform/tile_primitive_dispatch.cc index 0e0e4932caf4..cb431e474d66 100644 --- a/src/tirx/transform/tile_primitive_dispatch.cc +++ b/src/tirx/transform/tile_primitive_dispatch.cc @@ -54,8 +54,7 @@ namespace { // Gather every ScopeIdDef declared anywhere under a given Stmt, paired with // the source stmt node that declared it (for implicit-eval routing). The -// source is either an enclosing ExecScopeStmt or the AttrStmt(kDeviceEntry) -// marker. +// source is the AttrStmt(kDeviceEntry) marker. struct ScopeIdDefWithSource { ScopeIdDef def; const StmtNode* source_stmt; @@ -69,10 +68,6 @@ class ScopeIdDefGather : public StmtExprVisitor { return std::move(gather.out_); } - void VisitStmt_(const ExecScopeStmtNode* op) override { - EnterSourceAndPartition(op, [&]() { StmtExprVisitor::VisitStmt_(op); }); - } - void VisitStmt_(const AttrStmtNode* op) override { if (op->attr_key == tvm::tirx::attr::kDeviceEntry) { EnterSourceAndPartition(op, [&]() { StmtExprVisitor::VisitStmt_(op); }); @@ -129,7 +124,14 @@ class ElectSyncFinder : public StmtExprVisitor { using StmtExprVisitor::VisitStmt_; void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(tirx::builtin::ptx_elect_sync())) { + auto is_canonical_elect_sync = [&]() { + if (op->op.same_as(tirx::builtin::ptx_elect_sync())) return true; + if (auto call_op = op->op.as()) { + return call_op.value()->name == "tirx.ptx.elect_sync"; + } + return false; + }; + if (is_canonical_elect_sync()) { found_ = true; return; } @@ -183,7 +185,7 @@ class ScopeIdDefRemover : public StmtExprMutator { // For implicitly-named ScopeIdDefs (parser-emitted Var("")), inject an // Evaluate(var) at the source stmt's body so the binding stays observably // live in the IR even if user code never references it. Routing uses source -// stmt-node identity to match against the surviving ExecScopeStmt nodes. +// stmt-node identity to match against the device-entry marker. class ImplicitScopeIdEvalInjector : public StmtExprMutator { public: static Stmt Inject(const Stmt& stmt, @@ -213,16 +215,6 @@ class ImplicitScopeIdEvalInjector : public StmtExprMutator { return evals; } - Stmt VisitStmt_(const ExecScopeStmtNode* op) final { - Stmt body = VisitStmt(op->body); - auto evals = ConsumeEvalsFor(op); - if (!evals.empty()) { - body = SeqStmt::Flatten(evals, body); - } - if (body.same_as(op->body)) return ffi::GetRef(op); - return ExecScopeStmt(op->exec_scope, body); - } - Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt body = VisitStmt(op->body); if (op->attr_key == tvm::tirx::attr::kDeviceEntry) { @@ -312,20 +304,6 @@ class TilePrimitiveDispatcher : public StmtExprMutator { Stmt body_; }; - Stmt VisitStmt_(const ExecScopeStmtNode* op) final { - exec_scope_stack_.push_back(op->exec_scope); - scope_id_defs_at_level_.push_back({}); - bool pushed_scope_ctx = PushScopeSwitchCtx(op->exec_scope->kind); - Stmt body = VisitStmt(op->body); - if (pushed_scope_ctx) ctx_stack_.pop_back(); - exec_scope_stack_.pop_back(); - scope_id_defs_at_level_.pop_back(); - if (body.same_as(op->body)) { - return ffi::GetRef(op); - } - return ExecScopeStmt(op->exec_scope, body); - } - Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tirx::attr::kDeviceEntry) { return ProcessDeviceEntry(op); @@ -351,16 +329,10 @@ class TilePrimitiveDispatcher : public StmtExprMutator { PrepareLaunchParams(entry_node, body_to_visit, &scope_binds); bool pushed_base_ctx = PushKernelEntryCtx(); - bool prev_inside = inside_device_entry_; - int prev_size = device_entry_stack_size_; - inside_device_entry_ = true; - device_entry_stack_size_ = static_cast(exec_scope_stack_.size()); // Direct ScopeIdDefStmt children of the device-entry marker live here. scope_id_defs_at_level_.push_back({}); Stmt body = VisitStmt(body_to_visit); scope_id_defs_at_level_.pop_back(); - inside_device_entry_ = prev_inside; - device_entry_stack_size_ = prev_size; // Post-dispatch: re-gather the now-inlined body and resolve every // ``ScopeIdDef`` (kernel-side + dispatch-introduced) into ``scope_binds``. @@ -424,14 +396,13 @@ class TilePrimitiveDispatcher : public StmtExprMutator { // alloc buffers wrapping ``body`` directly. Stmt res = body; - // Inject implicit scope-id evals sourced from inner ExecScopeStmts. - // Must run before ScopeIdDefRemover, which rebuilds ExecScope nodes - // and invalidates source identities. + // Inject implicit scope-id evals sourced from the device-entry marker. + // Must run before ScopeIdDefRemover, which rebuilds nodes and + // invalidates source identities. res = ImplicitScopeIdEvalInjector::Inject(res, implicit_scope_id_evals); - // Strip scope_id_def from inner ExecScopeStmts and standalone - // ScopeIdDefStmt nodes -- their values are now bound at kernel scope via - // the Bind statements below. + // Strip standalone ScopeIdDefStmt nodes -- their values are now bound at + // kernel scope via the Bind statements below. res = ScopeIdDefRemover::Remove(res); // Prepend Bind(var, value) for every resolved scope id (and the derived @@ -558,38 +529,27 @@ class TilePrimitiveDispatcher : public StmtExprMutator { } Stmt VisitStmt_(const tirx::TilePrimitiveCallNode* op) final { + // Scope is a per-call field on the node. Derive the (inter, intra) split + // on the spot from the current active set ``A`` (tracked through control + // flow on ``ctx_stack_``) under this call's own ``op->scope``. ffi::Map> inter_map, intra_map; - // scope_kind defaults to the current ExecScope's name (or "kernel" when - // we're at the device-region root without any inner ExecScope). When - // ExecContext tracking is active the tracked scope_kind wins (consistent - // once predicates change the active set). - ffi::String scope_kind; - ExecScope dispatch_scope; - if (exec_scope_stack_.empty()) { - // At the device-region root (inside AttrStmt(kDeviceEntry) but no - // inner ExecScope). Use ``kernel`` for dispatcher continuity. - scope_kind = "kernel"; - dispatch_scope = ExecScope("thread"); // placeholder; not load-bearing - } else { - scope_kind = exec_scope_stack_.back()->name(); - dispatch_scope = exec_scope_stack_.back(); - } + ffi::String scope_kind = ScopeKindToString(op->scope->kind); if (!ctx_stack_.empty()) { - const auto& ctx = ctx_stack_.back(); - inter_map = EncodeSplitSide(ctx.split.inter); - intra_map = EncodeSplitSide(ctx.split.intra); - scope_kind = ScopeKindToString(ctx.scope_kind); - } - // Preserve the "kernel" label at the device-region root (where - // dispatchers historically checked ``scope_kind == "kernel"`` to fire). - // The root corresponds to the dispatch site whose exec_scope_stack_ size - // matches the size at entry to ProcessDeviceEntry (the level where the - // marker was opened, before any inner ExecScope is pushed). - if (inside_device_entry_ && - static_cast(exec_scope_stack_.size()) == device_entry_stack_size_) { - scope_kind = "kernel"; - } - tirx::DispatchContext sctx(target_, dispatch_scope, launch_params_, var_range_map_, + ExecSplit split; + std::string err; + if (ScopeSwitch(ctx_stack_.back().A, op->scope->kind, &split, &err)) { + inter_map = EncodeSplitSide(split.inter); + intra_map = EncodeSplitSide(split.intra); + } else { + // Factoring failure (e.g. warpgroup with a lane that crosses a + // warpgroup boundary unaligned). Leave the split empty; dispatchers + // fall back to scope_kind. This is not validated earlier, so an + // incompatible per-call scope only warns here and yields a degenerate + // split rather than a hard error. + LOG(WARNING) << "ExecContext scope_switch failed: " << err; + } + } + tirx::DispatchContext sctx(target_, op->scope, launch_params_, var_range_map_, /*alloc_only=*/false, /*callbacks=*/{}, shared_state_, inter_map, intra_map, scope_kind); static auto f_op_dispatcher_ = ffi::Function::GetGlobal("tirx.f_op_dispatcher"); @@ -702,7 +662,7 @@ class TilePrimitiveDispatcher : public StmtExprMutator { for (size_t i = 0; i < def->def_ids.size(); i++) { // Reuse the original Var as the bind target -- no rename, no // substitution. The IR already references this Var directly, and - // dispatch's filter resolution walks ExecScopeStmt::scope_id_def + // dispatch's filter resolution walks ScopeIdDefStmt::def // to map Vars back to their ScopeBinding. Var bind_var = def->def_ids[i]; PrimExpr value = resolved[i]; @@ -818,21 +778,6 @@ class TilePrimitiveDispatcher : public StmtExprMutator { return true; } - bool PushScopeSwitchCtx(ScopeKind new_scope_kind) { - if (ctx_stack_.empty()) return false; - ExecContext new_ctx; - std::string err; - if (!ctx_stack_.back().WithScopeSwitch(new_scope_kind, &new_ctx, &err)) { - // Factoring failure (e.g. warpgroup case 3 / world scope_switch). - // Pause tracking; dispatchers fall back to scope_kind. The verifier - // (VerifyTIRxWellFormed) is responsible for catching this earlier. - LOG(WARNING) << "ExecContext scope_switch failed: " << err; - return false; - } - ctx_stack_.push_back(new_ctx); - return true; - } - struct ScopeIdTarget { ScopeBinding binding; int dim = 0; @@ -1472,17 +1417,10 @@ class TilePrimitiveDispatcher : public StmtExprMutator { ffi::Map var_range_map_; arith::Analyzer analyzer_; const Target& target_; - std::vector exec_scope_stack_; - // Parallel to exec_scope_stack_ plus one entry for the device-entry body - // itself: list of ScopeIdDefs visible at each level. Grows as - // ScopeIdDefStmt nodes are visited. + // List of ScopeIdDefs visible at each nesting level (one entry for the + // device-entry body itself, plus one per ScopeIdDefStmt-bearing region). + // Grows as ScopeIdDefStmt nodes are visited. std::vector> scope_id_defs_at_level_; - // True while inside the AttrStmt(kDeviceEntry) body. - bool inside_device_entry_ = false; - // ``exec_scope_stack_.size()`` at the moment ProcessDeviceEntry was called. - // A TilePrimitiveCall whose dispatch site is at this same stack size is at - // the device-entry root level (no inner ExecScope opened yet). - int device_entry_stack_size_ = -1; std::vector ctx_stack_; std::unordered_map launch_params_; std::vector alloc_buffers_; @@ -1523,41 +1461,6 @@ class TilePrimitiveDispatcher : public StmtExprMutator { // No failure aggregation; pass surfaces per-op exceptions }; -class ScopeMerger : public StmtExprMutator { - public: - static Stmt Merge(const Stmt& stmt) { return ScopeMerger()(stmt); } - - private: - Stmt VisitStmt_(const SeqStmtNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - if (auto* n = stmt.as()) { - std::vector seq; - for (size_t i = 0; i < n->seq.size();) { - if (auto* exec_scope_stmt = n->seq[i].as()) { - // Find a sequence of ExecScopeStmts with the same exec_scope - std::vector new_body{exec_scope_stmt->body}; - auto scope = exec_scope_stmt->exec_scope; - for (i++; i < n->seq.size(); i++) { - if (auto* next_exec_scope = n->seq[i].as()) { - if (scope->kind == next_exec_scope->exec_scope->kind) { - new_body.push_back(next_exec_scope->body); - continue; - } - } - break; - } - seq.push_back(ExecScopeStmt(scope, SeqStmt::Flatten(new_body))); - } else { - seq.push_back(n->seq[i]); - i++; - } - } - return SeqStmt::Flatten(seq); - } - return stmt; - }; -}; - namespace { Target ResolveTarget(const PrimFunc& f) { auto target = f->GetAttr(tvm::attr::kTarget); diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index fa61b6a50338..821f987e635b 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -46,7 +46,7 @@ def test_inject_ptx_intrin(): if major < 8: # Require at least SM80 return - with tvm.transform.PassContext(config={"tirx.ptx_ldg32": True}): + with tvm.transform.PassContext(config={"tirx.ptx.ldg32": True}): mod = tvm.compile(f, target="cuda") A_np = np.random.rand(16).astype("float32") B_np = np.zeros(32).astype("float32") diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py index 5731c368c42c..d739e2259ef2 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py @@ -37,7 +37,7 @@ def _count_ptx_ldg32(stmt): num_call = [0] def visit(n): - if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx_ldg32": + if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx.ldg32": num_call[0] += 1 tvm.tirx.stmt_functor.post_order_visit(stmt, visit) diff --git a/tests/python/tirx-base/test_tir_op_types.py b/tests/python/tirx-base/test_tir_op_types.py index f0d5d1ab6b03..2ffce7dce8c6 100644 --- a/tests/python/tirx-base/test_tir_op_types.py +++ b/tests/python/tirx-base/test_tir_op_types.py @@ -164,7 +164,7 @@ def test_tir_op_ptx_mma(): 0, False, ) - assert expr.op.name == "tirx.ptx_mma_legacy" + assert expr.op.name == "tirx.ptx.mma_legacy" def test_tir_op_ptx_mma_sp(): @@ -190,7 +190,7 @@ def test_tir_op_ptx_mma_sp(): 0, False, ) - assert expr.op.name == "tirx.ptx_mma_sp" + assert expr.op.name == "tirx.ptx.mma_sp" def test_tir_op_mma_store(): @@ -232,21 +232,21 @@ def test_op_ptx_ldmatrix(): buffer_local.data, buffer_local.data, ) - assert expr.op.name == "tirx.ptx_ldmatrix" + assert expr.op.name == "tirx.ptx.ldmatrix" def test_op_ptx_cp_async(): buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") buffer_local = tirx.decl_buffer([8], "float16", scope="local") expr = tirx.ptx_cp_async_legacy(buffer_shared.data, 0, buffer_local.data, 0, 16) - assert expr.op.name == "tirx.ptx_cp_async" + assert expr.op.name == "tirx.ptx.cp_async" def test_op_ptx_cp_async_bulk(): buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") buffer_local = tirx.decl_buffer([8], "float16", scope="local") expr = tirx.ptx_cp_async_bulk("float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 0) - assert expr.op.name == "tirx.ptx_cp_async_bulk" + assert expr.op.name == "tirx.ptx.cp_async_bulk" def test_tir_op_vectorlow(): diff --git a/tests/python/tirx-base/test_tir_stmt_functor.py b/tests/python/tirx-base/test_tir_stmt_functor.py index aff44eb9d471..3b53ef8d29b5 100644 --- a/tests/python/tirx-base/test_tir_stmt_functor.py +++ b/tests/python/tirx-base/test_tir_stmt_functor.py @@ -23,6 +23,7 @@ from tvm import tirx as tir from tvm.ir import Range from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.expr import EQ, GT, LT, Add, IntImm, Mul, Sub, Var from tvm.tirx.stmt_functor import StmtExprMutator, StmtExprVisitor, StmtMutator, StmtVisitor @@ -670,8 +671,7 @@ def func(A: T.Buffer((10,), "int32")): # OpCall @T.prim_func(s_tir=True) def op_call(A: T.Buffer((10,), "int32"), B: T.Buffer((10,), "int32")): - with T.thread(): - T.add(A, B, 1.0) + Tx.add(A, B, 1.0) return { "evaluate": evaluate_stmt, @@ -684,7 +684,7 @@ def op_call(A: T.Buffer((10,), "int32"), B: T.Buffer((10,), "int32")): "if_then_else": if_then_else, "for_with_break": func.body, "decl_buffer": buffer_decl, - "op_call": op_call.body.body, + "op_call": op_call.body, } diff --git a/tests/python/tirx/codegen/test_codegen_ampere.py b/tests/python/tirx/codegen/test_codegen_ampere.py index 86e7ca16a7e3..f0c8911cd9b4 100644 --- a/tests/python/tirx/codegen/test_codegen_ampere.py +++ b/tests/python/tirx/codegen/test_codegen_ampere.py @@ -17,7 +17,7 @@ # pylint: disable=missing-function-docstring """Codegen tests for Ampere (sm_80) warp-level ``mma.sync`` tensor cores. -These exercise the ``Tx.ptx.mma`` intrinsic directly (not via the gemm +These exercise the ``T.ptx.mma`` intrinsic directly (not via the gemm dispatch). ``ptx.mma`` takes one pointer per 32-bit register for each operand (``d_ptrs`` / ``a_ptrs`` / ``b_ptrs`` / ``c_ptrs``), enumerated in the fixed PTX register order, so the b32 registers may be scattered in the register file @@ -34,7 +34,7 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T DEV = tvm.device("cuda") @@ -81,59 +81,58 @@ def test_ptx_mma_m16n8k16(a_type, no_c_ptr): b_type = a_type # fmt: off - @Tx.prim_func + @T.prim_func def main( - D: Tx.Buffer((16, 8), "float32"), - A: Tx.Buffer((16, 16), a_type), - B: Tx.Buffer((16, 8), b_type), - C: Tx.Buffer((16, 8), "float32"), + D: T.Buffer((16, 8), "float32"), + A: T.Buffer((16, 16), a_type), + B: T.Buffer((16, 8), b_type), + C: T.Buffer((16, 8), "float32"), ): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) - with Tx.thread(): - D_local = Tx.alloc_local([4], "float32") - A_local = Tx.alloc_local([8], a_type) - B_local = Tx.alloc_local([4], b_type) - C_local = Tx.alloc_local([4], "float32") - - @Tx.inline - def G2L(buf_local, buf_global, block_8x8, mode="row"): - if mode == "row": - for i in range(block_8x8): - row = Tx.meta_var(i % 2 * 8 + tx // 4) - col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2) - for j in range(2): - buf_local[i * 2 + j] = buf_global[row, col + j] - elif mode == "col": - for i in range(block_8x8): - row = Tx.meta_var(i % 2 * 8 + (tx % 4) * 2) - col = Tx.meta_var(i // 2 * 8 + tx // 4) - for j in range(2): - buf_local[i * 2 + j] = buf_global[row + j, col] - - G2L(D_local, D, 2) - G2L(A_local, A, 4) - G2L(B_local, B, 2, "col") - G2L(C_local, C, 2) - - # One pointer per b32 register, in PTX order: A=4, B=2, D/C=4. - d_ptrs = [D_local.ptr_to([i]) for i in range(4)] - a_ptrs = [A_local.ptr_to([2 * i]) for i in range(4)] - b_ptrs = [B_local.ptr_to([2 * i]) for i in range(2)] - if no_c_ptr: - Tx.ptx.mma("m16n8k16", "row", "col", "float32", a_type, b_type, "float32", - d_ptrs, a_ptrs, b_ptrs) - else: - c_ptrs = [C_local.ptr_to([i]) for i in range(4)] - Tx.ptx.mma("m16n8k16", "row", "col", "float32", a_type, b_type, "float32", - d_ptrs, a_ptrs, b_ptrs, c_ptrs) - - for i in range(2): - row = Tx.meta_var(i % 2 * 8 + tx // 4) - col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2) - for j in range(2): - D[row, col + j] = D_local[i * 2 + j] + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) + D_local = T.alloc_local([4], "float32") + A_local = T.alloc_local([8], a_type) + B_local = T.alloc_local([4], b_type) + C_local = T.alloc_local([4], "float32") + + @T.inline + def G2L(buf_local, buf_global, block_8x8, mode="row"): + if mode == "row": + for i in range(block_8x8): + row = T.meta_var(i % 2 * 8 + tx // 4) + col = T.meta_var(i // 2 * 8 + (tx % 4) * 2) + for j in range(2): + buf_local[i * 2 + j] = buf_global[row, col + j] + elif mode == "col": + for i in range(block_8x8): + row = T.meta_var(i % 2 * 8 + (tx % 4) * 2) + col = T.meta_var(i // 2 * 8 + tx // 4) + for j in range(2): + buf_local[i * 2 + j] = buf_global[row + j, col] + + G2L(D_local, D, 2) + G2L(A_local, A, 4) + G2L(B_local, B, 2, "col") + G2L(C_local, C, 2) + + # One pointer per b32 register, in PTX order: A=4, B=2, D/C=4. + d_ptrs = [D_local.ptr_to([i]) for i in range(4)] + a_ptrs = [A_local.ptr_to([2 * i]) for i in range(4)] + b_ptrs = [B_local.ptr_to([2 * i]) for i in range(2)] + if no_c_ptr: + T.ptx.mma("m16n8k16", "row", "col", "float32", a_type, b_type, "float32", + d_ptrs, a_ptrs, b_ptrs) + else: + c_ptrs = [C_local.ptr_to([i]) for i in range(4)] + T.ptx.mma("m16n8k16", "row", "col", "float32", a_type, b_type, "float32", + d_ptrs, a_ptrs, b_ptrs, c_ptrs) + + for i in range(2): + row = T.meta_var(i % 2 * 8 + tx // 4) + col = T.meta_var(i // 2 * 8 + (tx % 4) * 2) + for j in range(2): + D[row, col + j] = D_local[i * 2 + j] # fmt: on src, mod = _get_source(main) @@ -152,59 +151,58 @@ def test_ptx_mma_m16n8k8(a_type, no_c_ptr): b_type = a_type # fmt: off - @Tx.prim_func + @T.prim_func def main( - D: Tx.Buffer((16, 8), "float32"), - A: Tx.Buffer((16, 8), a_type), - B: Tx.Buffer((8, 8), b_type), - C: Tx.Buffer((16, 8), "float32"), + D: T.Buffer((16, 8), "float32"), + A: T.Buffer((16, 8), a_type), + B: T.Buffer((8, 8), b_type), + C: T.Buffer((16, 8), "float32"), ): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) - with Tx.thread(): - D_local = Tx.alloc_local([4], "float32") - A_local = Tx.alloc_local([4], a_type) - B_local = Tx.alloc_local([2], b_type) - C_local = Tx.alloc_local([4], "float32") - - @Tx.inline - def G2L(buf_local, buf_global, block_8x8, mode="row"): - if mode == "row": - for i in range(block_8x8): - row = Tx.meta_var(i % 2 * 8 + tx // 4) - col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2) - for j in range(2): - buf_local[i * 2 + j] = buf_global[row, col + j] - elif mode == "col": - for i in range(block_8x8): - row = Tx.meta_var(i % 2 * 8 + (tx % 4) * 2) - col = Tx.meta_var(i // 2 * 8 + tx // 4) - for j in range(2): - buf_local[i * 2 + j] = buf_global[row + j, col] - - G2L(D_local, D, 2) - G2L(A_local, A, 2) - G2L(B_local, B, 1, "col") - G2L(C_local, C, 2) - - # One pointer per b32 register, in PTX order: A=2, B=1, D/C=4. - d_ptrs = [D_local.ptr_to([i]) for i in range(4)] - a_ptrs = [A_local.ptr_to([2 * i]) for i in range(2)] - b_ptrs = [B_local.ptr_to([0])] - if no_c_ptr: - Tx.ptx.mma("m16n8k8", "row", "col", "float32", a_type, b_type, "float32", - d_ptrs, a_ptrs, b_ptrs) - else: - c_ptrs = [C_local.ptr_to([i]) for i in range(4)] - Tx.ptx.mma("m16n8k8", "row", "col", "float32", a_type, b_type, "float32", - d_ptrs, a_ptrs, b_ptrs, c_ptrs) - - for i in range(2): - row = Tx.meta_var(i % 2 * 8 + tx // 4) - col = Tx.meta_var(i // 2 * 8 + (tx % 4) * 2) - for j in range(2): - D[row, col + j] = D_local[i * 2 + j] + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) + D_local = T.alloc_local([4], "float32") + A_local = T.alloc_local([4], a_type) + B_local = T.alloc_local([2], b_type) + C_local = T.alloc_local([4], "float32") + + @T.inline + def G2L(buf_local, buf_global, block_8x8, mode="row"): + if mode == "row": + for i in range(block_8x8): + row = T.meta_var(i % 2 * 8 + tx // 4) + col = T.meta_var(i // 2 * 8 + (tx % 4) * 2) + for j in range(2): + buf_local[i * 2 + j] = buf_global[row, col + j] + elif mode == "col": + for i in range(block_8x8): + row = T.meta_var(i % 2 * 8 + (tx % 4) * 2) + col = T.meta_var(i // 2 * 8 + tx // 4) + for j in range(2): + buf_local[i * 2 + j] = buf_global[row + j, col] + + G2L(D_local, D, 2) + G2L(A_local, A, 2) + G2L(B_local, B, 1, "col") + G2L(C_local, C, 2) + + # One pointer per b32 register, in PTX order: A=2, B=1, D/C=4. + d_ptrs = [D_local.ptr_to([i]) for i in range(4)] + a_ptrs = [A_local.ptr_to([2 * i]) for i in range(2)] + b_ptrs = [B_local.ptr_to([0])] + if no_c_ptr: + T.ptx.mma("m16n8k8", "row", "col", "float32", a_type, b_type, "float32", + d_ptrs, a_ptrs, b_ptrs) + else: + c_ptrs = [C_local.ptr_to([i]) for i in range(4)] + T.ptx.mma("m16n8k8", "row", "col", "float32", a_type, b_type, "float32", + d_ptrs, a_ptrs, b_ptrs, c_ptrs) + + for i in range(2): + row = T.meta_var(i % 2 * 8 + tx // 4) + col = T.meta_var(i // 2 * 8 + (tx % 4) * 2) + for j in range(2): + D[row, col + j] = D_local[i * 2 + j] # fmt: on src, mod = _get_source(main) diff --git a/tests/python/tirx/codegen/test_codegen_blackwell.py b/tests/python/tirx/codegen/test_codegen_blackwell.py index d40a87e23616..f6c526a2a193 100644 --- a/tests/python/tirx/codegen/test_codegen_blackwell.py +++ b/tests/python/tirx/codegen/test_codegen_blackwell.py @@ -20,7 +20,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx def _get_source(func: tvm.tirx.PrimFunc) -> str: @@ -37,28 +38,25 @@ def test_tmem_alloc_dealloc_relinquish(): cta_group = 1 # fmt: off - @Tx.prim_func - def test_tmem(A: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([128]) - with Tx.cta(): - # tmem_addr = Tx.alloc_buffer((1,), "uint32", scope="shared", align=8) - tmem_addr = Tx.shared_scalar("uint32") - - # alloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) # noqa: E501 - Tx.cuda.cta_sync() - - # dealloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) + @T.prim_func + def test_tmem(A: T.Buffer((16, 16), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + tid = T.thread_id([128]) + # tmem_addr = T.alloc_buffer((1,), "uint32", scope="shared", align=8) + tmem_addr = T.shared_scalar("uint32") + + # alloc TMEM + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) + T.cuda.cta_sync() + + # dealloc TMEM + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) # fmt: on target = tvm.target.Target("cuda") @@ -72,14 +70,13 @@ def test_tmem(A: Tx.Buffer((16, 16), "float16")): @tvm.testing.requires_cuda_compute_version(10) def test_mbarrier_try_wait_once_codegen(): # fmt: off - @Tx.prim_func - def test_try_wait_once(A: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([128]) - with Tx.cta(): - bar = Tx.shared_scalar("uint64") - Tx.evaluate(Tx.ptx.mbarrier.try_wait_once(Tx.address_of(bar), 0, 0)) + @T.prim_func + def test_try_wait_once(A: T.Buffer((16, 16), "float16")): + T.device_entry() + T.cta_id([1]) + T.thread_id([128]) + bar = T.shared_scalar("uint64") + T.evaluate(T.ptx.mbarrier.try_wait_once(T.address_of(bar), 0, 0)) # fmt: on target = tvm.target.Target("cuda") @@ -92,17 +89,16 @@ def test_try_wait_once(A: Tx.Buffer((16, 16), "float16")): @tvm.testing.requires_cuda_compute_version(10) def test_fence_before_after_thread_sync(): # fmt: off - @Tx.prim_func - def test_fence(A: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([128]) - with Tx.thread(): - Tx.ptx.tcgen05.fence.before_thread_sync() - Tx.ptx.bar.sync(0, 32) - Tx.ptx.tcgen05.fence.after_thread_sync() + @T.prim_func + def test_fence(A: T.Buffer((16, 16), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + tid = T.thread_id([128]) + T.ptx.tcgen05.fence.before_thread_sync() + T.ptx.bar.sync(0, 32) + T.ptx.tcgen05.fence.after_thread_sync() # fmt: on target = tvm.target.Target("cuda") @@ -121,51 +117,46 @@ def test_tcgen05_ld_st_roundtrip(): cta_group = 1 # fmt: off - @Tx.prim_func - def test_ld_st(A: Tx.Buffer((HEIGHT, WIDTH), "float32"), B: Tx.Buffer((HEIGHT, WIDTH), "float32")): # noqa: E501 - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - tx = Tx.thread_id([128]) - with Tx.cta(): - reg = Tx.alloc_buffer((WIDTH,), "float32", scope="local") - # tmem_addr = Tx.alloc_buffer((1,), "uint32", scope="shared", align=8) - tmem_addr = Tx.shared_scalar("uint32") - - # alloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) # noqa: E501 - Tx.cuda.cta_sync() - - with Tx.thread(): - # GMEM -> RF - for i in range(WIDTH): - reg[i] = A[tx, i] - # RF -> TMEM - for i in range(WIDTH): - Tx.ptx.tcgen05.st(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 - Tx.ptx.tcgen05.wait.st() - Tx.cuda.cta_sync() - # reset RF - for i in range(WIDTH): - reg[i] = 0.0 - Tx.cuda.cta_sync() - # TMEM -> RF - Tx.ptx.tcgen05.fence.after_thread_sync() - for i in range(WIDTH): - Tx.ptx.tcgen05.ld(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 - Tx.ptx.tcgen05.wait.ld() - # RF -> GMEM - for i in range(WIDTH): - B[tx, i] = reg[i] - - # dealloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) + @T.prim_func + def test_ld_st(A: T.Buffer((HEIGHT, WIDTH), "float32"), B: T.Buffer((HEIGHT, WIDTH), "float32")): # noqa: E501 + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + tx = T.thread_id([128]) + reg = T.alloc_buffer((WIDTH,), "float32", scope="local") + # tmem_addr = T.alloc_buffer((1,), "uint32", scope="shared", align=8) + tmem_addr = T.shared_scalar("uint32") + + # alloc TMEM + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) + T.cuda.cta_sync() + # GMEM -> RF + for i in range(WIDTH): + reg[i] = A[tx, i] + # RF -> TMEM + for i in range(WIDTH): + T.ptx.tcgen05.st(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + # reset RF + for i in range(WIDTH): + reg[i] = 0.0 + T.cuda.cta_sync() + # TMEM -> RF + T.ptx.tcgen05.fence.after_thread_sync() + for i in range(WIDTH): + T.ptx.tcgen05.ld(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 + T.ptx.tcgen05.wait.ld() + # RF -> GMEM + for i in range(WIDTH): + B[tx, i] = reg[i] + + # dealloc TMEM + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) # fmt: on DEV = tvm.cuda(0) @@ -191,69 +182,61 @@ def test_tcgen05_cp_ld_roundtrip(): N_COLS = 512 REPEAT_NUM = 1 SWIZZLE = 0 - A_layout = Tx.TileLayout(Tx.S[(HEIGHT, WIDTH // 4, 4) : (4, HEIGHT * 4, 1)]) + A_layout = T.TileLayout(T.S[(HEIGHT, WIDTH // 4, 4) : (4, HEIGHT * 4, 1)]) ldo, sdo = 128, 8 cta_group = 1 # fmt: off - @Tx.prim_func - def test_cp_ld(A: Tx.Buffer((HEIGHT, WIDTH), dtype, layout=Tx.TileLayout(Tx.S[(HEIGHT, WIDTH // 4, 4) : (4, HEIGHT * 4, 1)])), # noqa: E501 - B: Tx.Buffer((HEIGHT, WIDTH), dtype, layout=Tx.TileLayout(Tx.S[(HEIGHT, WIDTH // 4, 4) : (4, HEIGHT * 4, 1)]))): # noqa: E501 - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - tx = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer((HEIGHT, WIDTH), dtype, scope="shared", layout=A_layout) - reg = Tx.alloc_buffer((WIDTH,), dtype, scope="local") - # tmem_addr = Tx.alloc_buffer((1,), "uint32", scope="shared", align=8) - tmem_addr = Tx.shared_scalar("uint32") - descA = Tx.alloc_buffer((1,), "uint64", scope="local") - bar = Tx.alloc_buffer((1,), "uint64", scope="shared", align=8) - phase = Tx.alloc_buffer((1,), "int32", scope="local") - - # alloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) # noqa: E501 - Tx.cuda.cta_sync() - - # GMEM -> SMEM - with Tx.cta(): - Tx.copy(A_smem[:, :], A[:, :]) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - - with Tx.thread(): - # reset RF - for i in range(WIDTH): - reg[i] = 0.0 - # SMEM -> TMEM (cp) - phase[0] = 0 - if tx == 0: - Tx.ptx.mbarrier.init(bar.data, 1) - for k in range(dtype_bits * WIDTH // 256): - Tx.ptx.tcgen05.encode_matrix_descriptor(descA.data, A_smem.access_ptr("r", offset=A_smem.elem_offset_of([0, k * 8])), ldo=ldo, sdo=sdo, swizzle=SWIZZLE) # noqa: E501 - Tx.ptx.tcgen05.cp(tmem_addr, descA[0], shape="128x256b", cta_group=cta_group, col=k * 256 // 32) # noqa: E501 - Tx.ptx.tcgen05.commit(bar.data, cta_group) - Tx.ptx.mbarrier.try_wait(bar.data, phase[0]) - phase[0] = phase[0] ^ 1 - Tx.cuda.cta_sync() - # TMEM -> RF (ld) - Tx.ptx.tcgen05.fence.after_thread_sync() - for i in range(WIDTH): - Tx.ptx.tcgen05.ld(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 - Tx.ptx.tcgen05.wait.ld() - # RF -> GMEM - for i in range(WIDTH): - B[tx, i] = reg[i] - - # dealloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) + @T.prim_func + def test_cp_ld(A: T.Buffer((HEIGHT, WIDTH), dtype, layout=T.TileLayout(T.S[(HEIGHT, WIDTH // 4, 4) : (4, HEIGHT * 4, 1)])), # noqa: E501 + B: T.Buffer((HEIGHT, WIDTH), dtype, layout=T.TileLayout(T.S[(HEIGHT, WIDTH // 4, 4) : (4, HEIGHT * 4, 1)]))): # noqa: E501 + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + tx = T.thread_id([128]) + A_smem = T.alloc_buffer((HEIGHT, WIDTH), dtype, scope="shared", layout=A_layout) + reg = T.alloc_buffer((WIDTH,), dtype, scope="local") + # tmem_addr = T.alloc_buffer((1,), "uint32", scope="shared", align=8) + tmem_addr = T.shared_scalar("uint32") + descA = T.alloc_buffer((1,), "uint64", scope="local") + bar = T.alloc_buffer((1,), "uint64", scope="shared", align=8) + phase = T.alloc_buffer((1,), "int32", scope="local") + + # alloc TMEM + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) + T.cuda.cta_sync() + Tx.cta.copy(A_smem[:, :], A[:, :]) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + # reset RF + for i in range(WIDTH): + reg[i] = 0.0 + # SMEM -> TMEM (cp) + phase[0] = 0 + if tx == 0: + T.ptx.mbarrier.init(bar.data, 1) + for k in range(dtype_bits * WIDTH // 256): + T.ptx.tcgen05.encode_matrix_descriptor(descA.data, A_smem.access_ptr("r", offset=A_smem.elem_offset_of([0, k * 8])), ldo=ldo, sdo=sdo, swizzle=SWIZZLE) # noqa: E501 + T.ptx.tcgen05.cp(tmem_addr, descA[0], shape="128x256b", cta_group=cta_group, col=k * 256 // 32) # noqa: E501 + T.ptx.tcgen05.commit(bar.data, cta_group) + T.ptx.mbarrier.try_wait(bar.data, phase[0]) + phase[0] = phase[0] ^ 1 + T.cuda.cta_sync() + # TMEM -> RF (ld) + T.ptx.tcgen05.fence.after_thread_sync() + for i in range(WIDTH): + T.ptx.tcgen05.ld(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 + T.ptx.tcgen05.wait.ld() + # RF -> GMEM + for i in range(WIDTH): + B[tx, i] = reg[i] + + # dealloc TMEM + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) # fmt: on DEV = tvm.cuda(0) @@ -282,37 +265,37 @@ def test_tcgen05_mma_ss_no_tma(swizzle): cta_group = 1 if SWIZZLE == 0: - A_layout = Tx.TileLayout(Tx.S[(M, K // 8, 8) : (8, M * 8, 1)]) - B_layout = Tx.TileLayout(Tx.S[(N, K // 8, 8) : (8, N * 8, 1)]) + A_layout = T.TileLayout(T.S[(M, K // 8, 8) : (8, M * 8, 1)]) + B_layout = T.TileLayout(T.S[(N, K // 8, 8) : (8, N * 8, 1)]) ldo, sdo = 128, 8 elif SWIZZLE == 1: - A_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 1, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(M, K // 16, 16) : (16, M * 16, 1)]), + A_layout = T.ComposeLayout( + T.SwizzleLayout(3, 1, 3, swizzle_inner=True), + T.TileLayout(T.S[(M, K // 16, 16) : (16, M * 16, 1)]), ) - B_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 1, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(N, K // 16, 16) : (16, N * 16, 1)]), + B_layout = T.ComposeLayout( + T.SwizzleLayout(3, 1, 3, swizzle_inner=True), + T.TileLayout(T.S[(N, K // 16, 16) : (16, N * 16, 1)]), ) ldo, sdo = 256, 16 elif SWIZZLE == 2: - A_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 2, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(M, K // 32, 32) : (32, M * 32, 1)]), + A_layout = T.ComposeLayout( + T.SwizzleLayout(3, 2, 3, swizzle_inner=True), + T.TileLayout(T.S[(M, K // 32, 32) : (32, M * 32, 1)]), ) - B_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 2, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(N, K // 32, 32) : (32, N * 32, 1)]), + B_layout = T.ComposeLayout( + T.SwizzleLayout(3, 2, 3, swizzle_inner=True), + T.TileLayout(T.S[(N, K // 32, 32) : (32, N * 32, 1)]), ) ldo, sdo = 512, 32 elif SWIZZLE == 3: - A_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 3, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(M, 1, 64) : (64, M * 64, 1)]), + A_layout = T.ComposeLayout( + T.SwizzleLayout(3, 3, 3, swizzle_inner=True), + T.TileLayout(T.S[(M, 1, 64) : (64, M * 64, 1)]), ) - B_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 3, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(N, 1, 64) : (64, N * 64, 1)]), + B_layout = T.ComposeLayout( + T.SwizzleLayout(3, 3, 3, swizzle_inner=True), + T.TileLayout(T.S[(N, 1, 64) : (64, N * 64, 1)]), ) ldo, sdo = 1, 64 else: @@ -321,78 +304,67 @@ def test_tcgen05_mma_ss_no_tma(swizzle): dyn_smem_bytes = 1024 + (M * K + N * K) * 2 # fmt: off - @Tx.prim_func - def test_mma_ss_no_tma(A: Tx.Buffer((M, K), a_type, layout=Tx.TileLayout(Tx.S[M, K])), - B: Tx.Buffer((N, K), b_type, layout=Tx.TileLayout(Tx.S[N, K])), - C: Tx.Buffer((M, N), d_type)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - tx = Tx.thread_id([128]) - with Tx.cta(): - dyn = Tx.alloc_buffer((dyn_smem_bytes,), "uint8", scope="shared") - tmem_addr = Tx.decl_scalar("uint32", dyn.data, scope="shared", elem_offset=0) - A_smem = Tx.decl_buffer((M, K), a_type, dyn.data, elem_offset=256, layout=A_layout) - B_smem = Tx.decl_buffer((N, K), b_type, dyn.data, elem_offset=256 + M*K, layout=B_layout) # noqa: E501 - bar = Tx.decl_buffer((1,), "uint64", dyn.data, scope="shared", elem_offset=8) - - reg = Tx.alloc_buffer((N,), d_type, scope="local") - descA = Tx.alloc_buffer((1,), "uint64", scope="local") - descB = Tx.alloc_buffer((1,), "uint64", scope="local") - descI = Tx.alloc_buffer((1,), "uint32", scope="local") - phase = Tx.alloc_buffer((1,), "int32", scope="local") - - # alloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) # noqa: E501 - Tx.cuda.cta_sync() - - # reset RF - with Tx.thread(): - for i in range(N): - reg[i] = 0.0 - - # GMEM -> SMEM - with Tx.cta(): - Tx.copy(A_smem[:, :], A[:, :]) - Tx.copy(B_smem[:, :], B[:, :]) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - - with Tx.thread(): - # MMA - phase[0] = 0 - if tx == 0: - Tx.ptx.mbarrier.init(bar.data, 1) - Tx.ptx.tcgen05.encode_instr_descriptor(descI.data, d_dtype=d_type, a_dtype=a_type, b_dtype=b_type, M=M, N=N, K=MMA_K, trans_a=False, trans_b=False, n_cta_groups=cta_group) # noqa: E501 - for k in range(K // MMA_K): - Tx.ptx.tcgen05.encode_matrix_descriptor(descA.data, A_smem.access_ptr("r", offset=A_smem.elem_offset_of([0, k * MMA_K])), ldo=ldo, sdo=sdo, swizzle=SWIZZLE) # noqa: E501 - Tx.ptx.tcgen05.encode_matrix_descriptor(descB.data, B_smem.access_ptr("r", offset=B_smem.elem_offset_of([0, k * MMA_K])), ldo=ldo, sdo=sdo, swizzle=SWIZZLE) # noqa: E501 - if k == 0: - Tx.ptx.tcgen05.mma(tmem_addr, descA[0], descB[0], descI[0], d_dtype=d_type, a_dtype=a_type, b_dtype=b_type, use_a_tmem=False, cta_group=cta_group, enable_input_d=0) # noqa: E501 - else: - Tx.ptx.tcgen05.mma(tmem_addr, descA[0], descB[0], descI[0], d_dtype=d_type, a_dtype=a_type, b_dtype=b_type, use_a_tmem=False, cta_group=cta_group, enable_input_d=1) # noqa: E501 - Tx.ptx.tcgen05.commit(bar.data, cta_group) - Tx.ptx.mbarrier.try_wait(bar.data, phase[0]) - phase[0] = phase[0] ^ 1 - Tx.cuda.cta_sync() - - # TMEM -> RF - Tx.ptx.tcgen05.fence.after_thread_sync() - for i in range(N): - Tx.ptx.tcgen05.ld(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 - Tx.ptx.tcgen05.wait.ld() - # RF -> GMEM - for i in range(N): - C[tx, i] = reg[i] - - # dealloc TMEM - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) + @T.prim_func + def test_mma_ss_no_tma(A: T.Buffer((M, K), a_type, layout=T.TileLayout(T.S[M, K])), + B: T.Buffer((N, K), b_type, layout=T.TileLayout(T.S[N, K])), + C: T.Buffer((M, N), d_type)): + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + tx = T.thread_id([128]) + dyn = T.alloc_buffer((dyn_smem_bytes,), "uint8", scope="shared") + tmem_addr = T.decl_scalar("uint32", dyn.data, scope="shared", elem_offset=0) + A_smem = T.decl_buffer((M, K), a_type, dyn.data, elem_offset=256, layout=A_layout) + B_smem = T.decl_buffer((N, K), b_type, dyn.data, elem_offset=256 + M*K, layout=B_layout) + bar = T.decl_buffer((1,), "uint64", dyn.data, scope="shared", elem_offset=8) + + reg = T.alloc_buffer((N,), d_type, scope="local") + descA = T.alloc_buffer((1,), "uint64", scope="local") + descB = T.alloc_buffer((1,), "uint64", scope="local") + descI = T.alloc_buffer((1,), "uint32", scope="local") + phase = T.alloc_buffer((1,), "int32", scope="local") + + # alloc TMEM + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=N_COLS, cta_group=cta_group) + T.cuda.cta_sync() + for i in range(N): + reg[i] = 0.0 + Tx.cta.copy(A_smem[:, :], A[:, :]) + Tx.cta.copy(B_smem[:, :], B[:, :]) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + # MMA + phase[0] = 0 + if tx == 0: + T.ptx.mbarrier.init(bar.data, 1) + T.ptx.tcgen05.encode_instr_descriptor(descI.data, d_dtype=d_type, a_dtype=a_type, b_dtype=b_type, M=M, N=N, K=MMA_K, trans_a=False, trans_b=False, n_cta_groups=cta_group) # noqa: E501 + for k in range(K // MMA_K): + T.ptx.tcgen05.encode_matrix_descriptor(descA.data, A_smem.access_ptr("r", offset=A_smem.elem_offset_of([0, k * MMA_K])), ldo=ldo, sdo=sdo, swizzle=SWIZZLE) # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descB.data, B_smem.access_ptr("r", offset=B_smem.elem_offset_of([0, k * MMA_K])), ldo=ldo, sdo=sdo, swizzle=SWIZZLE) # noqa: E501 + if k == 0: + T.ptx.tcgen05.mma(tmem_addr, descA[0], descB[0], descI[0], d_dtype=d_type, a_dtype=a_type, b_dtype=b_type, use_a_tmem=False, cta_group=cta_group, enable_input_d=0) # noqa: E501 + else: + T.ptx.tcgen05.mma(tmem_addr, descA[0], descB[0], descI[0], d_dtype=d_type, a_dtype=a_type, b_dtype=b_type, use_a_tmem=False, cta_group=cta_group, enable_input_d=1) # noqa: E501 + T.ptx.tcgen05.commit(bar.data, cta_group) + T.ptx.mbarrier.try_wait(bar.data, phase[0]) + phase[0] = phase[0] ^ 1 + T.cuda.cta_sync() + + # TMEM -> RF + T.ptx.tcgen05.fence.after_thread_sync() + for i in range(N): + T.ptx.tcgen05.ld(tmem_addr, reg[i], shape="32x32b", num=REPEAT_NUM, row=warp_id * 32, col=i) # noqa: E501 + T.ptx.tcgen05.wait.ld() + # RF -> GMEM + for i in range(N): + C[tx, i] = reg[i] + + # dealloc TMEM + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr, n_cols=N_COLS, cta_group=cta_group) # fmt: on import torch diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py b/tests/python/tirx/codegen/test_codegen_cuda.py index 563fbd2ecfc9..f253d6d375c6 100644 --- a/tests/python/tirx/codegen/test_codegen_cuda.py +++ b/tests/python/tirx/codegen/test_codegen_cuda.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T DEV = tvm.device("cuda") @@ -41,17 +41,45 @@ def _helper_source(src: str, helper_name: str) -> str: return src[start:next_helper] +def test_tirx_launch_bounds_omits_min_blocks_without_persistent_schedule(): + @T.prim_func + def main(A: T.Buffer((4,), "int32")): + T.device_entry() + bx = T.cta_id([4]) + tx = T.thread_id([128]) + if tx == 0: + A[bx] = A[bx] + 1 + + src, _ = _get_source(main) + assert 'extern "C" __global__ void __launch_bounds__(128) main_kernel' in src + assert "__launch_bounds__(128, 1)" not in src + + +def test_tirx_launch_bounds_min_blocks_attr_sets_one_block_per_sm(): + @T.prim_func + def main(A: T.Buffer((4,), "int32")): + T.device_entry() + T.attr({"tirx.launch_bounds_min_blocks_per_sm": 1}) + bx = T.cta_id([4]) + tx = T.thread_id([128]) + if tx == 0: + A[bx] = A[bx] + 1 + + src, _ = _get_source(main) + assert 'extern "C" __global__ void __launch_bounds__(128, 1) main_kernel' in src + assert "tirx.launch_bounds_min_blocks_per_sm" not in src + + def test_serial_pragma_unroll_codegen(): - @Tx.prim_func - def main(A: Tx.Buffer((4,), "int32")): - Tx.device_entry() - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((4,), "int32")): + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - for i in Tx.serial(4, unroll=True): - if i == 2: - break - A[i] = A[i] + 1 + for i in T.serial(4, unroll=True): + if i == 2: + break + A[i] = A[i] + 1 src, _ = _get_source(main) assert "#pragma unroll\n" in src @@ -60,14 +88,13 @@ def main(A: Tx.Buffer((4,), "int32")): def test_cluster_cta_id_codegen_uses_coordinate_sregs(): - @Tx.prim_func - def main(A: Tx.Buffer((1,), "int32")): - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([2, 2]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((1,), "int32")): + T.device_entry() + cbx, cby = T.cta_id_in_cluster([2, 2]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - A[0] = cbx + cby + A[0] = cbx + cby src, _ = _get_source(main) assert "%cluster_ctaid.x" in src @@ -77,14 +104,13 @@ def main(A: Tx.Buffer((1,), "int32")): def test_cuda_handle_uint64_reinterpret_codegen(): - @Tx.prim_func - def main(A: Tx.Buffer((1,), "uint64")): - Tx.device_entry() - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((1,), "uint64")): + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - ptr = Tx.reinterpret("handle", A[0]) - A[0] = Tx.reinterpret("uint64", ptr) + ptr = T.reinterpret("handle", A[0]) + A[0] = T.reinterpret("uint64", ptr) src, _ = _get_source(main) assert "reinterpret_cast" in src @@ -93,15 +119,14 @@ def main(A: Tx.Buffer((1,), "uint64")): def test_cuda_atomic_add(): - @Tx.prim_func - def main(A: Tx.Buffer((1,), "int32"), B: Tx.Buffer((1,), "float32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "float32")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - Tx.cuda.atomic_add(A.data, Tx.int32(1)) - Tx.cuda.atomic_add(B.data, Tx.float32(1.0)) + T.cuda.atomic_add(A.data, T.int32(1)) + T.cuda.atomic_add(B.data, T.float32(1.0)) src, mod = _get_source(main) assert "tvm_builtin_cuda_atomic_add" in src @@ -115,19 +140,16 @@ def main(A: Tx.Buffer((1,), "int32"), B: Tx.Buffer((1,), "float32")): def test_ptx_ld_acquire_and_volatile_codegen(): - @Tx.prim_func - def main( - A: Tx.Buffer((1,), "uint64"), B: Tx.Buffer((1,), "int32"), C: Tx.Buffer((1,), "uint32") - ): - Tx.device_entry() - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((1,), "uint64"), B: T.Buffer((1,), "int32"), C: T.Buffer((1,), "uint32")): + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - A[0] = Tx.ptx.ld_acquire(A.data, "uint64", "u64", scope="gpu", space="global") - B[0] = Tx.ptx.ld_acquire(B.data, "int32", "s32", scope="sys", space="global") - C[0] = Tx.ptx.ld_acquire(C.data, "uint32", "b32", scope="gpu", space="global") - Tx.ptx.ld_global_acquire(B[0], B.data) - A[0] = Tx.ptx.ld_volatile(A.data, "uint64", "u64", space="global") + A[0] = T.ptx.ld_acquire(A.data, "uint64", "u64", scope="gpu", space="global") + B[0] = T.ptx.ld_acquire(B.data, "int32", "s32", scope="sys", space="global") + C[0] = T.ptx.ld_acquire(C.data, "uint32", "b32", scope="gpu", space="global") + T.ptx.ld_global_acquire(B[0], B.data) + A[0] = T.ptx.ld_volatile(A.data, "uint64", "u64", space="global") src, _ = _get_source(main) assert "ld.acquire.gpu.global.u64" in src @@ -139,84 +161,83 @@ def main( def test_megamoe_extracted_intrinsics_codegen(): - @Tx.prim_func + @T.prim_func def main( - U32: Tx.Buffer((4,), "uint32"), - I32: Tx.Buffer((1,), "int32"), - U64: Tx.Buffer((1,), "uint64"), - F32: Tx.Buffer((4,), "float32"), + U32: T.Buffer((4,), "uint32"), + I32: T.Buffer((1,), "int32"), + U64: T.Buffer((1,), "uint64"), + F32: T.Buffer((4,), "float32"), ): - Tx.device_entry() - tx = Tx.thread_id([32]) + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - Tx.ptx.red_scalar( - U64.data, - U64[0], - sem="release", - scope="gpu", - space="global", - op="or", - ptx_type="b64", - ) - Tx.ptx.red_scalar( - I32.data, - I32[0], - sem="release", - scope="sys", - space="global", - op="add", - ptx_type="s32", - ) - U32[0] = Tx.ptx.atom_scalar( - U32.data, - U32[0], - sem="release", - scope="gpu", - space="global", - op="add", - ptx_type="u32", - ) - U64[0] = Tx.ptx.atom_scalar( - U64.data, U64[0], scope="sys", space="global", op="add", ptx_type="u64" - ) - Tx.ptx.red_scalar( - U32.data, U32[0], scope="gpu", space="global", op="add", ptx_type="u32" - ) - Tx.ptx.st(U32.data, U32[0], space="shared", ptx_type="u32") - Tx.ptx.st( - U32.data, - U32[0], - U32[1], - U32[2], - U32[3], - space="shared", - vec="v4", - ptx_type="b32", - ) - Tx.ptx.st_bulk(U32.data, Tx.uint32(16), weak=True, space="shared::cta") - U32[0] = Tx.ptx.fns_b32(U32[0], U32[1], I32[0]) - Tx.ptx.stmatrix( - True, # trans - 1, # num - ".b8", # dtype - U32.data, # smem_ptr - U32.data, # src0 - shape="m16n8", - space="shared", - ) + T.ptx.red_scalar( + U64.data, + U64[0], + sem="release", + scope="gpu", + space="global", + op="or", + ptx_type="b64", + ) + T.ptx.red_scalar( + I32.data, + I32[0], + sem="release", + scope="sys", + space="global", + op="add", + ptx_type="s32", + ) + U32[0] = T.ptx.atom_scalar( + U32.data, + U32[0], + sem="release", + scope="gpu", + space="global", + op="add", + ptx_type="u32", + ) + U64[0] = T.ptx.atom_scalar( + U64.data, U64[0], scope="sys", space="global", op="add", ptx_type="u64" + ) + T.ptx.red_scalar( + U32.data, U32[0], scope="gpu", space="global", op="add", ptx_type="u32" + ) + T.ptx.st(U32.data, U32[0], space="shared", ptx_type="u32") + T.ptx.st( + U32.data, + U32[0], + U32[1], + U32[2], + U32[3], + space="shared", + vec="v4", + ptx_type="b32", + ) + T.ptx.st_bulk(U32.data, T.uint32(16), weak=True, space="shared::cta") + U32[0] = T.ptx.fns_b32(U32[0], U32[1], I32[0]) + T.ptx.stmatrix( + True, # trans + 1, # num + ".b8", # dtype + U32.data, # smem_ptr + U32.data, # src0 + shape="m16n8", + space="shared", + ) - F32[1] = Tx.cuda.uint_as_float(U32[0]) - F32[2] = Tx.ptx.ld(F32.data, "float32", "f32", space="global") - U32[3] = Tx.cuda.float_as_uint(F32[1]) - F32[0] = Tx.ptx.add_rn_f32_bf16(F32[0], Tx.cast(U32[0], "uint16")) - U64[0] = Tx.reinterpret("uint64", U32.data) - U32[0] = Tx.cuda.ballot_sync(Tx.uint32(0xFFFFFFFF), I32[0]) - I32[0] = Tx.cuda.ffs_u32(U32[0]) - U32[0] = Tx.cuda.reduce_add_sync_u32(Tx.uint32(0xFFFFFFFF), U32[0]) - U32[0] = Tx.cuda.reduce_min_sync_u32(Tx.uint32(0xFFFFFFFF), U32[0]) - U64[0] = Tx.cuda.clock64() - U32[0] = Tx.cuda.float22bfloat162_rn(F32[0], F32[1]) + F32[1] = T.cuda.uint_as_float(U32[0]) + F32[2] = T.ptx.ld(F32.data, "float32", "f32", space="global") + U32[3] = T.cuda.float_as_uint(F32[1]) + F32[0] = T.ptx.add_rn_f32_bf16(F32[0], T.cast(U32[0], "uint16")) + U64[0] = T.reinterpret("uint64", U32.data) + U32[0] = T.cuda.ballot_sync(T.uint32(0xFFFFFFFF), I32[0]) + I32[0] = T.cuda.ffs_u32(U32[0]) + U32[0] = T.cuda.reduce_add_sync_u32(T.uint32(0xFFFFFFFF), U32[0]) + U32[0] = T.cuda.reduce_min_sync_u32(T.uint32(0xFFFFFFFF), U32[0]) + U64[0] = T.cuda.clock64() + U32[0] = T.cuda.float22bfloat162_rn(F32[0], F32[1]) src, _ = _get_source(main) for snippet in [ @@ -245,24 +266,23 @@ def main( def test_ptx_cp_async_bulk_non_tma_form_codegen(): - @Tx.prim_func + @T.prim_func def main( - A: Tx.Buffer((128,), "float32"), - B: Tx.Buffer((128,), "float32"), - C: Tx.Buffer((1,), "uint64"), + A: T.Buffer((128,), "float32"), + B: T.Buffer((128,), "float32"), + C: T.Buffer((1,), "uint64"), ): - Tx.device_entry() - tx = Tx.thread_id([32]) + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - smem = Tx.alloc_shared([128], "float32") - Tx.ptx.cp_async_bulk_g2s_cta( - smem.ptr_to([0]), A.data, Tx.uint32(64), smem.ptr_to([0]), cache_policy=C[0] - ) - Tx.ptx.cp_async_bulk_g2s_cluster( - smem.ptr_to([0]), A.data, Tx.uint32(64), smem.ptr_to([0]), cache_policy=C[0] - ) - Tx.ptx.cp_async_bulk_s2g(B.data, smem.ptr_to([0]), Tx.uint32(64), cache_policy=C[0]) + smem = T.alloc_shared([128], "float32") + T.ptx.cp_async_bulk_g2s_cta( + smem.ptr_to([0]), A.data, T.uint32(64), smem.ptr_to([0]), cache_policy=C[0] + ) + T.ptx.cp_async_bulk_g2s_cluster( + smem.ptr_to([0]), A.data, T.uint32(64), smem.ptr_to([0]), cache_policy=C[0] + ) + T.ptx.cp_async_bulk_s2g(B.data, smem.ptr_to([0]), T.uint32(64), cache_policy=C[0]) src, _ = _get_source(main) assert "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" in src @@ -272,13 +292,12 @@ def main( def test_tensor_map_param_codegen(): - @Tx.prim_func - def main(A_map: Tx.TensorMap()): - Tx.device_entry() - tx = Tx.thread_id([32]) + @T.prim_func + def main(A_map: T.TensorMap()): + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - Tx.evaluate(Tx.address_of(A_map)) + T.evaluate(T.address_of(A_map)) src, _ = _get_source(main) assert "const __grid_constant__ CUtensorMap A_map" in src @@ -286,22 +305,62 @@ def main(A_map: Tx.TensorMap()): def test_tma_cache_policy_operand_codegen(): - @Tx.prim_func - def main(Cache: Tx.Buffer((1,), "uint64")): - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) + @T.prim_func + def main(Cache: T.Buffer((1,), "uint64")): + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) - Tx.device_entry() - tx = Tx.thread_id([32]) + T.device_entry() + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - smem = Tx.alloc_buffer((128,), "float32", scope="shared", align=128) - bar = Tx.shared_scalar("uint64") - Tx.ptx.cp_async.bulk.tensor.g2c( + smem = T.alloc_buffer((128,), "float32", scope="shared", align=128) + bar = T.shared_scalar("uint64") + T.ptx.cp_async.bulk.tensor.g2c( + 2, + smem.data, + T.address_of(bar), + T.address_of(A_map), + 1, + 2, + "", + 0, + 0, + cache_policy=Cache[0], + ) + T.ptx.cp_async.bulk.tensor.g2c( + 2, + smem.data, + T.address_of(bar), + T.address_of(A_map), + 3, + 2, + "", + 0, + 0, + cache_policy=Cache[0], + ) + T.ptx.cp_async.bulk.tensor.s2g( + 2, smem.data, T.address_of(A_map), "", 0, 0, cache_policy=Cache[0] + ) + masked_bar = T.cuda.sm100_tma_2sm_mbarrier_addr(T.address_of(bar)) + T.ptx.cp_async.bulk.tensor.g2c_bar_addr( + 2, + smem.data, + masked_bar, + T.address_of(A_map), + 1, + 2, + "", + 0, + 0, + cache_policy=Cache[0], + ) + if tx == 0: + T.ptx.cp_async.bulk.tensor.g2c_bar_addr( 2, smem.data, - Tx.address_of(bar), - Tx.address_of(A_map), + masked_bar, + T.address_of(A_map), 1, 2, "", @@ -309,27 +368,12 @@ def main(Cache: Tx.Buffer((1,), "uint64")): 0, cache_policy=Cache[0], ) - Tx.ptx.cp_async.bulk.tensor.g2c( - 2, - smem.data, - Tx.address_of(bar), - Tx.address_of(A_map), - 3, - 2, - "", - 0, - 0, - cache_policy=Cache[0], - ) - Tx.ptx.cp_async.bulk.tensor.s2g( - 2, smem.data, Tx.address_of(A_map), "", 0, 0, cache_policy=Cache[0] - ) - masked_bar = Tx.cuda.sm100_tma_2sm_mbarrier_addr(Tx.address_of(bar)) - Tx.ptx.cp_async.bulk.tensor.g2c_bar_addr( + else: + T.ptx.cp_async.bulk.tensor.g2c_bar_addr( 2, smem.data, masked_bar, - Tx.address_of(A_map), + T.address_of(B_map), 1, 2, "", @@ -337,32 +381,6 @@ def main(Cache: Tx.Buffer((1,), "uint64")): 0, cache_policy=Cache[0], ) - if tx == 0: - Tx.ptx.cp_async.bulk.tensor.g2c_bar_addr( - 2, - smem.data, - masked_bar, - Tx.address_of(A_map), - 1, - 2, - "", - 0, - 0, - cache_policy=Cache[0], - ) - else: - Tx.ptx.cp_async.bulk.tensor.g2c_bar_addr( - 2, - smem.data, - masked_bar, - Tx.address_of(B_map), - 1, - 2, - "", - 0, - 0, - cache_policy=Cache[0], - ) src, _ = _get_source(main) assert "ptx_cp_async_bulk_tensor_g2cluster_tile_2d_cache_hint" in src @@ -386,42 +404,39 @@ def main(Cache: Tx.Buffer((1,), "uint64")): def test_cuda_thread_fence(): - @Tx.prim_func - def main(A: Tx.Buffer((16, 16), "int32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((16, 16), "int32")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - Tx.cuda.thread_fence() + T.cuda.thread_fence() src, mod = _get_source(main) assert "tvm_builtin_cuda_thread_fence" in src def test_cuda_nano_sleep(): - @Tx.prim_func - def main(A: Tx.Buffer((16, 16), "int32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((16, 16), "int32")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - Tx.cuda.nano_sleep(1) + T.cuda.nano_sleep(1) src, mod = _get_source(main) assert "tvm_builtin_cuda_nano_sleep" in src def test_cuda_atomic_cas(): - @Tx.prim_func - def main(A: Tx.Buffer((16, 16), "int32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(A: T.Buffer((16, 16), "int32")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - Tx.cuda.atomic_cas(A.data, Tx.int32(1), Tx.int32(2)) + T.cuda.atomic_cas(A.data, T.int32(1), T.int32(2)) src, mod = _get_source(main) assert "tvm_builtin_cuda_atomic_cas" in src @@ -435,17 +450,16 @@ def test_add_one(): } """ - @Tx.prim_func - def main(a: Tx.Buffer((16, 16), "int32"), b: Tx.Buffer((16, 16), "int32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(a: T.Buffer((16, 16), "int32"), b: T.Buffer((16, 16), "int32")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - b[i, j] = Tx.cuda.func_call( - "add_one", a[i, j], source_code=add_one, return_type="int32" - ) + for i, j in T.grid(16, 16): + b[i, j] = T.cuda.func_call( + "add_one", a[i, j], source_code=add_one, return_type="int32" + ) src, mod = _get_source(main) A = np.random.randint(0, 10, (16, 16)).astype("int32") @@ -465,15 +479,14 @@ def test_print(): } """ - @Tx.prim_func - def main(a: Tx.Buffer((16, 16), "int32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) + @T.prim_func + def main(a: T.Buffer((16, 16), "int32")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - Tx.cuda.func_call("print", a[i, j], source_code=print_func) + for i, j in T.grid(16, 16): + T.cuda.func_call("print", a[i, j], source_code=print_func) src, mod = _get_source(main) A = np.random.randint(0, 10, (16, 16)).astype("int32") @@ -486,22 +499,22 @@ def main(a: Tx.Buffer((16, 16), "int32")): def test_warp_shuffle_xor_sync(): # fmt: off - @Tx.prim_func - def func(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (32,), dtype="float32", align=16) + @T.prim_func + def func(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (32,), dtype="float32", align=16) - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) - A_local = Tx.alloc_buffer([1], "float32", scope="local") - i = Tx.alloc_buffer([1], "int32", scope="local") + A_local = T.alloc_buffer([1], "float32", scope="local") + i = T.alloc_buffer([1], "int32", scope="local") - A_local[0] = Tx.float32(31 - lane_id) + A_local[0] = T.float32(31 - lane_id) i[0] = 16 while i[0] >= 1: - A_local[0] += Tx.tvm_warp_shuffle_xor(0xFFFFFFFF, A_local[0], i[0], 32, 32) + A_local[0] += T.tvm_warp_shuffle_xor(0xFFFFFFFF, A_local[0], i[0], 32, 32) i[0] = i[0] // 2 A[lane_id] = A_local[0] @@ -522,7 +535,7 @@ def func(A_ptr: Tx.handle): @pytest.mark.parametrize("cp_size", [4, 8, 16]) @pytest.mark.parametrize("cache_hint", ["", "evict_last"]) @pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256]) -@pytest.mark.parametrize("predicate", [-1, Tx.int32(0), Tx.int32(1)]) +@pytest.mark.parametrize("predicate", [-1, T.int32(0), T.int32(1)]) @pytest.mark.parametrize("fill_mode", ["", "zero"]) def test_ptx_cp_async(cp_size, cache_hint, prefetch_size, predicate, fill_mode): if fill_mode != "" and predicate == -1: @@ -531,19 +544,19 @@ def test_ptx_cp_async(cp_size, cache_hint, prefetch_size, predicate, fill_mode): N = cp_size // 2 # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer((N), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([32]) - A_shared = Tx.alloc_shared([N], "float16") - for i in Tx.vectorized(N): + @T.prim_func + def main(A: T.Buffer((N), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([32]) + A_shared = T.alloc_shared([N], "float16") + for i in T.vectorized(N): A_shared[i] = 5.0 - Tx.ptx.fence.proxy_async("shared::cta") - Tx.ptx.cp_async(A_shared.ptr_to([0]), A.ptr_to([0]), cp_size, cache_hint=cache_hint, prefetch_size=prefetch_size, predicate=predicate, fill_mode=fill_mode) # noqa: E501 - Tx.ptx.cp_async.commit_group() - Tx.ptx.cp_async.wait_group(0) - for i in Tx.serial(N): + T.ptx.fence.proxy_async("shared::cta") + T.ptx.cp_async(A_shared.ptr_to([0]), A.ptr_to([0]), cp_size, cache_hint=cache_hint, prefetch_size=prefetch_size, predicate=predicate, fill_mode=fill_mode) # noqa: E501 + T.ptx.cp_async.commit_group() + T.ptx.cp_async.wait_group(0) + for i in T.serial(N): A[i] = A_shared[i] + 1.0 # fmt: on @@ -568,47 +581,46 @@ def test_ptx_ldmatrix(trans, num): dtype = ".b16" # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer((16, 16), "float16"), B: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) - A_shared = Tx.alloc_shared([16, 16], "float16") + @T.prim_func + def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) + A_shared = T.alloc_shared([16, 16], "float16") if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - A_shared[i, j] = A[i, j] - Tx.cuda.cta_sync() - A_local = Tx.alloc_local([8], "float16") + for i, j in T.grid(16, 16): + A_shared[i, j] = A[i, j] + T.cuda.cta_sync() + A_local = T.alloc_local([8], "float16") A_local[0] = -1.0 # ldmatrix .x{num}.b16 writes `num` 32-bit registers; A_local # is a contiguous fp16[8] buffer, so consecutive register # destinations land 2 fp16 elements apart. if num == 1: - Tx.ptx.ldmatrix( + T.ptx.ldmatrix( trans, num, dtype, A_shared.ptr_to([tx % 16, tx // 16 * 8]), - Tx.address_of(A_local[0]), + T.address_of(A_local[0]), ) elif num == 2: - Tx.ptx.ldmatrix( + T.ptx.ldmatrix( trans, num, dtype, A_shared.ptr_to([tx % 16, tx // 16 * 8]), - Tx.address_of(A_local[0]), - Tx.address_of(A_local[2]), + T.address_of(A_local[0]), + T.address_of(A_local[2]), ) else: - Tx.ptx.ldmatrix( + T.ptx.ldmatrix( trans, num, dtype, A_shared.ptr_to([tx % 16, tx // 16 * 8]), - Tx.address_of(A_local[0]), - Tx.address_of(A_local[2]), - Tx.address_of(A_local[4]), - Tx.address_of(A_local[6]), + T.address_of(A_local[0]), + T.address_of(A_local[2]), + T.address_of(A_local[4]), + T.address_of(A_local[6]), ) for i in range(8): - row: Tx.let = (i // 2) % 2 * 8 - col: Tx.let = (i // 4) * 8 + row: T.let = (i // 2) % 2 * 8 + col: T.let = (i // 4) * 8 B[row + tx // 4, col + tx % 4 * 2 + i % 2] = A_local[i] # fmt: on diff --git a/tests/python/tirx/codegen/test_codegen_dsmem.py b/tests/python/tirx/codegen/test_codegen_dsmem.py index 4c83c9247ce3..d538be571f88 100644 --- a/tests/python/tirx/codegen/test_codegen_dsmem.py +++ b/tests/python/tirx/codegen/test_codegen_dsmem.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T def _get_source(func: tvm.tirx.PrimFunc) -> str: @@ -31,24 +31,24 @@ def _get_source(func: tvm.tirx.PrimFunc) -> str: def test_ptx_cp_async_bulk_s2c_codegen(): - """Test that Tx.ptx.cp_async.bulk.s2c emits the correct PTX instruction.""" + """Test that T.ptx.cp_async.bulk.s2c emits the correct PTX instruction.""" # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer((128,), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([1]) - A_smem = Tx.alloc_shared([128], "float16") - for i in Tx.serial(128): + @T.prim_func + def main(A: T.Buffer((128,), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([1]) + A_smem = T.alloc_shared([128], "float16") + for i in T.serial(128): A_smem[i] = A[i] # Use the raw PTX instruction directly - dst_ptr = Tx.ptx.map_shared_rank(A_smem.ptr_to([0]), Tx.int32(1)) - mbar_ptr = Tx.ptx.map_shared_rank(A_smem.ptr_to([0]), Tx.int32(1)) - Tx.ptx.cp_async.bulk.s2c( + dst_ptr = T.ptx.map_shared_rank(A_smem.ptr_to([0]), T.int32(1)) + mbar_ptr = T.ptx.map_shared_rank(A_smem.ptr_to([0]), T.int32(1)) + T.ptx.cp_async.bulk.s2c( dst_ptr, A_smem.ptr_to([0]), - Tx.int32(256), # 128 elements * 2 bytes + T.int32(256), # 128 elements * 2 bytes mbar_ptr, ) # fmt: on @@ -62,20 +62,20 @@ def test_ptx_cp_async_bulk_s2c_codegen_address_conversion(): """Test that the codegen correctly converts addresses to shared space.""" # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer((64,), "float32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([1]) - A_smem = Tx.alloc_shared([64], "float32") - for i in Tx.serial(64): + @T.prim_func + def main(A: T.Buffer((64,), "float32")): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([1]) + A_smem = T.alloc_shared([64], "float32") + for i in T.serial(64): A_smem[i] = A[i] - dst_ptr = Tx.ptx.map_shared_rank(A_smem.ptr_to([0]), Tx.int32(0)) - mbar_ptr = Tx.ptx.map_shared_rank(A_smem.ptr_to([0]), Tx.int32(0)) - Tx.ptx.cp_async.bulk.s2c( + dst_ptr = T.ptx.map_shared_rank(A_smem.ptr_to([0]), T.int32(0)) + mbar_ptr = T.ptx.map_shared_rank(A_smem.ptr_to([0]), T.int32(0)) + T.ptx.cp_async.bulk.s2c( dst_ptr, A_smem.ptr_to([0]), - Tx.int32(256), # 64 * 4 bytes + T.int32(256), # 64 * 4 bytes mbar_ptr, ) # fmt: on diff --git a/tests/python/tirx/codegen/test_codegen_hopper.py b/tests/python/tirx/codegen/test_codegen_hopper.py index 538f780e5948..8f14dfc3c22d 100644 --- a/tests/python/tirx/codegen/test_codegen_hopper.py +++ b/tests/python/tirx/codegen/test_codegen_hopper.py @@ -22,7 +22,7 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.tirx import Buffer @@ -36,18 +36,17 @@ def _get_source(func: tvm.tirx.PrimFunc) -> tuple[str, tvm.IRModule]: def _run_tensormap_encode(shape, dtype, encode_args): # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shape, dtype=dtype, align=32) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, dtype, len(shape), A.data, *encode_args) # noqa: E501 - - Tx.device_entry() - for blockIdx in Tx.thread_binding(1, thread="blockIdx.x"): - for threadIdx in Tx.thread_binding(1, thread="threadIdx.x"): - with Tx.thread(): - Tx.evaluate(blockIdx + threadIdx) + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, shape, dtype=dtype, align=32) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, dtype, len(shape), A.data, *encode_args) # noqa: E501 + + T.device_entry() + for blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for threadIdx in T.thread_binding(1, thread="threadIdx.x"): + T.evaluate(blockIdx + threadIdx) # fmt: on target = tvm.target.Target("cuda") @@ -61,12 +60,12 @@ def main(A_ptr: Tx.handle): @tvm.testing.requires_cuda_compute_version(9) def test_ptx_setmaxnreg(inc): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - Tx.ptx.setmaxnreg(inc, 32) + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + T.ptx.setmaxnreg(inc, 32) # fmt: on src, mod = _get_source(func) @@ -81,25 +80,23 @@ def func(A: Tx.Buffer(1)): @tvm.testing.requires_cuda_compute_version(9) def test_stmatrix_sync_aligned(trans): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer((16, 16), "float16", scope="shared", align=16) - with Tx.thread(): - reg = Tx.alloc_buffer((8,), "float16", scope="local") - for i in range(8): - reg[i] = tx * 8 + i - Tx.ptx.stmatrix( - trans, 4, ".b16", - A_smem.ptr_to([tx % 16, tx // 16 * 8]), - reg.ptr_to([0]), reg.ptr_to([2]), reg.ptr_to([4]), reg.ptr_to([6]), - ) - if tx == 0: - for i, j in Tx.grid(16, 16): - A[i, j] = A_smem[i, j] + @T.prim_func + def func(A: T.Buffer((16, 16), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) + A_smem = T.alloc_buffer((16, 16), "float16", scope="shared", align=16) + reg = T.alloc_buffer((8,), "float16", scope="local") + for i in range(8): + reg[i] = tx * 8 + i + T.ptx.stmatrix( + trans, 4, ".b16", + A_smem.ptr_to([tx % 16, tx // 16 * 8]), + reg.ptr_to([0]), reg.ptr_to([2]), reg.ptr_to([4]), reg.ptr_to([6]), + ) + if tx == 0: + for i, j in T.grid(16, 16): + A[i, j] = A_smem[i, j] # fmt: on DEV = tvm.cuda(0) @@ -144,30 +141,28 @@ def func(A: Tx.Buffer((16, 16), "float16")): @pytest.mark.parametrize("num", [1, 2, 4]) def test_ptx_stmatrix(trans, num): # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) - A_shared = Tx.alloc_shared([16, 16], "float16") + @T.prim_func + def main(A: T.Buffer((16, 16), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) + A_shared = T.alloc_shared([16, 16], "float16") if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - A_shared[i, j] = Tx.float16(0.0) - Tx.cuda.cta_sync() - A_local = Tx.alloc_local([8], "float16") + for i, j in T.grid(16, 16): + A_shared[i, j] = T.float16(0.0) + T.cuda.cta_sync() + A_local = T.alloc_local([8], "float16") for i in range(8): A_local[i] = (i // 2) * 64 + tx * 2 + i % 2 - Tx.ptx.stmatrix( + T.ptx.stmatrix( trans, num, ".b16", A_shared.ptr_to([tx % 16, tx // 16 * 8]), *[A_local.ptr_to([i * 2]) for i in range(num)], ) - Tx.cuda.cta_sync() + T.cuda.cta_sync() if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - A[i, j] = A_shared[i, j] + for i, j in T.grid(16, 16): + A[i, j] = A_shared[i, j] # fmt: on DEV = tvm.cuda(0) @@ -216,31 +211,29 @@ def test_ptx_stmatrix_noncontiguous(trans, num): LOCAL_SIZE = STRIDE * num # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer((16, 16), "float16")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([32]) - A_shared = Tx.alloc_shared([16, 16], "float16") + @T.prim_func + def main(A: T.Buffer((16, 16), "float16")): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([32]) + A_shared = T.alloc_shared([16, 16], "float16") if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - A_shared[i, j] = Tx.float16(0.0) - Tx.cuda.cta_sync() - A_local = Tx.alloc_local([LOCAL_SIZE], "float16") + for i, j in T.grid(16, 16): + A_shared[i, j] = T.float16(0.0) + T.cuda.cta_sync() + A_local = T.alloc_local([LOCAL_SIZE], "float16") for i in range(num): - A_local[i * STRIDE + 0] = Tx.float16(i * 64 + tx * 2 + 0) - A_local[i * STRIDE + 1] = Tx.float16(i * 64 + tx * 2 + 1) - Tx.ptx.stmatrix( + A_local[i * STRIDE + 0] = T.float16(i * 64 + tx * 2 + 0) + A_local[i * STRIDE + 1] = T.float16(i * 64 + tx * 2 + 1) + T.ptx.stmatrix( trans, num, ".b16", A_shared.ptr_to([tx % 16, tx // 16 * 8]), *[A_local.ptr_to([i * STRIDE]) for i in range(num)], ) - Tx.cuda.cta_sync() + T.cuda.cta_sync() if tx == 0: - with Tx.thread(): - for i, j in Tx.grid(16, 16): - A[i, j] = A_shared[i, j] + for i, j in T.grid(16, 16): + A[i, j] = A_shared[i, j] # fmt: on DEV = tvm.cuda(0) @@ -277,12 +270,12 @@ def main(A: Tx.Buffer((16, 16), "float16")): @tvm.testing.requires_cuda_compute_version(9) def test_bar_arrive(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - Tx.ptx.bar.arrive(0, 128) + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + T.ptx.bar.arrive(0, 128) # fmt: on src, mod = _get_source(func) @@ -293,12 +286,12 @@ def func(A: Tx.Buffer(1)): @tvm.testing.requires_cuda_compute_version(9) def test_bar_sync(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - Tx.ptx.bar.sync(0, 128) + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + T.ptx.bar.sync(0, 128) # fmt: on src, mod = _get_source(func) @@ -309,12 +302,12 @@ def func(A: Tx.Buffer(1)): @tvm.testing.requires_cuda_compute_version(9) def test_fence_mbarrier_init_release_clsuter(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - Tx.ptx.fence.mbarrier_init() + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + T.ptx.fence.mbarrier_init() # fmt: on src, mod = _get_source(func) @@ -324,12 +317,12 @@ def func(A: Tx.Buffer(1)): @tvm.testing.requires_cuda_compute_version(9) def test_ptx_elect_sync(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([128]) - if (Tx.ptx.elect_sync()): + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([128]) + if (T.ptx.elect_sync()): A[tx] = tx # fmt: on @@ -342,12 +335,12 @@ def func(A: Tx.Buffer(1)): @pytest.mark.parametrize("sem,scope", [("sc", "cta"), ("acq_rel", "gpu"), ("sc", "sys")]) def test_ptx_fence(sem, scope): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - Tx.ptx.fence(sem, scope) + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + T.ptx.fence(sem, scope) # fmt: on src, mod = _get_source(func) @@ -357,13 +350,13 @@ def func(A: Tx.Buffer(1)): @tvm.testing.requires_cuda_compute_version(9) def test_fence_proxy_async(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - Tx.ptx.fence.proxy_async("global") - Tx.ptx.fence.proxy_async("shared::cta") + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + T.ptx.fence.proxy_async("global") + T.ptx.fence.proxy_async("shared::cta") # fmt: on @@ -394,40 +387,39 @@ def get_ir(shape, tma_args): tma_args_copy[len(shape) + i] *= t_dtype.bits // 8 # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shape, dtype=dtype, align=16) - B = Tx.match_buffer(B_ptr, shape, dtype=dtype, align=16) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, dtype, len(shape), A.data, *tma_args_copy) # noqa: E501 - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", B_map, dtype, len(shape), B.data, *tma_args_copy) # noqa: E501 - - Tx.device_entry() - for blockIdx in Tx.thread_binding(1, thread="blockIdx.x"): - for threadIdx in Tx.thread_binding(128, thread="threadIdx.x"): - with Tx.thread(): - bar = Tx.shared_scalar("uint64") - phase: Tx.int32 - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", align=128) - - phase = 0 - if threadIdx == 0: - Tx.ptx.mbarrier.init(Tx.address_of(bar), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.data, Tx.address_of(bar), Tx.address_of(A_map), 0, 1, "", *coord) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(Tx.address_of(bar), total_bytes) - Tx.ptx.mbarrier.try_wait(Tx.address_of(bar), phase) - phase = phase ^ 1 - - Tx.cuda.cta_sync() - Tx.ptx.fence.proxy_async("shared::cta") - - if threadIdx == 0: - Tx.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), Tx.address_of(B_map), "", *coord) # noqa: E501 - Tx.ptx.cp_async.bulk.commit_group() - Tx.ptx.cp_async.bulk.wait_group(0) + @T.prim_func + def main(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, shape, dtype=dtype, align=16) + B = T.match_buffer(B_ptr, shape, dtype=dtype, align=16) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, dtype, len(shape), A.data, *tma_args_copy) # noqa: E501 + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", B_map, dtype, len(shape), B.data, *tma_args_copy) # noqa: E501 + + T.device_entry() + for blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for threadIdx in T.thread_binding(128, thread="threadIdx.x"): + bar = T.shared_scalar("uint64") + phase: T.int32 + A_smem = T.alloc_buffer(shape, dtype, scope="shared", align=128) + + phase = 0 + if threadIdx == 0: + T.ptx.mbarrier.init(T.address_of(bar), 1) + T.ptx.fence.proxy_async("shared::cta") + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.data, T.address_of(bar), T.address_of(A_map), 0, 1, "", *coord) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(T.address_of(bar), total_bytes) + T.ptx.mbarrier.try_wait(T.address_of(bar), phase) + phase = phase ^ 1 + + T.cuda.cta_sync() + T.ptx.fence.proxy_async("shared::cta") + + if threadIdx == 0: + T.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), T.address_of(B_map), "", *coord) # noqa: E501 + T.ptx.cp_async.bulk.commit_group() + T.ptx.cp_async.bulk.wait_group(0) # fmt: on return main @@ -556,40 +548,39 @@ def get_ir(swizzle, dtype): coord = [0 for _ in shape] # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, total_elems, dtype=dtype, align=16) - B = Tx.match_buffer(B_ptr, total_elems, dtype=dtype, align=16) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, dtype, len(shape), A.data, *load_args) # noqa: E501 - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", B_map, dtype, len(shape), B.data, *store_args) # noqa: E501 - - Tx.device_entry() - for blockIdx in Tx.thread_binding(1, thread="blockIdx.x"): - for threadIdx in Tx.thread_binding(128, thread="threadIdx.x"): - with Tx.thread(): - A_smem = Tx.alloc_buffer((total_elems,), dtype, scope="shared", align=128) - bar = Tx.shared_scalar("uint64") - phase: Tx.int32 - - phase = 0 - if threadIdx == 0: - Tx.ptx.mbarrier.init(Tx.address_of(bar), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.data, Tx.address_of(bar), Tx.address_of(A_map), 0, 1, "", *coord) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(Tx.address_of(bar), total_bytes) - Tx.ptx.mbarrier.try_wait(Tx.address_of(bar), phase) - phase = phase ^ 1 - - Tx.cuda.cta_sync() - Tx.ptx.fence.proxy_async("shared::cta") - - if threadIdx == 0: - Tx.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), Tx.address_of(B_map), "", *coord) # noqa: E501 - Tx.ptx.cp_async.bulk.commit_group() - Tx.ptx.cp_async.bulk.wait_group(0) + @T.prim_func + def main(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, total_elems, dtype=dtype, align=16) + B = T.match_buffer(B_ptr, total_elems, dtype=dtype, align=16) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, dtype, len(shape), A.data, *load_args) # noqa: E501 + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", B_map, dtype, len(shape), B.data, *store_args) # noqa: E501 + + T.device_entry() + for blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for threadIdx in T.thread_binding(128, thread="threadIdx.x"): + A_smem = T.alloc_buffer((total_elems,), dtype, scope="shared", align=128) + bar = T.shared_scalar("uint64") + phase: T.int32 + + phase = 0 + if threadIdx == 0: + T.ptx.mbarrier.init(T.address_of(bar), 1) + T.ptx.fence.proxy_async("shared::cta") + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.data, T.address_of(bar), T.address_of(A_map), 0, 1, "", *coord) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(T.address_of(bar), total_bytes) + T.ptx.mbarrier.try_wait(T.address_of(bar), phase) + phase = phase ^ 1 + + T.cuda.cta_sync() + T.ptx.fence.proxy_async("shared::cta") + + if threadIdx == 0: + T.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), T.address_of(B_map), "", *coord) # noqa: E501 + T.ptx.cp_async.bulk.commit_group() + T.ptx.cp_async.bulk.wait_group(0) # fmt: on return main, shape @@ -610,7 +601,7 @@ def main(A_ptr: Tx.handle, B_ptr: Tx.handle): B = tvm.runtime.tensor(B_np, device=DEV) mod(A, B) dtype = tvm.DataType(dtype) - layout = Tx.SwizzleLayout( + layout = T.SwizzleLayout( per_element=int(math.log2(128 // dtype.bits)), swizzle_len=swizzle, atom_len=3 ) B_np = B.numpy() @@ -640,45 +631,44 @@ def get_ir(shape, tma_args): coord = [0 for _ in shape] # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shape, dtype="float32", align=16) - B = Tx.match_buffer(B_ptr, shape, dtype="float32", align=16) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", len(shape), A.data, *tma_args) # noqa: E501 - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", B_map, "float32", len(shape), B.data, *tma_args) # noqa: E501 - - Tx.device_entry() - for clusterCtaIdx in Tx.thread_binding(4, thread="clusterCtaIdx.x"): - for bx in Tx.thread_binding(4, thread="blockIdx.x"): - for tx in Tx.thread_binding(128, thread="threadIdx.x"): - with Tx.thread(): - bar = Tx.shared_scalar("uint64") - phase: Tx.int32 - A_smem = Tx.alloc_buffer(shape[::-1], "float32", scope="shared", align=128) # noqa: E501 - - phase = 0 + @T.prim_func + def main(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, shape, dtype="float32", align=16) + B = T.match_buffer(B_ptr, shape, dtype="float32", align=16) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", len(shape), A.data, *tma_args) # noqa: E501 + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", B_map, "float32", len(shape), B.data, *tma_args) # noqa: E501 + + T.device_entry() + for clusterCtaIdx in T.thread_binding(4, thread="clusterCtaIdx.x"): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(128, thread="threadIdx.x"): + bar = T.shared_scalar("uint64") + phase: T.int32 + A_smem = T.alloc_buffer(shape[::-1], "float32", scope="shared", align=128) + + phase = 0 + if tx == 0: + # leader thread in each CTA + T.ptx.mbarrier.init(T.address_of(bar), 1) + T.ptx.fence.proxy_async("shared::cta") + T.ptx.mbarrier.arrive.expect_tx(T.address_of(bar), total_bytes) + if clusterCtaIdx == 0: + # only the first CTA in the cluster does the copy, and then multicast # noqa: E501 + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.data, T.address_of(bar), T.address_of(A_map), int("1111", 2), 1, "", *coord) # noqa: E501 + # wait for the copy to finish + T.ptx.mbarrier.try_wait(T.address_of(bar), phase) + phase = phase ^ 1 + T.cuda.cta_sync() + T.ptx.fence.proxy_async("shared::cta") + + if bx == 2: if tx == 0: - # leader thread in each CTA - Tx.ptx.mbarrier.init(Tx.address_of(bar), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.ptx.mbarrier.arrive.expect_tx(Tx.address_of(bar), total_bytes) - if clusterCtaIdx == 0: - # only the first CTA in the cluster does the copy, and then multicast # noqa: E501 - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.data, Tx.address_of(bar), Tx.address_of(A_map), int("1111", 2), 1, "", *coord) # noqa: E501 - # wait for the copy to finish - Tx.ptx.mbarrier.try_wait(Tx.address_of(bar), phase) - phase = phase ^ 1 - Tx.cuda.cta_sync() - Tx.ptx.fence.proxy_async("shared::cta") - - if bx == 2: - if tx == 0: - Tx.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), Tx.address_of(B_map), "", *coord) # noqa: E501 - Tx.ptx.cp_async.bulk.commit_group() - Tx.ptx.cp_async.bulk.wait_group(0) + T.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), T.address_of(B_map), "", *coord) # noqa: E501 + T.ptx.cp_async.bulk.commit_group() + T.ptx.cp_async.bulk.wait_group(0) # fmt: on return main @@ -722,53 +712,52 @@ def get_ir(shape, tma_args): tma_store_args[3 * len(shape) - 2] = shape[-1] # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shape, dtype="float32", align=16) - B = Tx.match_buffer(B_ptr, shape, dtype="float32", align=16) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", len(shape), A.data, *tma_args) # noqa: E501 - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", B_map, "float32", len(shape), B.data, *tma_store_args) # noqa: E501 - - Tx.device_entry() - for clusterCtaIdx in Tx.thread_binding(4, thread="clusterCtaIdx.x"): - for bx in Tx.thread_binding(4, thread="blockIdx.x"): - for tx in Tx.thread_binding(128, thread="threadIdx.x"): - with Tx.thread(): - bar = Tx.shared_scalar("uint64") - phase: Tx.int32 - A_smem = Tx.alloc_buffer(shape[::-1], "float32", scope="shared", align=128) # noqa: E501 - - phase = 0 + @T.prim_func + def main(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, shape, dtype="float32", align=16) + B = T.match_buffer(B_ptr, shape, dtype="float32", align=16) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", len(shape), A.data, *tma_args) # noqa: E501 + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", B_map, "float32", len(shape), B.data, *tma_store_args) # noqa: E501 + + T.device_entry() + for clusterCtaIdx in T.thread_binding(4, thread="clusterCtaIdx.x"): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(128, thread="threadIdx.x"): + bar = T.shared_scalar("uint64") + phase: T.int32 + A_smem = T.alloc_buffer(shape[::-1], "float32", scope="shared", align=128) + + phase = 0 + if tx == 0: + # leader thread in each CTA + T.ptx.mbarrier.init(T.address_of(bar), 1) + T.ptx.fence.proxy_async("shared::cta") + T.ptx.mbarrier.arrive.expect_tx(T.address_of(bar), total_bytes) + if clusterCtaIdx == 0: + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord0[::-1])), # noqa: E501 + T.address_of(bar), T.address_of(A_map), int("1111", 2), 1, "", *coord0) # noqa: E501 + if clusterCtaIdx == 1: + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord1[::-1])), # noqa: E501 + T.address_of(bar), T.address_of(A_map), int("1111", 2), 1, "", *coord1) # noqa: E501 + if clusterCtaIdx == 2: + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord2[::-1])), # noqa: E501 + T.address_of(bar), T.address_of(A_map), int("1111", 2), 1, "", *coord2) # noqa: E501 + if clusterCtaIdx == 3: + T.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord3[::-1])), # noqa: E501 + T.address_of(bar), T.address_of(A_map), int("1111", 2), 1, "", *coord3) # noqa: E501 + # wait for the copy to finish + T.ptx.mbarrier.try_wait(T.address_of(bar), phase) + phase = phase ^ 1 + T.cuda.cta_sync() + + if bx == 1: if tx == 0: - # leader thread in each CTA - Tx.ptx.mbarrier.init(Tx.address_of(bar), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.ptx.mbarrier.arrive.expect_tx(Tx.address_of(bar), total_bytes) - if clusterCtaIdx == 0: - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord0[::-1])), # noqa: E501 - Tx.address_of(bar), Tx.address_of(A_map), int("1111", 2), 1, "", *coord0) # noqa: E501 - if clusterCtaIdx == 1: - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord1[::-1])), # noqa: E501 - Tx.address_of(bar), Tx.address_of(A_map), int("1111", 2), 1, "", *coord1) # noqa: E501 - if clusterCtaIdx == 2: - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord2[::-1])), # noqa: E501 - Tx.address_of(bar), Tx.address_of(A_map), int("1111", 2), 1, "", *coord2) # noqa: E501 - if clusterCtaIdx == 3: - Tx.ptx.cp_async.bulk.tensor.g2c(len(shape), A_smem.access_ptr(Buffer.WRITE, offset=A_smem.elem_offset_of(coord3[::-1])), # noqa: E501 - Tx.address_of(bar), Tx.address_of(A_map), int("1111", 2), 1, "", *coord3) # noqa: E501 - # wait for the copy to finish - Tx.ptx.mbarrier.try_wait(Tx.address_of(bar), phase) - phase = phase ^ 1 - Tx.cuda.cta_sync() - - if bx == 1: - if tx == 0: - Tx.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), Tx.address_of(B_map), "", *coord0) # noqa: E501 - Tx.ptx.cp_async.bulk.commit_group() - Tx.ptx.cp_async.bulk.wait_group(0) + T.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), T.address_of(B_map), "", *coord0) # noqa: E501 + T.ptx.cp_async.bulk.commit_group() + T.ptx.cp_async.bulk.wait_group(0) # fmt: on return main @@ -806,29 +795,29 @@ def get_ir(shape, tma_args): coord = [0 for _ in shape] # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shape, dtype="float32", align=16) + @T.prim_func + def main(A_ptr: T.handle): + A = T.match_buffer(A_ptr, shape, dtype="float32", align=16) - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", len(shape), A.data, *tma_args) # noqa: E501 + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", len(shape), A.data, *tma_args) # noqa: E501 - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([128]) + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([128]) - A_smem = Tx.alloc_buffer(elems, "float32", scope="shared", align=128) + A_smem = T.alloc_buffer(elems, "float32", scope="shared", align=128) if tx == 0: - for i in Tx.serial(0, elems): + for i in T.serial(0, elems): A_smem[i] = i - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if tx == 0: - Tx.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), Tx.address_of(A_map), "", *coord) # noqa: E501 - Tx.ptx.cp_async.bulk.commit_group() - Tx.ptx.cp_async.bulk.wait_group(0) + T.ptx.cp_async.bulk.tensor.s2g(len(shape), A_smem.access_ptr("r", offset=0), T.address_of(A_map), "", *coord) # noqa: E501 + T.ptx.cp_async.bulk.commit_group() + T.ptx.cp_async.bulk.wait_group(0) # fmt: on return main @@ -875,73 +864,73 @@ def get_ir( def get_init_value(dtype): if dtype == "float32": - return Tx.float32(0.0) + return T.float32(0.0) assert False, f"Unsupported dtype {dtype}" def get_accum_list(C, C_elems): return [C[i] for i in range(C_elems)] # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shapeA, dtype=in_dtype, align=16) - B = Tx.match_buffer(B_ptr, shapeB, dtype=in_dtype, align=16) - C = Tx.match_buffer(C_ptr, shapeC, dtype=out_dtype, align=16) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, in_dtype, len(shapeA), A.data, *A_tma_args) # noqa: E501 - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", B_map, in_dtype, len(shapeB), B.data, *B_tma_args) # noqa: E501 - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([128]) # A warpgroup is 128 threads - - A_smem = Tx.alloc_buffer(shapeA, in_dtype, scope="shared", align=1024) - B_smem = Tx.alloc_buffer(shapeB, in_dtype, scope="shared", align=1024) - bar = Tx.shared_scalar("uint64") - phase: Tx.int32 - - descA: Tx.uint64 - descB: Tx.uint64 - C_local = Tx.alloc_buffer((C_elems,), out_dtype, scope="local") + @T.prim_func + def main(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle): + A = T.match_buffer(A_ptr, shapeA, dtype=in_dtype, align=16) + B = T.match_buffer(B_ptr, shapeB, dtype=in_dtype, align=16) + C = T.match_buffer(C_ptr, shapeC, dtype=out_dtype, align=16) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, in_dtype, len(shapeA), A.data, *A_tma_args) # noqa: E501 + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", B_map, in_dtype, len(shapeB), B.data, *B_tma_args) # noqa: E501 + + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([128]) # A warpgroup is 128 threads + + A_smem = T.alloc_buffer(shapeA, in_dtype, scope="shared", align=1024) + B_smem = T.alloc_buffer(shapeB, in_dtype, scope="shared", align=1024) + bar = T.shared_scalar("uint64") + phase: T.int32 + + descA: T.uint64 + descB: T.uint64 + C_local = T.alloc_buffer((C_elems,), out_dtype, scope="local") # init phase and bar phase = 0 if tx == 0: - Tx.ptx.mbarrier.init(Tx.address_of(bar), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(T.address_of(bar), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # load A and B to smem if tx == 0: - Tx.ptx.cp_async.bulk.tensor.g2c(len(shapeA), A_smem.data, Tx.address_of(bar), Tx.address_of(A_map), 0, 1, "", *coordA) # noqa: E501 - Tx.ptx.cp_async.bulk.tensor.g2c(len(shapeB), B_smem.data, Tx.address_of(bar), Tx.address_of(B_map), 0, 1, "", *coordB) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(Tx.address_of(bar), A_bytes + B_bytes) - Tx.ptx.mbarrier.try_wait(Tx.address_of(bar), phase) + T.ptx.cp_async.bulk.tensor.g2c(len(shapeA), A_smem.data, T.address_of(bar), T.address_of(A_map), 0, 1, "", *coordA) # noqa: E501 + T.ptx.cp_async.bulk.tensor.g2c(len(shapeB), B_smem.data, T.address_of(bar), T.address_of(B_map), 0, 1, "", *coordB) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(T.address_of(bar), A_bytes + B_bytes) + T.ptx.mbarrier.try_wait(T.address_of(bar), phase) phase = phase ^ 1 - Tx.cuda.cta_sync() + T.cuda.cta_sync() # init C_local - for i in Tx.serial(0, C_elems): - C_local[i] = Tx.Cast(out_dtype, get_init_value(out_dtype)) - Tx.ptx.wgmma.noop_barrier(C_local[i]) + for i in T.serial(0, C_elems): + C_local[i] = T.Cast(out_dtype, get_init_value(out_dtype)) + T.ptx.wgmma.noop_barrier(C_local[i]) # do wgmma - Tx.ptx.wgmma.encode_matrix_descriptor(Tx.address_of(descA), A_smem.data, *A_encode_args) # noqa: F821 - Tx.ptx.wgmma.encode_matrix_descriptor(Tx.address_of(descB), B_smem.data, *B_encode_args) # noqa: F821 - Tx.ptx.wgmma.fence() - Tx.ptx.wgmma.mma_async.ss(descA, descB, *get_accum_list(C_local, C_elems), # noqa: F821 + T.ptx.wgmma.encode_matrix_descriptor(T.address_of(descA), A_smem.data, *A_encode_args) # noqa: F821 + T.ptx.wgmma.encode_matrix_descriptor(T.address_of(descB), B_smem.data, *B_encode_args) # noqa: F821 + T.ptx.wgmma.fence() + T.ptx.wgmma.mma_async.ss(descA, descB, *get_accum_list(C_local, C_elems), # noqa: F821 M=M, N=N, K=K, in_dtype=in_dtype, out_dtype=out_dtype, transA=transA, transB=transB, scaleA=1.0, scaleB=1.0, scaleD=False) # noqa: E501 - Tx.ptx.wgmma.commit_group() - Tx.ptx.wgmma.wait_group(0) + T.ptx.wgmma.commit_group() + T.ptx.wgmma.wait_group(0) - for i in Tx.serial(0, C_elems): - Tx.ptx.wgmma.noop_barrier(C_local[i]) + for i in T.serial(0, C_elems): + T.ptx.wgmma.noop_barrier(C_local[i]) # store C_local to C - for i in Tx.serial(0, C_elems // 4): - row = Tx.meta_var((tx % 32) // 4 + (tx // 32) * 16) - col = Tx.meta_var(i * 8 + tx % 4 * 2) + for i in T.serial(0, C_elems // 4): + row = T.meta_var((tx % 32) // 4 + (tx // 32) * 16) + col = T.meta_var(i * 8 + tx % 4 * 2) C[row, col] = C_local[i * 4] C[row, col + 1] = C_local[i * 4 + 1] C[row + 8, col] = C_local[i * 4 + 2] @@ -1021,7 +1010,7 @@ def get_ir( def get_init_value(dtype): if dtype == "float32": - return Tx.float32(0.0) + return T.float32(0.0) assert False, f"Unsupported dtype {dtype}" def get_A_list(A_local, A_elems): @@ -1031,79 +1020,79 @@ def get_accum_list(C, C_elems): return [C[i] for i in range(C_elems)] # fmt: off - @Tx.prim_func - def main(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, shapeA, dtype=in_dtype, align=16) - B = Tx.match_buffer(B_ptr, shapeB, dtype=in_dtype, align=16) - C = Tx.match_buffer(C_ptr, shapeC, dtype=out_dtype, align=16) + @T.prim_func + def main(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle): + A = T.match_buffer(A_ptr, shapeA, dtype=in_dtype, align=16) + B = T.match_buffer(B_ptr, shapeB, dtype=in_dtype, align=16) + C = T.match_buffer(C_ptr, shapeC, dtype=out_dtype, align=16) - B_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", B_map, in_dtype, len(shapeB), B.data, *B_tma_args) # noqa: E501 + B_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", B_map, in_dtype, len(shapeB), B.data, *B_tma_args) # noqa: E501 - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx = Tx.thread_id([128]) # A warpgroup is 128 threads + T.device_entry() + cta_id = T.cta_id([1]) + tx = T.thread_id([128]) # A warpgroup is 128 threads - B_smem = Tx.alloc_buffer(shapeB, in_dtype, scope="shared", align=1024) - # bar = Tx.alloc_buffer((1,), "uint64", scope="shared", align=8) - bar = Tx.shared_scalar("uint64") + B_smem = T.alloc_buffer(shapeB, in_dtype, scope="shared", align=1024) + # bar = T.alloc_buffer((1,), "uint64", scope="shared", align=8) + bar = T.shared_scalar("uint64") - # descB = Tx.alloc_buffer((1,), "uint64", scope="local") - descB: Tx.uint64 - A_local = Tx.alloc_buffer((A_elems,), in_dtype, scope="local") - C_local = Tx.alloc_buffer((C_elems,), out_dtype, scope="local") + # descB = T.alloc_buffer((1,), "uint64", scope="local") + descB: T.uint64 + A_local = T.alloc_buffer((A_elems,), in_dtype, scope="local") + C_local = T.alloc_buffer((C_elems,), out_dtype, scope="local") - A_elems_b32 = Tx.meta_var(A_elems // (32 // in_dtype_bits)) - A_local_b32 = Tx.decl_buffer((A_elems_b32,), "uint32", data=A_local.data) + A_elems_b32 = T.meta_var(A_elems // (32 // in_dtype_bits)) + A_local_b32 = T.decl_buffer((A_elems_b32,), "uint32", data=A_local.data) # load A to regs - for i in Tx.serial(0, A_elems // 4): - row = Tx.meta_var((tx % 32) // 4 + (tx // 32) * 16) - col = Tx.meta_var(i * 8 + tx % 4 * 2) + for i in T.serial(0, A_elems // 4): + row = T.meta_var((tx % 32) // 4 + (tx // 32) * 16) + col = T.meta_var(i * 8 + tx % 4 * 2) A_local[i * 4] = A[row, col] A_local[i * 4 + 1] = A[row, col + 1] A_local[i * 4 + 2] = A[row + 8, col] A_local[i * 4 + 3] = A[row + 8, col + 1] # init bar, and make sure it's visible to all threads and async proxy if tx == 0: - Tx.ptx.mbarrier.init(Tx.address_of(bar), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(T.address_of(bar), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # load B to smem if tx == 0: - Tx.ptx.cp_async.bulk.tensor.g2c(len(shapeB), B_smem.data, Tx.address_of(bar), Tx.address_of(B_map), 0, 1, "", *coordB) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(Tx.address_of(bar), B_bytes) - Tx.ptx.mbarrier.try_wait(Tx.address_of(bar), 0) - Tx.cuda.cta_sync() + T.ptx.cp_async.bulk.tensor.g2c(len(shapeB), B_smem.data, T.address_of(bar), T.address_of(B_map), 0, 1, "", *coordB) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(T.address_of(bar), B_bytes) + T.ptx.mbarrier.try_wait(T.address_of(bar), 0) + T.cuda.cta_sync() # init C_local - for i in Tx.serial(0, C_elems): - C_local[i] = Tx.Cast(out_dtype, get_init_value(out_dtype)) + for i in T.serial(0, C_elems): + C_local[i] = T.Cast(out_dtype, get_init_value(out_dtype)) # fence A_local and C_local - for i in Tx.serial(0, A_elems_b32): - Tx.ptx.wgmma.noop_barrier(A_local_b32[i]) - for i in Tx.serial(0, C_elems): - Tx.ptx.wgmma.noop_barrier(C_local[i]) + for i in T.serial(0, A_elems_b32): + T.ptx.wgmma.noop_barrier(A_local_b32[i]) + for i in T.serial(0, C_elems): + T.ptx.wgmma.noop_barrier(C_local[i]) # do wgmma - Tx.ptx.wgmma.encode_matrix_descriptor(Tx.address_of(descB), B_smem.data, *B_encode_args) # noqa: F821 - Tx.ptx.wgmma.fence() - Tx.ptx.wgmma.mma_async.rs(descB, *(get_A_list(A_local_b32, A_elems_b32) + get_accum_list(C_local, C_elems)), # noqa: E501, F821 + T.ptx.wgmma.encode_matrix_descriptor(T.address_of(descB), B_smem.data, *B_encode_args) # noqa: F821 + T.ptx.wgmma.fence() + T.ptx.wgmma.mma_async.rs(descB, *(get_A_list(A_local_b32, A_elems_b32) + get_accum_list(C_local, C_elems)), # noqa: E501, F821 M=M, N=N, K=K, in_dtype=in_dtype, out_dtype=out_dtype, transA=transA, transB=transB, scaleA=1.0, scaleB=1.0, scaleD=False) # noqa: E501 - Tx.ptx.wgmma.commit_group() - Tx.ptx.wgmma.wait_group(0) + T.ptx.wgmma.commit_group() + T.ptx.wgmma.wait_group(0) # fence A_local - for i in Tx.serial(0, A_elems_b32): - Tx.ptx.wgmma.noop_barrier(A_local_b32[i]) + for i in T.serial(0, A_elems_b32): + T.ptx.wgmma.noop_barrier(A_local_b32[i]) # fence C_local - for i in Tx.serial(0, C_elems): - Tx.ptx.wgmma.noop_barrier(C_local[i]) + for i in T.serial(0, C_elems): + T.ptx.wgmma.noop_barrier(C_local[i]) # store C_local to C - for i in Tx.serial(0, C_elems // 4): - row = Tx.meta_var((tx % 32) // 4 + (tx // 32) * 16) - col = Tx.meta_var(i * 8 + tx % 4 * 2) + for i in T.serial(0, C_elems // 4): + row = T.meta_var((tx % 32) // 4 + (tx // 32) * 16) + col = T.meta_var(i * 8 + tx % 4 * 2) C[row, col] = C_local[i * 4] C[row, col + 1] = C_local[i * 4 + 1] C[row + 8, col] = C_local[i * 4 + 2] @@ -1163,17 +1152,15 @@ def main(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle): @tvm.testing.requires_cuda_compute_version(9) def test_ptx_map_shared_rank(): - @Tx.prim_func - def func(A: Tx.Buffer(1)): - Tx.device_entry() - cbx = Tx.cta_id_in_cluster([2]) - cta_id = Tx.cta_id([2]) - tx = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer([1], "uint32", scope="shared") - if cbx == 0 and tx == 0: - with Tx.thread(): - Tx.ptx.map_shared_rank(A_smem.data, cbx) + @T.prim_func + def func(A: T.Buffer(1)): + T.device_entry() + cbx = T.cta_id_in_cluster([2]) + cta_id = T.cta_id([2]) + tx = T.thread_id([128]) + A_smem = T.alloc_buffer([1], "uint32", scope="shared") + if cbx == 0 and tx == 0: + T.ptx.map_shared_rank(A_smem.data, cbx) src, mod = _get_source(func) print(src) diff --git a/tests/python/tirx/codegen/test_codegen_nki.py b/tests/python/tirx/codegen/test_codegen_nki.py index 73587a02b844..ca8965e7d361 100644 --- a/tests/python/tirx/codegen/test_codegen_nki.py +++ b/tests/python/tirx/codegen/test_codegen_nki.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T target = tvm.target.Target("aws/trn1/trn1.2xlarge") @@ -38,24 +38,24 @@ def compare_strings_ignore_whitespace(s1, s2): def test_nki_add_1(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer((128, 512)), B: Tx.Buffer((128, 512))): - Tx.func_attr({"num_inputs": 1}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) - B_sbuf = Tx.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) - with Tx.attr(0, "tensorized_nki_instruction", 1): + @T.prim_func + def func(A: T.Buffer((128, 512)), B: T.Buffer((128, 512))): + T.func_attr({"num_inputs": 1}) + T.device_entry() + A_sbuf = T.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) + B_sbuf = T.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(0, 128): for j in range(0, 512): - Tx.nki.load(A_sbuf[i, j], A[i, j]) - with Tx.attr(0, "tensorized_nki_instruction", 1): + T.nki.load(A_sbuf[i, j], A[i, j]) + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(0, 128): for j in range(0, 512): - Tx.nki.tensorscalar(B_sbuf[i, j], A_sbuf[i, j], Tx.float32(1.0), "add") - with Tx.attr(0, "tensorized_nki_instruction", 1): + T.nki.tensorscalar(B_sbuf[i, j], A_sbuf[i, j], T.float32(1.0), "add") + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(0, 128): for j in range(0, 512): - Tx.nki.store(B[i, j], B_sbuf[i, j]) + T.nki.store(B[i, j], B_sbuf[i, j]) # fmt: on src = lower_and_get_source(func) print(src) @@ -92,25 +92,25 @@ def func_kernel(A_ptr, B_ptr: nt.mutable_tensor, ): def test_nki_add_2(): # fmt: off - @Tx.prim_func - def func(A: Tx.Buffer((128, 2048)), B: Tx.Buffer((128, 2048))): - Tx.func_attr({"num_inputs": 1}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) - B_sbuf = Tx.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) + @T.prim_func + def func(A: T.Buffer((128, 2048)), B: T.Buffer((128, 2048))): + T.func_attr({"num_inputs": 1}) + T.device_entry() + A_sbuf = T.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) + B_sbuf = T.alloc_buffer((128, 512), "float32", scope="trn.sbuf",) for k in range(0, 4): - with Tx.attr(0, "tensorized_nki_instruction", 1): + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(0, 128): for j in range(0, 512): - Tx.nki.load(A_sbuf[i, j], A[i, 512*k+j]) - with Tx.attr(0, "tensorized_nki_instruction", 1): + T.nki.load(A_sbuf[i, j], A[i, 512*k+j]) + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(0, 128): for j in range(0, 512): - Tx.nki.tensorscalar(B_sbuf[i, j], A_sbuf[i, j], Tx.float32(1.0), "add") - with Tx.attr(0, "tensorized_nki_instruction", 1): + T.nki.tensorscalar(B_sbuf[i, j], A_sbuf[i, j], T.float32(1.0), "add") + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(0, 128): for j in range(0, 512): - Tx.nki.store(B[i, 512*k+j], B_sbuf[i, j]) + T.nki.store(B[i, 512*k+j], B_sbuf[i, j]) # fmt: on src = lower_and_get_source(func) @@ -168,104 +168,99 @@ def test_nki_matmul_1(): NUM_BLOCK_N = N // BLOCK_N NUM_BLOCK_K = K // BLOCK_K - @Tx.prim_func + @T.prim_func def func( - lhsT: Tx.Buffer((K, M), "float16"), - rhs: Tx.Buffer((K, N), "float16"), - result: Tx.buffer((M, N), "float16"), + lhsT: T.Buffer((K, M), "float16"), + rhs: T.Buffer((K, N), "float16"), + result: T.buffer((M, N), "float16"), ): - Tx.func_attr({"num_inputs": 2}) - with Tx.thread(): - result_tiles = Tx.alloc_buffer( - (TILE_M, NUM_BLOCK_M, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILE_N), - "float32", - scope="trn.sbuf", - ) - rhs_tiles = Tx.alloc_buffer( - (TILE_K, TILES_IN_BLOCK_K, BLOCK_N), "float16", scope="trn.sbuf" - ) - lhsT_tiles = Tx.alloc_buffer( - (TILE_K, TILES_IN_BLOCK_K, BLOCK_M), "float16", scope="trn.sbuf" - ) - res_tile = Tx.alloc_buffer((1, TILE_M, TILE_N), "float32", scope="trn.psum") - result_packed = Tx.alloc_buffer((TILE_K, BLOCK_N), "float32", scope="trn.sbuf") - for n in range(NUM_BLOCK_N): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for i0 in range(TILE_M): - for i1 in range(NUM_BLOCK_M): - for i2 in range(TILES_IN_BLOCK_M): - for i3 in range(TILES_IN_BLOCK_N): - for i4 in range(TILE_N): - Tx.nki.memset( - result_tiles[i0, i1, i2, i3, i4], Tx.float32(0.0) - ) - for k in range(NUM_BLOCK_K): - for bk_r in range(TILES_IN_BLOCK_K): - with Tx.attr(0, "tensorized_nki_instruction", 1): + T.func_attr({"num_inputs": 2}) + result_tiles = T.alloc_buffer( + (TILE_M, NUM_BLOCK_M, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILE_N), + "float32", + scope="trn.sbuf", + ) + rhs_tiles = T.alloc_buffer((TILE_K, TILES_IN_BLOCK_K, BLOCK_N), "float16", scope="trn.sbuf") + lhsT_tiles = T.alloc_buffer( + (TILE_K, TILES_IN_BLOCK_K, BLOCK_M), "float16", scope="trn.sbuf" + ) + res_tile = T.alloc_buffer((1, TILE_M, TILE_N), "float32", scope="trn.psum") + result_packed = T.alloc_buffer((TILE_K, BLOCK_N), "float32", scope="trn.sbuf") + for n in range(NUM_BLOCK_N): + with T.attr(0, "tensorized_nki_instruction", 1): + for i0 in range(TILE_M): + for i1 in range(NUM_BLOCK_M): + for i2 in range(TILES_IN_BLOCK_M): + for i3 in range(TILES_IN_BLOCK_N): + for i4 in range(TILE_N): + T.nki.memset(result_tiles[i0, i1, i2, i3, i4], T.float32(0.0)) + for k in range(NUM_BLOCK_K): + for bk_r in range(TILES_IN_BLOCK_K): + with T.attr(0, "tensorized_nki_instruction", 1): + for i in range(TILE_K): + for j in range(BLOCK_N): + T.nki.load( + rhs_tiles[i, bk_r, j], + rhs[ + (TILES_IN_BLOCK_K * k + bk_r) * TILE_K + i, + n * BLOCK_N + j, + ], + ) + for m in range(NUM_BLOCK_M): + for bk_l in range(TILES_IN_BLOCK_K): + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(TILE_K): - for j in range(BLOCK_N): - Tx.nki.load( - rhs_tiles[i, bk_r, j], - rhs[ - (TILES_IN_BLOCK_K * k + bk_r) * TILE_K + i, - n * BLOCK_N + j, + for j in range(BLOCK_M): + T.nki.load( + lhsT_tiles[i, bk_l, j], + lhsT[ + (TILES_IN_BLOCK_K * k + bk_l) * TILE_K + i, + m * BLOCK_M + j, ], ) - for m in range(NUM_BLOCK_M): - for bk_l in range(TILES_IN_BLOCK_K): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for i in range(TILE_K): - for j in range(BLOCK_M): - Tx.nki.load( - lhsT_tiles[i, bk_l, j], - lhsT[ - (TILES_IN_BLOCK_K * k + bk_l) * TILE_K + i, - m * BLOCK_M + j, - ], - ) - for bn in range(TILES_IN_BLOCK_N): - for bm in range(TILES_IN_BLOCK_M): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for i in range(TILE_M): - for j in range(TILE_N): - Tx.nki.memset(res_tile[0, i, j], Tx.float32(0.0)) - for bk in range(TILES_IN_BLOCK_K): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for i in range(TILE_M): - for j in range(TILE_N): - for k in range(TILE_K): - Tx.nki.matmul( - res_tile[0, i, j], - lhsT_tiles[k, bk, bm * TILE_M + i], - rhs_tiles[k, bk, bn * TILE_N + j], - 1, - ) - with Tx.attr(0, "tensorized_nki_instruction", 1): + for bn in range(TILES_IN_BLOCK_N): + for bm in range(TILES_IN_BLOCK_M): + with T.attr(0, "tensorized_nki_instruction", 1): + for i in range(TILE_M): + for j in range(TILE_N): + T.nki.memset(res_tile[0, i, j], T.float32(0.0)) + for bk in range(TILES_IN_BLOCK_K): + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(TILE_M): for j in range(TILE_N): - Tx.nki.tensortensor( - result_tiles[i, m, bm, bn, j], - result_tiles[i, m, bm, bn, j], - res_tile[0, i, j], - "add", - ) - for m in range(NUM_BLOCK_M): - for bm in range(TILES_IN_BLOCK_M): - for bn in range(TILES_IN_BLOCK_N): - with Tx.attr(0, "tensorized_nki_instruction", 1): - for i in range(TILE_K): + for k in range(TILE_K): + T.nki.matmul( + res_tile[0, i, j], + lhsT_tiles[k, bk, bm * TILE_M + i], + rhs_tiles[k, bk, bn * TILE_N + j], + 1, + ) + with T.attr(0, "tensorized_nki_instruction", 1): + for i in range(TILE_M): for j in range(TILE_N): - Tx.nki.tensor_copy( - result_packed[i, bn * TILE_N + j], + T.nki.tensortensor( + result_tiles[i, m, bm, bn, j], result_tiles[i, m, bm, bn, j], + res_tile[0, i, j], + "add", ) - with Tx.attr(0, "tensorized_nki_instruction", 1): + for m in range(NUM_BLOCK_M): + for bm in range(TILES_IN_BLOCK_M): + for bn in range(TILES_IN_BLOCK_N): + with T.attr(0, "tensorized_nki_instruction", 1): for i in range(TILE_K): - for j in range(BLOCK_N): - Tx.nki.store( - result[m * BLOCK_M + bm * TILE_M + i, n * BLOCK_N + j], - result_packed[i, j], + for j in range(TILE_N): + T.nki.tensor_copy( + result_packed[i, bn * TILE_N + j], + result_tiles[i, m, bm, bn, j], ) + with T.attr(0, "tensorized_nki_instruction", 1): + for i in range(TILE_K): + for j in range(BLOCK_N): + T.nki.store( + result[m * BLOCK_M + bm * TILE_M + i, n * BLOCK_N + j], + result_packed[i, j], + ) # fmt: on diff --git a/tests/python/tirx/codegen/test_codegen_nvshmem.py b/tests/python/tirx/codegen/test_codegen_nvshmem.py index 10ee76e89d72..ff9f17170ddd 100644 --- a/tests/python/tirx/codegen/test_codegen_nvshmem.py +++ b/tests/python/tirx/codegen/test_codegen_nvshmem.py @@ -26,7 +26,7 @@ import tvm.testing from tvm.runtime import ShapeTuple from tvm.runtime import disco as di -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.support.popen_pool import PopenWorker NUM_WORKERS = 4 @@ -73,13 +73,13 @@ def _test_func(): sess.sync_worker_0() def test_thread_info(sess): - @Tx.prim_func - def main(res: Tx.Buffer((2,), "int32")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([nwarps * 32]) - res[0] = Tx.nvshmem.my_pe() - res[1] = Tx.nvshmem.n_pes() + @T.prim_func + def main(res: T.Buffer((2,), "int32")): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([nwarps * 32]) + res[0] = T.nvshmem.my_pe() + res[1] = T.nvshmem.n_pes() res_array = sess.empty((2,), "int32") run_prim_func(sess, main, res_array) @@ -88,26 +88,26 @@ def test_transfer(sess, scope, shape, nwarps, nelems, op_name): """Tests data transfer operations (get/put) at thread, warp, and block scopes.""" dtype = "float32" is_get = "get" in op_name - op_func = getattr(Tx.nvshmem, op_name) + op_func = getattr(T.nvshmem, op_name) if scope != "thread": op_func = getattr(op_func, scope) # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer(shape, dtype), B: Tx.Buffer(shape, dtype)): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([nwarps]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([nwarps * 32]) - - my_pe = Tx.nvshmem.my_pe() - n_pes = Tx.nvshmem.n_pes() - offset = Tx.if_then_else( - scope == "block", 0, Tx.if_then_else(scope == "thread", tid, warp_id * 32) + @T.prim_func + def main(A: T.Buffer(shape, dtype), B: T.Buffer(shape, dtype)): + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([nwarps]) + lane_id = T.lane_id([32]) + tid = T.thread_id([nwarps * 32]) + + my_pe = T.nvshmem.my_pe() + n_pes = T.nvshmem.n_pes() + offset = T.if_then_else( + scope == "block", 0, T.if_then_else(scope == "thread", tid, warp_id * 32) ) op_func(dst=B.ptr_to([offset]), src=A.ptr_to([offset]), nelems=nelems, pe=(my_pe + 1) % n_pes) # noqa: E501 - Tx.nvshmem.quiet() + T.nvshmem.quiet() # fmt: on def init_fn(i, s, d): @@ -132,19 +132,19 @@ def test_signal_op(sess, sig_op): cmp_value = 1 if sig_op == "set" else 2 # fmt: off - @Tx.prim_func - def main(res: Tx.Buffer((1,), "uint64")): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([nwarps * 32]) - my_pe = Tx.nvshmem.my_pe() - n_pes = Tx.nvshmem.n_pes() + @T.prim_func + def main(res: T.Buffer((1,), "uint64")): + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([nwarps * 32]) + my_pe = T.nvshmem.my_pe() + n_pes = T.nvshmem.n_pes() dst_pe = (my_pe + 1) % n_pes if sig_op == "add": res[0] = 1 - Tx.nvshmem.barrier_all() - Tx.nvshmem.signal_op(sig_addr=res.ptr_to([0]), signal=1, sig_op=sig_op, pe=dst_pe) - Tx.nvshmem.wait_until(ivar=res.ptr_to([0]), cmp="eq", cmp_value=cmp_value) + T.nvshmem.barrier_all() + T.nvshmem.signal_op(sig_addr=res.ptr_to([0]), signal=1, sig_op=sig_op, pe=dst_pe) + T.nvshmem.wait_until(ivar=res.ptr_to([0]), cmp="eq", cmp_value=cmp_value) # fmt: on res_array = create_nvshmem_array(sess, (1,), "uint64") @@ -161,45 +161,43 @@ def main(res: Tx.Buffer((1,), "uint64")): def test_put_signal(sess, scope, shape, nwarps, nelems, cmp_value): """Tests combined data transfer and signal operations at thread/warp/block scopes.""" dtype = "float32" - op_func = getattr(Tx.nvshmem, "putmem_signal_nbi") + op_func = getattr(T.nvshmem, "putmem_signal_nbi") if scope != "thread": op_func = getattr(op_func, scope) - @Tx.prim_func + @T.prim_func def main( - A: Tx.Buffer(shape, dtype), - B: Tx.Buffer(shape, dtype), - signal_array: Tx.Buffer((1,), "uint64"), + A: T.Buffer(shape, dtype), + B: T.Buffer(shape, dtype), + signal_array: T.Buffer((1,), "uint64"), ): - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([nwarps]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([nwarps * 32]) - - with Tx.thread(): - my_pe = Tx.nvshmem.my_pe() - n_pes = Tx.nvshmem.n_pes() - dst_pe = (my_pe + 1) % n_pes - offset = Tx.if_then_else( - scope == "block", - 0, - Tx.if_then_else(scope == "thread", tid, warp_id * 32), - ) - op_func( - dst=B.access_ptr("w", offset=offset), - src=A.access_ptr("r", offset=offset), - nelems=nelems, - sig_addr=signal_array.access_ptr("w", offset=0), - signal=1, - sig_op="set", - pe=dst_pe, - ) - Tx.nvshmem.wait_until( - ivar=signal_array.access_ptr("r", offset=0), - cmp="eq", - cmp_value=cmp_value, - ) + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([nwarps]) + lane_id = T.lane_id([32]) + tid = T.thread_id([nwarps * 32]) + my_pe = T.nvshmem.my_pe() + n_pes = T.nvshmem.n_pes() + dst_pe = (my_pe + 1) % n_pes + offset = T.if_then_else( + scope == "block", + 0, + T.if_then_else(scope == "thread", tid, warp_id * 32), + ) + op_func( + dst=B.access_ptr("w", offset=offset), + src=A.access_ptr("r", offset=offset), + nelems=nelems, + sig_addr=signal_array.access_ptr("w", offset=0), + signal=1, + sig_op="set", + pe=dst_pe, + ) + T.nvshmem.wait_until( + ivar=signal_array.access_ptr("r", offset=0), + cmp="eq", + cmp_value=cmp_value, + ) def init_A(i, s, d): return np.arange(s[0], dtype=d) + i * 100 @@ -223,24 +221,22 @@ def test_fence_barrier(sess): dtype = "float32" # fmt: off - @Tx.prim_func - def main(A: Tx.Buffer(shape, dtype), B: Tx.Buffer(shape, dtype), res: Tx.Buffer((1,), "uint64")): # noqa: E501 - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([nwarps]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([2 * 32]) - - with Tx.thread(): - my_pe = Tx.nvshmem.my_pe() - n_pes = Tx.nvshmem.n_pes() - dst_pe = (my_pe + 1) % n_pes - Tx.nvshmem.barrier_all() - Tx.nvshmem.putmem_nbi.block(dst=B.ptr_to([0]), src=A.ptr_to([0]), nelems=4 * 64, pe=(my_pe + 1) % n_pes) # noqa: E501 - Tx.nvshmem.fence() - if tid == 0: - Tx.nvshmem.signal_op(sig_addr=res.ptr_to([0]), signal=1, sig_op="set", pe=dst_pe) # noqa: E501 - Tx.nvshmem.wait_until(ivar=res.ptr_to([0]), cmp="eq", cmp_value=1) + @T.prim_func + def main(A: T.Buffer(shape, dtype), B: T.Buffer(shape, dtype), res: T.Buffer((1,), "uint64")): # noqa: E501 + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([nwarps]) + lane_id = T.lane_id([32]) + tid = T.thread_id([2 * 32]) + my_pe = T.nvshmem.my_pe() + n_pes = T.nvshmem.n_pes() + dst_pe = (my_pe + 1) % n_pes + T.nvshmem.barrier_all() + T.nvshmem.putmem_nbi.block(dst=B.ptr_to([0]), src=A.ptr_to([0]), nelems=4 * 64, pe=(my_pe + 1) % n_pes) # noqa: E501 + T.nvshmem.fence() + if tid == 0: + T.nvshmem.signal_op(sig_addr=res.ptr_to([0]), signal=1, sig_op="set", pe=dst_pe) + T.nvshmem.wait_until(ivar=res.ptr_to([0]), cmp="eq", cmp_value=1) # fmt: on def init_fn(i, s, d): return np.arange(s[0], dtype=d) + i * 100 diff --git a/tests/python/tirx/codegen/test_cuda_copy.py b/tests/python/tirx/codegen/test_cuda_copy.py index fa23e01a5276..cb08f4247318 100644 --- a/tests/python/tirx/codegen/test_cuda_copy.py +++ b/tests/python/tirx/codegen/test_cuda_copy.py @@ -20,7 +20,7 @@ import pytest import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -38,27 +38,23 @@ def test_copy_128b(): """copy_128b: copies 16 bytes (4 float32 elements) via uint4 load/store.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (4,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - src_buf = Tx.alloc_buffer((4,), "float32", scope="shared") - dst_buf = Tx.alloc_buffer((4,), "float32", scope="shared") - with Tx.thread(): - if lane < 4: - src_buf[lane] = Tx.float32(lane + 1) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - Tx.cuda.copy_128b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane < 4: - out[lane] = dst_buf[lane] + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (4,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + src_buf = T.alloc_buffer((4,), "float32", scope="shared") + dst_buf = T.alloc_buffer((4,), "float32", scope="shared") + if lane < 4: + src_buf[lane] = T.float32(lane + 1) + T.cuda.cta_sync() + if lane == 0: + T.cuda.copy_128b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) + T.cuda.cta_sync() + if lane < 4: + out[lane] = dst_buf[lane] # fmt: on out_np = np.zeros(4, dtype="float32") @@ -71,27 +67,23 @@ def test_copy_64b(): """copy_64b: copies 8 bytes (2 float32 elements) via uint2 load/store.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (2,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - src_buf = Tx.alloc_buffer((2,), "float32", scope="shared") - dst_buf = Tx.alloc_buffer((2,), "float32", scope="shared") - with Tx.thread(): - if lane < 2: - src_buf[lane] = Tx.float32(lane + 10) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - Tx.cuda.copy_64b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane < 2: - out[lane] = dst_buf[lane] + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (2,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + src_buf = T.alloc_buffer((2,), "float32", scope="shared") + dst_buf = T.alloc_buffer((2,), "float32", scope="shared") + if lane < 2: + src_buf[lane] = T.float32(lane + 10) + T.cuda.cta_sync() + if lane == 0: + T.cuda.copy_64b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) + T.cuda.cta_sync() + if lane < 2: + out[lane] = dst_buf[lane] # fmt: on out_np = np.zeros(2, dtype="float32") @@ -104,27 +96,23 @@ def test_copy_32b(): """copy_32b: copies 4 bytes (1 float32 element) via unsigned int load/store.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (1,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - src_buf = Tx.alloc_buffer((1,), "float32", scope="shared") - dst_buf = Tx.alloc_buffer((1,), "float32", scope="shared") - with Tx.thread(): - if lane == 0: - src_buf[0] = Tx.float32(42) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - Tx.cuda.copy_32b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - out[0] = dst_buf[0] + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (1,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + src_buf = T.alloc_buffer((1,), "float32", scope="shared") + dst_buf = T.alloc_buffer((1,), "float32", scope="shared") + if lane == 0: + src_buf[0] = T.float32(42) + T.cuda.cta_sync() + if lane == 0: + T.cuda.copy_32b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) + T.cuda.cta_sync() + if lane == 0: + out[0] = dst_buf[0] # fmt: on out_np = np.zeros(1, dtype="float32") @@ -137,27 +125,23 @@ def test_copy_16b(): """copy_16b: copies 2 bytes (1 float16 element) via unsigned short load/store.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (1,), "float16") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - src_buf = Tx.alloc_buffer((1,), "float16", scope="shared") - dst_buf = Tx.alloc_buffer((1,), "float16", scope="shared") - with Tx.thread(): - if lane == 0: - src_buf[0] = Tx.float16(7) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - Tx.cuda.copy_16b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - out[0] = dst_buf[0] + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (1,), "float16") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + src_buf = T.alloc_buffer((1,), "float16", scope="shared") + dst_buf = T.alloc_buffer((1,), "float16", scope="shared") + if lane == 0: + src_buf[0] = T.float16(7) + T.cuda.cta_sync() + if lane == 0: + T.cuda.copy_16b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) + T.cuda.cta_sync() + if lane == 0: + out[0] = dst_buf[0] # fmt: on out_np = np.zeros(1, dtype="float16") @@ -170,27 +154,23 @@ def test_copy_8b(): """copy_8b: copies 1 byte (1 uint8 element) via unsigned char load/store.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (1,), "uint8") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - src_buf = Tx.alloc_buffer((1,), "uint8", scope="shared") - dst_buf = Tx.alloc_buffer((1,), "uint8", scope="shared") - with Tx.thread(): - if lane == 0: - src_buf[0] = Tx.uint8(255) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - Tx.cuda.copy_8b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) - Tx.cuda.cta_sync() - with Tx.thread(): - if lane == 0: - out[0] = dst_buf[0] + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (1,), "uint8") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + src_buf = T.alloc_buffer((1,), "uint8", scope="shared") + dst_buf = T.alloc_buffer((1,), "uint8", scope="shared") + if lane == 0: + src_buf[0] = T.uint8(255) + T.cuda.cta_sync() + if lane == 0: + T.cuda.copy_8b(dst_buf.ptr_to([0]), src_buf.ptr_to([0])) + T.cuda.cta_sync() + if lane == 0: + out[0] = dst_buf[0] # fmt: on out_np = np.zeros(1, dtype="uint8") @@ -205,23 +185,21 @@ def func(out_ptr: Tx.handle): def test_codegen_function_names(num_bytes, func_suffix): """Verify each copy variant generates the expected C++ function name.""" - copy_fn = getattr(Tx.cuda, f"copy_{func_suffix}") + copy_fn = getattr(T.cuda, f"copy_{func_suffix}") # fmt: off - @Tx.prim_func - def func(dummy_ptr: Tx.handle): - dummy = Tx.match_buffer(dummy_ptr, (16,), "uint8") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - a = Tx.alloc_buffer((16,), "uint8", scope="shared") - b = Tx.alloc_buffer((16,), "uint8", scope="shared") - with Tx.thread(): - if lane == 0: - copy_fn(b.ptr_to([0]), a.ptr_to([0])) - dummy[0] = Tx.uint8(0) + @T.prim_func + def func(dummy_ptr: T.handle): + dummy = T.match_buffer(dummy_ptr, (16,), "uint8") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + a = T.alloc_buffer((16,), "uint8", scope="shared") + b = T.alloc_buffer((16,), "uint8", scope="shared") + if lane == 0: + copy_fn(b.ptr_to([0]), a.ptr_to([0])) + dummy[0] = T.uint8(0) # fmt: on mod = tvm.IRModule({"main": func}) diff --git a/tests/python/tirx/codegen/test_cuda_cta_reduce.py b/tests/python/tirx/codegen/test_cuda_cta_reduce.py index c17709cfaa7b..51b8f1099a91 100644 --- a/tests/python/tirx/codegen/test_cuda_cta_reduce.py +++ b/tests/python/tirx/codegen/test_cuda_cta_reduce.py @@ -20,7 +20,7 @@ import pytest import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -41,20 +41,18 @@ def test_cta_sum_4_warps(): N = NUM_WARPS * 32 # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (N,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([NUM_WARPS]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([N]) - with Tx.cta(): - scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared") - with Tx.thread(): - val: Tx.f32 = Tx.float32(tid + 1) - val = Tx.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0])) - out[tid] = val + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (N,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([NUM_WARPS]) + lane_id = T.lane_id([32]) + tid = T.thread_id([N]) + scratch = T.alloc_buffer((NUM_WARPS,), "float32", scope="shared") + val: T.f32 = T.float32(tid + 1) + val = T.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0])) + out[tid] = val # fmt: on result, mod = _build_and_run(func, N) @@ -69,20 +67,18 @@ def test_cta_sum_8_warps(): N = NUM_WARPS * 32 # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (N,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([NUM_WARPS]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([N]) - with Tx.cta(): - scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared") - with Tx.thread(): - val: Tx.f32 = Tx.float32(tid + 1) - val = Tx.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0])) - out[tid] = val + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (N,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([NUM_WARPS]) + lane_id = T.lane_id([32]) + tid = T.thread_id([N]) + scratch = T.alloc_buffer((NUM_WARPS,), "float32", scope="shared") + val: T.f32 = T.float32(tid + 1) + val = T.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0])) + out[tid] = val # fmt: on result, _ = _build_and_run(func, N) @@ -96,20 +92,18 @@ def test_cta_max_4_warps(): N = NUM_WARPS * 32 # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (N,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([NUM_WARPS]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([N]) - with Tx.cta(): - scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared") - with Tx.thread(): - val: Tx.f32 = Tx.float32(tid + 1) - val = Tx.cuda.cta_max(val, NUM_WARPS, scratch.ptr_to([0])) - out[tid] = val + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (N,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([NUM_WARPS]) + lane_id = T.lane_id([32]) + tid = T.thread_id([N]) + scratch = T.alloc_buffer((NUM_WARPS,), "float32", scope="shared") + val: T.f32 = T.float32(tid + 1) + val = T.cuda.cta_max(val, NUM_WARPS, scratch.ptr_to([0])) + out[tid] = val # fmt: on result, _ = _build_and_run(func, N) @@ -122,20 +116,18 @@ def test_cta_min_4_warps(): N = NUM_WARPS * 32 # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (N,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([NUM_WARPS]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([N]) - with Tx.cta(): - scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared") - with Tx.thread(): - val: Tx.f32 = Tx.float32(tid + 1) - val = Tx.cuda.cta_min(val, NUM_WARPS, scratch.ptr_to([0])) - out[tid] = val + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (N,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([NUM_WARPS]) + lane_id = T.lane_id([32]) + tid = T.thread_id([N]) + scratch = T.alloc_buffer((NUM_WARPS,), "float32", scope="shared") + val: T.f32 = T.float32(tid + 1) + val = T.cuda.cta_min(val, NUM_WARPS, scratch.ptr_to([0])) + out[tid] = val # fmt: on result, _ = _build_and_run(func, N) @@ -148,20 +140,18 @@ def test_cta_sum_1_warp(): N = 32 # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (N,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([NUM_WARPS]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([N]) - with Tx.cta(): - scratch = Tx.alloc_buffer((NUM_WARPS,), "float32", scope="shared") - with Tx.thread(): - val: Tx.f32 = Tx.float32(tid + 1) - val = Tx.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0])) - out[tid] = val + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (N,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([NUM_WARPS]) + lane_id = T.lane_id([32]) + tid = T.thread_id([N]) + scratch = T.alloc_buffer((NUM_WARPS,), "float32", scope="shared") + val: T.f32 = T.float32(tid + 1) + val = T.cuda.cta_sum(val, NUM_WARPS, scratch.ptr_to([0])) + out[tid] = val # fmt: on result, _ = _build_and_run(func, N) @@ -175,20 +165,18 @@ def test_cta_sum_all_warp_counts(num_warps): N = num_warps * 32 # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (N,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([num_warps]) - lane_id = Tx.lane_id([32]) - tid = Tx.thread_id([N]) - with Tx.cta(): - scratch = Tx.alloc_buffer((num_warps,), "float32", scope="shared") - with Tx.thread(): - val: Tx.f32 = Tx.float32(tid + 1) - val = Tx.cuda.cta_sum(val, num_warps, scratch.ptr_to([0])) - out[tid] = val + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (N,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([num_warps]) + lane_id = T.lane_id([32]) + tid = T.thread_id([N]) + scratch = T.alloc_buffer((num_warps,), "float32", scope="shared") + val: T.f32 = T.float32(tid + 1) + val = T.cuda.cta_sum(val, num_warps, scratch.ptr_to([0])) + out[tid] = val # fmt: on result, _ = _build_and_run(func, N) diff --git a/tests/python/tirx/codegen/test_cuda_warp_reduce.py b/tests/python/tirx/codegen/test_cuda_warp_reduce.py index 615fa3eb36d1..df568a95e483 100644 --- a/tests/python/tirx/codegen/test_cuda_warp_reduce.py +++ b/tests/python/tirx/codegen/test_cuda_warp_reduce.py @@ -20,7 +20,7 @@ import pytest import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -39,15 +39,15 @@ def test_warp_sum_full(): """Full warp sum (width=32): each lane gets the sum of all 32 values.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (32,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - val: Tx.f32 = Tx.float32(lane + 1) - val = Tx.cuda.warp_sum(val) + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (32,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + val: T.f32 = T.float32(lane + 1) + val = T.cuda.warp_sum(val) out[lane] = val # fmt: on @@ -61,15 +61,15 @@ def test_warp_sum_partial_8(): """Partial warp sum (width=8): 4 groups of 8 lanes, each group sums independently.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (32,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - val: Tx.f32 = Tx.float32(lane + 1) - val = Tx.cuda.warp_sum(val, width=8) + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (32,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + val: T.f32 = T.float32(lane + 1) + val = T.cuda.warp_sum(val, width=8) out[lane] = val # fmt: on @@ -89,15 +89,15 @@ def test_warp_max_partial_4(): """Partial warp max (width=4): 8 groups of 4 lanes.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (32,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - val: Tx.f32 = Tx.float32(lane + 1) - val = Tx.cuda.warp_max(val, width=4) + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (32,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + val: T.f32 = T.float32(lane + 1) + val = T.cuda.warp_max(val, width=4) out[lane] = val # fmt: on @@ -113,15 +113,15 @@ def test_warp_min_full(): """Full warp min (width=32).""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (32,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - val: Tx.f32 = Tx.float32(lane + 1) - val = Tx.cuda.warp_min(val) + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (32,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + val: T.f32 = T.float32(lane + 1) + val = T.cuda.warp_min(val) out[lane] = val # fmt: on @@ -133,15 +133,15 @@ def test_warp_sum_partial_2(): """Smallest partial warp sum (width=2): 16 pairs of adjacent lanes.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (32,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - val: Tx.f32 = Tx.float32(lane) - val = Tx.cuda.warp_sum(val, width=2) + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (32,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + val: T.f32 = T.float32(lane) + val = T.cuda.warp_sum(val, width=2) out[lane] = val # fmt: on @@ -160,15 +160,15 @@ def test_warp_sum_all_widths(width): """Parametric test: warp_sum with every valid width.""" # fmt: off - @Tx.prim_func - def func(out_ptr: Tx.handle): - out = Tx.match_buffer(out_ptr, (32,), "float32") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - val: Tx.f32 = Tx.float32(lane) - val = Tx.cuda.warp_sum(val, width=width) + @T.prim_func + def func(out_ptr: T.handle): + out = T.match_buffer(out_ptr, (32,), "float32") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane = T.lane_id([32]) + val: T.f32 = T.float32(lane) + val = T.cuda.warp_sum(val, width=width) out[lane] = val # fmt: on diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py index 2347cd0a0561..340eb9809493 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py @@ -30,7 +30,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout # Force the fallback dispatch to register before any test compiles a kernel. @@ -74,57 +75,52 @@ def _build_round_trip_kernel(scope, n_threads, shape, dtype): # pair on ``A_smem`` would otherwise race. if scope == "warp": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - with Tx.warp(): - Tx.copy(A_smem[full], A[full]) - Tx.cuda.cta_sync() - Tx.copy(B[full], A_smem[full]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.warp.copy(A_smem[full], A[full]) + T.cuda.cta_sync() + Tx.warp.copy(B[full], A_smem[full]) elif scope == "warpgroup": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([n_threads // 128]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - with Tx.warpgroup(): - Tx.copy(A_smem[full], A[full]) - Tx.cuda.cta_sync() - Tx.copy(B[full], A_smem[full]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([n_threads // 128]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.wg.copy(A_smem[full], A[full]) + T.cuda.cta_sync() + Tx.wg.copy(B[full], A_smem[full]) elif scope == "cta": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([n_threads // 32]) - Tx.lane_id([32]) - Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[full], A[full]) - Tx.cuda.cta_sync() - Tx.copy(B[full], A_smem[full]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warp_id([n_threads // 32]) + T.lane_id([32]) + T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[full], A[full]) + T.cuda.cta_sync() + Tx.cta.copy(B[full], A_smem[full]) else: raise ValueError(f"unsupported scope {scope!r}") @@ -163,7 +159,7 @@ def test_fallback_round_trip(scope, n_threads, shape, why): def test_fallback_thread_scope(): - """``Tx.thread()`` — single thread, no gate. Either ``gmem_smem`` picks + """``T.thread()`` — single thread, no gate. Either ``gmem_smem`` picks it up (n_elements % 1 == 0) or ``fallback`` does — both end up emitting a sensible single-thread copy. We only check the round trip is correct, not which variant fired.""" @@ -172,18 +168,17 @@ def test_fallback_thread_scope(): s_layout = TileLayout(S[shape]) full = tuple(slice(0, d) for d in shape) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([1]) - with Tx.thread(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[full], A[full]) - Tx.cuda.cta_sync() - Tx.copy(B[full], A_smem[full]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.thread_id([1]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.copy(A_smem[full], A[full]) + T.cuda.cta_sync() + Tx.copy(B[full], A_smem[full]) dev = tvm.cuda(0) target = tvm.target.Target("cuda") @@ -209,19 +204,18 @@ def test_fallback_emits_gate(): s_layout = TileLayout(S[shape]) full = tuple(slice(0, d) for d in shape) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([8]) # 256 threads => 8 warps - Tx.lane_id([32]) - Tx.thread_id([256]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[full], A[full]) - Tx.copy(B[full], A_smem[full]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warp_id([8]) # 256 threads => 8 warps + T.lane_id([32]) + T.thread_id([256]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[full], A[full]) + Tx.cta.copy(B[full], A_smem[full]) target = tvm.target.Target("cuda") with target, pytest.warns(UserWarning, match="copy/fallback"): diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py index 3bde53a36d3d..86a33b940f9d 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py @@ -26,7 +26,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout @@ -36,57 +37,52 @@ def _build_kernel(scope, n_threads, shape, dtype): if scope == "warpgroup": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([n_threads // 128]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - with Tx.warpgroup(): - Tx.copy(A_smem[full_slices], A[full_slices]) - Tx.cuda.cta_sync() - Tx.copy(B[full_slices], A_smem[full_slices]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([n_threads // 128]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.wg.copy(A_smem[full_slices], A[full_slices]) + T.cuda.cta_sync() + Tx.wg.copy(B[full_slices], A_smem[full_slices]) elif scope == "warp": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - with Tx.warp(): - Tx.copy(A_smem[full_slices], A[full_slices]) - Tx.cuda.cta_sync() - Tx.copy(B[full_slices], A_smem[full_slices]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.warp.copy(A_smem[full_slices], A[full_slices]) + T.cuda.cta_sync() + Tx.warp.copy(B[full_slices], A_smem[full_slices]) elif scope == "cta": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([n_threads // 32]) - Tx.lane_id([32]) - Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[full_slices], A[full_slices]) - Tx.cuda.cta_sync() - Tx.copy(B[full_slices], A_smem[full_slices]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warp_id([n_threads // 32]) + T.lane_id([32]) + T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[full_slices], A[full_slices]) + T.cuda.cta_sync() + Tx.cta.copy(B[full_slices], A_smem[full_slices]) else: raise ValueError(f"unsupported scope {scope!r}") @@ -207,26 +203,24 @@ def test_copy_g2s_s2g(task, dtype, scope): r_smem = tuple(slice(None) for _ in range(len(s_shape))) r_gmem = tuple(slice(g_region[i][0], g_region[i][1]) for i in range(len(g_shape))) - if scope == "cta": - scoper = Tx.cta - elif scope == "thread": - scoper = Tx.thread + if scope == "thread": thread_cnt = 1 - @Tx.prim_func - def copy_sync(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) + @T.prim_func + def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) - Tx.device_entry() - Tx.cta_id([2]) - Tx.thread_id([thread_cnt]) + T.device_entry() + T.cta_id([2]) + T.thread_id([thread_cnt]) - with scoper(): - A_smem = Tx.alloc_buffer(s_shape, dtype, scope="shared", layout=layoutS) - Tx.copy(A_smem[r_smem], A[r_gmem]) - Tx.cuda.cta_sync() - Tx.copy(B[r_gmem], A_smem[r_smem]) + A_smem = T.alloc_buffer(s_shape, dtype, scope="shared", layout=layoutS) + # `scope` is parametrized at runtime; select the scope namespace + # dynamically (T.cta / T.thread) instead of a literal prefix. + getattr(Tx, scope).copy(A_smem[r_smem], A[r_gmem]) + T.cuda.cta_sync() + getattr(Tx, scope).copy(B[r_gmem], A_smem[r_smem]) np_dtype = tvm.testing.np_dtype_from_str(dtype) target = tvm.target.Target("cuda") @@ -351,26 +345,24 @@ def test_swizzled_smem_emit_must_be_swizzle_aware(): ``s_buf.ptr_to([0,..,0]) + linear_offset`` which only matches a non-swizzled storage layout.""" import tvm - from tvm.script import tirx as Tx + from tvm.script import tirx as T from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout shape = (128, 32) s_layout = ComposeLayout(SwizzleLayout(3, 3, 3), TileLayout(S[shape])) - @Tx.prim_func - def kernel(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, "float16", scope="shared", layout=s_layout) - with Tx.warpgroup(): - Tx.copy(A_smem[0:128, 0:32], A[0:128, 0:32]) + @T.prim_func + def kernel(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float16") + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + T.thread_id([128]) + A_smem = T.alloc_buffer(shape, "float16", scope="shared", layout=s_layout) + Tx.wg.copy(A_smem[0:128, 0:32], A[0:128, 0:32]) # NB: pin sm_90 explicitly — the default cuda target falls back to sm_50 # when no GPU is detected, which nvcc 13+ rejects. Codegen happens before @@ -529,20 +521,18 @@ def test_gmem_smem_swizzle_fast_path_fires_with_var_bounds(): g_layout = TileLayout(S[shape]) s_layout = ComposeLayout(swizzle, TileLayout(S[shape])) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float16", layout=g_layout) - B = Tx.match_buffer(B_ptr, shape, "float16", layout=g_layout) - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - Tx.thread_id([32]) - with Tx.cta(): - smem = Tx.alloc_buffer(shape, "float16", scope="shared", layout=s_layout) - with Tx.warp(): - Tx.copy(smem, A[:, :]) - Tx.cuda.cta_sync() - Tx.copy(B[:, :], smem) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float16", layout=g_layout) + B = T.match_buffer(B_ptr, shape, "float16", layout=g_layout) + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + T.thread_id([32]) + smem = T.alloc_buffer(shape, "float16", scope="shared", layout=s_layout) + Tx.warp.copy(smem, A[:, :]) + T.cuda.cta_sync() + Tx.warp.copy(B[:, :], smem) target = tvm.target.Target("cuda") with target: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py index 37b7ac95b085..fc62806c9bf6 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py @@ -18,9 +18,9 @@ """Round-trip tests for the ``ldstmatrix`` copy dispatch. Pipeline: - ld direction: A_gmem → A_smem (per-thread init) → R_local (Tx.copy dispatch + ld direction: A_gmem → A_smem (per-thread init) → R_local (T.copy dispatch under test) → B_gmem (per-thread write). - st direction: A_gmem → R_local (per-thread init) → A_smem (Tx.copy dispatch + st direction: A_gmem → R_local (per-thread init) → A_smem (T.copy dispatch under test) → B_gmem (per-thread write). Both directions must round-trip ``A == B``. Layout strides are constructed @@ -37,7 +37,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout, laneid, tid_in_wg, tx @@ -104,57 +105,53 @@ def _coord(row, cp, t, w): # fmt: off if direction == "ld": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), "float16") - B = Tx.match_buffer(B_ptr, (M, N), "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - tid = Tx.thread_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) - with Tx.warp(): - row = tid // 4 - cp = tid % 4 - for t in range(num): - for w in range(2): - gr, gc = _coord(row, cp, t, w) - A_smem[row, cp, t, w] = A[gr, gc] - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) - Tx.copy(R_local[full], A_smem[full]) - r_view = R_local.local() - for t in range(num): - for w in range(2): - gr, gc = _coord(row, cp, t, w) - B[gr, gc] = r_view[t * 2 + w] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), "float16") + B = T.match_buffer(B_ptr, (M, N), "float16") + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + tid = T.thread_id([32]) + A_smem = T.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) + row = tid // 4 + cp = tid % 4 + for t in range(num): + for w in range(2): + gr, gc = _coord(row, cp, t, w) + A_smem[row, cp, t, w] = A[gr, gc] + T.cuda.cta_sync() + R_local = T.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) + Tx.warp.copy(R_local[full], A_smem[full]) + r_view = R_local.local() + for t in range(num): + for w in range(2): + gr, gc = _coord(row, cp, t, w) + B[gr, gc] = r_view[t * 2 + w] else: # direction == "st" - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), "float16") - B = Tx.match_buffer(B_ptr, (M, N), "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - tid = Tx.thread_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) - with Tx.warp(): - row = tid // 4 - cp = tid % 4 - R_local = Tx.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) - r_view = R_local.local() - for t in range(num): - for w in range(2): - gr, gc = _coord(row, cp, t, w) - r_view[t * 2 + w] = A[gr, gc] - Tx.copy(A_smem[full], R_local[full]) - Tx.cuda.cta_sync() - for t in range(num): - for w in range(2): - gr, gc = _coord(row, cp, t, w) - B[gr, gc] = A_smem[row, cp, t, w] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), "float16") + B = T.match_buffer(B_ptr, (M, N), "float16") + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + tid = T.thread_id([32]) + A_smem = T.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) + row = tid // 4 + cp = tid % 4 + R_local = T.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) + r_view = R_local.local() + for t in range(num): + for w in range(2): + gr, gc = _coord(row, cp, t, w) + r_view[t * 2 + w] = A[gr, gc] + Tx.warp.copy(A_smem[full], R_local[full]) + T.cuda.cta_sync() + for t in range(num): + for w in range(2): + gr, gc = _coord(row, cp, t, w) + B[gr, gc] = A_smem[row, cp, t, w] # fmt: on return kernel, (M, N) @@ -176,67 +173,63 @@ def _coord(wid, row, cp, t, w): # fmt: off if direction == "ld": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), "float16") - B = Tx.match_buffer(B_ptr, (M, N), "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - tid = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) - with Tx.warpgroup(): - wid = tid // 32 - lid = tid % 32 - row = lid // 4 - cp = lid % 4 - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - A_smem[wid, row, cp, t, w] = A[gr, gc] - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) - Tx.copy(R_local[full], A_smem[full]) - r_view = R_local.local() - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - B[gr, gc] = r_view[t * 2 + w] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), "float16") + B = T.match_buffer(B_ptr, (M, N), "float16") + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + tid = T.thread_id([128]) + A_smem = T.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) + wid = tid // 32 + lid = tid % 32 + row = lid // 4 + cp = lid % 4 + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + A_smem[wid, row, cp, t, w] = A[gr, gc] + T.cuda.cta_sync() + R_local = T.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) + Tx.wg.copy(R_local[full], A_smem[full]) + r_view = R_local.local() + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + B[gr, gc] = r_view[t * 2 + w] else: - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), "float16") - B = Tx.match_buffer(B_ptr, (M, N), "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - tid = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) - with Tx.warpgroup(): - wid = tid // 32 - lid = tid % 32 - row = lid // 4 - cp = lid % 4 - R_local = Tx.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) - r_view = R_local.local() - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - r_view[t * 2 + w] = A[gr, gc] - Tx.copy(A_smem[full], R_local[full]) - Tx.cuda.cta_sync() - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - B[gr, gc] = A_smem[wid, row, cp, t, w] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), "float16") + B = T.match_buffer(B_ptr, (M, N), "float16") + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + tid = T.thread_id([128]) + A_smem = T.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) + wid = tid // 32 + lid = tid % 32 + row = lid // 4 + cp = lid % 4 + R_local = T.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) + r_view = R_local.local() + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + r_view[t * 2 + w] = A[gr, gc] + Tx.wg.copy(A_smem[full], R_local[full]) + T.cuda.cta_sync() + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + B[gr, gc] = A_smem[wid, row, cp, t, w] # fmt: on return kernel, (M, N) @@ -258,61 +251,59 @@ def _coord(wid, row, cp, t, w): # fmt: off if direction == "ld": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), "float16") - B = Tx.match_buffer(B_ptr, (M, N), "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([4]) - Tx.lane_id([32]) - tid = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) - wid = tid // 32 - lid = tid % 32 - row = lid // 4 - cp = lid % 4 - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - A_smem[wid, row, cp, t, w] = A[gr, gc] - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) - Tx.copy(R_local[full], A_smem[full]) - r_view = R_local.local() - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - B[gr, gc] = r_view[t * 2 + w] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), "float16") + B = T.match_buffer(B_ptr, (M, N), "float16") + T.device_entry() + T.cta_id([1]) + T.warp_id([4]) + T.lane_id([32]) + tid = T.thread_id([128]) + A_smem = T.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) + wid = tid // 32 + lid = tid % 32 + row = lid // 4 + cp = lid % 4 + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + A_smem[wid, row, cp, t, w] = A[gr, gc] + T.cuda.cta_sync() + R_local = T.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) + Tx.cta.copy(R_local[full], A_smem[full]) + r_view = R_local.local() + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + B[gr, gc] = r_view[t * 2 + w] else: - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), "float16") - B = Tx.match_buffer(B_ptr, (M, N), "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([4]) - Tx.lane_id([32]) - tid = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) - wid = tid // 32 - lid = tid % 32 - row = lid // 4 - cp = lid % 4 - R_local = Tx.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) - r_view = R_local.local() - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - r_view[t * 2 + w] = A[gr, gc] - Tx.copy(A_smem[full], R_local[full]) - Tx.cuda.cta_sync() - for t in range(num): - for w in range(2): - gr, gc = _coord(wid, row, cp, t, w) - B[gr, gc] = A_smem[wid, row, cp, t, w] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), "float16") + B = T.match_buffer(B_ptr, (M, N), "float16") + T.device_entry() + T.cta_id([1]) + T.warp_id([4]) + T.lane_id([32]) + tid = T.thread_id([128]) + A_smem = T.alloc_buffer(s_shape, "float16", scope="shared", layout=s_layout) + wid = tid // 32 + lid = tid % 32 + row = lid // 4 + cp = lid % 4 + R_local = T.alloc_buffer(s_shape, "float16", scope="local", layout=r_layout) + r_view = R_local.local() + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + r_view[t * 2 + w] = A[gr, gc] + Tx.cta.copy(A_smem[full], R_local[full]) + T.cuda.cta_sync() + for t in range(num): + for w in range(2): + gr, gc = _coord(wid, row, cp, t, w) + B[gr, gc] = A_smem[wid, row, cp, t, w] # fmt: on return kernel, (M, N) @@ -406,35 +397,29 @@ def _build_multi_iter_kernel(outer_ext: int): s_layout = SwizzleLayout(3, 3, 3) full = tuple(slice(0, e) for e in shape) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float16") - B = Tx.match_buffer(B_ptr, shape, "float16") - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - tid = Tx.thread_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, "float16", scope="shared", layout=s_layout) - with Tx.warp(): - for a in range(outer_ext): - for c in range(2): - for d in range(4): - for e in range(2): - A_smem[a, tid // 4, c, d, tid % 4, e] = A[ - a, tid // 4, c, d, tid % 4, e - ] - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, "float16", scope="local", layout=r_layout) - Tx.copy(R_local[full], A_smem[full]) - r_view = R_local.local() - for a in range(outer_ext): - for c in range(2): - for d in range(4): - for e in range(2): - B[a, tid // 4, c, d, tid % 4, e] = r_view[ - a * 16 + c * 8 + d * 2 + e - ] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float16") + B = T.match_buffer(B_ptr, shape, "float16") + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + tid = T.thread_id([32]) + A_smem = T.alloc_buffer(shape, "float16", scope="shared", layout=s_layout) + for a in range(outer_ext): + for c in range(2): + for d in range(4): + for e in range(2): + A_smem[a, tid // 4, c, d, tid % 4, e] = A[a, tid // 4, c, d, tid % 4, e] + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, "float16", scope="local", layout=r_layout) + Tx.warp.copy(R_local[full], A_smem[full]) + r_view = R_local.local() + for a in range(outer_ext): + for c in range(2): + for d in range(4): + for e in range(2): + B[a, tid // 4, c, d, tid % 4, e] = r_view[a * 16 + c * 8 + d * 2 + e] return kernel, shape diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py index 3e3bca1de601..451622530318 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py @@ -33,7 +33,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx @@ -65,162 +66,152 @@ def _build_roundtrip_kernel(scope, n_threads, k, dtype, non_r_scope): if scope == "warpgroup": - @Tx.prim_func - def kernel(B_ptr: Tx.handle) -> None: - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([n_threads // 128]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - tid = Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - with Tx.warpgroup(): - for kk in range(k): - A_smem[tid, kk] = Tx.cast(tid * 100 + kk + 1, dtype) - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, dtype, scope="local", layout=r_layout) - Tx.copy(R_local[full_slices], A_smem[full_slices]) - for kk in range(k): - A_smem[tid, kk] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - Tx.copy(A_smem[full_slices], R_local[full_slices]) - Tx.cuda.cta_sync() - for kk in range(k): - B[tid, kk] = A_smem[tid, kk] + @T.prim_func + def kernel(B_ptr: T.handle) -> None: + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([n_threads // 128]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + tid = T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + for kk in range(k): + A_smem[tid, kk] = T.cast(tid * 100 + kk + 1, dtype) + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, dtype, scope="local", layout=r_layout) + Tx.wg.copy(R_local[full_slices], A_smem[full_slices]) + for kk in range(k): + A_smem[tid, kk] = T.cast(0, dtype) + T.cuda.cta_sync() + Tx.wg.copy(A_smem[full_slices], R_local[full_slices]) + T.cuda.cta_sync() + for kk in range(k): + B[tid, kk] = A_smem[tid, kk] elif scope == "warp": - @Tx.prim_func - def kernel(B_ptr: Tx.handle) -> None: - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - tid = Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - with Tx.warp(): - for kk in range(k): - A_smem[tid, kk] = Tx.cast(tid * 100 + kk + 1, dtype) - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, dtype, scope="local", layout=r_layout) - Tx.copy(R_local[full_slices], A_smem[full_slices]) - for kk in range(k): - A_smem[tid, kk] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - Tx.copy(A_smem[full_slices], R_local[full_slices]) - Tx.cuda.cta_sync() - for kk in range(k): - B[tid, kk] = A_smem[tid, kk] + @T.prim_func + def kernel(B_ptr: T.handle) -> None: + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + tid = T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + for kk in range(k): + A_smem[tid, kk] = T.cast(tid * 100 + kk + 1, dtype) + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, dtype, scope="local", layout=r_layout) + Tx.warp.copy(R_local[full_slices], A_smem[full_slices]) + for kk in range(k): + A_smem[tid, kk] = T.cast(0, dtype) + T.cuda.cta_sync() + Tx.warp.copy(A_smem[full_slices], R_local[full_slices]) + T.cuda.cta_sync() + for kk in range(k): + B[tid, kk] = A_smem[tid, kk] elif scope == "cta": - @Tx.prim_func - def kernel(B_ptr: Tx.handle) -> None: - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([n_threads // 32]) - Tx.lane_id([32]) - tid = Tx.thread_id([n_threads]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) - for kk in range(k): - A_smem[tid, kk] = Tx.cast(tid * 100 + kk + 1, dtype) - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, dtype, scope="local", layout=r_layout) - Tx.copy(R_local[full_slices], A_smem[full_slices]) - for kk in range(k): - A_smem[tid, kk] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - Tx.copy(A_smem[full_slices], R_local[full_slices]) - Tx.cuda.cta_sync() - for kk in range(k): - B[tid, kk] = A_smem[tid, kk] + @T.prim_func + def kernel(B_ptr: T.handle) -> None: + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warp_id([n_threads // 32]) + T.lane_id([32]) + tid = T.thread_id([n_threads]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=s_layout) + for kk in range(k): + A_smem[tid, kk] = T.cast(tid * 100 + kk + 1, dtype) + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, dtype, scope="local", layout=r_layout) + Tx.cta.copy(R_local[full_slices], A_smem[full_slices]) + for kk in range(k): + A_smem[tid, kk] = T.cast(0, dtype) + T.cuda.cta_sync() + Tx.cta.copy(A_smem[full_slices], R_local[full_slices]) + T.cuda.cta_sync() + for kk in range(k): + B[tid, kk] = A_smem[tid, kk] return kernel if non_r_scope == "global": if scope == "warpgroup": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warpgroup_id([n_threads // 128]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id_in_wg([128]) - tid = Tx.thread_id([n_threads]) - with Tx.cta(): - with Tx.warpgroup(): - for kk in range(k): - A[tid, kk] = Tx.cast(tid * 100 + kk + 1, dtype) - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, dtype, scope="local", layout=r_layout) - Tx.copy(R_local[full_slices], A[full_slices]) - for kk in range(k): - A[tid, kk] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - Tx.copy(A[full_slices], R_local[full_slices]) - Tx.cuda.cta_sync() - for kk in range(k): - B[tid, kk] = A[tid, kk] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warpgroup_id([n_threads // 128]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id_in_wg([128]) + tid = T.thread_id([n_threads]) + for kk in range(k): + A[tid, kk] = T.cast(tid * 100 + kk + 1, dtype) + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, dtype, scope="local", layout=r_layout) + Tx.wg.copy(R_local[full_slices], A[full_slices]) + for kk in range(k): + A[tid, kk] = T.cast(0, dtype) + T.cuda.cta_sync() + Tx.wg.copy(A[full_slices], R_local[full_slices]) + T.cuda.cta_sync() + for kk in range(k): + B[tid, kk] = A[tid, kk] elif scope == "warp": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.lane_id([32]) - tid = Tx.thread_id([n_threads]) - with Tx.cta(): - with Tx.warp(): - for kk in range(k): - A[tid, kk] = Tx.cast(tid * 100 + kk + 1, dtype) - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, dtype, scope="local", layout=r_layout) - Tx.copy(R_local[full_slices], A[full_slices]) - for kk in range(k): - A[tid, kk] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - Tx.copy(A[full_slices], R_local[full_slices]) - Tx.cuda.cta_sync() - for kk in range(k): - B[tid, kk] = A[tid, kk] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.lane_id([32]) + tid = T.thread_id([n_threads]) + for kk in range(k): + A[tid, kk] = T.cast(tid * 100 + kk + 1, dtype) + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, dtype, scope="local", layout=r_layout) + Tx.warp.copy(R_local[full_slices], A[full_slices]) + for kk in range(k): + A[tid, kk] = T.cast(0, dtype) + T.cuda.cta_sync() + Tx.warp.copy(A[full_slices], R_local[full_slices]) + T.cuda.cta_sync() + for kk in range(k): + B[tid, kk] = A[tid, kk] elif scope == "cta": - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([n_threads // 32]) - Tx.lane_id([32]) - tid = Tx.thread_id([n_threads]) - with Tx.cta(): - for kk in range(k): - A[tid, kk] = Tx.cast(tid * 100 + kk + 1, dtype) - Tx.cuda.cta_sync() - R_local = Tx.alloc_buffer(shape, dtype, scope="local", layout=r_layout) - Tx.copy(R_local[full_slices], A[full_slices]) - for kk in range(k): - A[tid, kk] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - Tx.copy(A[full_slices], R_local[full_slices]) - Tx.cuda.cta_sync() - for kk in range(k): - B[tid, kk] = A[tid, kk] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + T.device_entry() + T.cta_id([1]) + T.warp_id([n_threads // 32]) + T.lane_id([32]) + tid = T.thread_id([n_threads]) + for kk in range(k): + A[tid, kk] = T.cast(tid * 100 + kk + 1, dtype) + T.cuda.cta_sync() + R_local = T.alloc_buffer(shape, dtype, scope="local", layout=r_layout) + Tx.cta.copy(R_local[full_slices], A[full_slices]) + for kk in range(k): + A[tid, kk] = T.cast(0, dtype) + T.cuda.cta_sync() + Tx.cta.copy(A[full_slices], R_local[full_slices]) + T.cuda.cta_sync() + for kk in range(k): + B[tid, kk] = A[tid, kk] return kernel @@ -305,19 +296,17 @@ def test_copy_g2l_l2g_vec_load(task, dtype): r_lmem = tuple(slice(None) for _ in range(len(l_shape))) r_gmem = tuple(slice(g_region[i][0], g_region[i][1]) for i in range(len(g_shape))) - @Tx.prim_func - def copy_sync(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) + @T.prim_func + def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) - Tx.device_entry() - Tx.cta_id([2]) - Tx.thread_id([thread_cnt]) - - with Tx.thread(): - A_local = Tx.alloc_buffer(l_shape, dtype, scope="local", layout=layoutLocal) - Tx.copy(A_local[r_lmem], A[r_gmem]) - Tx.copy(B[r_gmem], A_local[r_lmem]) + T.device_entry() + T.cta_id([2]) + T.thread_id([thread_cnt]) + A_local = T.alloc_buffer(l_shape, dtype, scope="local", layout=layoutLocal) + Tx.copy(A_local[r_lmem], A[r_gmem]) + Tx.copy(B[r_gmem], A_local[r_lmem]) np_dtype = tvm.testing.np_dtype_from_str(dtype) target = tvm.target.Target("cuda") @@ -356,7 +345,7 @@ def test_reg_copy_wg_local_to_swizzled_shared_uses_swizzle_fastpath(): (Python ``range`` doesn't actually unroll in TVMScript) the swizzle fast path's per-iter constant-fold can't kick in and the ``tvm_builtin_pointer_offset`` swizzle XOR ends up recomputed every - iteration. Loop must be ``Tx.unroll``. + iteration. Loop must be ``T.unroll``. """ from tvm.tirx.layout import SwizzleLayout, wg_local_layout @@ -366,30 +355,27 @@ def test_reg_copy_wg_local_to_swizzled_shared_uses_swizzle_fastpath(): # 128b swizzle on the SMEM side (per_element=3 ⇒ 8 fp16 atom width). smem_layout = SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, "float16", layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, "float16", layout=g_layout) - - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([N_THREADS]) - tid = Tx.thread_id_in_wg([N_THREADS]) - - with Tx.thread(): - reg = Tx.alloc_buffer(g_shape, "float16", scope="local", layout=wg_local_layout(EPI_N)) - smem = Tx.alloc_buffer(g_shape, "float16", scope="shared", layout=smem_layout) - - # Populate the per-thread slice via .local() (decomposes the wg - # thread-axis layout into a per-thread 1D view). - reg_local = reg.local(EPI_N) - for i in Tx.serial(EPI_N): - reg_local[i] = A[tid, i] - with Tx.warpgroup(): - Tx.copy(smem, reg) - Tx.cuda.cta_sync() - for i in Tx.serial(EPI_N): - B[tid, i] = smem[tid, i] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, "float16", layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, "float16", layout=g_layout) + + T.device_entry() + T.cta_id([1]) + T.thread_id([N_THREADS]) + tid = T.thread_id_in_wg([N_THREADS]) + reg = T.alloc_buffer(g_shape, "float16", scope="local", layout=wg_local_layout(EPI_N)) + smem = T.alloc_buffer(g_shape, "float16", scope="shared", layout=smem_layout) + + # Populate the per-thread slice via .local() (decomposes the wg + # thread-axis layout into a per-thread 1D view). + reg_local = reg.local(EPI_N) + for i in T.serial(EPI_N): + reg_local[i] = A[tid, i] + Tx.wg.copy(smem, reg) + T.cuda.cta_sync() + for i in T.serial(EPI_N): + B[tid, i] = smem[tid, i] target = tvm.target.Target("cuda") with target: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py index 2372347c951a..3e3070e8994f 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py @@ -28,7 +28,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx import IntImm, Var from tvm.tirx.exec_scope import ExecScope from tvm.tirx.layout import S, TileLayout @@ -69,7 +70,7 @@ def visit_for_(self, op): def visit_evaluate_(self, op): if isinstance(op.value, tvm.tirx.Call): - if op.value.op.name == "tirx.ptx_cp_async_bulk_shared_to_cluster": + if op.value.op.name == "tirx.ptx.cp_async_bulk_shared_to_cluster": n = 1 for e in self._loop_extents: n *= e @@ -127,7 +128,7 @@ def test_dsmem(shape, dtype, src_spec, dst_spec, expected): """Dispatch assertion + GPU correctness for DSMEM copy. Always tests dispatch (s2c op count or DispatchFail). - For non-fail cases: also runs a 2-CTA cluster kernel via Tx.copy_async + For non-fail cases: also runs a 2-CTA cluster kernel via T.copy_async dispatch (using src_spec as layout for both CTAs) and verifies correctness. """ from tvm.tirx.lang.pipeline import MBarrier @@ -157,54 +158,51 @@ def test_dsmem(shape, dtype, src_spec, dst_spec, expected): r = tuple(slice(0, s) for s in shape) # fmt: off - @Tx.prim_func - def dsmem_copy(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype) - B = Tx.match_buffer(B_ptr, shape, dtype) - - Tx.device_entry() - cbx = Tx.cta_id_in_cluster([CLUSTER_N]) - Tx.cta_id([CLUSTER_N]) - tid = Tx.thread_id([1]) - - with Tx.cta(): - pool = Tx.SMEMPool() - # src_smem: CTA 0 writes here, dispatch reads from here - src_raw = pool.alloc([src_phys], dtype, align=128) - src_smem = Tx.decl_buffer( - list(shape), dtype, src_raw.data, - elem_offset=0, scope="shared.dyn", layout=src_layout, - ) - # dst_smem: dispatch writes here (on remote CTA), CTA 1 reads - dst_raw = pool.alloc([dst_phys], dtype, align=128) - dst_smem = Tx.decl_buffer( - list(shape), dtype, dst_raw.data, - elem_offset=0, scope="shared.dyn", layout=dst_layout, - ) - mbar = MBarrier(pool, 1) - pool.commit() - - mbar.init(1) - Tx.ptx.fence.mbarrier_init() - Tx.cuda.cluster_sync() - - if tid == 0: - with Tx.thread(): - if cbx == 0: - Tx.copy(src_smem[r], A[r]) - Tx.ptx.fence.proxy_async("shared::cta") - - Tx.copy_async( - dst_smem[r], src_smem[r], - dispatch="dsmem", - mbar=mbar.ptr_to([0]), - remote_cta_id=Tx.int32(1), - ) - else: - Tx.ptx.mbarrier.arrive.expect_tx(mbar.ptr_to([0]), copy_bytes) - mbar.wait(0, 0) - - Tx.copy(B[r], dst_smem[r]) + @T.prim_func + def dsmem_copy(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype) + B = T.match_buffer(B_ptr, shape, dtype) + + T.device_entry() + cbx = T.cta_id_in_cluster([CLUSTER_N]) + T.cta_id([CLUSTER_N]) + tid = T.thread_id([1]) + pool = T.SMEMPool() + # src_smem: CTA 0 writes here, dispatch reads from here + src_raw = pool.alloc([src_phys], dtype, align=128) + src_smem = T.decl_buffer( + list(shape), dtype, src_raw.data, + elem_offset=0, scope="shared.dyn", layout=src_layout, + ) + # dst_smem: dispatch writes here (on remote CTA), CTA 1 reads + dst_raw = pool.alloc([dst_phys], dtype, align=128) + dst_smem = T.decl_buffer( + list(shape), dtype, dst_raw.data, + elem_offset=0, scope="shared.dyn", layout=dst_layout, + ) + mbar = MBarrier(pool, 1) + pool.commit() + + mbar.init(1) + T.ptx.fence.mbarrier_init() + T.cuda.cluster_sync() + + if tid == 0: + if cbx == 0: + Tx.copy(src_smem[r], A[r]) + T.ptx.fence.proxy_async("shared::cta") + + Tx.copy_async( + dst_smem[r], src_smem[r], + dispatch="dsmem", + mbar=mbar.ptr_to([0]), + remote_cta_id=T.int32(1), + ) + else: + T.ptx.mbarrier.arrive.expect_tx(mbar.ptr_to([0]), copy_bytes) + mbar.wait(0, 0) + + Tx.copy(B[r], dst_smem[r]) # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py index 08ee93b3b6ba..b4d54d2b4109 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py @@ -22,7 +22,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout @@ -75,22 +76,21 @@ def test_copy_g2s_s2g_cta_vec_load(task, dtype): r_gmem = list(slice(g_st[i], g_st[i] + g_extent[i]) for i in range(len(g_shape))) # fmt: off - @Tx.prim_func - def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) + @T.prim_func + def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, dtype, scope="shared", layout=layoutS) + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(s_shape, dtype, scope="shared", layout=layoutS) - Tx.copy_async(A_smem[tuple(r_smem)], A[tuple(r_gmem)], dispatch="ldgsts") - Tx.ptx.cp_async.commit_group() - Tx.ptx.cp_async.wait_group() - Tx.cuda.cta_sync() - Tx.copy(B[tuple(r_gmem)], A_smem[tuple(r_smem)]) + Tx.cta.copy_async(A_smem[tuple(r_smem)], A[tuple(r_gmem)], dispatch="ldgsts") + T.ptx.cp_async.commit_group() + T.ptx.cp_async.wait_group() + T.cuda.cta_sync() + Tx.cta.copy(B[tuple(r_gmem)], A_smem[tuple(r_smem)]) # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py index a01ee2a95928..036bd786a24a 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py @@ -29,7 +29,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import R, S, TCol, TileLayout, TLane from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode, mma_shared_layout @@ -57,75 +58,66 @@ def _make_2d_kernel( OUT_LANES = 32 OUT_BYTES = 16 - @Tx.prim_func(check_well_formed=False) - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, s_full_shape, dtype) - B = Tx.match_buffer(B_ptr, (OUT_LANES, OUT_BYTES), dtype) - Tx.device_entry() - warp_id = Tx.warp_id([4]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - lane_id = Tx.lane_id([32]) - A_smem = Tx.alloc_buffer(s_full_shape, dtype, scope="shared", layout=s_full, align=1024) - tmem_addr = Tx.alloc_shared([1], "uint32") - cp_mbar = Tx.alloc_shared([1], "uint64") + @T.prim_func(check_well_formed=False) + def kernel(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, s_full_shape, dtype) + B = T.match_buffer(B_ptr, (OUT_LANES, OUT_BYTES), dtype) + T.device_entry() + warp_id = T.warp_id([4]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + lane_id = T.lane_id([32]) + A_smem = T.alloc_buffer(s_full_shape, dtype, scope="shared", layout=s_full, align=1024) + tmem_addr = T.alloc_shared([1], "uint32") + cp_mbar = T.alloc_shared([1], "uint64") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), - n_cols=n_tmem_cols_total, - cta_group=cta_group, - ) - if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(cp_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - with Tx.cta(): - Tx.copy(A_smem[:, :], A[:, :]) - Tx.cuda.cta_sync() - tmem = Tx.decl_buffer( - t_full_shape, - dtype, - scope="tmem", - allocated_addr=tmem_addr[0], - layout=t_full, + if warp_id == 0: + T.ptx.tcgen05.alloc( + T.address_of(tmem_addr), + n_cols=n_tmem_cols_total, + cta_group=cta_group, ) - if tid_in_wg == 0: - with Tx.thread(): - Tx.copy_async( - tmem[t_r0:t_r1, t_c0:t_c1], - A_smem[s_r0:s_r1, s_c0:s_c1], - cta_group=cta_group, - ) - Tx.ptx.tcgen05.commit(cp_mbar.ptr_to([0]), cta_group=cta_group) - Tx.ptx.mbarrier.try_wait(cp_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - Tx.ptx.tcgen05.fence.after_thread_sync() - if warp_id == 0: - with Tx.warp(): - reg = Tx.alloc_buffer((4,), "uint32", scope="local") - for i in range(4): - Tx.ptx.tcgen05.ld( - tmem.allocated_addr[0], - reg[i], - shape="32x32b", - num=1, - row=0, - col=i, - ) - Tx.ptx.tcgen05.wait.ld() - B_bytes = reg.view(dtype) - for i in range(OUT_BYTES): - B[lane_id, i] = B_bytes[i] - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc( - tmem_addr[0], n_cols=n_tmem_cols_total, cta_group=cta_group - ) + if tid_in_wg == 0: + T.ptx.mbarrier.init(cp_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + Tx.cta.copy(A_smem[:, :], A[:, :]) + T.cuda.cta_sync() + tmem = T.decl_buffer( + t_full_shape, + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=t_full, + ) + if tid_in_wg == 0: + Tx.copy_async( + tmem[t_r0:t_r1, t_c0:t_c1], + A_smem[s_r0:s_r1, s_c0:s_c1], + cta_group=cta_group, + ) + T.ptx.tcgen05.commit(cp_mbar.ptr_to([0]), cta_group=cta_group) + T.ptx.mbarrier.try_wait(cp_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + T.ptx.tcgen05.fence.after_thread_sync() + if warp_id == 0: + reg = T.alloc_buffer((4,), "uint32", scope="local") + for i in range(4): + T.ptx.tcgen05.ld( + tmem.allocated_addr[0], + reg[i], + shape="32x32b", + num=1, + row=0, + col=i, + ) + T.ptx.tcgen05.wait.ld() + B_bytes = reg.view(dtype) + for i in range(OUT_BYTES): + B[lane_id, i] = B_bytes[i] + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=n_tmem_cols_total, cta_group=cta_group) return kernel @@ -134,75 +126,66 @@ def _make_3d_4tile_kernel(s_full, t_full, s_full_shape, t_full_shape, dtype, cta """3D variant: 4 stacked tiles (NVFP4-style multi-cp test).""" n_tmem_cols_total = max(32, t_full_shape[-1]) - @Tx.prim_func(check_well_formed=False) - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, s_full_shape, dtype) - B = Tx.match_buffer(B_ptr, (32, 16), dtype) - Tx.device_entry() - warp_id = Tx.warp_id([4]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - lane_id = Tx.lane_id([32]) - A_smem = Tx.alloc_buffer(s_full_shape, dtype, scope="shared", layout=s_full, align=1024) - tmem_addr = Tx.alloc_shared([1], "uint32") - cp_mbar = Tx.alloc_shared([1], "uint64") + @T.prim_func(check_well_formed=False) + def kernel(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, s_full_shape, dtype) + B = T.match_buffer(B_ptr, (32, 16), dtype) + T.device_entry() + warp_id = T.warp_id([4]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + lane_id = T.lane_id([32]) + A_smem = T.alloc_buffer(s_full_shape, dtype, scope="shared", layout=s_full, align=1024) + tmem_addr = T.alloc_shared([1], "uint32") + cp_mbar = T.alloc_shared([1], "uint64") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), - n_cols=n_tmem_cols_total, - cta_group=cta_group, - ) - if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(cp_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - with Tx.cta(): - Tx.copy(A_smem[:, :, :], A[:, :, :]) - Tx.cuda.cta_sync() - tmem = Tx.decl_buffer( - t_full_shape, - dtype, - scope="tmem", - allocated_addr=tmem_addr[0], - layout=t_full, + if warp_id == 0: + T.ptx.tcgen05.alloc( + T.address_of(tmem_addr), + n_cols=n_tmem_cols_total, + cta_group=cta_group, + ) + if tid_in_wg == 0: + T.ptx.mbarrier.init(cp_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + Tx.cta.copy(A_smem[:, :, :], A[:, :, :]) + T.cuda.cta_sync() + tmem = T.decl_buffer( + t_full_shape, + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=t_full, + ) + if tid_in_wg == 0: + Tx.copy_async( + tmem[:, :, :], + A_smem[:, :, :], + cta_group=cta_group, ) - if tid_in_wg == 0: - with Tx.thread(): - Tx.copy_async( - tmem[:, :, :], - A_smem[:, :, :], - cta_group=cta_group, - ) - Tx.ptx.tcgen05.commit(cp_mbar.ptr_to([0]), cta_group=cta_group) - Tx.ptx.mbarrier.try_wait(cp_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - Tx.ptx.tcgen05.fence.after_thread_sync() - if warp_id == 0: - with Tx.warp(): - reg = Tx.alloc_buffer((4,), "uint32", scope="local") - for i in range(4): - Tx.ptx.tcgen05.ld( - tmem.allocated_addr[0], - reg[i], - shape="32x32b", - num=1, - row=0, - col=i, - ) - Tx.ptx.tcgen05.wait.ld() - B_bytes = reg.view(dtype) - for i in range(16): - B[lane_id, i] = B_bytes[i] - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc( - tmem_addr[0], n_cols=n_tmem_cols_total, cta_group=cta_group - ) + T.ptx.tcgen05.commit(cp_mbar.ptr_to([0]), cta_group=cta_group) + T.ptx.mbarrier.try_wait(cp_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + T.ptx.tcgen05.fence.after_thread_sync() + if warp_id == 0: + reg = T.alloc_buffer((4,), "uint32", scope="local") + for i in range(4): + T.ptx.tcgen05.ld( + tmem.allocated_addr[0], + reg[i], + shape="32x32b", + num=1, + row=0, + col=i, + ) + T.ptx.tcgen05.wait.ld() + B_bytes = reg.view(dtype) + for i in range(16): + B[lane_id, i] = B_bytes[i] + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=n_tmem_cols_total, cta_group=cta_group) return kernel @@ -327,67 +310,58 @@ def test_align_middle_2_to_1_nvfp4_sfb(): t_full_shape = [256, 16] n_tmem_cols_total = max(32, 32) # SFB occupies 32 cols total (8*4 elements / 4 epc) - @Tx.prim_func(check_well_formed=False) - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, s_full_shape, "uint8") - B = Tx.match_buffer(B_ptr, (32, 16), "uint8") - Tx.device_entry() - warp_id = Tx.warp_id([4]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - lane_id = Tx.lane_id([32]) - A_smem = Tx.alloc_buffer(s_full_shape, "uint8", scope="shared", layout=s_full, align=1024) - tmem_addr = Tx.alloc_shared([1], "uint32") - cp_mbar = Tx.alloc_shared([1], "uint64") + @T.prim_func(check_well_formed=False) + def kernel(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, s_full_shape, "uint8") + B = T.match_buffer(B_ptr, (32, 16), "uint8") + T.device_entry() + warp_id = T.warp_id([4]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + lane_id = T.lane_id([32]) + A_smem = T.alloc_buffer(s_full_shape, "uint8", scope="shared", layout=s_full, align=1024) + tmem_addr = T.alloc_shared([1], "uint32") + cp_mbar = T.alloc_shared([1], "uint64") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), n_cols=n_tmem_cols_total, cta_group=1 - ) - if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(cp_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - with Tx.cta(): - Tx.copy(A_smem[:, :], A[:, :]) - Tx.cuda.cta_sync() - tmem = Tx.decl_buffer( - t_full_shape, - "uint8", - scope="tmem", - allocated_addr=tmem_addr[0], - layout=t_full, - ) - if tid_in_wg == 0: - with Tx.thread(): - Tx.copy_async(tmem[:, :], A_smem[:, :], cta_group=1) - Tx.ptx.tcgen05.commit(cp_mbar.ptr_to([0]), cta_group=1) - Tx.ptx.mbarrier.try_wait(cp_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - Tx.ptx.tcgen05.fence.after_thread_sync() - if warp_id == 0: - with Tx.warp(): - reg = Tx.alloc_buffer((4,), "uint32", scope="local") - for i in range(4): - Tx.ptx.tcgen05.ld( - tmem.allocated_addr[0], - reg[i], - shape="32x32b", - num=1, - row=0, - col=i, - ) - Tx.ptx.tcgen05.wait.ld() - B_bytes = reg.view("uint8") - for i in range(16): - B[lane_id, i] = B_bytes[i] - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=n_tmem_cols_total, cta_group=1) + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=n_tmem_cols_total, cta_group=1) + if tid_in_wg == 0: + T.ptx.mbarrier.init(cp_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + Tx.cta.copy(A_smem[:, :], A[:, :]) + T.cuda.cta_sync() + tmem = T.decl_buffer( + t_full_shape, + "uint8", + scope="tmem", + allocated_addr=tmem_addr[0], + layout=t_full, + ) + if tid_in_wg == 0: + Tx.copy_async(tmem[:, :], A_smem[:, :], cta_group=1) + T.ptx.tcgen05.commit(cp_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(cp_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + T.ptx.tcgen05.fence.after_thread_sync() + if warp_id == 0: + reg = T.alloc_buffer((4,), "uint32", scope="local") + for i in range(4): + T.ptx.tcgen05.ld( + tmem.allocated_addr[0], + reg[i], + shape="32x32b", + num=1, + row=0, + col=i, + ) + T.ptx.tcgen05.wait.ld() + B_bytes = reg.view("uint8") + for i in range(16): + B[lane_id, i] = B_bytes[i] + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=n_tmem_cols_total, cta_group=1) A_np = (np.arange(256 * 16, dtype=np.int32) & 0xFF).astype(np.uint8).reshape(256, 16) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py index 5cb5ab66a0c5..1c4bf5221625 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py @@ -24,7 +24,8 @@ import tvm.testing from tvm.ir import PointerType, PrimType from tvm.ir.type import TensorMapType -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx import IntImm, StringImm, Var from tvm.tirx.exec_scope import ExecScope from tvm.tirx.layout import S, TileLayout @@ -64,9 +65,9 @@ def visit_for_(self, op): def visit_evaluate_(self, op): if isinstance(op.value, tvm.tirx.Call): if op.value.op.name in ( - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", + "tirx.ptx.cp_async_bulk_tensor_global_to_cluster", + "tirx.ptx.cp_async_bulk_tensor_shared_to_global", + "tirx.ptx.cp_async_bulk_tensor_shared_to_global_reduce", ): # Multiply all enclosing loop extents iters = 1 @@ -158,7 +159,7 @@ def _build_expected_host_init(dtype, encode_args): + [IntImm("int32", v) for v in encode_args[1:]] ) encode_call = tvm.tirx.Call("int32", tvm.ir.Op.get("tirx.tvm_call_packed"), call_args) - replace_point = TilePrimitiveCall(op=tvm.ir.Op.get("tirx.tvm_kernel_replace_point")) + replace_point = TilePrimitiveCall(op=tvm.ir.Op.get("tirx.tile.tvm_kernel_replace_point")) return tvm.tirx.SeqStmt( [tvm.tirx.Bind(A_tensormap, stack_alloca), tvm.tirx.Evaluate(encode_call), replace_point] ) @@ -234,7 +235,7 @@ def _build_expected_impl(direction, dtype, s_shape, s_layout, impl_spec): if direction == "g2s": # g2c(dim, addr, mbar, tensormap, cta_mask, cta_group, # cache_policy, has_cache_policy, *coords) - ptx_op = tvm.ir.Op.get("tirx.ptx_cp_async_bulk_tensor_global_to_cluster") + ptx_op = tvm.ir.Op.get("tirx.ptx.cp_async_bulk_tensor_global_to_cluster") ptx_args = [ IntImm("int32", dim), addr_of, @@ -248,7 +249,7 @@ def _build_expected_impl(direction, dtype, s_shape, s_layout, impl_spec): ] else: # s2g # s2g(dim, addr, tensormap, cache_policy, has_cache_policy, *coords) - ptx_op = tvm.ir.Op.get("tirx.ptx_cp_async_bulk_tensor_shared_to_global") + ptx_op = tvm.ir.Op.get("tirx.ptx.cp_async_bulk_tensor_shared_to_global") ptx_args = [ IntImm("int32", dim), addr_of, @@ -1067,9 +1068,9 @@ def test_copy_tma_symbolic_dimension(dtype, swizzle_len): dev = tvm.cuda(0) # Shared memory layout with swizzle - shared_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, swizzle_len, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(SMEM_PIPE_DEPTH, BLK_M, BLK_K) : (BLK_M * BLK_K, BLK_K, 1)]), + shared_layout = T.ComposeLayout( + T.SwizzleLayout(3, swizzle_len, 3, swizzle_inner=True), + T.TileLayout(T.S[(SMEM_PIPE_DEPTH, BLK_M, BLK_K) : (BLK_M * BLK_K, BLK_K, 1)]), ) # Compute bytes for mbarrier @@ -1077,54 +1078,47 @@ def test_copy_tma_symbolic_dimension(dtype, swizzle_len): copy_bytes = BLK_M * BLK_K * tvm.DataType(dtype).bits // 8 # fmt: off - @Tx.prim_func - def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - M = Tx.int32() - A = Tx.match_buffer(A_ptr, [M, K], dtype) - B = Tx.match_buffer(B_ptr, [SMEM_PIPE_DEPTH, BLK_M, BLK_K], dtype) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) - - with Tx.thread(): - dyn = Tx.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") - A_smem = Tx.decl_buffer( - [SMEM_PIPE_DEPTH, BLK_M, BLK_K], dtype, dyn.data, elem_offset=0, layout=shared_layout # noqa: E501 - ) - mbarrier = Tx.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) - mbar_ptr = Tx.meta_var(mbarrier.ptr_to([0])) + @T.prim_func + def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: + M = T.int32() + A = T.match_buffer(A_ptr, [M, K], dtype) + B = T.match_buffer(B_ptr, [SMEM_PIPE_DEPTH, BLK_M, BLK_K], dtype) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + dyn = T.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") + A_smem = T.decl_buffer( + [SMEM_PIPE_DEPTH, BLK_M, BLK_K], dtype, dyn.data, elem_offset=0, layout=shared_layout + ) + mbarrier = T.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) + mbar_ptr = T.meta_var(mbarrier.ptr_to([0])) + + if tid == 0: + T.ptx.mbarrier.init(mbar_ptr, 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + # Copy with pipeline index (like hgemm pattern) + for ks in range(SMEM_PIPE_DEPTH): if tid == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(mbar_ptr, 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + Tx.copy_async( + A_smem[ks, :, :], + A[0:BLK_M, ks * BLK_K:(ks + 1) * BLK_K], + dispatch="tma", + mbar=mbar_ptr + ) + T.ptx.mbarrier.arrive.expect_tx(mbar_ptr, copy_bytes) - # Copy with pipeline index (like hgemm pattern) - for ks in range(SMEM_PIPE_DEPTH): - if tid == 0: - with Tx.thread(): - Tx.copy_async( - A_smem[ks, :, :], - A[0:BLK_M, ks * BLK_K:(ks + 1) * BLK_K], - dispatch="tma", - mbar=mbar_ptr - ) - Tx.ptx.mbarrier.arrive.expect_tx(mbar_ptr, copy_bytes) - - Tx.ptx.mbarrier.try_wait(mbar_ptr, ks % 2) - - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - - # Copy back to global for verification - with Tx.cta(): - for ks in range(SMEM_PIPE_DEPTH): - Tx.copy( - B[ks, :, :], - A_smem[ks, :, :] - ) + T.ptx.mbarrier.try_wait(mbar_ptr, ks % 2) + + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + for ks in range(SMEM_PIPE_DEPTH): + Tx.cta.copy( + B[ks, :, :], + A_smem[ks, :, :] + ) # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) @@ -1166,62 +1160,55 @@ def test_copy_tma_3d_with_view(dtype, swizzle_len): copy_bytes_per_blk = 32 * 4 * 64 * tvm.DataType(dtype).bits // 8 # Shared memory layout with swizzle - shared_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, swizzle_len, 3, swizzle_inner=True), - Tx.TileLayout(Tx.S[(2, 128, 128) : (128 * 128, 128, 1)]), + shared_layout = T.ComposeLayout( + T.SwizzleLayout(3, swizzle_len, 3, swizzle_inner=True), + T.TileLayout(T.S[(2, 128, 128) : (128 * 128, 128, 1)]), ) # fmt: off - @Tx.prim_func - def copy_async(Q_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - Q = Tx.match_buffer(Q_ptr, (2, 128, 8, 128), dtype) - B = Tx.match_buffer(B_ptr, (32, 4, 64), dtype) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([128]) - - with Tx.thread(): - dyn = Tx.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") - # Allocate as 4D like FA4: (SMEM_PIPE_DEPTH, NUM_BLK_K, BLK_M, BLK_K) - Q_smem = Tx.decl_buffer( - (2, 2, 128, 64), - dtype, dyn.data, elem_offset=0, layout=shared_layout + @T.prim_func + def copy_async(Q_ptr: T.handle, B_ptr: T.handle) -> None: + Q = T.match_buffer(Q_ptr, (2, 128, 8, 128), dtype) + B = T.match_buffer(B_ptr, (32, 4, 64), dtype) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([128]) + dyn = T.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") + # Allocate as 4D like FA4: (SMEM_PIPE_DEPTH, NUM_BLK_K, BLK_M, BLK_K) + Q_smem = T.decl_buffer( + (2, 2, 128, 64), + dtype, dyn.data, elem_offset=0, layout=shared_layout + ) + mbarrier = T.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) + mbar_ptr = T.meta_var(mbarrier.ptr_to([0])) + + # Create 5D view for 3D copy pattern + Q_smem_5d = Q_smem.view(2, 2, 32, 4, 64) + + if tid == 0: + T.ptx.mbarrier.init(mbar_ptr, 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + + if tid == 0: + # 3D copy: [SEQ_Q_PER_TILE, GQA_RATIO, BLK_K] + Tx.copy_async( + Q_smem_5d[0, 0, :, :, :], + Q[0, 0:32, 0:4, 0:64], + dispatch="tma", + mbar=mbar_ptr ) - mbarrier = Tx.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) - mbar_ptr = Tx.meta_var(mbarrier.ptr_to([0])) - - # Create 5D view for 3D copy pattern - Q_smem_5d = Q_smem.view(2, 2, 32, 4, 64) + T.ptx.mbarrier.arrive.expect_tx(mbar_ptr, copy_bytes_per_blk) - if tid == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(mbar_ptr, 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.try_wait(mbar_ptr, 0) - if tid == 0: - with Tx.thread(): - # 3D copy: [SEQ_Q_PER_TILE, GQA_RATIO, BLK_K] - Tx.copy_async( - Q_smem_5d[0, 0, :, :, :], - Q[0, 0:32, 0:4, 0:64], - dispatch="tma", - mbar=mbar_ptr - ) - Tx.ptx.mbarrier.arrive.expect_tx(mbar_ptr, copy_bytes_per_blk) - - Tx.ptx.mbarrier.try_wait(mbar_ptr, 0) - - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - - # Copy back to global for verification - with Tx.cta(): - Tx.copy( - B[:, :, :], - Q_smem_5d[0, 0, :, :, :] - ) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + Tx.cta.copy( + B[:, :, :], + Q_smem_5d[0, 0, :, :, :] + ) # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) @@ -1332,41 +1319,36 @@ def r_gmem(stage): ] # fmt: off - @Tx.prim_func - def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) - - with Tx.thread(): - dyn = Tx.alloc_buffer([smem_bytes + 8], "uint8", scope="shared.dyn") - A_smem = Tx.decl_buffer(s_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout) # noqa: E501 - mbarrier = Tx.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) - phase: Tx.int32 - - phase = 0 + @T.prim_func + def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + dyn = T.alloc_buffer([smem_bytes + 8], "uint8", scope="shared.dyn") + A_smem = T.decl_buffer(s_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout) + mbarrier = T.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) + phase: T.int32 + + phase = 0 + if tid == 0: + T.ptx.mbarrier.init(mbarrier.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + + for stage in range(n): if tid == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(mbarrier.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - - for stage in range(n): - if tid == 0: - with Tx.thread(): - Tx.copy_async(A_smem[tuple(r_smem)], A[tuple(r_gmem(stage))], dispatch="tma", mbar=mbarrier.ptr_to([0])) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(mbarrier.ptr_to([0]), smem_bytes) - - Tx.ptx.mbarrier.try_wait(mbarrier.ptr_to([0]), phase) - phase = phase ^ 1 - - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - with Tx.cta(): - Tx.copy(B[tuple(r_gmem(stage))], A_smem[tuple(r_smem)]) + Tx.copy_async(A_smem[tuple(r_smem)], A[tuple(r_gmem(stage))], dispatch="tma", mbar=mbarrier.ptr_to([0])) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(mbarrier.ptr_to([0]), smem_bytes) + + T.ptx.mbarrier.try_wait(mbarrier.ptr_to([0]), phase) + phase = phase ^ 1 + + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + Tx.cta.copy(B[tuple(r_gmem(stage))], A_smem[tuple(r_smem)]) # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) @@ -1396,36 +1378,30 @@ def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: r_gmem = [slice(g_region[i][0], g_region[i][1]) for i in range(len(g_shape))] # fmt: off - @Tx.prim_func - def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) + @T.prim_func + def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + dyn = T.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") + A_smem = T.decl_buffer(s_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout) + mbarrier = T.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) + mbar_ptr = T.meta_var(mbarrier.ptr_to([0])) - with Tx.thread(): - dyn = Tx.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") - A_smem = Tx.decl_buffer(s_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout) # noqa: E501 - mbarrier = Tx.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) - mbar_ptr = Tx.meta_var(mbarrier.ptr_to([0])) - - if tid == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(mbar_ptr, 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + if tid == 0: + T.ptx.mbarrier.init(mbar_ptr, 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() - if tid == 0: - with Tx.thread(): - Tx.copy_async(A_smem[tuple(r_smem)], A[tuple(r_gmem)], dispatch="tma", mbar=mbar_ptr) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(mbar_ptr, total_bytes) - Tx.ptx.mbarrier.try_wait(mbar_ptr, 0) - Tx.cuda.cta_sync() - - with Tx.cta(): - Tx.copy(B[tuple(r_gmem)], A_smem[tuple(r_smem)]) + if tid == 0: + Tx.copy_async(A_smem[tuple(r_smem)], A[tuple(r_gmem)], dispatch="tma", mbar=mbar_ptr) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(mbar_ptr, total_bytes) + T.ptx.mbarrier.try_wait(mbar_ptr, 0) + T.cuda.cta_sync() + Tx.cta.copy(B[tuple(r_gmem)], A_smem[tuple(r_smem)]) # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) @@ -1470,29 +1446,26 @@ def r_gmem(stage): layoutB = TileLayout(S[3, 8, 256]) # fmt: off - @Tx.prim_func - def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) - - with Tx.thread(): - dyn = Tx.alloc_buffer([smem_bytes], "uint8", scope="shared.dyn") - A_smem = Tx.decl_buffer(s_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout) - - for stage in range(n): - Tx.copy(A_smem[tuple(r_smem)], A[tuple(r_gmem(stage))]) - Tx.cuda.cta_sync() - Tx.ptx.fence.proxy_async("shared::cta") - if tid == 0: - with Tx.thread(): - Tx.copy_async(B[tuple(r_gmem(stage))], A_smem[tuple(r_smem)], dispatch="tma") # noqa: E501 - Tx.ptx.cp_async.bulk.commit_group() - Tx.ptx.cp_async.bulk.wait_group() - Tx.cuda.cta_sync() + @T.prim_func + def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=layoutA) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=layoutB) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + dyn = T.alloc_buffer([smem_bytes], "uint8", scope="shared.dyn") + A_smem = T.decl_buffer(s_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout) + + for stage in range(n): + Tx.copy(A_smem[tuple(r_smem)], A[tuple(r_gmem(stage))]) + T.cuda.cta_sync() + T.ptx.fence.proxy_async("shared::cta") + if tid == 0: + Tx.copy_async(B[tuple(r_gmem(stage))], A_smem[tuple(r_smem)], dispatch="tma") + T.ptx.cp_async.bulk.commit_group() + T.ptx.cp_async.bulk.wait_group() + T.cuda.cta_sync() # fmt: on np_dtype = tvm.testing.np_dtype_from_str(dtype) @@ -1519,7 +1492,7 @@ def copy_async(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: def test_copy_tma_dynamic_cta_mask(dtype): """Regression test for B00004: dynamic cta_mask expression in TMA multicast. - Verifies that a TIR expression (depending on Tx.cta_id) used as cta_mask in + Verifies that a TIR expression (depending on T.cta_id) used as cta_mask in copy_async compiles through the full TIRX pipeline without crashing. Previously, lower_tirx_scope_ids replaced scope-ID vars via Substitute, but Substitute didn't visit TilePrimitiveCall.config values, leaving stale var @@ -1533,52 +1506,48 @@ def test_copy_tma_dynamic_cta_mask(dtype): thread_cnt = 128 smem_shape = (BLK_M, BLK_K) - shared_layout = Tx.ComposeLayout( - Tx.SwizzleLayout(3, 3, 3, swizzle_inner=True), Tx.TileLayout(Tx.S[smem_shape : (BLK_K, 1)]) + shared_layout = T.ComposeLayout( + T.SwizzleLayout(3, 3, 3, swizzle_inner=True), T.TileLayout(T.S[smem_shape : (BLK_K, 1)]) ) smem_bytes = BLK_M * BLK_K * tvm.DataType(dtype).bits // 8 copy_bytes = smem_bytes # fmt: off - @Tx.prim_func - def copy_async_dynamic_mask(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, [BLK_M, BLK_K], dtype) + @T.prim_func + def copy_async_dynamic_mask(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, [BLK_M, BLK_K], dtype) - Tx.device_entry() - cbx = Tx.cta_id_in_cluster([CLUSTER_SIZE]) - cta_id = Tx.cta_id([CLUSTER_SIZE]) - tid = Tx.thread_id([thread_cnt]) + T.device_entry() + cbx = T.cta_id_in_cluster([CLUSTER_SIZE]) + cta_id = T.cta_id([CLUSTER_SIZE]) + tid = T.thread_id([thread_cnt]) # Dynamic cta_mask: exact expression from B00004 bug report - cta_mask = Tx.meta_var(5 + 5 * cbx) - - with Tx.thread(): - dyn = Tx.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") - A_smem = Tx.decl_buffer( - smem_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout, + cta_mask = T.meta_var(5 + 5 * cbx) + dyn = T.alloc_buffer([smem_bytes + 64], "uint8", scope="shared.dyn") + A_smem = T.decl_buffer( + smem_shape, dtype, dyn.data, elem_offset=0, layout=shared_layout, + ) + mbarrier = T.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) + mbar_ptr = T.meta_var(mbarrier.ptr_to([0])) + + if tid == 0: + T.ptx.mbarrier.init(mbar_ptr, 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + + if tid == 0: + Tx.copy_async( + A_smem[:, :], + A[:, :], + dispatch="tma", + mbar=mbar_ptr, + cta_mask=cta_mask, + cta_group=CTA_GROUP, ) - mbarrier = Tx.decl_buffer([1], "uint64", dyn.data, elem_offset=smem_bytes // 8) - mbar_ptr = Tx.meta_var(mbarrier.ptr_to([0])) - - if tid == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(mbar_ptr, 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.arrive.expect_tx(mbar_ptr, copy_bytes) - if tid == 0: - with Tx.thread(): - Tx.copy_async( - A_smem[:, :], - A[:, :], - dispatch="tma", - mbar=mbar_ptr, - cta_mask=cta_mask, - cta_group=CTA_GROUP, - ) - Tx.ptx.mbarrier.arrive.expect_tx(mbar_ptr, copy_bytes) - - Tx.ptx.mbarrier.try_wait(mbar_ptr, 0) + T.ptx.mbarrier.try_wait(mbar_ptr, 0) # fmt: on target = tvm.target.Target("cuda") diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py index 4ca1c99cec23..0f910a43766d 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py @@ -22,7 +22,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TCol, TileLayout, TLane from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg @@ -55,69 +56,60 @@ def next_power_of_2(x): local_view = TileLayout(S[(128, WIDTH) : (1 @ axis_tid_in_wg, 1)]) # fmt: off - @Tx.prim_func - def copy_async_test(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128, WIDTH), dtype) - B = Tx.match_buffer(B_ptr, (128, WIDTH), dtype) + @T.prim_func + def copy_async_test(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, WIDTH), dtype) + B = T.match_buffer(B_ptr, (128, WIDTH), dtype) A_flat = A.view(-1) B_flat = B.view(-1) - Tx.device_entry() - warp_id = Tx.warp_id([(128) // 32]) - cta_id = Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - warp_id_in_wg = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - tid_in_wg = Tx.thread_id([128]) + T.device_entry() + warp_id = T.warp_id([(128) // 32]) + cta_id = T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + warp_id_in_wg = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + tid_in_wg = T.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 - - Tx.tvm_storage_sync("shared") - - tmem = Tx.decl_buffer((128, WIDTH), dtype, scope="tmem", allocated_addr=tmem_addr[0], # noqa: E501 - layout=TileLayout(S[(128, WIDTH) : (1 @ TLane, 1 @ TCol)])) - - A_reg = Tx.alloc_local((WIDTH), dtype) - B_reg = Tx.alloc_local((WIDTH), dtype) - A_local = A_reg.view(128, WIDTH, layout=local_view) - B_local = B_reg.view(128, WIDTH, layout=local_view) - - # A -> A_local - with Tx.thread(): - for i in range(WIDTH // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy(A_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN], A_flat[g_offset: g_offset + VEC_LEN]) # noqa: E501 - for i in range(WIDTH): - B_reg[i] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - - # A_local -> tmem (async) - Tx.copy_async(tmem[:, :], A_local[:, :]) - Tx.ptx.tcgen05.wait.st() # explicit wait - Tx.cuda.cta_sync() - - # tmem -> B_local (async) - Tx.copy_async(B_local[:, :], tmem[:, :]) - Tx.ptx.tcgen05.wait.ld() # explicit wait - Tx.cuda.cta_sync() - - # B_local -> B - with Tx.thread(): - for i in range(WIDTH // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy(B_flat[g_offset: g_offset + VEC_LEN], B_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN]) # noqa: E501 - - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 + + T.tvm_storage_sync("shared") + + tmem = T.decl_buffer((128, WIDTH), dtype, scope="tmem", allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, WIDTH) : (1 @ TLane, 1 @ TCol)])) + + A_reg = T.alloc_local((WIDTH), dtype) + B_reg = T.alloc_local((WIDTH), dtype) + A_local = A_reg.view(128, WIDTH, layout=local_view) + B_local = B_reg.view(128, WIDTH, layout=local_view) + for i in range(WIDTH // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy(A_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN], A_flat[g_offset: g_offset + VEC_LEN]) # noqa: E501 + for i in range(WIDTH): + B_reg[i] = T.cast(0, dtype) + T.cuda.cta_sync() + + # A_local -> tmem (async) + Tx.wg.copy_async(tmem[:, :], A_local[:, :]) + T.ptx.tcgen05.wait.st() # explicit wait + T.cuda.cta_sync() + + # tmem -> B_local (async) + Tx.wg.copy_async(B_local[:, :], tmem[:, :]) + T.ptx.tcgen05.wait.ld() # explicit wait + T.cuda.cta_sync() + for i in range(WIDTH // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy(B_flat[g_offset: g_offset + VEC_LEN], B_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN]) # noqa: E501 + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 # fmt: on target = tvm.target.Target("cuda") @@ -134,7 +126,7 @@ def copy_async_test(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: # ---------------------------------------------------------------------------- -# Migrated from test_copy_sync.py: tmem<->reg round-trip via Tx.copy_async +# Migrated from test_copy_sync.py: tmem<->reg round-trip via T.copy_async # (the kernels themselves are the actual async tmem dispatch tests; the # G↔L copies bookending them just stage data). # ---------------------------------------------------------------------------- @@ -163,69 +155,60 @@ def next_power_of_2(x): local_view = TileLayout(S[(128, WIDTH) : (1 @ axis_tid_in_wg, 1)]) # fmt: off - @Tx.prim_func - def copy_sync(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128, WIDTH), dtype) - B = Tx.match_buffer(B_ptr, (128, WIDTH), dtype) + @T.prim_func + def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, WIDTH), dtype) + B = T.match_buffer(B_ptr, (128, WIDTH), dtype) A_flat = A.view(-1) B_flat = B.view(-1) - Tx.device_entry() - warp_id = Tx.warp_id([(128) // 32]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - tid_in_wg = Tx.thread_id([128]) + T.device_entry() + warp_id = T.warp_id([(128) // 32]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=max(32, next_power_of_2(offset_32b + width_32b)), cta_group=1) # noqa: E501 - - Tx.tvm_storage_sync("shared") - - tmem = Tx.decl_buffer((128, OFFSET + WIDTH), dtype, scope="tmem", allocated_addr=tmem_addr[0], # noqa: E501 - layout=TileLayout(S[(128, OFFSET + WIDTH) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - - A_reg = Tx.alloc_local((WIDTH), dtype) - B_reg = Tx.alloc_local((WIDTH), dtype) - A_local = A_reg.view(128, WIDTH, layout=local_view) - B_local = B_reg.view(128, WIDTH, layout=local_view) - - # A -> A_local - with Tx.thread(): - for i in range(WIDTH // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy(A_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN], A_flat[g_offset: g_offset + VEC_LEN]) # noqa: E501 - for i in range(WIDTH): - B_reg[i] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - - # A_local -> tmem - Tx.copy_async(tmem[:, OFFSET: OFFSET + WIDTH], A_local[:, :]) - Tx.ptx.tcgen05.wait.st() - Tx.cuda.cta_sync() - - # tmem -> B_local - Tx.copy_async(B_local[:, :], tmem[:, OFFSET: OFFSET + WIDTH]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - - # B_local -> B - with Tx.thread(): - for i in range(WIDTH // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy(B_flat[g_offset: g_offset + VEC_LEN], B_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN]) # noqa: E501 - - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, next_power_of_2(offset_32b + width_32b)), cta_group=1) # noqa: E501 + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=max(32, next_power_of_2(offset_32b + width_32b)), cta_group=1) # noqa: E501 + + T.tvm_storage_sync("shared") + + tmem = T.decl_buffer((128, OFFSET + WIDTH), dtype, scope="tmem", allocated_addr=tmem_addr[0], # noqa: E501 + layout=TileLayout(S[(128, OFFSET + WIDTH) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + + A_reg = T.alloc_local((WIDTH), dtype) + B_reg = T.alloc_local((WIDTH), dtype) + A_local = A_reg.view(128, WIDTH, layout=local_view) + B_local = B_reg.view(128, WIDTH, layout=local_view) + for i in range(WIDTH // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy(A_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN], A_flat[g_offset: g_offset + VEC_LEN]) # noqa: E501 + for i in range(WIDTH): + B_reg[i] = T.cast(0, dtype) + T.cuda.cta_sync() + + # A_local -> tmem + Tx.wg.copy_async(tmem[:, OFFSET: OFFSET + WIDTH], A_local[:, :]) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # tmem -> B_local + Tx.wg.copy_async(B_local[:, :], tmem[:, OFFSET: OFFSET + WIDTH]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(WIDTH // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy(B_flat[g_offset: g_offset + VEC_LEN], B_reg[i * VEC_LEN: i * VEC_LEN + VEC_LEN]) # noqa: E501 + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, next_power_of_2(offset_32b + width_32b)), cta_group=1) # noqa: E501 # fmt: on target = tvm.target.Target("cuda") @@ -269,69 +252,60 @@ def next_power_of_2(x): local_view = TileLayout(S[(128, TOTAL_LOCAL_WIDTH) : (1 @ axis_tid_in_wg, 1)]) # fmt: off - @Tx.prim_func - def copy_sync(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128, WIDTH), dtype) - B = Tx.match_buffer(B_ptr, (128, WIDTH), dtype) + @T.prim_func + def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, WIDTH), dtype) + B = T.match_buffer(B_ptr, (128, WIDTH), dtype) A_flat = A.view(-1) B_flat = B.view(-1) - Tx.device_entry() - warp_id = Tx.warp_id([(128) // 32]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - tid_in_wg = Tx.thread_id([128]) + T.device_entry() + warp_id = T.warp_id([(128) // 32]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 - - Tx.tvm_storage_sync("shared") - - tmem = Tx.decl_buffer((128, WIDTH), dtype, scope="tmem", allocated_addr=tmem_addr[0], # noqa: E501 - layout=TileLayout(S[(128, WIDTH) : (1 @ TLane, 1 @ TCol)])) - - A_reg = Tx.alloc_local((TOTAL_LOCAL_WIDTH), dtype) - B_reg = Tx.alloc_local((TOTAL_LOCAL_WIDTH), dtype) - A_local = A_reg.view(128, TOTAL_LOCAL_WIDTH, layout=local_view) - B_local = B_reg.view(128, TOTAL_LOCAL_WIDTH, layout=local_view) - - # A -> A_local (only the slice we care about) - with Tx.thread(): - for i in range(WIDTH // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy(A_reg[LOCAL_OFFSET + i * VEC_LEN: LOCAL_OFFSET + i * VEC_LEN + VEC_LEN], A_flat[g_offset: g_offset + VEC_LEN]) # noqa: E501 - for i in range(TOTAL_LOCAL_WIDTH): - B_reg[i] = Tx.cast(0, dtype) - Tx.cuda.cta_sync() - - # A_local[sliced] -> tmem (use sliced region) - Tx.copy_async(tmem[:, 0:WIDTH], A_local[:, LOCAL_OFFSET:LOCAL_OFFSET + WIDTH]) - Tx.ptx.tcgen05.wait.st() - Tx.cuda.cta_sync() - - # tmem -> B_local[sliced] (use sliced region) - Tx.copy_async(B_local[:, LOCAL_OFFSET:LOCAL_OFFSET + WIDTH], tmem[:, 0:WIDTH]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - - # B_local -> B - with Tx.thread(): - for i in range(WIDTH // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy(B_flat[g_offset: g_offset + VEC_LEN], B_reg[LOCAL_OFFSET + i * VEC_LEN: LOCAL_OFFSET + i * VEC_LEN + VEC_LEN]) # noqa: E501 - - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 + + T.tvm_storage_sync("shared") + + tmem = T.decl_buffer((128, WIDTH), dtype, scope="tmem", allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, WIDTH) : (1 @ TLane, 1 @ TCol)])) + + A_reg = T.alloc_local((TOTAL_LOCAL_WIDTH), dtype) + B_reg = T.alloc_local((TOTAL_LOCAL_WIDTH), dtype) + A_local = A_reg.view(128, TOTAL_LOCAL_WIDTH, layout=local_view) + B_local = B_reg.view(128, TOTAL_LOCAL_WIDTH, layout=local_view) + for i in range(WIDTH // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy(A_reg[LOCAL_OFFSET + i * VEC_LEN: LOCAL_OFFSET + i * VEC_LEN + VEC_LEN], A_flat[g_offset: g_offset + VEC_LEN]) # noqa: E501 + for i in range(TOTAL_LOCAL_WIDTH): + B_reg[i] = T.cast(0, dtype) + T.cuda.cta_sync() + + # A_local[sliced] -> tmem (use sliced region) + Tx.wg.copy_async(tmem[:, 0:WIDTH], A_local[:, LOCAL_OFFSET:LOCAL_OFFSET + WIDTH]) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # tmem -> B_local[sliced] (use sliced region) + Tx.wg.copy_async(B_local[:, LOCAL_OFFSET:LOCAL_OFFSET + WIDTH], tmem[:, 0:WIDTH]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(WIDTH // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy(B_flat[g_offset: g_offset + VEC_LEN], B_reg[LOCAL_OFFSET + i * VEC_LEN: LOCAL_OFFSET + i * VEC_LEN + VEC_LEN]) # noqa: E501 + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, next_power_of_2(width_32b)), cta_group=1) # noqa: E501 # fmt: on target = tvm.target.Target("cuda") diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py index 7f1c42598b7e..420935946028 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py @@ -21,7 +21,7 @@ 1. Fill a (128, FULL_W) host buffer ``A`` with random values. 2. Stage ``A`` into TMEM via the existing ``.32x32b`` ld/st round-trip. -3. Issue the new ``.16x*b`` atom via ``Tx.copy_async`` to read a (64, K_cols) +3. Issue the new ``.16x*b`` atom via ``T.copy_async`` to read a (64, K_cols) fragment from TMEM into a register tile shaped by ``tcgen05_atom_layout``. 4. Dump the register tile to a ``(128, regs_per_thread)`` global buffer indexed ``B[tid_in_wg, r]``. @@ -41,7 +41,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import ( S, TCol, @@ -256,73 +257,66 @@ def _run_roundtrip_16b( atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) tmem_layout = tmem_datapath_layout(tmem_datapath, tmem_rows, stage_width_elem) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: # Per-thread input/output: A[tid_in_wg, i] feeds register slot i of the # warpgroup-collective fragment; B[tid_in_wg, i] is what comes back # after a .16x*b.st → .16x*b.ld round-trip. - A = Tx.match_buffer(A_ptr, (128, per_thread_elems), dtype) - B = Tx.match_buffer(B_ptr, (128, per_thread_elems), dtype) + A = T.match_buffer(A_ptr, (128, per_thread_elems), dtype) + B = T.match_buffer(B_ptr, (128, per_thread_elems), dtype) - Tx.device_entry() - warp_id = Tx.warp_id([128 // 32]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - tid_in_wg = Tx.thread_id([128]) + T.device_entry() + warp_id = T.warp_id([128 // 32]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), - n_cols=tmem_col_width_32b, - cta_group=1, - ) - - Tx.tvm_storage_sync("shared") - - tmem = Tx.decl_buffer( - (tmem_rows, stage_width_elem), - dtype, - scope="tmem", - allocated_addr=tmem_addr[0], - layout=tmem_layout, + if warp_id == 0: + T.ptx.tcgen05.alloc( + T.address_of(tmem_addr), + n_cols=tmem_col_width_32b, + cta_group=1, ) - # Load per-thread A → reg_in - reg_in = Tx.alloc_local((per_thread_elems,), dtype) - with Tx.thread(): - for i in range(per_thread_elems): - reg_in[i] = A[tid_in_wg, i] - Tx.cuda.cta_sync() - - # reg_in -> TMEM via ..x.st.unpack::16b - frag_in = reg_in.view(frag_rows, K_cols_elem, layout=atom_view) - Tx.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_in[:, :]) - Tx.ptx.tcgen05.wait.st() - Tx.cuda.cta_sync() - - # TMEM -> reg_out via ..x.ld.pack::16b - reg_out = Tx.alloc_local((per_thread_elems,), dtype) - frag_out = reg_out.view(frag_rows, K_cols_elem, layout=atom_view) - Tx.copy_async(frag_out[:, :], tmem[0:frag_rows, 0:K_cols_elem]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - - # reg_out -> B - with Tx.thread(): - for i in range(per_thread_elems): - B[tid_in_wg, i] = reg_out[i] - - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) + T.tvm_storage_sync("shared") + + tmem = T.decl_buffer( + (tmem_rows, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=tmem_layout, + ) + + # Load per-thread A → reg_in + reg_in = T.alloc_local((per_thread_elems,), dtype) + for i in range(per_thread_elems): + reg_in[i] = A[tid_in_wg, i] + T.cuda.cta_sync() + + # reg_in -> TMEM via ..x.st.unpack::16b + frag_in = reg_in.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.wg.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_in[:, :]) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # TMEM -> reg_out via ..x.ld.pack::16b + reg_out = T.alloc_local((per_thread_elems,), dtype) + frag_out = reg_out.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.wg.copy_async(frag_out[:, :], tmem[0:frag_rows, 0:K_cols_elem]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(per_thread_elems): + B[tid_in_wg, i] = reg_out[i] + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) target = tvm.target.Target("cuda") with target: @@ -451,29 +445,28 @@ def test_layout_F_rejects_incompatible_atoms(atom_kind, frag_rows): tmem_rows = 64 stage_width_elem = max(32, local_cols) - @Tx.prim_func + @T.prim_func def kernel() -> None: - Tx.device_entry() - Tx.warp_id([128 // 32]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + T.device_entry() + T.warp_id([128 // 32]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id([128]) + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - Tx.tvm_storage_sync("shared") - tmem = Tx.decl_buffer( - (tmem_rows, stage_width_elem), - "float32", - scope="tmem", - allocated_addr=tmem_addr[0], - layout=tmem_layout, - ) - frag = Tx.alloc_local((local_extent_rows * local_cols // 128,), "float32") - frag_view = frag.view(local_extent_rows, local_cols, layout=atom_view) - Tx.copy_async(frag_view[:, :], tmem[0:local_extent_rows, 0:local_cols]) + T.tvm_storage_sync("shared") + tmem = T.decl_buffer( + (tmem_rows, stage_width_elem), + "float32", + scope="tmem", + allocated_addr=tmem_addr[0], + layout=tmem_layout, + ) + frag = T.alloc_local((local_extent_rows * local_cols // 128,), "float32") + frag_view = frag.view(local_extent_rows, local_cols, layout=atom_view) + Tx.wg.copy_async(frag_view[:, :], tmem[0:local_extent_rows, 0:local_cols]) target = tvm.target.Target("cuda") with target: @@ -484,7 +477,7 @@ def kernel() -> None: def _run_load_test(shape: str, rep: int, dtype: str): """Stage A into TMEM via .32x32b, then read it back as the fragment via - ..x (through ``Tx.alloc_tcgen05_ldst_frag``), and compare each + ..x (through ``T.alloc_tcgen05_ldst_frag``), and compare each thread's registers against the expected layout-derived value.""" bits = tvm.runtime.DataType(dtype).bits elem_per_32b = 32 // bits @@ -520,94 +513,85 @@ def _run_load_test(shape: str, rep: int, dtype: str): chunk_view = TileLayout(S[(128, chunk_width_elem) : (1 @ axis_tid_in_wg, 1)]) # The factory + wrapper both go through ``tcgen05_atom_layout``; we use it # explicitly here so that ``frag_local`` has the canonical layout that - # ``Tx.copy_async`` matches when dispatching to the right atom path. + # ``T.copy_async`` matches when dispatching to the right atom path. atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: # A is the host data we stage into TMEM via the standard .32x32b path. - A = Tx.match_buffer(A_ptr, (128, stage_width_elem), dtype) + A = T.match_buffer(A_ptr, (128, stage_width_elem), dtype) # B is a per-thread register dump: B[tid_in_wg, reg_idx_in_elements]. - B = Tx.match_buffer(B_ptr, (128, per_thread_elems), dtype) + B = T.match_buffer(B_ptr, (128, per_thread_elems), dtype) A_flat = A.view(-1) - Tx.device_entry() - warp_id = Tx.warp_id([128 // 32]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - tid_in_wg = Tx.thread_id([128]) + T.device_entry() + warp_id = T.warp_id([128 // 32]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), - n_cols=tmem_col_width_32b, - cta_group=1, - ) - - Tx.tvm_storage_sync("shared") - - tmem = Tx.decl_buffer( - (128, stage_width_elem), - dtype, - scope="tmem", - allocated_addr=tmem_addr[0], - layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + if warp_id == 0: + T.ptx.tcgen05.alloc( + T.address_of(tmem_addr), + n_cols=tmem_col_width_32b, + cta_group=1, ) - # Per-thread chunk staging buffer (CHUNK_FP32 fp32 worth). - stage_reg = Tx.alloc_local((chunk_width_elem,), dtype) - stage_local = stage_reg.view(128, chunk_width_elem, layout=chunk_view) - - # Walk chunks: A[:, ck:ck+chunk] -> stage_reg -> TMEM[:, ck:ck+chunk] - for chunk_idx in range(num_chunks): - col_off_elem = chunk_idx * chunk_width_elem - with Tx.thread(): - for i in range(chunk_width_elem // VEC_LEN): - # Each thread's row offset in A_flat: stage_width_elem; within - # the row, this chunk starts at col_off_elem and each vector - # picks up VEC_LEN elements at slot i. - g_offset = Tx.meta_var( - tid_in_wg * stage_width_elem + col_off_elem + i * VEC_LEN - ) - Tx.copy( - stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], - A_flat[g_offset : g_offset + VEC_LEN], - ) - Tx.cuda.cta_sync() - Tx.copy_async( - tmem[:, col_off_elem : col_off_elem + chunk_width_elem], - stage_local[:, :], + T.tvm_storage_sync("shared") + + tmem = T.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + + # Per-thread chunk staging buffer (CHUNK_FP32 fp32 worth). + stage_reg = T.alloc_local((chunk_width_elem,), dtype) + stage_local = stage_reg.view(128, chunk_width_elem, layout=chunk_view) + + # Walk chunks: A[:, ck:ck+chunk] -> stage_reg -> TMEM[:, ck:ck+chunk] + for chunk_idx in range(num_chunks): + col_off_elem = chunk_idx * chunk_width_elem + for i in range(chunk_width_elem // VEC_LEN): + # Each thread's row offset in A_flat: stage_width_elem; within + # the row, this chunk starts at col_off_elem and each vector + # picks up VEC_LEN elements at slot i. + g_offset = T.meta_var(tid_in_wg * stage_width_elem + col_off_elem + i * VEC_LEN) + Tx.copy( + stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], + A_flat[g_offset : g_offset + VEC_LEN], ) - Tx.ptx.tcgen05.wait.st() - Tx.cuda.cta_sync() - - # TMEM[0:frag_rows, 0:K_cols] -> frag_local via ..x.ld. - # Use ``tcgen05_atom_layout`` so dispatch matches the new path - # (or stays on .32x32b for instr_shape="32x32b"). Keep the flat - # ``frag_reg`` for the per-thread dump below. - frag_reg = Tx.alloc_local((per_thread_elems,), dtype) - frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view) - Tx.copy_async(frag_local[:, :], tmem[0:frag_rows, 0:K_cols_elem]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - - # Dump per-thread regs to B[tid_in_wg, :] - with Tx.thread(): - for i in range(per_thread_elems): - B[tid_in_wg, i] = frag_reg[i] - - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) + T.cuda.cta_sync() + Tx.wg.copy_async( + tmem[:, col_off_elem : col_off_elem + chunk_width_elem], + stage_local[:, :], + ) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # TMEM[0:frag_rows, 0:K_cols] -> frag_local via ..x.ld. + # Use ``tcgen05_atom_layout`` so dispatch matches the new path + # (or stays on .32x32b for instr_shape="32x32b"). Keep the flat + # ``frag_reg`` for the per-thread dump below. + frag_reg = T.alloc_local((per_thread_elems,), dtype) + frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.wg.copy_async(frag_local[:, :], tmem[0:frag_rows, 0:K_cols_elem]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(per_thread_elems): + B[tid_in_wg, i] = frag_reg[i] + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) target = tvm.target.Target("cuda") with target: @@ -694,77 +678,70 @@ def test_tcgen05_st_16xnb_store(shape, rep, dtype): stage_view = TileLayout(S[(128, stage_width_elem) : (1 @ axis_tid_in_wg, 1)]) atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: # A[tid_in_wg, i] is the i-th per-thread element to feed into the atom store. - A = Tx.match_buffer(A_ptr, (128, per_thread_elems), dtype) + A = T.match_buffer(A_ptr, (128, per_thread_elems), dtype) # B[lane, col] is the TMEM-staged readout after the round-trip. - B = Tx.match_buffer(B_ptr, (128, stage_width_elem), dtype) + B = T.match_buffer(B_ptr, (128, stage_width_elem), dtype) B_flat = B.view(-1) - Tx.device_entry() - warp_id = Tx.warp_id([128 // 32]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - tid_in_wg = Tx.thread_id([128]) + T.device_entry() + warp_id = T.warp_id([128 // 32]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) - tmem_addr = Tx.alloc_shared([1], "uint32") + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), - n_cols=tmem_col_width_32b, - cta_group=1, - ) - - Tx.tvm_storage_sync("shared") - - tmem = Tx.decl_buffer( - (128, stage_width_elem), - dtype, - scope="tmem", - allocated_addr=tmem_addr[0], - layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + if warp_id == 0: + T.ptx.tcgen05.alloc( + T.address_of(tmem_addr), + n_cols=tmem_col_width_32b, + cta_group=1, + ) + + T.tvm_storage_sync("shared") + + tmem = T.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + + # Load per-thread A → frag_reg + frag_reg = T.alloc_local((per_thread_elems,), dtype) + for i in range(per_thread_elems): + frag_reg[i] = A[tid_in_wg, i] + T.cuda.cta_sync() + + # frag_local -> TMEM via ..x.st + frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view) + Tx.wg.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_local[:, :]) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # TMEM -> readout via .32x32b.ld + stage_reg = T.alloc_local((stage_width_elem,), dtype) + stage_local = stage_reg.view(128, stage_width_elem, layout=stage_view) + Tx.wg.copy_async(stage_local[:, :], tmem[:, :]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(stage_width_elem // VEC_LEN): + g_offset = T.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) + Tx.copy( + B_flat[g_offset : g_offset + VEC_LEN], + stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], ) - # Load per-thread A → frag_reg - frag_reg = Tx.alloc_local((per_thread_elems,), dtype) - with Tx.thread(): - for i in range(per_thread_elems): - frag_reg[i] = A[tid_in_wg, i] - Tx.cuda.cta_sync() - - # frag_local -> TMEM via ..x.st - frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view) - Tx.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_local[:, :]) - Tx.ptx.tcgen05.wait.st() - Tx.cuda.cta_sync() - - # TMEM -> readout via .32x32b.ld - stage_reg = Tx.alloc_local((stage_width_elem,), dtype) - stage_local = stage_reg.view(128, stage_width_elem, layout=stage_view) - Tx.copy_async(stage_local[:, :], tmem[:, :]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - - # readout -> B (full 128xstage_width_elem dump) - with Tx.thread(): - for i in range(stage_width_elem // VEC_LEN): - g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"]) - Tx.copy( - B_flat[g_offset : g_offset + VEC_LEN], - stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], - ) - - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) target = tvm.target.Target("cuda") with target: @@ -816,7 +793,7 @@ def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: # -------------------------------------------------------------------------- -# Wrapper test: exercise Tx.alloc_tcgen05_ldst_frag directly (compile-only smoke). +# Wrapper test: exercise T.alloc_tcgen05_ldst_frag directly (compile-only smoke). # -------------------------------------------------------------------------- @@ -831,44 +808,39 @@ def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: ], ) def test_alloc_tcgen05_frag_wrapper_compiles(shape, frag_rows, K_cols): - """Ensure Tx.alloc_tcgen05_ldst_frag yields a buffer that ``Tx.copy_async`` accepts + """Ensure T.alloc_tcgen05_ldst_frag yields a buffer that ``T.copy_async`` accepts and lowers to the correct tcgen05 atom for each supported instr_shape.""" - @Tx.prim_func - def kernel(A_ptr: Tx.handle) -> None: - Tx.match_buffer(A_ptr, (128, K_cols), "float32") - Tx.device_entry() - warp_id = Tx.warp_id([4]) - Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - Tx.thread_id([128]) - - tmem_addr = Tx.alloc_shared([1], "uint32") + @T.prim_func + def kernel(A_ptr: T.handle) -> None: + T.match_buffer(A_ptr, (128, K_cols), "float32") + T.device_entry() + warp_id = T.warp_id([4]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + T.thread_id([128]) + + tmem_addr = T.alloc_shared([1], "uint32") if wg_id == 0: - with Tx.warpgroup(): - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), n_cols=max(32, K_cols), cta_group=1 - ) - Tx.tvm_storage_sync("shared") - tmem = Tx.decl_buffer( - (128, K_cols), - "float32", - scope="tmem", - allocated_addr=tmem_addr[0], - layout=TileLayout(S[(128, K_cols) : (1 @ TLane, 1 @ TCol)]), - ) - # One-liner: wrapper handles per-thread storage + layout. - frag = Tx.alloc_tcgen05_ldst_frag(shape, (frag_rows, K_cols), "float32") - Tx.copy_async(frag[:, :], tmem[0:frag_rows, 0:K_cols]) - Tx.ptx.tcgen05.wait.ld() - if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, K_cols), cta_group=1) + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=max(32, K_cols), cta_group=1) + T.tvm_storage_sync("shared") + tmem = T.decl_buffer( + (128, K_cols), + "float32", + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, K_cols) : (1 @ TLane, 1 @ TCol)]), + ) + # One-liner: wrapper handles per-thread storage + layout. + frag = T.alloc_tcgen05_ldst_frag(shape, (frag_rows, K_cols), "float32") + Tx.wg.copy_async(frag[:, :], tmem[0:frag_rows, 0:K_cols]) + T.ptx.tcgen05.wait.ld() + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=max(32, K_cols), cta_group=1) target = tvm.target.Target("cuda") with target: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py index 8780768f031e..1ce0d34ea6e0 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py @@ -21,7 +21,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout, wg_local_layout @@ -82,72 +83,68 @@ def test_binary_op_shared(input, op_type, operands_type, dtype): map_slice_b = list(slice(st_b[i], st_b[i] + ext_b[i]) for i in range(len(g_shape))) map_slice_res = list(slice(st_res[i], st_res[i] + ext_res[i]) for i in range(len(g_shape))) - const = Tx.float16(3.0) if dtype == "float16" else Tx.float32(3.0) + const = T.float16(3.0) if dtype == "float16" else T.float32(3.0) # fmt: off - @Tx.prim_func - def binary_op_region_region(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=g_layout) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(g_shape, dtype, scope="shared", layout=s_layout) - B_smem = Tx.alloc_buffer(g_shape, dtype, scope="shared", layout=s_layout) - - Tx.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) - Tx.copy(B_smem[tuple(copy_slice)], B[tuple(copy_slice)]) - Tx.cuda.cta_sync() - if op_type == "add": - Tx.add(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 - elif op_type == "sub": - Tx.sub(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 - elif op_type == "mul": - Tx.mul(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 - elif op_type == "fdiv": - Tx.fdiv(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 - Tx.cuda.cta_sync() - Tx.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) - - @Tx.prim_func - def binary_op_const_region_or_region_const(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) - _B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=g_layout) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(g_shape, dtype, scope="shared", layout=s_layout) - - Tx.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) - Tx.cuda.cta_sync() - if op_type == "add": - if operands_type == "const_region": - Tx.add(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) - elif operands_type == "region_const": - Tx.add(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) - elif op_type == "sub": - if operands_type == "const_region": - Tx.sub(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) - elif operands_type == "region_const": - Tx.sub(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) - elif op_type == "mul": - if operands_type == "const_region": - Tx.mul(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) - elif operands_type == "region_const": - Tx.mul(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) - elif op_type == "fdiv": - if operands_type == "const_region": - Tx.fdiv(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) - elif operands_type == "region_const": - Tx.fdiv(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) - Tx.cuda.cta_sync() - Tx.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) + @T.prim_func + def binary_op_region_region(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=g_layout) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(g_shape, dtype, scope="shared", layout=s_layout) + B_smem = T.alloc_buffer(g_shape, dtype, scope="shared", layout=s_layout) + + Tx.cta.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) + Tx.cta.copy(B_smem[tuple(copy_slice)], B[tuple(copy_slice)]) + T.cuda.cta_sync() + if op_type == "add": + Tx.cta.add(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 + elif op_type == "sub": + Tx.cta.sub(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 + elif op_type == "mul": + Tx.cta.mul(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 + elif op_type == "fdiv": + Tx.cta.fdiv(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], B_smem[tuple(map_slice_b)]) # noqa: E501 + T.cuda.cta_sync() + Tx.cta.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) + + @T.prim_func + def binary_op_const_region_or_region_const(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) + _B = T.match_buffer(B_ptr, g_shape, dtype, layout=g_layout) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(g_shape, dtype, scope="shared", layout=s_layout) + + Tx.cta.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) + T.cuda.cta_sync() + if op_type == "add": + if operands_type == "const_region": + Tx.cta.add(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) + elif operands_type == "region_const": + Tx.cta.add(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) + elif op_type == "sub": + if operands_type == "const_region": + Tx.cta.sub(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) + elif operands_type == "region_const": + Tx.cta.sub(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) + elif op_type == "mul": + if operands_type == "const_region": + Tx.cta.mul(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) + elif operands_type == "region_const": + Tx.cta.mul(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) + elif op_type == "fdiv": + if operands_type == "const_region": + Tx.cta.fdiv(A_smem[tuple(map_slice_res)], const, A_smem[tuple(map_slice_a)]) + elif operands_type == "region_const": + Tx.cta.fdiv(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)], const) + T.cuda.cta_sync() + Tx.cta.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) # fmt: on def get_prim_func(operands_type): @@ -205,21 +202,20 @@ def test_binary_non_commutative_const_lhs_rejected(op_type): dtype = "float16" shape = (16, 16) layout = TileLayout(S[shape]) - const = Tx.float16(3.0) + const = T.float16(3.0) with pytest.raises(Exception): - @Tx.prim_func + @T.prim_func def bad_kernel() -> None: - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([64]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(shape, dtype, scope="shared", layout=layout) - if op_type == "sub": - Tx.sub(A_smem, const, A_smem) - elif op_type == "fdiv": - Tx.fdiv(A_smem, const, A_smem) + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([64]) + A_smem = T.alloc_buffer(shape, dtype, scope="shared", layout=layout) + if op_type == "sub": + Tx.cta.sub(A_smem, const, A_smem) + elif op_type == "fdiv": + Tx.cta.fdiv(A_smem, const, A_smem) target = tvm.target.Target("cuda") with target: @@ -235,33 +231,35 @@ def test_binary_op_shared_subcta_scope(exec_scope, op_type): n_warps = 4 if exec_scope == "warpgroup" else 1 g_shape = (n_warps * 32, 8) dev = tvm.cuda(0) - tx_op = {"add": Tx.add, "mul": Tx.mul}[op_type] - - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=TileLayout(S[g_shape])) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=TileLayout(S[g_shape])) - Tx.device_entry() - warp_id = Tx.warp_id([(256) // 32]) - wg_id = Tx.warpgroup_id([(256) // 128]) - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([256]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(g_shape, dtype, scope="shared", layout=TileLayout(S[g_shape])) - B_smem = Tx.alloc_buffer(g_shape, dtype, scope="shared", layout=TileLayout(S[g_shape])) - Tx.copy(A_smem, A) - Tx.copy(B_smem, B) - Tx.cuda.cta_sync() - if exec_scope == "warp": - if warp_id == 5: - with Tx.warp(): - tx_op(A_smem, A_smem, B_smem) - elif exec_scope == "warpgroup": - if wg_id == 1: - with Tx.warpgroup(): - tx_op(A_smem, A_smem, B_smem) - Tx.cuda.cta_sync() - Tx.copy(A, A_smem) + tx_op = { + ("warp", "add"): Tx.warp.add, + ("warp", "mul"): Tx.warp.mul, + ("warpgroup", "add"): Tx.wg.add, + ("warpgroup", "mul"): Tx.wg.mul, + }[(exec_scope, op_type)] + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=TileLayout(S[g_shape])) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=TileLayout(S[g_shape])) + T.device_entry() + warp_id = T.warp_id([(256) // 32]) + wg_id = T.warpgroup_id([(256) // 128]) + _bx = T.cta_id([1]) + _tid = T.thread_id([256]) + A_smem = T.alloc_buffer(g_shape, dtype, scope="shared", layout=TileLayout(S[g_shape])) + B_smem = T.alloc_buffer(g_shape, dtype, scope="shared", layout=TileLayout(S[g_shape])) + Tx.cta.copy(A_smem, A) + Tx.cta.copy(B_smem, B) + T.cuda.cta_sync() + if exec_scope == "warp": + if warp_id == 5: + tx_op(A_smem, A_smem, B_smem) + elif exec_scope == "warpgroup": + if wg_id == 1: + tx_op(A_smem, A_smem, B_smem) + T.cuda.cta_sync() + Tx.cta.copy(A, A_smem) target = tvm.target.Target("cuda") with target: @@ -290,72 +288,62 @@ def test_binary_op_local_subcta_trivial(exec_scope, rhs_kind, op_type): a_shape = (n_threads, m, n) b_shape = (n_threads, m, n if rhs_kind == "region" else 1) c_shape = a_shape - const = Tx.float16(1.25) + const = T.float16(1.25) dev = tvm.cuda(0) tx_op = {"add": Tx.add, "sub": Tx.sub, "mul": Tx.mul, "fdiv": Tx.fdiv}[op_type] - tid_in_scope_fn = {"cta": Tx.thread_id, "warpgroup": Tx.thread_id_in_wg, "warp": Tx.lane_id}[ + tid_in_scope_fn = {"cta": T.thread_id, "warpgroup": T.thread_id_in_wg, "warp": T.lane_id}[ exec_scope ] - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) - B = Tx.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) - C = Tx.match_buffer(C_ptr, c_shape, dtype, layout=TileLayout(S[c_shape])) - - Tx.device_entry() - wg_id = Tx.warpgroup_id([(256) // 128]) - warp_id = Tx.warp_id([(256) // 32]) - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([256]) + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) + B = T.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) + C = T.match_buffer(C_ptr, c_shape, dtype, layout=TileLayout(S[c_shape])) + + T.device_entry() + wg_id = T.warpgroup_id([(256) // 128]) + warp_id = T.warp_id([(256) // 32]) + _bx = T.cta_id([1]) + _tid = T.thread_id([256]) tid_in_scope = tid_in_scope_fn([n_threads]) + b_n = T.meta_var(n if rhs_kind == "region" else 1) + A_local = T.alloc_buffer((m, n), dtype, scope="local", layout=TileLayout(S[(m, n)])) + C_local = T.alloc_buffer((m, n), dtype, scope="local", layout=TileLayout(S[(m, n)])) + B_local = T.alloc_buffer((m, b_n), dtype, scope="local", layout=TileLayout(S[(m, b_n)])) + + if thr_str <= _tid and _tid < thr_str + n_threads: + for i in T.serial(m): + for j in T.serial(n): + A_local[i, j] = A[tid_in_scope, i, j] + if rhs_kind != "const": + for i in T.serial(m): + for j in T.serial(b_n): + B_local[i, j] = B[tid_in_scope, i, j] - with Tx.cta(): - b_n = Tx.meta_var(n if rhs_kind == "region" else 1) - A_local = Tx.alloc_buffer((m, n), dtype, scope="local", layout=TileLayout(S[(m, n)])) - C_local = Tx.alloc_buffer((m, n), dtype, scope="local", layout=TileLayout(S[(m, n)])) - B_local = Tx.alloc_buffer( - (m, b_n), dtype, scope="local", layout=TileLayout(S[(m, b_n)]) - ) - - if thr_str <= _tid and _tid < thr_str + n_threads: - with Tx.thread(): - for i in Tx.serial(m): - for j in Tx.serial(n): - A_local[i, j] = A[tid_in_scope, i, j] - if rhs_kind != "const": - for i in Tx.serial(m): - for j in Tx.serial(b_n): - B_local[i, j] = B[tid_in_scope, i, j] - # Tx.cuda.cta_sync() - - if exec_scope == "cta": - with Tx.cta(): - if rhs_kind == "const": - tx_op(C_local, A_local, const) - else: - tx_op(C_local, A_local, B_local) - elif exec_scope == "warpgroup": - if wg_id == 1: - with Tx.warpgroup(): - if rhs_kind == "const": - tx_op(C_local, A_local, const) - else: - tx_op(C_local, A_local, B_local) + if exec_scope == "cta": + if rhs_kind == "const": + tx_op(C_local, A_local, const) else: - if warp_id == 3: - with Tx.warp(): - if rhs_kind == "const": - tx_op(C_local, A_local, const) - else: - tx_op(C_local, A_local, B_local) - # Tx.cuda.cta_sync() - - if thr_str <= _tid and _tid < thr_str + n_threads: - with Tx.thread(): - for i in Tx.serial(m): - for j in Tx.serial(n): - C[tid_in_scope, i, j] = C_local[i, j] + tx_op(C_local, A_local, B_local) + elif exec_scope == "warpgroup": + if wg_id == 1: + if rhs_kind == "const": + tx_op(C_local, A_local, const) + else: + tx_op(C_local, A_local, B_local) + else: + if warp_id == 3: + if rhs_kind == "const": + tx_op(C_local, A_local, const) + else: + tx_op(C_local, A_local, B_local) + # T.cuda.cta_sync() + + if thr_str <= _tid and _tid < thr_str + n_threads: + for i in T.serial(m): + for j in T.serial(n): + C[tid_in_scope, i, j] = C_local[i, j] target = tvm.target.Target("cuda") with target: @@ -413,76 +401,71 @@ def test_binary_op_vectorized(input, storage_scope, exec_scope, op_type, dtype): tx_op = {"add": Tx.add, "sub": Tx.sub, "mul": Tx.mul, "fdiv": Tx.fdiv}[op_type] # fmt: off - @Tx.prim_func - def test_binary_cta(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) - B = Tx.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([thread_cnt]) - with Tx.cta(): - if storage_scope == "shared": - A_smem = Tx.alloc_buffer( - a_shape, dtype, scope="shared", layout=TileLayout(S[a_shape]) - ) - B_smem = Tx.alloc_buffer( - b_shape, dtype, scope="shared", layout=TileLayout(S[b_shape]) - ) - Tx.copy(A_smem, A) - Tx.copy(B_smem, B) - Tx.cuda.cta_sync() - tx_op(A_smem, A_smem, B_smem) - Tx.cuda.cta_sync() - Tx.copy(A, A_smem) - with Tx.thread(): - if storage_scope == "local": - A_local = Tx.alloc_buffer( - a_shape[1:], dtype, scope="local", layout=TileLayout(S[a_shape[1:]]) - ) - B_local = Tx.alloc_buffer( - b_shape[1:], dtype, scope="local", layout=TileLayout(S[b_shape[1:]]) - ) - Tx.copy(A_local, A[tx]) - Tx.copy(B_local, B[tx]) - with Tx.cta(): - tx_op(A_local, A_local, B_local) - Tx.copy(A[tx], A_local) - - @Tx.prim_func - def test_binary_thread(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) - B = Tx.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([thread_cnt]) - - with Tx.thread(): - if storage_scope == "shared": - A_smem = Tx.alloc_buffer( - a_shape, dtype, scope="shared", layout=TileLayout(S[a_shape]) - ) - B_smem = Tx.alloc_buffer( - b_shape, dtype, scope="shared", layout=TileLayout(S[b_shape]) - ) - Tx.copy(A_smem, A) - Tx.copy(B_smem, B) - Tx.cuda.cta_sync() - tx_op(A_smem, A_smem, B_smem) - Tx.cuda.cta_sync() - Tx.copy(A, A_smem) - elif storage_scope == "local": - A_local = Tx.alloc_buffer( - a_shape[1:], dtype, scope="local", layout=TileLayout(S[a_shape[1:]]) - ) - B_local = Tx.alloc_buffer( - b_shape[1:], dtype, scope="local", layout=TileLayout(S[b_shape[1:]]) - ) - Tx.copy(A_local, A[tx]) - Tx.copy(B_local, B[tx]) - tx_op(A_local, A_local, B_local) - Tx.copy(A[tx], A_local) + @T.prim_func + def test_binary_cta(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) + B = T.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) + + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([thread_cnt]) + if storage_scope == "shared": + A_smem = T.alloc_buffer( + a_shape, dtype, scope="shared", layout=TileLayout(S[a_shape]) + ) + B_smem = T.alloc_buffer( + b_shape, dtype, scope="shared", layout=TileLayout(S[b_shape]) + ) + Tx.cta.copy(A_smem, A) + Tx.cta.copy(B_smem, B) + T.cuda.cta_sync() + tx_op(A_smem, A_smem, B_smem) + T.cuda.cta_sync() + Tx.cta.copy(A, A_smem) + if storage_scope == "local": + A_local = T.alloc_buffer( + a_shape[1:], dtype, scope="local", layout=TileLayout(S[a_shape[1:]]) + ) + B_local = T.alloc_buffer( + b_shape[1:], dtype, scope="local", layout=TileLayout(S[b_shape[1:]]) + ) + Tx.copy(A_local, A[tx]) + Tx.copy(B_local, B[tx]) + tx_op(A_local, A_local, B_local) + Tx.copy(A[tx], A_local) + + @T.prim_func + def test_binary_thread(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) + B = T.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) + + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([thread_cnt]) + if storage_scope == "shared": + A_smem = T.alloc_buffer( + a_shape, dtype, scope="shared", layout=TileLayout(S[a_shape]) + ) + B_smem = T.alloc_buffer( + b_shape, dtype, scope="shared", layout=TileLayout(S[b_shape]) + ) + Tx.copy(A_smem, A) + Tx.copy(B_smem, B) + T.cuda.cta_sync() + tx_op(A_smem, A_smem, B_smem) + T.cuda.cta_sync() + Tx.copy(A, A_smem) + elif storage_scope == "local": + A_local = T.alloc_buffer( + a_shape[1:], dtype, scope="local", layout=TileLayout(S[a_shape[1:]]) + ) + B_local = T.alloc_buffer( + b_shape[1:], dtype, scope="local", layout=TileLayout(S[b_shape[1:]]) + ) + Tx.copy(A_local, A[tx]) + Tx.copy(B_local, B[tx]) + tx_op(A_local, A_local, B_local) + Tx.copy(A[tx], A_local) # fmt: on def get_prim_func(): @@ -529,30 +512,29 @@ def test_binary_op_packed_f32x2_auto_dispatch(op_type): dtype = "float32" dev = tvm.cuda(0) - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) - B = Tx.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([64]) - with Tx.thread(): - A_local = Tx.alloc_buffer( - a_shape[1:], dtype, scope="local", layout=TileLayout(S[a_shape[1:]]) - ) - B_local = Tx.alloc_buffer( - b_shape[1:], dtype, scope="local", layout=TileLayout(S[b_shape[1:]]) - ) - Tx.copy(A_local, A[tx]) - Tx.copy(B_local, B[tx]) - if op_type == "add": - Tx.add(A_local, A_local, B_local) - elif op_type == "sub": - Tx.sub(A_local, A_local, B_local) - elif op_type == "mul": - Tx.mul(A_local, A_local, B_local) - Tx.copy(A[tx], A_local) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, a_shape, dtype, layout=TileLayout(S[a_shape])) + B = T.match_buffer(B_ptr, b_shape, dtype, layout=TileLayout(S[b_shape])) + + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([64]) + A_local = T.alloc_buffer( + a_shape[1:], dtype, scope="local", layout=TileLayout(S[a_shape[1:]]) + ) + B_local = T.alloc_buffer( + b_shape[1:], dtype, scope="local", layout=TileLayout(S[b_shape[1:]]) + ) + Tx.copy(A_local, A[tx]) + Tx.copy(B_local, B[tx]) + if op_type == "add": + Tx.add(A_local, A_local, B_local) + elif op_type == "sub": + Tx.sub(A_local, A_local, B_local) + elif op_type == "mul": + Tx.mul(A_local, A_local, B_local) + Tx.copy(A[tx], A_local) with target: np.random.seed(0) @@ -593,42 +575,36 @@ def test_binary_op_warpgroup_wg_local_layout(op_name): dev = tvm.cuda(0) target = tvm.target.Target("cuda") - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - B = Tx.match_buffer(B_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - C = Tx.match_buffer(C_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid = Tx.thread_id_in_wg([rows]) - - lhs = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - rhs = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - out = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - - with Tx.thread(): - lhs_row = lhs.local(cols) - rhs_row = rhs.local(cols) - out_row = out.local(cols) - for i in Tx.serial(cols): - lhs_row[i] = A[tid, i] - rhs_row[i] = B[tid, i] - out_row[i] = Tx.float32(0) - - with Tx.warpgroup(): - if op_name == "add": - Tx.add(out, lhs, rhs) - elif op_name == "sub": - Tx.sub(out, lhs, rhs) - elif op_name == "mul": - Tx.mul(out, lhs, rhs) - - with Tx.thread(): - out_row = out.local(cols) - for i in Tx.serial(cols): - C[tid, i] = out_row[i] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + B = T.match_buffer(B_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + C = T.match_buffer(C_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + + T.device_entry() + _bx = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([rows]) + + lhs = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + rhs = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + out = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + lhs_row = lhs.local(cols) + rhs_row = rhs.local(cols) + out_row = out.local(cols) + for i in T.serial(cols): + lhs_row[i] = A[tid, i] + rhs_row[i] = B[tid, i] + out_row[i] = T.float32(0) + if op_name == "add": + Tx.wg.add(out, lhs, rhs) + elif op_name == "sub": + Tx.wg.sub(out, lhs, rhs) + elif op_name == "mul": + Tx.wg.mul(out, lhs, rhs) + out_row_1 = out.local(cols) + for i in T.serial(cols): + C[tid, i] = out_row_1[i] with target: np.random.seed(0) @@ -658,7 +634,7 @@ def test_binary_op_warpgroup_wg_local_emits_packed_f32x2(op_name, ptx_op): """Warpgroup-scope binary on a wg-local fp32 view must lower to packed f32x2 PTX on SM100+, mirroring the thread-scope packed dispatch. - Regression test for the fa4 perf path: rescale-style ``Tx.{add,sub,mul}`` + Regression test for the fa4 perf path: rescale-style ``T.{add,sub,mul}`` calls in warpgroup scope used to fall through to scalar codegen because ``_emit_binary_local_view`` only emitted ``op_func(...)`` per element. """ @@ -673,42 +649,36 @@ def test_binary_op_warpgroup_wg_local_emits_packed_f32x2(op_name, ptx_op): dtype = "float32" rows, cols = 128, 16 - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - B = Tx.match_buffer(B_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - C = Tx.match_buffer(C_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _wg_id = Tx.warpgroup_id([1]) - tid = Tx.thread_id_in_wg([rows]) - - lhs = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - rhs = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - out = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - - with Tx.thread(): - lhs_row = lhs.local(cols) - rhs_row = rhs.local(cols) - out_row = out.local(cols) - for i in Tx.serial(cols): - lhs_row[i] = A[tid, i] - rhs_row[i] = B[tid, i] - out_row[i] = Tx.float32(0) - - with Tx.warpgroup(): - if op_name == "add": - Tx.add(out, lhs, rhs) - elif op_name == "sub": - Tx.sub(out, lhs, rhs) - else: - Tx.mul(out, lhs, rhs) - - with Tx.thread(): - out_row = out.local(cols) - for i in Tx.serial(cols): - C[tid, i] = out_row[i] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + B = T.match_buffer(B_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + C = T.match_buffer(C_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + + T.device_entry() + _bx = T.cta_id([1]) + _wg_id = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([rows]) + + lhs = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + rhs = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + out = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + lhs_row = lhs.local(cols) + rhs_row = rhs.local(cols) + out_row = out.local(cols) + for i in T.serial(cols): + lhs_row[i] = A[tid, i] + rhs_row[i] = B[tid, i] + out_row[i] = T.float32(0) + if op_name == "add": + Tx.wg.add(out, lhs, rhs) + elif op_name == "sub": + Tx.wg.sub(out, lhs, rhs) + else: + Tx.wg.mul(out, lhs, rhs) + out_row_1 = out.local(cols) + for i in T.serial(cols): + C[tid, i] = out_row_1[i] with target: mod = tvm.IRModule({"main": test_func}) @@ -722,7 +692,7 @@ def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: def test_fma_warpgroup_wg_local_emits_packed_f32x2(): - """Same regression coverage as the binary case but for ``Tx.fma``.""" + """Same regression coverage as the binary case but for ``T.fma``.""" target = tvm.target.Target("cuda") arch = target.arch if hasattr(target, "arch") else "" if not arch.startswith("sm_"): @@ -734,30 +704,24 @@ def test_fma_warpgroup_wg_local_emits_packed_f32x2(): dtype = "float32" rows, cols = 128, 16 - @Tx.prim_func - def test_func(A_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - C = Tx.match_buffer(C_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _wg_id = Tx.warpgroup_id([1]) - tid = Tx.thread_id_in_wg([rows]) - - buf = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - - with Tx.thread(): - buf_row = buf.local(cols) - for i in Tx.serial(cols): - buf_row[i] = A[tid, i] - - with Tx.warpgroup(): - Tx.fma(buf, buf, Tx.float32(2.0), Tx.float32(0.5)) - - with Tx.thread(): - buf_row = buf.local(cols) - for i in Tx.serial(cols): - C[tid, i] = buf_row[i] + @T.prim_func + def test_func(A_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + C = T.match_buffer(C_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + + T.device_entry() + _bx = T.cta_id([1]) + _wg_id = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([rows]) + + buf = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + buf_row = buf.local(cols) + for i in T.serial(cols): + buf_row[i] = A[tid, i] + Tx.wg.fma(buf, buf, T.float32(2.0), T.float32(0.5)) + buf_row_1 = buf.local(cols) + for i in T.serial(cols): + C[tid, i] = buf_row_1[i] with target: mod = tvm.IRModule({"main": test_func}) @@ -776,28 +740,23 @@ def test_func(A_ptr: Tx.handle, C_ptr: Tx.handle) -> None: # even on hosts where ``Target("cuda")`` cannot detect the GPU. # ----------------------------------------------------------------------------- def test_binary_add_f32_sm100_packed_f32x2_dispatch(): - """add f32 + all-local → reg.py + add_f32x2 packed (no Tx.vectorized).""" + """add f32 + all-local → reg.py + add_f32x2 packed (no T.vectorized).""" shape = (64, 32) lay = TileLayout(S[shape]) - @Tx.prim_func - def k(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float32", layout=lay) - B = Tx.match_buffer(B_ptr, shape, "float32", layout=lay) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([64]) - with Tx.thread(): - ra = Tx.alloc_buffer( - shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]]) - ) - rb = Tx.alloc_buffer( - shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]]) - ) - Tx.copy(ra, A[tx]) - Tx.copy(rb, B[tx]) - Tx.add(ra, ra, rb) - Tx.copy(A[tx], ra) + @T.prim_func + def k(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float32", layout=lay) + B = T.match_buffer(B_ptr, shape, "float32", layout=lay) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([64]) + ra = T.alloc_buffer(shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]])) + rb = T.alloc_buffer(shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]])) + Tx.copy(ra, A[tx]) + Tx.copy(rb, B[tx]) + Tx.add(ra, ra, rb) + Tx.copy(A[tx], ra) target = tvm.target.Target({"kind": "cuda", "arch": "sm_100a"}) with target: @@ -810,28 +769,23 @@ def k(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: def test_binary_add_f16_scalar_fallback_dispatch(): - """add f16 has no packed VecImpl → reg.py scalar fallback (Tx.vectorized).""" + """add f16 has no packed VecImpl → reg.py scalar fallback (T.vectorized).""" shape = (64, 32) lay = TileLayout(S[shape]) - @Tx.prim_func - def k(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float16", layout=lay) - B = Tx.match_buffer(B_ptr, shape, "float16", layout=lay) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([64]) - with Tx.thread(): - ra = Tx.alloc_buffer( - shape[1:], "float16", scope="local", layout=TileLayout(S[shape[1:]]) - ) - rb = Tx.alloc_buffer( - shape[1:], "float16", scope="local", layout=TileLayout(S[shape[1:]]) - ) - Tx.copy(ra, A[tx]) - Tx.copy(rb, B[tx]) - Tx.add(ra, ra, rb) - Tx.copy(A[tx], ra) + @T.prim_func + def k(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float16", layout=lay) + B = T.match_buffer(B_ptr, shape, "float16", layout=lay) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([64]) + ra = T.alloc_buffer(shape[1:], "float16", scope="local", layout=TileLayout(S[shape[1:]])) + rb = T.alloc_buffer(shape[1:], "float16", scope="local", layout=TileLayout(S[shape[1:]])) + Tx.copy(ra, A[tx]) + Tx.copy(rb, B[tx]) + Tx.add(ra, ra, rb) + Tx.copy(A[tx], ra) target = tvm.target.Target({"kind": "cuda", "arch": "sm_80"}) with target: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py index dd6a8a1fdd0e..aa0f5ced8f58 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py @@ -24,7 +24,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout, wg_local_layout @@ -53,17 +54,16 @@ def test_fma_scalar_scalar(): scale_val = 0.5 bias_val = -1.0 - @Tx.prim_func - def test_func(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([N]) - with Tx.thread(): - buf = Tx.alloc_buffer((1,), dtype, scope="local", layout=TileLayout(S[1])) - Tx.copy(buf, A[tx : tx + 1]) - Tx.fma(buf, buf, Tx.float32(scale_val), Tx.float32(bias_val)) - Tx.copy(A[tx : tx + 1], buf) + @T.prim_func + def test_func(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([N]) + buf = T.alloc_buffer((1,), dtype, scope="local", layout=TileLayout(S[1])) + Tx.copy(buf, A[tx : tx + 1]) + Tx.fma(buf, buf, T.float32(scale_val), T.float32(bias_val)) + Tx.copy(A[tx : tx + 1], buf) with target: A_np = np.random.rand(N).astype(dtype) @@ -90,20 +90,19 @@ def test_fma_buffer_scale_scalar_bias(): coeff = 0.695 - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) - B = Tx.match_buffer(B_ptr, (N,), dtype, layout=TileLayout(S[N])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([1]) - with Tx.thread(): - acc = Tx.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) - frac = Tx.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) - Tx.copy(acc, A[0:N]) - Tx.copy(frac, B[0:N]) - Tx.fma(acc, acc, frac, Tx.float32(coeff)) - Tx.copy(A[0:N], acc) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) + B = T.match_buffer(B_ptr, (N,), dtype, layout=TileLayout(S[N])) + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([1]) + acc = T.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) + frac = T.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) + Tx.copy(acc, A[0:N]) + Tx.copy(frac, B[0:N]) + Tx.fma(acc, acc, frac, T.float32(coeff)) + Tx.copy(A[0:N], acc) with target: A_np = np.random.rand(N).astype(dtype) @@ -130,20 +129,19 @@ def test_mul_scalar_broadcast(): dev = tvm.cuda(0) target = tvm.target.Target("cuda") - @Tx.prim_func - def test_func(A_ptr: Tx.handle, S_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) - Scale = Tx.match_buffer(S_ptr, (1,), dtype, layout=TileLayout(S[1])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([1]) - with Tx.thread(): - a_local = Tx.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) - s_local = Tx.alloc_buffer((1,), dtype, scope="local", layout=TileLayout(S[1])) - Tx.copy(a_local, A[0:N]) - Tx.copy(s_local, Scale[0:1]) - Tx.mul(a_local, a_local, s_local[0]) - Tx.copy(A[0:N], a_local) + @T.prim_func + def test_func(A_ptr: T.handle, S_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) + Scale = T.match_buffer(S_ptr, (1,), dtype, layout=TileLayout(S[1])) + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([1]) + a_local = T.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) + s_local = T.alloc_buffer((1,), dtype, scope="local", layout=TileLayout(S[1])) + Tx.copy(a_local, A[0:N]) + Tx.copy(s_local, Scale[0:1]) + Tx.mul(a_local, a_local, s_local[0]) + Tx.copy(A[0:N], a_local) with target: A_np = np.random.rand(N).astype(dtype) @@ -172,17 +170,16 @@ def test_add_rounding_mode(): round_const = float(2**23 + 2**22) - @Tx.prim_func - def test_func(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([1]) - with Tx.thread(): - buf = Tx.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) - Tx.copy(buf, A[0:N]) - Tx.add(buf, buf, Tx.float32(round_const), rounding_mode="rm") - Tx.copy(A[0:N], buf) + @T.prim_func + def test_func(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([1]) + buf = T.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) + Tx.copy(buf, A[0:N]) + Tx.add(buf, buf, T.float32(round_const), rounding_mode="rm") + Tx.copy(A[0:N], buf) with target: A_np = np.array([1.3, 2.7], dtype=dtype) @@ -215,19 +212,18 @@ def test_fma_no_layout(): scale_val = 2.0 bias_val = 1.0 - @Tx.prim_func - def test_func(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([1]) - with Tx.thread(): - buf = Tx.alloc_local([N], dtype) - for i in Tx.serial(N): - buf[i] = A[i] - Tx.fma(buf[0:N], buf[0:N], Tx.float32(scale_val), Tx.float32(bias_val)) - for i in Tx.serial(N): - A[i] = buf[i] + @T.prim_func + def test_func(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([1]) + buf = T.alloc_local([N], dtype) + for i in T.serial(N): + buf[i] = A[i] + Tx.fma(buf[0:N], buf[0:N], T.float32(scale_val), T.float32(bias_val)) + for i in T.serial(N): + A[i] = buf[i] with target: A_np = np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype) @@ -252,20 +248,19 @@ def test_sub_buffer_buffer_rounding(): dev = tvm.cuda(0) target = tvm.target.Target("cuda") - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) - B = Tx.match_buffer(B_ptr, (N,), dtype, layout=TileLayout(S[N])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([1]) - with Tx.thread(): - a_buf = Tx.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) - b_buf = Tx.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) - Tx.copy(a_buf, A[0:N]) - Tx.copy(b_buf, B[0:N]) - Tx.sub(a_buf, a_buf, b_buf, rounding_mode="rn") - Tx.copy(A[0:N], a_buf) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (N,), dtype, layout=TileLayout(S[N])) + B = T.match_buffer(B_ptr, (N,), dtype, layout=TileLayout(S[N])) + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([1]) + a_buf = T.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) + b_buf = T.alloc_buffer((N,), dtype, scope="local", layout=TileLayout(S[N])) + Tx.copy(a_buf, A[0:N]) + Tx.copy(b_buf, B[0:N]) + Tx.sub(a_buf, a_buf, b_buf, rounding_mode="rn") + Tx.copy(A[0:N], a_buf) with target: A_np = np.array([3.14, 2.71], dtype=dtype) @@ -291,29 +286,23 @@ def test_fma_warpgroup_wg_local_layout(): dev = tvm.cuda(0) target = tvm.target.Target("cuda") - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - B = Tx.match_buffer(B_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid = Tx.thread_id_in_wg([rows]) - - reg = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - - with Tx.thread(): - reg_row = reg.local(cols) - for i in Tx.serial(cols): - reg_row[i] = A[tid, i] - - with Tx.warpgroup(): - Tx.fma(reg, reg, Tx.float32(scale_val), Tx.float32(bias_val)) - - with Tx.thread(): - reg_row = reg.local(cols) - for i in Tx.serial(cols): - B[tid, i] = reg_row[i] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + B = T.match_buffer(B_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + T.device_entry() + _bx = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([rows]) + + reg = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + reg_row = reg.local(cols) + for i in T.serial(cols): + reg_row[i] = A[tid, i] + Tx.wg.fma(reg, reg, T.float32(scale_val), T.float32(bias_val)) + reg_row_1 = reg.local(cols) + for i in T.serial(cols): + B[tid, i] = reg_row_1[i] with target: np.random.seed(0) @@ -334,37 +323,28 @@ def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: # the host-detected ``Target("cuda")`` and skips when arch < sm_100). # ----------------------------------------------------------------------------- def test_fma_f32_sm100_packed_f32x2_dispatch(): - """fma f32 + all-local → reg.py + fma_f32x2 packed (no Tx.vectorized).""" + """fma f32 + all-local → reg.py + fma_f32x2 packed (no T.vectorized).""" shape = (64, 32) lay = TileLayout(S[shape]) - @Tx.prim_func - def k(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, D_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float32", layout=lay) - B = Tx.match_buffer(B_ptr, shape, "float32", layout=lay) - C = Tx.match_buffer(C_ptr, shape, "float32", layout=lay) - D = Tx.match_buffer(D_ptr, shape, "float32", layout=lay) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([64]) - with Tx.thread(): - ra = Tx.alloc_buffer( - shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]]) - ) - rb = Tx.alloc_buffer( - shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]]) - ) - rc = Tx.alloc_buffer( - shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]]) - ) - rd = Tx.alloc_buffer( - shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]]) - ) - Tx.copy(ra, A[tx]) - Tx.copy(rb, B[tx]) - Tx.copy(rc, C[tx]) - Tx.fma(rd, ra, rb, rc) - Tx.copy(D[tx], rd) + @T.prim_func + def k(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, D_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float32", layout=lay) + B = T.match_buffer(B_ptr, shape, "float32", layout=lay) + C = T.match_buffer(C_ptr, shape, "float32", layout=lay) + D = T.match_buffer(D_ptr, shape, "float32", layout=lay) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([64]) + ra = T.alloc_buffer(shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]])) + rb = T.alloc_buffer(shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]])) + rc = T.alloc_buffer(shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]])) + rd = T.alloc_buffer(shape[1:], "float32", scope="local", layout=TileLayout(S[shape[1:]])) + Tx.copy(ra, A[tx]) + Tx.copy(rb, B[tx]) + Tx.copy(rc, C[tx]) + Tx.fma(rd, ra, rb, rc) + Tx.copy(D[tx], rd) target = tvm.target.Target({"kind": "cuda", "arch": "sm_100a"}) with target: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py index bd2f6463efe9..3aa02bb5e2f0 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py @@ -21,7 +21,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx, warpid from tvm.tirx.operator.tile_primitive.cuda.layout_utils import ( cast_layout_supported_for_local as _cast_layout_supported_for_local, @@ -69,47 +70,43 @@ def test_unary_op_shared(input, op_type, src_dtype, dst_dtype): if in_place: # fmt: off - @Tx.prim_func - def unary_op(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) - Tx.cuda.cta_sync() - if op_type == "zero": - Tx.zero(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) - elif op_type == "sqrt": - Tx.sqrt(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) - Tx.cuda.cta_sync() - Tx.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) + @T.prim_func + def unary_op(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) + + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) + T.cuda.cta_sync() + if op_type == "zero": + Tx.cta.zero(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) + elif op_type == "sqrt": + Tx.cta.sqrt(A_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) + T.cuda.cta_sync() + Tx.cta.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) # fmt: on else: # fmt: off - @Tx.prim_func - def unary_op(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, dst_dtype, layout=g_layout) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) - B_smem = Tx.alloc_buffer(s_shape, dst_dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) - Tx.cuda.cta_sync() - if op_type == "zero": - Tx.zero(B_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) - elif op_type == "sqrt": - Tx.sqrt(B_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) - Tx.cuda.cta_sync() - Tx.copy(B[tuple(map_slice_res)], B_smem[tuple(map_slice_res)]) + @T.prim_func + def unary_op(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, dst_dtype, layout=g_layout) + + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) + B_smem = T.alloc_buffer(s_shape, dst_dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) + T.cuda.cta_sync() + if op_type == "zero": + Tx.cta.zero(B_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) + elif op_type == "sqrt": + Tx.cta.sqrt(B_smem[tuple(map_slice_res)], A_smem[tuple(map_slice_a)]) + T.cuda.cta_sync() + Tx.cta.copy(B[tuple(map_slice_res)], B_smem[tuple(map_slice_res)]) # fmt: on def get_ref(A_np): @@ -155,29 +152,26 @@ def test_unary_op_shared_subcta_scope(exec_scope): g_shape = (n_warps * 32, 8) dev = tvm.cuda(0) - @Tx.prim_func - def unary_op_subcta(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=TileLayout(S[g_shape])) - - Tx.device_entry() - warp_id = Tx.warp_id([(256) // 32]) - wg_id = Tx.warpgroup_id([(256) // 128]) - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([256]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(g_shape, dtype, scope="shared", layout=TileLayout(S[g_shape])) - Tx.copy(A_smem, A) - Tx.cuda.cta_sync() - if exec_scope == "warp": - if warp_id == 5: - with Tx.warp(): - Tx.zero(A_smem, A_smem) - elif exec_scope == "warpgroup": - if wg_id == 1: - with Tx.warpgroup(): - Tx.zero(A_smem, A_smem) - Tx.cuda.cta_sync() - Tx.copy(A, A_smem) + @T.prim_func + def unary_op_subcta(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=TileLayout(S[g_shape])) + + T.device_entry() + warp_id = T.warp_id([(256) // 32]) + wg_id = T.warpgroup_id([(256) // 128]) + _bx = T.cta_id([1]) + _tid = T.thread_id([256]) + A_smem = T.alloc_buffer(g_shape, dtype, scope="shared", layout=TileLayout(S[g_shape])) + Tx.cta.copy(A_smem, A) + T.cuda.cta_sync() + if exec_scope == "warp": + if warp_id == 5: + Tx.warp.zero(A_smem, A_smem) + elif exec_scope == "warpgroup": + if wg_id == 1: + Tx.wg.zero(A_smem, A_smem) + T.cuda.cta_sync() + Tx.cta.copy(A, A_smem) target = tvm.target.Target("cuda") with target: @@ -237,109 +231,105 @@ def test_unary_op_shared_with_bias_scale(input, op_type, bias_type, src_dtype, d map_slice_res = list(slice(st_res[i], st_res[i] + ext_res[i]) for i in range(len(g_shape))) # scale and bias in compute_dtype (= src_dtype) - scale = Tx.FloatImm(src_dtype, 1.5) - const_bias = Tx.FloatImm(src_dtype, 0.88) + scale = T.FloatImm(src_dtype, 1.5) + const_bias = T.FloatImm(src_dtype, 0.88) if in_place: - @Tx.prim_func - def unary_op_with_bias(A_ptr: Tx.handle, bias_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) - bias = Tx.match_buffer(bias_ptr, g_shape, src_dtype, layout=g_layout) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) - bias_smem = Tx.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) - Tx.copy(bias_smem[tuple(copy_slice)], bias[tuple(copy_slice)]) - Tx.cuda.cta_sync() - if bias_type == "const": - if op_type == "sqrt": - Tx.sqrt( - A_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - const_bias, - scale, - ) - elif op_type == "exp": - Tx.exp( - A_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - const_bias, - scale, - ) - elif bias_type == "region": - if op_type == "sqrt": - Tx.sqrt( - A_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - bias_smem[tuple(map_slice_a)], - scale, - ) - elif op_type == "exp": - Tx.exp( - A_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - bias_smem[tuple(map_slice_a)], - scale, - ) - Tx.cuda.cta_sync() - Tx.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) + @T.prim_func + def unary_op_with_bias(A_ptr: T.handle, bias_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) + bias = T.match_buffer(bias_ptr, g_shape, src_dtype, layout=g_layout) + + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) + bias_smem = T.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) + Tx.cta.copy(bias_smem[tuple(copy_slice)], bias[tuple(copy_slice)]) + T.cuda.cta_sync() + if bias_type == "const": + if op_type == "sqrt": + Tx.cta.sqrt( + A_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + const_bias, + scale, + ) + elif op_type == "exp": + Tx.cta.exp( + A_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + const_bias, + scale, + ) + elif bias_type == "region": + if op_type == "sqrt": + Tx.cta.sqrt( + A_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + bias_smem[tuple(map_slice_a)], + scale, + ) + elif op_type == "exp": + Tx.cta.exp( + A_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + bias_smem[tuple(map_slice_a)], + scale, + ) + T.cuda.cta_sync() + Tx.cta.copy(A[tuple(copy_slice)], A_smem[tuple(copy_slice)]) else: - @Tx.prim_func - def unary_op_with_bias(A_ptr: Tx.handle, B_ptr: Tx.handle, bias_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, dst_dtype, layout=g_layout) - bias = Tx.match_buffer(bias_ptr, g_shape, src_dtype, layout=g_layout) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) - B_smem = Tx.alloc_buffer(s_shape, dst_dtype, scope="shared", layout=s_layout) - bias_smem = Tx.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) - Tx.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) - Tx.copy(bias_smem[tuple(copy_slice)], bias[tuple(copy_slice)]) - Tx.cuda.cta_sync() - if bias_type == "const": - if op_type == "sqrt": - Tx.sqrt( - B_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - const_bias, - scale, - ) - elif op_type == "exp": - Tx.exp( - B_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - const_bias, - scale, - ) - elif bias_type == "region": - if op_type == "sqrt": - Tx.sqrt( - B_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - bias_smem[tuple(map_slice_a)], - scale, - ) - elif op_type == "exp": - Tx.exp( - B_smem[tuple(map_slice_res)], - A_smem[tuple(map_slice_a)], - bias_smem[tuple(map_slice_a)], - scale, - ) - Tx.cuda.cta_sync() - Tx.copy(B[tuple(map_slice_res)], B_smem[tuple(map_slice_res)]) + @T.prim_func + def unary_op_with_bias(A_ptr: T.handle, B_ptr: T.handle, bias_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, src_dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, dst_dtype, layout=g_layout) + bias = T.match_buffer(bias_ptr, g_shape, src_dtype, layout=g_layout) + + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) + B_smem = T.alloc_buffer(s_shape, dst_dtype, scope="shared", layout=s_layout) + bias_smem = T.alloc_buffer(s_shape, src_dtype, scope="shared", layout=s_layout) + Tx.cta.copy(A_smem[tuple(copy_slice)], A[tuple(copy_slice)]) + Tx.cta.copy(bias_smem[tuple(copy_slice)], bias[tuple(copy_slice)]) + T.cuda.cta_sync() + if bias_type == "const": + if op_type == "sqrt": + Tx.cta.sqrt( + B_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + const_bias, + scale, + ) + elif op_type == "exp": + Tx.cta.exp( + B_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + const_bias, + scale, + ) + elif bias_type == "region": + if op_type == "sqrt": + Tx.cta.sqrt( + B_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + bias_smem[tuple(map_slice_a)], + scale, + ) + elif op_type == "exp": + Tx.cta.exp( + B_smem[tuple(map_slice_res)], + A_smem[tuple(map_slice_a)], + bias_smem[tuple(map_slice_a)], + scale, + ) + T.cuda.cta_sync() + Tx.cta.copy(B[tuple(map_slice_res)], B_smem[tuple(map_slice_res)]) def get_ref(A_np, bias_np): if in_place: @@ -456,67 +446,60 @@ def test_unary_op_local(input, op_type, src_dtype, dst_dtype): g_layout_a = g_layout_b = TileLayout(S[g_shape_a]) acc_shape = red_shape = (16, NUM_COL) - @Tx.prim_func - def test_unary(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape_a, src_dtype, layout=g_layout_a) - B = Tx.match_buffer(B_ptr, g_shape_b, dst_dtype, layout=g_layout_b) - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - wg_id = Tx.warpgroup_id([N_GROUPS]) - warp_id_in_wg = Tx.warp_id_in_wg([N_WARPS // N_GROUPS]) - lane_id = Tx.lane_id([thread_cnt]) - - with Tx.thread(): - # acc layout - atom = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - warp_layout = Tx.TileLayout(Tx.S[(8, 4) : (4 @ laneid, 1 @ laneid)]) - warp_atom = atom.tile(warp_layout, (8, 4), (1, 2)) - tile = Tx.TileLayout(Tx.S[(2, NUM_COL // 8) : (1, 2)]) - acc_layout = warp_atom.tile(tile, (2, NUM_COL // 8), (8, 8)) - acc = Tx.alloc_buffer( - [2, NUM_COL // 4], - dtype=src_dtype, - scope="local", - layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), - ) - res = Tx.alloc_buffer( - [2, NUM_COL // 4], - dtype=dst_dtype, - scope="local", - layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), - ) + @T.prim_func + def test_unary(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape_a, src_dtype, layout=g_layout_a) + B = T.match_buffer(B_ptr, g_shape_b, dst_dtype, layout=g_layout_b) + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + wg_id = T.warpgroup_id([N_GROUPS]) + warp_id_in_wg = T.warp_id_in_wg([N_WARPS // N_GROUPS]) + lane_id = T.lane_id([thread_cnt]) + # acc layout + atom = T.TileLayout(T.S[(1, 2) : (2, 1)]) + warp_layout = T.TileLayout(T.S[(8, 4) : (4 @ laneid, 1 @ laneid)]) + warp_atom = atom.tile(warp_layout, (8, 4), (1, 2)) + tile = T.TileLayout(T.S[(2, NUM_COL // 8) : (1, 2)]) + acc_layout = warp_atom.tile(tile, (2, NUM_COL // 8), (8, 8)) + acc = T.alloc_buffer( + [2, NUM_COL // 4], + dtype=src_dtype, + scope="local", + layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), + ) + res = T.alloc_buffer( + [2, NUM_COL // 4], + dtype=dst_dtype, + scope="local", + layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), + ) + for i in T.serial(NUM_COL // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + acc[j, i * 2 + vec] = A[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] + + # unary op + acc_view = acc.view(*acc_shape, layout=acc_layout) + res_view = res.view(*red_shape, layout=acc_layout) + if op_type == "reciprocal": + Tx.warp.reciprocal(res_view, acc_view) + elif op_type == "exp": + Tx.warp.exp(res_view, acc_view) + elif op_type == "exp2": + Tx.warp.exp2(res_view, acc_view) - # load A into acc - with Tx.thread(): - for i in Tx.serial(NUM_COL // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - acc[j, i * 2 + vec] = A[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] - - # unary op - with Tx.warp(): - acc_view = acc.view(*acc_shape, layout=acc_layout) - res_view = res.view(*red_shape, layout=acc_layout) - if op_type == "reciprocal": - Tx.reciprocal(res_view, acc_view) - elif op_type == "exp": - Tx.exp(res_view, acc_view) - elif op_type == "exp2": - Tx.exp2(res_view, acc_view) - - # write res into B - with Tx.thread(): - for i in Tx.serial(NUM_COL // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - B[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] = res[j, i * 2 + vec] + # write res into B + for i in T.serial(NUM_COL // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + B[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] = res[j, i * 2 + vec] # fmt: on @@ -586,91 +569,83 @@ def test_unary_op_local_with_bias_scale(input, op_type, bias_type, src_dtype, ds g_layout_a = g_layout_b = g_layout_bias = TileLayout(S[g_shape_a]) acc_shape = red_shape = bias_shape = (16, NUM_COL) - scale = Tx.float16(1.5) if src_dtype == "float16" else Tx.float32(1.5) - const_bias = Tx.float16(0.88) if src_dtype == "float16" else Tx.float32(0.88) - - @Tx.prim_func - def test_unary_with_bias(A_ptr: Tx.handle, B_ptr: Tx.handle, bias_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape_a, src_dtype, layout=g_layout_a) - B = Tx.match_buffer(B_ptr, g_shape_b, dst_dtype, layout=g_layout_b) - bias = Tx.match_buffer(bias_ptr, g_shape_bias, src_dtype, layout=g_layout_bias) - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - wg_id = Tx.warpgroup_id([N_GROUPS]) - warp_id_in_wg = Tx.warp_id_in_wg([N_WARPS // N_GROUPS]) - lane_id = Tx.lane_id([thread_cnt]) - - with Tx.thread(): - # acc layout - atom = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - warp_layout = Tx.TileLayout(Tx.S[(8, 4) : (4 @ laneid, 1 @ laneid)]) - warp_atom = atom.tile(warp_layout, (8, 4), (1, 2)) - tile = Tx.TileLayout(Tx.S[(2, NUM_COL // 8) : (1, 2)]) - acc_layout = warp_atom.tile(tile, (2, NUM_COL // 8), (8, 8)) - acc = Tx.alloc_buffer( - [2, NUM_COL // 4], - dtype=src_dtype, - scope="local", - layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), - ) - bias_local = Tx.alloc_buffer( - [2, NUM_COL // 4], - dtype=src_dtype, - scope="local", - layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), - ) - res = Tx.alloc_buffer( - [2, NUM_COL // 4], - dtype=dst_dtype, - scope="local", - layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), - ) + scale = T.float16(1.5) if src_dtype == "float16" else T.float32(1.5) + const_bias = T.float16(0.88) if src_dtype == "float16" else T.float32(0.88) + + @T.prim_func + def test_unary_with_bias(A_ptr: T.handle, B_ptr: T.handle, bias_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape_a, src_dtype, layout=g_layout_a) + B = T.match_buffer(B_ptr, g_shape_b, dst_dtype, layout=g_layout_b) + bias = T.match_buffer(bias_ptr, g_shape_bias, src_dtype, layout=g_layout_bias) + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + wg_id = T.warpgroup_id([N_GROUPS]) + warp_id_in_wg = T.warp_id_in_wg([N_WARPS // N_GROUPS]) + lane_id = T.lane_id([thread_cnt]) + # acc layout + atom = T.TileLayout(T.S[(1, 2) : (2, 1)]) + warp_layout = T.TileLayout(T.S[(8, 4) : (4 @ laneid, 1 @ laneid)]) + warp_atom = atom.tile(warp_layout, (8, 4), (1, 2)) + tile = T.TileLayout(T.S[(2, NUM_COL // 8) : (1, 2)]) + acc_layout = warp_atom.tile(tile, (2, NUM_COL // 8), (8, 8)) + acc = T.alloc_buffer( + [2, NUM_COL // 4], + dtype=src_dtype, + scope="local", + layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), + ) + bias_local = T.alloc_buffer( + [2, NUM_COL // 4], + dtype=src_dtype, + scope="local", + layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), + ) + res = T.alloc_buffer( + [2, NUM_COL // 4], + dtype=dst_dtype, + scope="local", + layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), + ) + for i in T.serial(NUM_COL // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + acc[j, i * 2 + vec] = A[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] + # load bias into bias_local + for i in T.serial(NUM_COL // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + bias_local[j, i * 2 + vec] = bias[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] + + # unary op + acc_view = acc.view(*acc_shape, layout=acc_layout) + res_view = res.view(*red_shape, layout=acc_layout) + bias_view = bias_local.view(*bias_shape, layout=acc_layout) + if bias_type == "const": + if op_type == "sqrt": + Tx.warp.sqrt(res_view, acc_view, const_bias, scale) + elif op_type == "exp": + Tx.warp.exp(res_view, acc_view, const_bias, scale) + elif bias_type == "region": + if op_type == "sqrt": + Tx.warp.sqrt(res_view, acc_view, bias_view, scale) + elif op_type == "exp": + Tx.warp.exp(res_view, acc_view, bias_view, scale) - # load A into acc - with Tx.thread(): - for i in Tx.serial(NUM_COL // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - acc[j, i * 2 + vec] = A[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] - # load bias into bias_local - with Tx.thread(): - for i in Tx.serial(NUM_COL // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - bias_local[j, i * 2 + vec] = bias[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] - - # unary op - with Tx.warp(): - acc_view = acc.view(*acc_shape, layout=acc_layout) - res_view = res.view(*red_shape, layout=acc_layout) - bias_view = bias_local.view(*bias_shape, layout=acc_layout) - if bias_type == "const": - if op_type == "sqrt": - Tx.sqrt(res_view, acc_view, const_bias, scale) - elif op_type == "exp": - Tx.exp(res_view, acc_view, const_bias, scale) - elif bias_type == "region": - if op_type == "sqrt": - Tx.sqrt(res_view, acc_view, bias_view, scale) - elif op_type == "exp": - Tx.exp(res_view, acc_view, bias_view, scale) - - # write res into B - with Tx.thread(): - for i in Tx.serial(NUM_COL // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - B[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] = res[j, i * 2 + vec] + # write res into B + for i in T.serial(NUM_COL // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + B[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] = res[j, i * 2 + vec] def get_ref(A_np, bias_np): A_ref = A_np.copy() @@ -718,42 +693,40 @@ def test_unary_op_vectorized(shape, op_type, exec_scope, storage_scope): dtype = "float16" A_ref = np.random.rand(*shape).astype(dtype) A = tvm.runtime.tensor(A_ref, dev) - value = Tx.float16(7.89) if dtype == "float16" else Tx.float32(7.89) + value = T.float16(7.89) if dtype == "float16" else T.float32(7.89) # fmt: off - @Tx.prim_func - def test_unary_thread(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype, layout=TileLayout(S[shape])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([128]) - with Tx.thread(): - if storage_scope == "shared": - a_smem = Tx.alloc_buffer( - shape, dtype=dtype, layout=TileLayout(S[shape]), scope="shared" - ) - Tx.fill(a_smem[tx], value) - Tx.copy(A[tx], a_smem[tx]) - elif storage_scope == "local": - a_local = Tx.alloc_buffer( - shape[1:], dtype=dtype, layout=TileLayout(S[shape[1:]]), scope="local" - ) - Tx.fill(a_local, value) - Tx.copy(A[tx], a_local) - - @Tx.prim_func - def test_unary_cta(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype, layout=TileLayout(S[shape])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([128]) - with Tx.cta(): - if storage_scope == "shared": - a_smem = Tx.alloc_buffer( - shape, dtype=dtype, layout=TileLayout(S[shape]), scope="shared" - ) - Tx.fill(a_smem, value) - Tx.copy(A, a_smem) + @T.prim_func + def test_unary_thread(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype, layout=TileLayout(S[shape])) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([128]) + if storage_scope == "shared": + a_smem = T.alloc_buffer( + shape, dtype=dtype, layout=TileLayout(S[shape]), scope="shared" + ) + Tx.fill(a_smem[tx], value) + Tx.copy(A[tx], a_smem[tx]) + elif storage_scope == "local": + a_local = T.alloc_buffer( + shape[1:], dtype=dtype, layout=TileLayout(S[shape[1:]]), scope="local" + ) + Tx.fill(a_local, value) + Tx.copy(A[tx], a_local) + + @T.prim_func + def test_unary_cta(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype, layout=TileLayout(S[shape])) + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([128]) + if storage_scope == "shared": + a_smem = T.alloc_buffer( + shape, dtype=dtype, layout=TileLayout(S[shape]), scope="shared" + ) + Tx.cta.fill(a_smem, value) + Tx.cta.copy(A, a_smem) # fmt: on target = tvm.target.Target("cuda") @@ -775,28 +748,27 @@ def test_unary_op_local_thread_wise(op_type, dtype): local_shape = shape[1:] dev = tvm.cuda(0) - @Tx.prim_func - def kernel(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, dtype, layout=TileLayout(S[shape])) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tid = Tx.thread_id([64]) - with Tx.thread(): - a_local = Tx.alloc_buffer( - local_shape, dtype, scope="local", layout=TileLayout(S[local_shape]) - ) - Tx.copy(a_local, A[tid]) - if op_type == "zero": - Tx.zero(a_local, a_local) - elif op_type == "sqrt": - Tx.sqrt(a_local, a_local) - elif op_type == "reciprocal": - Tx.reciprocal(a_local, a_local) - elif op_type == "exp": - Tx.exp(a_local, a_local) - elif op_type == "silu": - Tx.silu(a_local, a_local) - Tx.copy(A[tid], a_local) + @T.prim_func + def kernel(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, dtype, layout=TileLayout(S[shape])) + T.device_entry() + _bx = T.cta_id([1]) + tid = T.thread_id([64]) + a_local = T.alloc_buffer( + local_shape, dtype, scope="local", layout=TileLayout(S[local_shape]) + ) + Tx.copy(a_local, A[tid]) + if op_type == "zero": + Tx.zero(a_local, a_local) + elif op_type == "sqrt": + Tx.sqrt(a_local, a_local) + elif op_type == "reciprocal": + Tx.reciprocal(a_local, a_local) + elif op_type == "exp": + Tx.exp(a_local, a_local) + elif op_type == "silu": + Tx.silu(a_local, a_local) + Tx.copy(A[tid], a_local) target = tvm.target.Target("cuda") with target: @@ -835,20 +807,19 @@ def test_cast_thread_local(shape, A_dtype, B_dtype): B_ref = A_ref.astype(B_dtype) # fmt: off - @Tx.prim_func - def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, A_dtype, layout=TileLayout(S[shape])) - B = Tx.match_buffer(B_ptr, shape, B_dtype, layout=TileLayout(S[shape])) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([256]) - with Tx.thread(): - A_local = Tx.alloc_local(shape, dtype=A_dtype, layout=TileLayout(S[shape])) - B_local = Tx.alloc_local(shape, dtype=B_dtype, layout=TileLayout(S[shape])) - Tx.copy(A_local, A) - Tx.cast(B_local, A_local) - Tx.copy(B, B_local) + @T.prim_func + def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, A_dtype, layout=TileLayout(S[shape])) + B = T.match_buffer(B_ptr, shape, B_dtype, layout=TileLayout(S[shape])) + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([256]) + A_local = T.alloc_local(shape, dtype=A_dtype, layout=TileLayout(S[shape])) + B_local = T.alloc_local(shape, dtype=B_dtype, layout=TileLayout(S[shape])) + Tx.copy(A_local, A) + Tx.cast(B_local, A_local) + Tx.copy(B, B_local) # fmt: on target = tvm.target.Target("cuda") @@ -862,7 +833,7 @@ def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_warpgroup_local_view(A_dtype, B_dtype): - """Tx.cast in warpgroup scope with offset (tid_in_wg + layout offset). Covers offset/tid_in_wg/warpgroup scope.""" # noqa: E501 + """T.cast in warpgroup scope with offset (tid_in_wg + layout offset). Covers offset/tid_in_wg/warpgroup scope.""" # noqa: E501 N_THREADS, LOCAL_LEN = 128, 8 g_shape = (N_THREADS, LOCAL_LEN) g_layout = TileLayout(S[g_shape]) @@ -884,29 +855,24 @@ def test_cast_warpgroup_local_view(A_dtype, B_dtype): B_ref = A_ref.astype(B_dtype) # fmt: off - @Tx.prim_func - def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([N_THREADS]) - - with Tx.thread(): - reg_src = Tx.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") - reg_dst = Tx.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - reg_src[i] = A[tid_in_wg, i] - with Tx.warpgroup(): - reg_src_view = reg_src.view(N_THREADS, LOCAL_LEN, layout=cast_layout) - reg_dst_view = reg_dst.view(N_THREADS, LOCAL_LEN, layout=cast_layout) - Tx.cast(reg_dst_view, reg_src_view) - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - B[tid_in_wg, i] = reg_dst[i] + @T.prim_func + def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) + + T.device_entry() + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([N_THREADS]) + reg_src = T.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") + reg_dst = T.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") + for i in T.serial(LOCAL_LEN): + reg_src[i] = A[tid_in_wg, i] + reg_src_view = reg_src.view(N_THREADS, LOCAL_LEN, layout=cast_layout) + reg_dst_view = reg_dst.view(N_THREADS, LOCAL_LEN, layout=cast_layout) + Tx.wg.cast(reg_dst_view, reg_src_view) + for i in T.serial(LOCAL_LEN): + B[tid_in_wg, i] = reg_dst[i] # fmt: on target = tvm.target.Target("cuda") @@ -940,34 +906,29 @@ def test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype, B_dtype) B_ref = A_ref.astype(B_dtype) # fmt: off - @Tx.prim_func - def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid = Tx.thread_id_in_wg([N_THREADS]) - - with Tx.thread(): - for no in Tx.unroll(N_CHUNKS): - reg_src = Tx.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") - Dreg_chunk = Tx.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - reg_src[i] = A[tid, no * LOCAL_LEN + i] - with Tx.warpgroup(): - reg_src_view = reg_src.view( - N_THREADS, LOCAL_LEN, layout=wg_local_layout(LOCAL_LEN) - ) - Dreg_chunk_view = Dreg_chunk.view( - N_THREADS, LOCAL_LEN, layout=wg_local_layout(LOCAL_LEN) - ) - Tx.cast(Dreg_chunk_view, reg_src_view) - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - B[tid, no * LOCAL_LEN + i] = Dreg_chunk[i] + @T.prim_func + def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) + + T.device_entry() + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([N_THREADS]) + for no in T.unroll(N_CHUNKS): + reg_src = T.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") + Dreg_chunk = T.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") + for i in T.serial(LOCAL_LEN): + reg_src[i] = A[tid, no * LOCAL_LEN + i] + reg_src_view = reg_src.view( + N_THREADS, LOCAL_LEN, layout=wg_local_layout(LOCAL_LEN) + ) + Dreg_chunk_view = Dreg_chunk.view( + N_THREADS, LOCAL_LEN, layout=wg_local_layout(LOCAL_LEN) + ) + Tx.wg.cast(Dreg_chunk_view, reg_src_view) + for i in T.serial(LOCAL_LEN): + B[tid, no * LOCAL_LEN + i] = Dreg_chunk[i] # fmt: on target = tvm.target.Target("cuda") @@ -976,7 +937,7 @@ def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: mod = tvm.compile(mod, target=target, tir_pipeline="tirx") src = mod.mod.imports[0].inspect_source() # The packed vec2 cast intrinsic must be present — guards against - # falling back to scalar Tx.cast inside Tx.vectorized. + # falling back to scalar T.cast inside T.vectorized. helper = f"tvm_builtin_cast_{A_dtype}x2_{B_dtype}x2" assert helper in src, f"expected {helper!r} in generated CUDA, fell back to scalar cast" mod(A, B) @@ -985,7 +946,7 @@ def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_cta_local_view(A_dtype, B_dtype): - """Tx.cast with view+layout in CTA scope (128 threads, register->register).""" + """T.cast with view+layout in CTA scope (128 threads, register->register).""" N_THREADS, LOCAL_LEN = 128, 8 g_shape = (N_THREADS, LOCAL_LEN) g_layout = TileLayout(S[g_shape]) @@ -999,28 +960,23 @@ def test_cast_cta_local_view(A_dtype, B_dtype): B_ref = A_ref.astype(B_dtype) # fmt: off - @Tx.prim_func - def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tx_var = Tx.thread_id([N_THREADS]) - - with Tx.thread(): - reg_src = Tx.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") - reg_dst = Tx.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - reg_src[i] = A[tx_var, i] - with Tx.cta(): - reg_src_view = reg_src.view(N_THREADS, LOCAL_LEN, layout=cast_layout) - reg_dst_view = reg_dst.view(N_THREADS, LOCAL_LEN, layout=cast_layout) - Tx.cast(reg_dst_view, reg_src_view) - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - B[tx_var, i] = reg_dst[i] + @T.prim_func + def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) + + T.device_entry() + cta_id = T.cta_id([1]) + tx_var = T.thread_id([N_THREADS]) + reg_src = T.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") + reg_dst = T.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") + for i in T.serial(LOCAL_LEN): + reg_src[i] = A[tx_var, i] + reg_src_view = reg_src.view(N_THREADS, LOCAL_LEN, layout=cast_layout) + reg_dst_view = reg_dst.view(N_THREADS, LOCAL_LEN, layout=cast_layout) + Tx.cta.cast(reg_dst_view, reg_src_view) + for i in T.serial(LOCAL_LEN): + B[tx_var, i] = reg_dst[i] # fmt: on target = tvm.target.Target("cuda") @@ -1035,7 +991,7 @@ def test_cast(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) @pytest.mark.parametrize("slice_start,slice_end", [(0, 4), (2, 6), (4, 8)]) def test_cast_local_view_sliced(A_dtype, B_dtype, slice_start, slice_end): - """Tx.cast with sliced view in CTA scope — exercises _emit_cast_local_view_sliced.""" + """T.cast with sliced view in CTA scope — exercises _emit_cast_local_view_sliced.""" N_THREADS, LOCAL_LEN = 128, 8 g_shape = (N_THREADS, LOCAL_LEN) g_layout = TileLayout(S[g_shape]) @@ -1049,29 +1005,25 @@ def test_cast_local_view_sliced(A_dtype, B_dtype, slice_start, slice_end): B_ref[:, slice_start:slice_end] = A_ref[:, slice_start:slice_end].astype(B_dtype) # fmt: off - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([N_THREADS]) - with Tx.thread(): - reg_src = Tx.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") - reg_dst = Tx.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - reg_src[i] = A[tx, i] - with Tx.cta(): - reg_src_view = reg_src.view(N_THREADS, LOCAL_LEN, layout=cast_layout) - reg_dst_view = reg_dst.view(N_THREADS, LOCAL_LEN, layout=cast_layout) - Tx.cast( - reg_dst_view[0:N_THREADS, slice_start:slice_end], - reg_src_view[0:N_THREADS, slice_start:slice_end], - ) - with Tx.thread(): - for i in Tx.serial(LOCAL_LEN): - B[tx, i] = reg_dst[i] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, A_dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, B_dtype, layout=g_layout) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([N_THREADS]) + reg_src = T.alloc_buffer((LOCAL_LEN,), A_dtype, scope="local") + reg_dst = T.alloc_buffer((LOCAL_LEN,), B_dtype, scope="local") + for i in T.serial(LOCAL_LEN): + reg_src[i] = A[tx, i] + reg_src_view = reg_src.view(N_THREADS, LOCAL_LEN, layout=cast_layout) + reg_dst_view = reg_dst.view(N_THREADS, LOCAL_LEN, layout=cast_layout) + Tx.cta.cast( + reg_dst_view[0:N_THREADS, slice_start:slice_end], + reg_src_view[0:N_THREADS, slice_start:slice_end], + ) + for i in T.serial(LOCAL_LEN): + B[tx, i] = reg_dst[i] # fmt: on target = tvm.target.Target("cuda") @@ -1158,32 +1110,28 @@ def test_cast_mixed_axes_and_subregion(slice_start, slice_end): A = tvm.runtime.tensor(A_ref, dev) B = tvm.runtime.tensor(np.zeros(full_shape, dtype="float16"), dev) - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, full_shape, "float32", layout=g_layout) - B = Tx.match_buffer(B_ptr, full_shape, "float16", layout=g_layout) - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([N_WARPS]) - lane_id = Tx.lane_id([LANES]) - with Tx.thread(): - reg_src = Tx.alloc_buffer((LOCAL_LEN,), "float32", scope="local") - reg_dst = Tx.alloc_buffer((LOCAL_LEN,), "float16", scope="local") - with Tx.thread(): - j, k = lane_id // 4, lane_id % 4 - for i in Tx.serial(LOCAL_LEN): - reg_src[i] = A[j, warp_id, k, i] - with Tx.cta(): - reg_src_view = reg_src.view(*full_shape, layout=cast_layout) - reg_dst_view = reg_dst.view(*full_shape, layout=cast_layout) - Tx.cast( - reg_dst_view[0:8, 0:N_WARPS, 0:4, slice_start:slice_end], - reg_src_view[0:8, 0:N_WARPS, 0:4, slice_start:slice_end], - ) - with Tx.thread(): - j, k = lane_id // 4, lane_id % 4 - for i in Tx.serial(LOCAL_LEN): - B[j, warp_id, k, i] = reg_dst[i] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, full_shape, "float32", layout=g_layout) + B = T.match_buffer(B_ptr, full_shape, "float16", layout=g_layout) + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([N_WARPS]) + lane_id = T.lane_id([LANES]) + reg_src = T.alloc_buffer((LOCAL_LEN,), "float32", scope="local") + reg_dst = T.alloc_buffer((LOCAL_LEN,), "float16", scope="local") + j, k = lane_id // 4, lane_id % 4 + for i in T.serial(LOCAL_LEN): + reg_src[i] = A[j, warp_id, k, i] + reg_src_view = reg_src.view(*full_shape, layout=cast_layout) + reg_dst_view = reg_dst.view(*full_shape, layout=cast_layout) + Tx.cta.cast( + reg_dst_view[0:8, 0:N_WARPS, 0:4, slice_start:slice_end], + reg_src_view[0:8, 0:N_WARPS, 0:4, slice_start:slice_end], + ) + j_1, k_1 = lane_id // 4, lane_id % 4 + for i in T.serial(LOCAL_LEN): + B[j_1, warp_id, k_1, i] = reg_dst[i] target = tvm.target.Target("cuda") with target: @@ -1236,29 +1184,25 @@ def test_cast_validate_extent_mismatch_rejected(): S[view_shape : (2 @ warpid, 8 @ laneid, 1 @ laneid, 1)] ) # dim1 extent 8 != 4 - @Tx.prim_func - def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, view_shape, "float32", layout=g_layout) - B = Tx.match_buffer(B_ptr, view_shape, "float16", layout=g_layout) - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([2]) - lane_id = Tx.lane_id([32]) - with Tx.thread(): - reg_src = Tx.alloc_buffer((8,), "float32", scope="local") - reg_dst = Tx.alloc_buffer((8,), "float16", scope="local") - with Tx.thread(): - j, k = lane_id // 4, lane_id % 4 - for i in Tx.serial(8): - reg_src[i] = A[warp_id, j, k, i] - with Tx.cta(): - reg_src_view = reg_src.view(*view_shape, layout=src_layout) - reg_dst_view = reg_dst.view(*view_shape, layout=dst_layout) - Tx.cast(reg_dst_view, reg_src_view) - with Tx.thread(): - j, k = lane_id // 4, lane_id % 4 - for i in Tx.serial(8): - B[warp_id, j, k, i] = reg_dst[i] + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, view_shape, "float32", layout=g_layout) + B = T.match_buffer(B_ptr, view_shape, "float16", layout=g_layout) + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([2]) + lane_id = T.lane_id([32]) + reg_src = T.alloc_buffer((8,), "float32", scope="local") + reg_dst = T.alloc_buffer((8,), "float16", scope="local") + j, k = lane_id // 4, lane_id % 4 + for i in T.serial(8): + reg_src[i] = A[warp_id, j, k, i] + reg_src_view = reg_src.view(*view_shape, layout=src_layout) + reg_dst_view = reg_dst.view(*view_shape, layout=dst_layout) + Tx.cta.cast(reg_dst_view, reg_src_view) + j_1, k_1 = lane_id // 4, lane_id % 4 + for i in T.serial(8): + B[warp_id, j_1, k_1, i] = reg_dst[i] target = tvm.target.Target("cuda") with target: @@ -1273,24 +1217,22 @@ def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: # Dispatch codegen checks (no GPU runtime — explicit target arch). # ----------------------------------------------------------------------------- def test_unary_exp_f16_shared_scalar_fallback_dispatch(): - """exp f16 + shared cta → smem.py + scalar (Tx.vectorized) — no exp packed.""" + """exp f16 + shared cta → smem.py + scalar (T.vectorized) — no exp packed.""" shape = (64, 32) lay = TileLayout(S[shape]) - @Tx.prim_func - def k(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, "float16", layout=lay) - B = Tx.match_buffer(B_ptr, shape, "float16", layout=lay) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tx = Tx.thread_id([64]) - with Tx.thread(): - sa = Tx.alloc_buffer(shape, "float16", scope="shared", layout=lay) - sb = Tx.alloc_buffer(shape, "float16", scope="shared", layout=lay) - Tx.copy(sa, A) - with Tx.cta(): - Tx.exp(sb, sa) - Tx.copy(B, sb) + @T.prim_func + def k(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, "float16", layout=lay) + B = T.match_buffer(B_ptr, shape, "float16", layout=lay) + T.device_entry() + _bx = T.cta_id([1]) + _tx = T.thread_id([64]) + sa = T.alloc_buffer(shape, "float16", scope="shared", layout=lay) + sb = T.alloc_buffer(shape, "float16", scope="shared", layout=lay) + Tx.copy(sa, A) + Tx.cta.exp(sb, sa) + Tx.copy(B, sb) target = tvm.target.Target({"kind": "cuda", "arch": "sm_80"}) with target: @@ -1312,23 +1254,18 @@ def test_cast_vec2_packed_dispatch(src_dtype, dst_dtype, intrinsic): shape = (64, 32) lay = TileLayout(S[shape]) - @Tx.prim_func - def k(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, shape, src_dtype, layout=lay) - B = Tx.match_buffer(B_ptr, shape, dst_dtype, layout=lay) - Tx.device_entry() - _bx = Tx.cta_id([1]) - tx = Tx.thread_id([64]) - with Tx.thread(): - ra = Tx.alloc_buffer( - shape[1:], src_dtype, scope="local", layout=TileLayout(S[shape[1:]]) - ) - rb = Tx.alloc_buffer( - shape[1:], dst_dtype, scope="local", layout=TileLayout(S[shape[1:]]) - ) - Tx.copy(ra, A[tx]) - Tx.cast(rb, ra) - Tx.copy(B[tx], rb) + @T.prim_func + def k(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, shape, src_dtype, layout=lay) + B = T.match_buffer(B_ptr, shape, dst_dtype, layout=lay) + T.device_entry() + _bx = T.cta_id([1]) + tx = T.thread_id([64]) + ra = T.alloc_buffer(shape[1:], src_dtype, scope="local", layout=TileLayout(S[shape[1:]])) + rb = T.alloc_buffer(shape[1:], dst_dtype, scope="local", layout=TileLayout(S[shape[1:]])) + Tx.copy(ra, A[tx]) + Tx.cast(rb, ra) + Tx.copy(B[tx], rb) target = tvm.target.Target({"kind": "cuda", "arch": "sm_80"}) with target: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py index 8a645dbe62e9..516366365f34 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py @@ -16,7 +16,7 @@ # under the License. """Tests for the CUDA synchronous ``gemm`` (mma.sync) tensor-core dispatch. -The dispatch lowers ``tirx.gemm`` over pure-register fragments to warp-level +The dispatch lowers ``tirx.tile.gemm`` over pure-register fragments to warp-level ``mma.sync.aligned.m16n8k16/k8`` for bf16/f16 inputs with f32 accumulation. The fragment layouts below are the standard m16n8 register maps (PTX ISA @@ -36,7 +36,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TileLayout, laneid from tvm.tirx.operator.tile_primitive import list_registered_schedules @@ -91,10 +92,10 @@ def _frag(Mt, Nt, Kt, kinst): def _build_tiled(Mt, Nt, Kt, kinst, *, beta=0.0, dtype="float16", store=False): - """A single-warp kernel issuing one ``Tx.gemm`` over an Mt x Nt x Kt tiling. + """A single-warp kernel issuing one ``T.gemm`` over an Mt x Nt x Kt tiling. With ``store=True`` the result is written back to a global buffer (a full - kernel for codegen); otherwise only the ``Tx.gemm`` is emitted (for + kernel for codegen); otherwise only the ``T.gemm`` is emitted (for ``LowerTIRx`` dispatch checks). """ Dl, Al, Bl = _frag(Mt, Nt, Kt, kinst) @@ -102,64 +103,58 @@ def _build_tiled(Mt, Nt, Kt, kinst, *, beta=0.0, dtype="float16", store=False): if not store: - @Tx.prim_func + @T.prim_func def gemm(): - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - _lane = Tx.lane_id([32]) - with Tx.cta(): - A = Tx.alloc_buffer((M, K), dtype, scope="local", layout=Al) - B = Tx.alloc_buffer((K, N), dtype, scope="local", layout=Bl) - C = Tx.alloc_buffer((M, N), "float32", scope="local", layout=Dl) - D = Tx.alloc_buffer((M, N), "float32", scope="local", layout=Dl) - with Tx.warp(): - Tx.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=1.0, beta=beta) + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + _lane = T.lane_id([32]) + A = T.alloc_buffer((M, K), dtype, scope="local", layout=Al) + B = T.alloc_buffer((K, N), dtype, scope="local", layout=Bl) + C = T.alloc_buffer((M, N), "float32", scope="local", layout=Dl) + D = T.alloc_buffer((M, N), "float32", scope="local", layout=Dl) + Tx.warp.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=1.0, beta=beta) return gemm - @Tx.prim_func - def gemm(D_ptr: Tx.handle): - D_g = Tx.match_buffer(D_ptr, (M, N), "float32") - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - A = Tx.alloc_buffer((M, K), dtype, scope="local", layout=Al) - B = Tx.alloc_buffer((K, N), dtype, scope="local", layout=Bl) - C = Tx.alloc_buffer((M, N), "float32", scope="local", layout=Dl) - D = Tx.alloc_buffer((M, N), "float32", scope="local", layout=Dl) - with Tx.warp(): - Tx.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=1.0, beta=beta) - # Decode D's per-thread registers (c = ((mt*Nt + nt)*2 + rM)*2 + rN) - # back to logical (M, N) and store, exercising the whole tiling. - D_reg = D.local(Mt * Nt * 4) - for c in Tx.unroll(Mt * Nt * 4): - rN = c % 2 - rM = (c // 2) % 2 - nt = (c // 4) % Nt - mt = c // (4 * Nt) - D_g[mt * 16 + lane // 4 + rM * 8, nt * 8 + (lane % 4) * 2 + rN] = D_reg[c] + @T.prim_func + def gemm(D_ptr: T.handle): + D_g = T.match_buffer(D_ptr, (M, N), "float32") + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + lane = T.lane_id([32]) + A = T.alloc_buffer((M, K), dtype, scope="local", layout=Al) + B = T.alloc_buffer((K, N), dtype, scope="local", layout=Bl) + C = T.alloc_buffer((M, N), "float32", scope="local", layout=Dl) + D = T.alloc_buffer((M, N), "float32", scope="local", layout=Dl) + Tx.warp.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=1.0, beta=beta) + # Decode D's per-thread registers (c = ((mt*Nt + nt)*2 + rM)*2 + rN) + # back to logical (M, N) and store, exercising the whole tiling. + D_reg = D.local(Mt * Nt * 4) + for c in T.unroll(Mt * Nt * 4): + rN = c % 2 + rM = (c // 2) % 2 + nt = (c // 4) % Nt + mt = c // (4 * Nt) + D_g[mt * 16 + lane // 4 + rM * 8, nt * 8 + (lane % 4) * 2 + rN] = D_reg[c] return gemm def _build_gemm(alpha=1.0, beta=0.0, dtype="bfloat16"): - """A single-warp kernel issuing one ``Tx.gemm`` over register fragments.""" + """A single-warp kernel issuing one ``T.gemm`` over register fragments.""" - @Tx.prim_func + @T.prim_func def gemm_min(): - Tx.device_entry() - _cta = Tx.cta_id([1]) - _tid = Tx.thread_id([32]) - with Tx.cta(): - D = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - C = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - A = Tx.alloc_buffer((16, 16), dtype, scope="local", layout=A_FRAG) - B = Tx.alloc_buffer((16, 8), dtype, scope="local", layout=B_FRAG) - with Tx.warp(): - Tx.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=alpha, beta=beta) + T.device_entry() + _cta = T.cta_id([1]) + _tid = T.thread_id([32]) + D = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + C = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + A = T.alloc_buffer((16, 16), dtype, scope="local", layout=A_FRAG) + B = T.alloc_buffer((16, 8), dtype, scope="local", layout=B_FRAG) + Tx.warp.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=alpha, beta=beta) return gemm_min @@ -173,57 +168,53 @@ def _build_transpose(transpose_A, transpose_B, *, store=False): if not store: - @Tx.prim_func + @T.prim_func def gemm(): - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - _lane = Tx.lane_id([32]) - with Tx.cta(): - A = Tx.alloc_buffer(A_shape, "float16", scope="local", layout=Al) - B = Tx.alloc_buffer(B_shape, "float16", scope="local", layout=Bl) - C = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - D = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - with Tx.warp(): - Tx.gemm( - D, - A, - B, - C, - transpose_A=transpose_A, - transpose_B=transpose_B, - alpha=1.0, - beta=0.0, - ) + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + _lane = T.lane_id([32]) + A = T.alloc_buffer(A_shape, "float16", scope="local", layout=Al) + B = T.alloc_buffer(B_shape, "float16", scope="local", layout=Bl) + C = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + D = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + Tx.warp.gemm( + D, + A, + B, + C, + transpose_A=transpose_A, + transpose_B=transpose_B, + alpha=1.0, + beta=0.0, + ) return gemm - @Tx.prim_func - def gemm(D_ptr: Tx.handle): - D_g = Tx.match_buffer(D_ptr, (16, 8), "float32") - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - A = Tx.alloc_buffer(A_shape, "float16", scope="local", layout=Al) - B = Tx.alloc_buffer(B_shape, "float16", scope="local", layout=Bl) - C = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - D = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - with Tx.warp(): - Tx.gemm( - D, - A, - B, - C, - transpose_A=transpose_A, - transpose_B=transpose_B, - alpha=1.0, - beta=0.0, - ) - D_reg = D.local(4) - for c in Tx.unroll(4): - D_g[lane // 4 + (c // 2) * 8, (lane % 4) * 2 + c % 2] = D_reg[c] + @T.prim_func + def gemm(D_ptr: T.handle): + D_g = T.match_buffer(D_ptr, (16, 8), "float32") + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + lane = T.lane_id([32]) + A = T.alloc_buffer(A_shape, "float16", scope="local", layout=Al) + B = T.alloc_buffer(B_shape, "float16", scope="local", layout=Bl) + C = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + D = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + Tx.warp.gemm( + D, + A, + B, + C, + transpose_A=transpose_A, + transpose_B=transpose_B, + alpha=1.0, + beta=0.0, + ) + D_reg = D.local(4) + for c in T.unroll(4): + D_g[lane // 4 + (c // 2) * 8, (lane % 4) * 2 + c % 2] = D_reg[c] return gemm @@ -231,24 +222,22 @@ def gemm(D_ptr: Tx.handle): def _build_dtypes(a_dtype, b_dtype, c_dtype, d_dtype): """Single tile with explicit per-operand dtypes (for decline checks).""" - @Tx.prim_func + @T.prim_func def gemm_min(): - Tx.device_entry() - _cta = Tx.cta_id([1]) - _tid = Tx.thread_id([32]) - with Tx.cta(): - D = Tx.alloc_buffer((16, 8), d_dtype, scope="local", layout=D_FRAG) - C = Tx.alloc_buffer((16, 8), c_dtype, scope="local", layout=D_FRAG) - A = Tx.alloc_buffer((16, 16), a_dtype, scope="local", layout=A_FRAG) - B = Tx.alloc_buffer((16, 8), b_dtype, scope="local", layout=B_FRAG) - with Tx.warp(): - Tx.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=1.0, beta=0.0) + T.device_entry() + _cta = T.cta_id([1]) + _tid = T.thread_id([32]) + D = T.alloc_buffer((16, 8), d_dtype, scope="local", layout=D_FRAG) + C = T.alloc_buffer((16, 8), c_dtype, scope="local", layout=D_FRAG) + A = T.alloc_buffer((16, 16), a_dtype, scope="local", layout=A_FRAG) + B = T.alloc_buffer((16, 8), b_dtype, scope="local", layout=B_FRAG) + Tx.warp.gemm(D, A, B, C, transpose_A=False, transpose_B=False, alpha=1.0, beta=0.0) return gemm_min def _build_tiled_numeric(Mt, Nt, Kt, kinst, beta, dtype): - """End-to-end ``Tx.gemm`` over an Mt x Nt x Kt tiling, with the A/B inputs + """End-to-end ``T.gemm`` over an Mt x Nt x Kt tiling, with the A/B inputs loaded and the D output stored register-by-register. Fragments are indexed through their per-register multi-dim ``.local()`` views @@ -262,54 +251,48 @@ def _build_tiled_numeric(Mt, Nt, Kt, kinst, beta, dtype): KP = 2 kHi_n = kinst // (4 * KP) - @Tx.prim_func - def gemm(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, D_ptr: Tx.handle): - A_g = Tx.match_buffer(A_ptr, (M, K), dtype) - B_g = Tx.match_buffer(B_ptr, (K, N), dtype) - C_g = Tx.match_buffer(C_ptr, (M, N), "float32") - D_g = Tx.match_buffer(D_ptr, (M, N), "float32") - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - A_f = Tx.alloc_buffer((M, K), dtype, scope="local", layout=Al) - B_f = Tx.alloc_buffer((K, N), dtype, scope="local", layout=Bl) - C_f = Tx.alloc_buffer((M, N), "float32", scope="local", layout=Dl) - D_f = Tx.alloc_buffer((M, N), "float32", scope="local", layout=Dl) - with Tx.warp(): - A_reg = A_f.local(Mt, 2, Kt, kHi_n, KP) - for mt, rM, kt, kHi, kp in Tx.grid(Mt, 2, Kt, kHi_n, KP): - A_reg[mt, rM, kt, kHi, kp] = A_g[ - mt * 16 + lane // 4 + 8 * rM, - kt * kinst + kHi * 8 + 2 * (lane % 4) + kp, - ] - B_reg = B_f.local(Kt, kHi_n, KP, Nt) - for kt, kHi, kp, nt in Tx.grid(Kt, kHi_n, KP, Nt): - B_reg[kt, kHi, kp, nt] = B_g[ - kt * kinst + kHi * 8 + 2 * (lane % 4) + kp, - nt * 8 + lane // 4, - ] - if beta == 1.0: - C_reg = C_f.local(Mt, 2, Nt, 2) - for mt, rM, nt, rN in Tx.grid(Mt, 2, Nt, 2): - C_reg[mt, rM, nt, rN] = C_g[ - mt * 16 + lane // 4 + 8 * rM, nt * 8 + 2 * (lane % 4) + rN - ] - Tx.gemm( - D_f, A_f, B_f, C_f, transpose_A=False, transpose_B=False, alpha=1.0, beta=beta - ) - D_reg = D_f.local(Mt, 2, Nt, 2) - for mt, rM, nt, rN in Tx.grid(Mt, 2, Nt, 2): - D_g[mt * 16 + lane // 4 + 8 * rM, nt * 8 + 2 * (lane % 4) + rN] = D_reg[ - mt, rM, nt, rN - ] + @T.prim_func + def gemm(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, D_ptr: T.handle): + A_g = T.match_buffer(A_ptr, (M, K), dtype) + B_g = T.match_buffer(B_ptr, (K, N), dtype) + C_g = T.match_buffer(C_ptr, (M, N), "float32") + D_g = T.match_buffer(D_ptr, (M, N), "float32") + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + lane = T.lane_id([32]) + A_f = T.alloc_buffer((M, K), dtype, scope="local", layout=Al) + B_f = T.alloc_buffer((K, N), dtype, scope="local", layout=Bl) + C_f = T.alloc_buffer((M, N), "float32", scope="local", layout=Dl) + D_f = T.alloc_buffer((M, N), "float32", scope="local", layout=Dl) + A_reg = A_f.local(Mt, 2, Kt, kHi_n, KP) + for mt, rM, kt, kHi, kp in T.grid(Mt, 2, Kt, kHi_n, KP): + A_reg[mt, rM, kt, kHi, kp] = A_g[ + mt * 16 + lane // 4 + 8 * rM, + kt * kinst + kHi * 8 + 2 * (lane % 4) + kp, + ] + B_reg = B_f.local(Kt, kHi_n, KP, Nt) + for kt, kHi, kp, nt in T.grid(Kt, kHi_n, KP, Nt): + B_reg[kt, kHi, kp, nt] = B_g[ + kt * kinst + kHi * 8 + 2 * (lane % 4) + kp, + nt * 8 + lane // 4, + ] + if beta == 1.0: + C_reg = C_f.local(Mt, 2, Nt, 2) + for mt, rM, nt, rN in T.grid(Mt, 2, Nt, 2): + C_reg[mt, rM, nt, rN] = C_g[ + mt * 16 + lane // 4 + 8 * rM, nt * 8 + 2 * (lane % 4) + rN + ] + Tx.warp.gemm(D_f, A_f, B_f, C_f, transpose_A=False, transpose_B=False, alpha=1.0, beta=beta) + D_reg = D_f.local(Mt, 2, Nt, 2) + for mt, rM, nt, rN in T.grid(Mt, 2, Nt, 2): + D_g[mt * 16 + lane // 4 + 8 * rM, nt * 8 + 2 * (lane % 4) + rN] = D_reg[mt, rM, nt, rN] return gemm, M, N, K def _build_transpose_numeric(transpose_A, transpose_B, dtype="float16"): - """End-to-end single-tile ``Tx.gemm`` for one A/B input orientation. + """End-to-end single-tile ``T.gemm`` for one A/B input orientation. The transposed A fragment (``A_KM_FRAG``) carries its registers in the [kHi, kp, rM] shard order (vs [rM, kHi, kp] for the K-major ``A_FRAG``); B's @@ -322,50 +305,48 @@ def _build_transpose_numeric(transpose_A, transpose_B, dtype="float16"): A_shape = (16, 16) B_shape = (8, 16) if transpose_B else (16, 8) - @Tx.prim_func - def gemm(A_ptr: Tx.handle, B_ptr: Tx.handle, D_ptr: Tx.handle): - A_g = Tx.match_buffer(A_ptr, A_shape, dtype) - B_g = Tx.match_buffer(B_ptr, B_shape, dtype) - D_g = Tx.match_buffer(D_ptr, (16, 8), "float32") - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - A_f = Tx.alloc_buffer(A_shape, dtype, scope="local", layout=Al) - B_f = Tx.alloc_buffer(B_shape, dtype, scope="local", layout=Bl) - D_f = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - with Tx.warp(): - A_reg = A_f.local(2, 2, 2) - if transpose_A: - # A_KM_FRAG register order is [kHi, kp, rM]; buffer is [K, M]. - for kHi, kp, rM in Tx.grid(2, 2, 2): - A_reg[kHi, kp, rM] = A_g[2 * (lane % 4) + kp + 8 * kHi, lane // 4 + 8 * rM] - else: - # A_FRAG register order is [rM, kHi, kp]; buffer is [M, K]. - for rM, kHi, kp in Tx.grid(2, 2, 2): - A_reg[rM, kHi, kp] = A_g[lane // 4 + 8 * rM, 2 * (lane % 4) + kp + 8 * kHi] - B_reg = B_f.local(2, 2) - if transpose_B: - # B_NK_FRAG buffer is [N, K]. - for kHi, kp in Tx.grid(2, 2): - B_reg[kHi, kp] = B_g[lane // 4, 2 * (lane % 4) + kp + 8 * kHi] - else: - for kHi, kp in Tx.grid(2, 2): - B_reg[kHi, kp] = B_g[2 * (lane % 4) + kp + 8 * kHi, lane // 4] - Tx.gemm( - D_f, - A_f, - B_f, - D_f, - transpose_A=transpose_A, - transpose_B=transpose_B, - alpha=1.0, - beta=0.0, - ) - D_reg = D_f.local(2, 2) - for rM, rN in Tx.grid(2, 2): - D_g[lane // 4 + 8 * rM, 2 * (lane % 4) + rN] = D_reg[rM, rN] + @T.prim_func + def gemm(A_ptr: T.handle, B_ptr: T.handle, D_ptr: T.handle): + A_g = T.match_buffer(A_ptr, A_shape, dtype) + B_g = T.match_buffer(B_ptr, B_shape, dtype) + D_g = T.match_buffer(D_ptr, (16, 8), "float32") + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + lane = T.lane_id([32]) + A_f = T.alloc_buffer(A_shape, dtype, scope="local", layout=Al) + B_f = T.alloc_buffer(B_shape, dtype, scope="local", layout=Bl) + D_f = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + A_reg = A_f.local(2, 2, 2) + if transpose_A: + # A_KM_FRAG register order is [kHi, kp, rM]; buffer is [K, M]. + for kHi, kp, rM in T.grid(2, 2, 2): + A_reg[kHi, kp, rM] = A_g[2 * (lane % 4) + kp + 8 * kHi, lane // 4 + 8 * rM] + else: + # A_FRAG register order is [rM, kHi, kp]; buffer is [M, K]. + for rM, kHi, kp in T.grid(2, 2, 2): + A_reg[rM, kHi, kp] = A_g[lane // 4 + 8 * rM, 2 * (lane % 4) + kp + 8 * kHi] + B_reg = B_f.local(2, 2) + if transpose_B: + # B_NK_FRAG buffer is [N, K]. + for kHi, kp in T.grid(2, 2): + B_reg[kHi, kp] = B_g[lane // 4, 2 * (lane % 4) + kp + 8 * kHi] + else: + for kHi, kp in T.grid(2, 2): + B_reg[kHi, kp] = B_g[2 * (lane % 4) + kp + 8 * kHi, lane // 4] + Tx.warp.gemm( + D_f, + A_f, + B_f, + D_f, + transpose_A=transpose_A, + transpose_B=transpose_B, + alpha=1.0, + beta=0.0, + ) + D_reg = D_f.local(2, 2) + for rM, rN in T.grid(2, 2): + D_g[lane // 4 + 8 * rM, 2 * (lane % 4) + rN] = D_reg[rM, rN] return gemm @@ -378,11 +359,11 @@ def _lower(func): def test_cuda_gemm_mma_variant_is_registered(): # Importing tvm.tirx registers all per-target schedule variants. The new # synchronous CUDA mma path must show up for ("gemm", "cuda"). The registry - # keys ops by their full name (``op.name`` == "tirx.gemm"). + # keys ops by their full name (``op.name`` == "tirx.tile.gemm"). schedules = list_registered_schedules() - cuda_gemm = schedules.get("tirx.gemm", {}).get("cuda", []) + cuda_gemm = schedules.get("tirx.tile.gemm", {}).get("cuda", []) assert "mma.m16n8k*" in cuda_gemm, ( - f"mma.m16n8k* not registered; tirx.gemm schedules = {schedules.get('tirx.gemm')}" + f"mma.m16n8k* not registered; tirx.tile.gemm schedules = {schedules.get('tirx.tile.gemm')}" ) @@ -392,10 +373,10 @@ def test_cuda_gemm_mma_lowers_to_mma_sync(dtype): the registers laid out in the fixed PTX fragment order.""" script = _lower(_build_gemm(alpha=1.0, beta=0.0, dtype=dtype))["main"].script() - assert "Tx.ptx.mma(" in script + assert "T.ptx.mma(" in script assert "m16n8k16" in script # beta == 0 clears the accumulator before the K loop. - assert "Tx.float32(0" in script + assert "T.float32(0" in script # D accumulator: c_id = 2*rM + rN -> regs 0..3. for r in range(4): assert f"d_local[{r}]" in script @@ -411,11 +392,11 @@ def test_cuda_gemm_mma_accumulates_c_when_beta_one(): """beta=1: the accumulator is initialized by copying C instead of zeroing.""" script = _lower(_build_gemm(alpha=1.0, beta=1.0))["main"].script() - assert "Tx.ptx.mma(" in script + assert "T.ptx.mma(" in script assert "m16n8k16" in script # The init reads C into D; nothing is zeroed. assert "c_local[" in script - assert "Tx.float32(0" not in script + assert "T.float32(0" not in script def test_cuda_gemm_mma_rejects_nonunit_alpha(): @@ -438,7 +419,7 @@ def test_cuda_gemm_mma_numerical(dtype): A is [M, K] = [16, 16], B is [K, N] = [16, 8], D is [M, N] = [16, 8]. The lane-distributed register fragments cannot be filled with a whole-tile - ``Tx.copy`` (the per-thread axis can't be matched coordinate-wise), so each + ``T.copy`` (the per-thread axis can't be matched coordinate-wise), so each of a lane's registers is loaded/stored by decoding the m16n8k16 register map with ``g = lane >> 2`` and ``t = lane & 3``. The per-register *slot* order matches the dispatch's fragment register layout: @@ -453,39 +434,35 @@ def test_cuda_gemm_mma_numerical(dtype): else: np_dtype = np.float16 - @Tx.prim_func - def gemm(A_ptr: Tx.handle, B_ptr: Tx.handle, D_ptr: Tx.handle): - A_g = Tx.match_buffer(A_ptr, (16, 16), dtype) - B_g = Tx.match_buffer(B_ptr, (16, 8), dtype) - D_g = Tx.match_buffer(D_ptr, (16, 8), "float32") - Tx.device_entry() - _cta = Tx.cta_id([1]) - _warp = Tx.warp_id([1]) - lane = Tx.lane_id([32]) - with Tx.cta(): - A_f = Tx.alloc_buffer((16, 16), dtype, scope="local", layout=A_FRAG) - B_f = Tx.alloc_buffer((16, 8), dtype, scope="local", layout=B_FRAG) - D_f = Tx.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) - with Tx.warp(): - A_reg = A_f.local(8) - for s in Tx.unroll(8): - kp = s % 2 - kHi = (s // 2) % 2 - rM = s // 4 - A_reg[s] = A_g[lane // 4 + 8 * rM, 2 * (lane % 4) + kp + 8 * kHi] - B_reg = B_f.local(4) - for s in Tx.unroll(4): - kp = s % 2 - kHi = s // 2 - B_reg[s] = B_g[2 * (lane % 4) + kp + 8 * kHi, lane // 4] - Tx.gemm( - D_f, A_f, B_f, D_f, transpose_A=False, transpose_B=False, alpha=1.0, beta=0.0 - ) - D_reg = D_f.local(4) - for s in Tx.unroll(4): - rN = s % 2 - rM = s // 2 - D_g[lane // 4 + 8 * rM, 2 * (lane % 4) + rN] = D_reg[s] + @T.prim_func + def gemm(A_ptr: T.handle, B_ptr: T.handle, D_ptr: T.handle): + A_g = T.match_buffer(A_ptr, (16, 16), dtype) + B_g = T.match_buffer(B_ptr, (16, 8), dtype) + D_g = T.match_buffer(D_ptr, (16, 8), "float32") + T.device_entry() + _cta = T.cta_id([1]) + _warp = T.warp_id([1]) + lane = T.lane_id([32]) + A_f = T.alloc_buffer((16, 16), dtype, scope="local", layout=A_FRAG) + B_f = T.alloc_buffer((16, 8), dtype, scope="local", layout=B_FRAG) + D_f = T.alloc_buffer((16, 8), "float32", scope="local", layout=D_FRAG) + A_reg = A_f.local(8) + for s in T.unroll(8): + kp = s % 2 + kHi = (s // 2) % 2 + rM = s // 4 + A_reg[s] = A_g[lane // 4 + 8 * rM, 2 * (lane % 4) + kp + 8 * kHi] + B_reg = B_f.local(4) + for s in T.unroll(4): + kp = s % 2 + kHi = s // 2 + B_reg[s] = B_g[2 * (lane % 4) + kp + 8 * kHi, lane // 4] + Tx.warp.gemm(D_f, A_f, B_f, D_f, transpose_A=False, transpose_B=False, alpha=1.0, beta=0.0) + D_reg = D_f.local(4) + for s in T.unroll(4): + rN = s % 2 + rM = s // 2 + D_g[lane // 4 + 8 * rM, 2 * (lane % 4) + rN] = D_reg[s] dev = tvm.cuda(0) with tvm.target.Target("cuda"): @@ -615,7 +592,7 @@ def test_cuda_gemm_mma_lowers_tiled(Mt, Nt, Kt, kinst): (an extent-1 high-K register group must not be rejected as a thread axis). """ script = _lower(_build_tiled(Mt, Nt, Kt, kinst))["main"].script() - assert "Tx.ptx.mma(" in script + assert "T.ptx.mma(" in script assert f"m16n8k{kinst}" in script @@ -656,7 +633,7 @@ def test_cuda_gemm_mma_lowers_transpose(transpose_A, transpose_B): """All four A/B orientations dispatch to the same m16n8k16. transpose only describes the input's logical orientation; the .row.col mma is unchanged.""" script = _lower(_build_transpose(transpose_A, transpose_B))["main"].script() - assert "Tx.ptx.mma(" in script + assert "T.ptx.mma(" in script assert "m16n8k16" in script diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py index 0076d1026480..8c32bbe04839 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py @@ -30,7 +30,8 @@ import tvm import tvm.testing from tvm.ir.type import PointerType, PrimType -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg from tvm.tirx.operator.tile_primitive.cuda.gemm_async import sf_tmem_layout @@ -210,68 +211,61 @@ def test_gemm_tcgen05_cta_group_1(task): r_smem_B = list(slice(B_region[i][0], B_region[i][1]) for i in range(len(B_shape))) # fmt: off - @Tx.prim_func - def gemm_async(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, A_shape, A_dtype) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") + @T.prim_func + def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, A_shape, A_dtype) + B = T.match_buffer(B_ptr, B_shape, B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) - Tx.cuda.cta_sync() - tmem = Tx.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) + T.cuda.cta_sync() + tmem = T.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) - Tx.copy_async(A_smem[tuple(r_gmem_A)], A[tuple(r_gmem_A)], **tma_args) - Tx.copy_async(B_smem[tuple(r_gmem_B)], B[tuple(r_gmem_B)], **tma_args) - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[tuple(r_gmem_A)], A[tuple(r_gmem_A)], **tma_args) + Tx.copy_async(B_smem[tuple(r_gmem_B)], B[tuple(r_gmem_B)], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], dispatch="tcgen05") # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - - Tx.ptx.tcgen05.fence.after_thread_sync() - C_reg = Tx.alloc_local(width, dtype=C_dtype) + Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], dispatch="tcgen05") # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + + T.ptx.tcgen05.fence.after_thread_sync() + C_reg = T.alloc_local(width, dtype=C_dtype) C_view = C_reg.view(128, width, layout=TileLayout(S[(128, width) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[tuple(r_tmem_C)]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[tid_in_wg, C_region[1][0]:C_region[1][1]], C_reg[:]) + Tx.wg.copy_async(C_view[:, :], tmem[tuple(r_tmem_C)]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[tid_in_wg, C_region[1][0]:C_region[1][1]], C_reg[:]) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) # fmt: on dev = tvm.cuda(0) @@ -322,81 +316,75 @@ def test_gemm_tcgen05_cta_group_1_layout_f_m64(): c_layout = tmem_datapath_layout("F", 64, N) # fmt: off - @Tx.prim_func - def gemm_layout_f(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, A_shape, A_dtype) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - - Tx.device_entry() - warp_id = Tx.warp_id([4]) - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - lane_id = Tx.lane_id([32]) - - A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") + @T.prim_func + def gemm_layout_f(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, A_shape, A_dtype) + B = T.match_buffer(B_ptr, B_shape, B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + + T.device_entry() + warp_id = T.warp_id([4]) + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + lane_id = T.lane_id([32]) + + A_smem = T.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=64, cta_group=1) - Tx.cuda.cta_sync() + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=64, cta_group=1) + T.cuda.cta_sync() # Layout F C operand — the path under test. - tmem = Tx.decl_buffer((64, N), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=c_layout) # noqa: E501 + tmem = T.decl_buffer((64, N), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=c_layout) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) - Tx.copy_async(A_smem[:, :], A[:, :], **tma_args) - Tx.copy_async(B_smem[:, :], B[:, :], **tma_args) - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), (M * K + N * K) * 2) - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[:, :], A[:, :], **tma_args) + Tx.copy_async(B_smem[:, :], B[:, :], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), (M * K + N * K) * 2) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[0:64, 0:N], A_smem[:, :], B_smem[:, :], dispatch="tcgen05") - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - Tx.ptx.tcgen05.fence.after_thread_sync() + Tx.gemm_async(tmem[0:64, 0:N], A_smem[:, :], B_smem[:, :], dispatch="tcgen05") + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + T.ptx.tcgen05.fence.after_thread_sync() # Read back via .16x256b M=64 (the canonical pairing). - reg = Tx.alloc_local(32, dtype="float32") + reg = T.alloc_local(32, dtype="float32") reg_view = reg.view(64, N, layout=tcgen05_atom_layout("16x256b", (64, N), "float32")) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(reg_view[:, :], tmem[0:64, 0:N]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() + Tx.wg.copy_async(reg_view[:, :], tmem[0:64, 0:N]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() # Per-(reg -> row, col) decomposition for .16x256b M=64 fp32 (BT=64 -> rep=8): # r = v0p + 2*va + 4*vb, v0p in {0,1}, va in {0,1}, vb in [0, 8) # row = (lane_id >> 2) + 8*va + 16*warp_id # col = v0p + ((lane_id & 3) << 1) + 8*vb - for vb in Tx.unroll(8): - for va in Tx.unroll(2): - for v0p in Tx.unroll(2): - r: Tx.let = v0p + 2 * va + 4 * vb - row: Tx.let = (lane_id >> 2) + 8 * va + 16 * warp_id - col: Tx.let = v0p + ((lane_id & 3) << 1) + 8 * vb + for vb in T.unroll(8): + for va in T.unroll(2): + for v0p in T.unroll(2): + r: T.let = v0p + 2 * va + 4 * vb + row: T.let = (lane_id >> 2) + 8 * va + 16 * warp_id + col: T.let = v0p + ((lane_id & 3) << 1) + 8 * vb C[row, col] = reg[r] if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=64, cta_group=1) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=64, cta_group=1) # fmt: on dev = tvm.cuda(0) @@ -464,77 +452,70 @@ def test_gemm_tcgen05_cta_group_2(task): r_smem_B = list(slice(B_region[i][0], B_region[i][1]) for i in range(len(B_shape))) # fmt: off - @Tx.prim_func - def gemm_async(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, A_shape, A_dtype) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cbx, cby = Tx.cta_id_in_cluster([2, 1]) - cta_id = Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape_per_cta, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape_per_cta, B_dtype, scope="shared", layout=B_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - - ptr: Tx.let[Tx.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = Tx.reinterpret("handle", Tx.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 - tma_mbar_cta_0 = Tx.decl_buffer([1], "uint64", data=ptr, scope="shared") + @T.prim_func + def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, A_shape, A_dtype) + B = T.match_buffer(B_ptr, B_shape, B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cbx, cby = T.cta_id_in_cluster([2, 1]) + cta_id = T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape_per_cta, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape_per_cta, B_dtype, scope="shared", layout=B_layout) + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + + ptr: T.let[T.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = T.reinterpret("handle", T.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 + tma_mbar_cta_0 = T.decl_buffer([1], "uint64", data=ptr, scope="shared") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) - tmem = Tx.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - Tx.ptx.fence.mbarrier_init() - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - Tx.cuda.cluster_sync() - - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) + tmem = T.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.fence.mbarrier_init() + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + T.cuda.cluster_sync() + + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - Tx.copy_async(A_smem[tuple(r_smem_A_in)], A[tuple(get_global_region(A_shape_per_cta, transA, cbx))], **tma_args) # noqa: E501 - Tx.copy_async(B_smem[tuple(r_smem_B_in)], B[tuple(get_global_region(B_shape_per_cta, transB, cbx))], **tma_args) # noqa: E501 - if cbx == 0: - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + Tx.copy_async(A_smem[tuple(r_smem_A_in)], A[tuple(get_global_region(A_shape_per_cta, transA, cbx))], **tma_args) # noqa: E501 + Tx.copy_async(B_smem[tuple(r_smem_B_in)], B[tuple(get_global_region(B_shape_per_cta, transB, cbx))], **tma_args) # noqa: E501 + if cbx == 0: + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) if cbx == 0: - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], dispatch="tcgen05", cta_group=2) # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) # signal cta 1's mbarrier # noqa: E501 - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) # both cta 0 and cta 1 have done mma - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() - - C_reg = Tx.alloc_local(width , dtype=C_dtype) + Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], dispatch="tcgen05", cta_group=2) # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) # signal cta 1's mbarrier # noqa: E501 + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) # both cta 0 and cta 1 have done mma + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() + + C_reg = T.alloc_local(width , dtype=C_dtype) C_view = C_reg.view(128, width, layout=TileLayout(S[(128, width) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[C_region[0][0]:C_region[0][1], C_region[1][0]:C_region[1][0] + width]) # noqa: E501 - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[cbx * 128 +tid_in_wg, C_region[1][0]:C_region[1][0] + width], C_reg[:]) - Tx.cuda.cta_sync() + Tx.wg.copy_async(C_view[:, :], tmem[C_region[0][0]:C_region[0][1], C_region[1][0]:C_region[1][0] + width]) # noqa: E501 + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[cbx * 128 +tid_in_wg, C_region[1][0]:C_region[1][0] + width], C_reg[:]) + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) # fmt: on dev = tvm.cuda(0) @@ -599,87 +580,77 @@ def test_gemm_tcgen05_cta_group_2_layout_b(): total_bytes = per_cta_bytes * 2 # fmt: off - @Tx.prim_func - def gemm_async(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M_per_cta * 2, K), A_dtype) - B = Tx.match_buffer(B_ptr, (N_logical, K), B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cbx, cby = Tx.cta_id_in_cluster([2, 1]) - cta_id = Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - - ptr: Tx.let[Tx.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = Tx.reinterpret("handle", Tx.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 - tma_mbar_cta_0 = Tx.decl_buffer([1], "uint64", data=ptr, scope="shared") + @T.prim_func + def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M_per_cta * 2, K), A_dtype) + B = T.match_buffer(B_ptr, (N_logical, K), B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cbx, cby = T.cta_id_in_cluster([2, 1]) + cta_id = T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + + ptr: T.let[T.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = T.reinterpret("handle", T.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 + tma_mbar_cta_0 = T.decl_buffer([1], "uint64", data=ptr, scope="shared") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) - # Logical TMEM buffer: (64, N_logical) with 2x2 shard layout - tmem = Tx.decl_buffer((M_per_cta, N_logical), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(M_per_cta, 2, N_half) : (1 @ TLane, 64 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) + tmem = T.decl_buffer((M_per_cta, N_logical), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(M_per_cta, 2, N_half) : (1 @ TLane, 64 @ TLane, 1 @ TCol)])) # noqa: E501 # Physical TMEM view for readback: (128, N_half) standard layout - tmem_phys = Tx.decl_buffer((128, N_half), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, N_half) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - Tx.ptx.fence.mbarrier_init() - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - Tx.cuda.cluster_sync() + tmem_phys = T.decl_buffer((128, N_half), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, N_half) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.fence.mbarrier_init() + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + T.cuda.cluster_sync() - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - # CTA cbx loads its portion of A and B - Tx.copy_async(A_smem[0:M_per_cta, 0:K], A[cbx * M_per_cta:(cbx + 1) * M_per_cta, 0:K], **tma_args) # noqa: E501 - Tx.copy_async(B_smem[0:N_half, 0:K], B[cbx * N_half:(cbx + 1) * N_half, 0:K], **tma_args) # noqa: E501 - if cbx == 0: - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + # CTA cbx loads its portion of A and B + Tx.copy_async(A_smem[0:M_per_cta, 0:K], A[cbx * M_per_cta:(cbx + 1) * M_per_cta, 0:K], **tma_args) # noqa: E501 + Tx.copy_async(B_smem[0:N_half, 0:K], B[cbx * N_half:(cbx + 1) * N_half, 0:K], **tma_args) # noqa: E501 + if cbx == 0: + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) if cbx == 0: - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[0:M_per_cta, 0:N_logical], A_smem[0:M_per_cta, 0:K], B_smem[0:N_half, 0:K], dispatch="tcgen05", cta_group=2) # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + Tx.gemm_async(tmem[0:M_per_cta, 0:N_logical], A_smem[0:M_per_cta, 0:K], B_smem[0:N_half, 0:K], dispatch="tcgen05", cta_group=2) # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() # Readback from physical TMEM view (128 rows x N_half cols) # Warps 0,1 (rows 0-63): first N half for M rows 0-63 # Warps 2,3 (rows 64-127): second N half for M rows 0-63 - C_reg = Tx.alloc_local(N_half, dtype=C_dtype) + C_reg = T.alloc_local(N_half, dtype=C_dtype) C_view = C_reg.view(128, N_half, layout=TileLayout(S[(128, N_half) : (1 @ axis_tid_in_wg, 1)])) # noqa: E501 if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem_phys[0:128, 0:N_half]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - - # Write to global: thread t holds M_row = t%64, N_half_idx = t//64 - with Tx.thread(): - n_off = (tid_in_wg // 64) * N_half - Tx.copy(C[cbx * M_per_cta + tid_in_wg % 64, n_off : n_off + N_half], C_reg[:]) - Tx.cuda.cta_sync() + Tx.wg.copy_async(C_view[:, :], tmem_phys[0:128, 0:N_half]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + n_off = (tid_in_wg // 64) * N_half + Tx.copy(C[cbx * M_per_cta + tid_in_wg % 64, n_off : n_off + N_half], C_reg[:]) + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) # fmt: on dev = tvm.cuda(0) @@ -722,7 +693,7 @@ def test_gemm_block_scaled_fp8_cta_group_1(task): """Test block-scaled fp8 GEMM with cta_group=1 using gemm_async op. Uses random per-row quantization with float8_e8m0fnu scale factors - loaded via tcgen05.cp. Reference: C = dequant(A) @ dequant(B).Tx. + loaded via tcgen05.cp. Reference: C = dequant(A) @ dequant(B).T. """ ( (C_shape, C_dtype, C_region), @@ -774,101 +745,90 @@ def test_gemm_block_scaled_fp8_cta_group_1(task): SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def gemm_async_fn(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, SFA_ptr: Tx.handle, SFB_ptr: Tx.handle) -> None: # noqa: E501 - A = Tx.match_buffer(A_ptr, A_shape, A_dtype) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - SFA_in = Tx.match_buffer(SFA_ptr, (128,), "uint32") - SFB_in = Tx.match_buffer(SFB_ptr, (128,), "uint32") - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) - SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) - SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + @T.prim_func + def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T.handle, SFB_ptr: T.handle) -> None: # noqa: E501 + A = T.match_buffer(A_ptr, A_shape, A_dtype) + B = T.match_buffer(B_ptr, B_shape, B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + SFA_in = T.match_buffer(SFA_ptr, (128,), "uint32") + SFB_in = T.match_buffer(SFB_ptr, (128,), "uint32") + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) + SFA_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFB_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - descSFA = Tx.alloc_buffer((1,), "uint64", scope="local") - descSFB = Tx.alloc_buffer((1,), "uint64", scope="local") + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + descSFA = T.alloc_buffer((1,), "uint64", scope="local") + descSFB = T.alloc_buffer((1,), "uint64", scope="local") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) - Tx.cuda.cta_sync() + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) + T.cuda.cta_sync() - tmem = Tx.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - sfa_tmem = Tx.decl_buffer((M, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 - sfb_tmem = Tx.decl_buffer((N, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 + tmem = T.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + sfa_tmem = T.decl_buffer((M, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 + sfb_tmem = T.decl_buffer((N, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 # TMA load A and B from global to shared if tid_in_wg == 0: - with Tx.thread(): - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) - Tx.copy_async(A_smem[tuple(r_gmem_A)], A[tuple(r_gmem_A)], **tma_args) - Tx.copy_async(B_smem[tuple(r_gmem_B)], B[tuple(r_gmem_B)], **tma_args) - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - - # Load packed scale factors from global to shared memory - with Tx.thread(): - SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[tid_in_wg] - SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[tuple(r_gmem_A)], A[tuple(r_gmem_A)], **tma_args) + Tx.copy_async(B_smem[tuple(r_gmem_B)], B[tuple(r_gmem_B)], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[tid_in_wg] + SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # Transpose scale factors in shared memory if warp_id == 0: - with Tx.warp(): - Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) - Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) - Tx.cuda.cta_sync() + Tx.warp.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.warp.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) + T.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp, then issue MMA if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 - Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], SFA=sfa_tmem[0:M, 0:sf_mma_k], SFB=sfb_tmem[0:N, 0:sf_mma_k], dispatch="tcgen05") # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() + Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], SFA=sfa_tmem[0:M, 0:sf_mma_k], SFB=sfb_tmem[0:N, 0:sf_mma_k], dispatch="tcgen05") # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() # Copy result from tmem to global - Tx.ptx.tcgen05.fence.after_thread_sync() - C_reg = Tx.alloc_local(width, dtype=C_dtype) + T.ptx.tcgen05.fence.after_thread_sync() + C_reg = T.alloc_local(width, dtype=C_dtype) C_view = C_reg.view(128, width, layout=TileLayout(S[(128, width) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[tuple(r_tmem_C)]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[tid_in_wg, C_region[1][0]:C_region[1][1]], C_reg[:]) + Tx.wg.copy_async(C_view[:, :], tmem[tuple(r_tmem_C)]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[tid_in_wg, C_region[1][0]:C_region[1][1]], C_reg[:]) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) # fmt: on dev = tvm.cuda(0) @@ -926,7 +886,7 @@ def test_gemm_block_scaled_fp8_cta_group_2(task): """Test block-scaled fp8 GEMM with cta_group=2 using gemm_async op. Uses random per-row SFA quantization (256 rows, indexed by cbx per CTA) - and uniform SFB. Reference: C = dequant(A) @ dequant(B).Tx. + and uniform SFB. Reference: C = dequant(A) @ dequant(B).T. """ ( (C_shape, C_dtype, C_region), @@ -980,115 +940,103 @@ def test_gemm_block_scaled_fp8_cta_group_2(task): SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def gemm_async_fn(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, SFA_ptr: Tx.handle, SFB_ptr: Tx.handle) -> None: # noqa: E501 - A = Tx.match_buffer(A_ptr, A_shape, A_dtype) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - SFA_in = Tx.match_buffer(SFA_ptr, (M_total,), "uint32") - SFB_in = Tx.match_buffer(SFB_ptr, (128,), "uint32") - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cbx, cby = Tx.cta_id_in_cluster([2, 1]) - cta_id = Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape_per_cta, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape_per_cta, B_dtype, scope="shared", layout=B_layout) - SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) - SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + @T.prim_func + def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T.handle, SFB_ptr: T.handle) -> None: # noqa: E501 + A = T.match_buffer(A_ptr, A_shape, A_dtype) + B = T.match_buffer(B_ptr, B_shape, B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + SFA_in = T.match_buffer(SFA_ptr, (M_total,), "uint32") + SFB_in = T.match_buffer(SFB_ptr, (128,), "uint32") + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cbx, cby = T.cta_id_in_cluster([2, 1]) + cta_id = T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape_per_cta, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape_per_cta, B_dtype, scope="shared", layout=B_layout) + SFA_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFB_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - descSFA = Tx.alloc_buffer((1,), "uint64", scope="local") - descSFB = Tx.alloc_buffer((1,), "uint64", scope="local") + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + descSFA = T.alloc_buffer((1,), "uint64", scope="local") + descSFB = T.alloc_buffer((1,), "uint64", scope="local") - ptr: Tx.let[Tx.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = Tx.reinterpret("handle", Tx.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 - tma_mbar_cta_0 = Tx.decl_buffer([1], "uint64", data=ptr, scope="shared") + ptr: T.let[T.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = T.reinterpret("handle", T.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 + tma_mbar_cta_0 = T.decl_buffer([1], "uint64", data=ptr, scope="shared") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) - tmem = Tx.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) + tmem = T.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - sfa_tmem = Tx.decl_buffer((128, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sf_layout) # noqa: E501 - sfb_tmem = Tx.decl_buffer((128, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sf_layout) # noqa: E501 + sfa_tmem = T.decl_buffer((128, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sf_layout) # noqa: E501 + sfb_tmem = T.decl_buffer((128, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sf_layout) # noqa: E501 - Tx.ptx.fence.mbarrier_init() - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - Tx.cuda.cluster_sync() + T.ptx.fence.mbarrier_init() + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + T.cuda.cluster_sync() # TMA load A and B (both CTAs issue with multicast) - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - Tx.copy_async(A_smem[tuple(r_smem_A_in)], A[tuple(get_global_region(A_shape_per_cta, transA, cbx))], **tma_args) # noqa: E501 - Tx.copy_async(B_smem[tuple(r_smem_B_in)], B[tuple(get_global_region(B_shape_per_cta, transB, cbx))], **tma_args) # noqa: E501 - if cbx == 0: - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - - # Load SFA per CTA (each CTA gets its 128 rows), SFB same for both - with Tx.thread(): - SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[cbx * 128 + tid_in_wg] - SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + Tx.copy_async(A_smem[tuple(r_smem_A_in)], A[tuple(get_global_region(A_shape_per_cta, transA, cbx))], **tma_args) # noqa: E501 + Tx.copy_async(B_smem[tuple(r_smem_B_in)], B[tuple(get_global_region(B_shape_per_cta, transB, cbx))], **tma_args) # noqa: E501 + if cbx == 0: + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[cbx * 128 + tid_in_wg] + SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # Transpose scale factors (both CTAs) if warp_id == 0: - with Tx.warp(): - Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) - Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) - Tx.cuda.cta_sync() + Tx.warp.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.warp.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) + T.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp (both CTAs, cta_group=2) if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 - Tx.cuda.cta_sync() - Tx.cuda.cluster_sync() + T.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 + T.cuda.cta_sync() + T.cuda.cluster_sync() if cbx == 0: - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], SFA=sfa_tmem[0:128, 0:sf_mma_k], SFB=sfb_tmem[0:128, 0:sf_mma_k], dispatch="tcgen05", cta_group=2) # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], SFA=sfa_tmem[0:128, 0:sf_mma_k], SFB=sfb_tmem[0:128, 0:sf_mma_k], dispatch="tcgen05", cta_group=2) # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() # Copy result from tmem to global - C_reg = Tx.alloc_local(width, dtype=C_dtype) + C_reg = T.alloc_local(width, dtype=C_dtype) C_view = C_reg.view(128, width, layout=TileLayout(S[(128, width) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[C_region[0][0]:C_region[0][1], C_region[1][0]:C_region[1][0] + width]) # noqa: E501 - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[cbx * 128 + tid_in_wg, C_region[1][0]:C_region[1][0] + width], C_reg[:]) - Tx.cuda.cta_sync() + Tx.wg.copy_async(C_view[:, :], tmem[C_region[0][0]:C_region[0][1], C_region[1][0]:C_region[1][0] + width]) # noqa: E501 + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[cbx * 128 + tid_in_wg, C_region[1][0]:C_region[1][0] + width], C_reg[:]) + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) # fmt: on dev = tvm.cuda(0) @@ -1146,7 +1094,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_1(): """Test block-scaled nvfp4 GEMM with cta_group=1. Uses float4_e2m1fn A/B with float8_e4m3fn per-row scale factors. - Reference: C = dequant(A) @ dequant(B).Tx. + Reference: C = dequant(A) @ dequant(B).T. """ M, N, K = 128, 32, 256 C_shape = (128, 512) @@ -1184,104 +1132,93 @@ def test_gemm_block_scaled_nvfp4_cta_group_1(): SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def gemm_async_fn(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, SFA_ptr: Tx.handle, SFB_ptr: Tx.handle) -> None: # noqa: E501 - A_packed = Tx.match_buffer(A_ptr, A_packed_shape, "uint8") - B_packed = Tx.match_buffer(B_ptr, B_packed_shape, "uint8") - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - SFA_in = Tx.match_buffer(SFA_ptr, (128,), "uint32") - SFB_in = Tx.match_buffer(SFB_ptr, (128,), "uint32") - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem_packed = Tx.alloc_buffer(A_packed_shape, "uint8", scope="shared", layout=A_uint8_layout) # noqa: E501 - B_smem_packed = Tx.alloc_buffer(B_packed_shape, "uint8", scope="shared", layout=B_uint8_layout) # noqa: E501 - A_smem = Tx.decl_buffer(A_fp4_shape, "float4_e2m1fn", data=A_smem_packed.data, scope="shared", layout=A_fp4_layout) # noqa: E501 - B_smem = Tx.decl_buffer(B_fp4_shape, "float4_e2m1fn", data=B_smem_packed.data, scope="shared", layout=B_fp4_layout) # noqa: E501 - - SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) - SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + @T.prim_func + def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T.handle, SFB_ptr: T.handle) -> None: # noqa: E501 + A_packed = T.match_buffer(A_ptr, A_packed_shape, "uint8") + B_packed = T.match_buffer(B_ptr, B_packed_shape, "uint8") + C = T.match_buffer(C_ptr, C_shape, C_dtype) + SFA_in = T.match_buffer(SFA_ptr, (128,), "uint32") + SFB_in = T.match_buffer(SFB_ptr, (128,), "uint32") + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem_packed = T.alloc_buffer(A_packed_shape, "uint8", scope="shared", layout=A_uint8_layout) # noqa: E501 + B_smem_packed = T.alloc_buffer(B_packed_shape, "uint8", scope="shared", layout=B_uint8_layout) # noqa: E501 + A_smem = T.decl_buffer(A_fp4_shape, "float4_e2m1fn", data=A_smem_packed.data, scope="shared", layout=A_fp4_layout) # noqa: E501 + B_smem = T.decl_buffer(B_fp4_shape, "float4_e2m1fn", data=B_smem_packed.data, scope="shared", layout=B_fp4_layout) # noqa: E501 + + SFA_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFB_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - descSFA = Tx.alloc_buffer((1,), "uint64", scope="local") - descSFB = Tx.alloc_buffer((1,), "uint64", scope="local") + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + descSFA = T.alloc_buffer((1,), "uint64", scope="local") + descSFB = T.alloc_buffer((1,), "uint64", scope="local") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) - Tx.cuda.cta_sync() + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) + T.cuda.cta_sync() - tmem = Tx.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - sfa_tmem = Tx.decl_buffer((M, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 - sfb_tmem = Tx.decl_buffer((N, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 + tmem = T.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + sfa_tmem = T.decl_buffer((M, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 + sfb_tmem = T.decl_buffer((N, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 # TMA load A and B as uint8 if tid_in_wg == 0: - with Tx.thread(): - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) - Tx.copy_async(A_smem_packed[:, :], A_packed[:, :], **tma_args) - Tx.copy_async(B_smem_packed[:, :], B_packed[:, :], **tma_args) - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - - # Load packed scale factors from global to shared memory - with Tx.thread(): - SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[tid_in_wg] - SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem_packed[:, :], A_packed[:, :], **tma_args) + Tx.copy_async(B_smem_packed[:, :], B_packed[:, :], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[tid_in_wg] + SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # Transpose scale factors in shared memory if warp_id == 0: - with Tx.warp(): - Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) - Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) - Tx.cuda.cta_sync() + Tx.warp.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.warp.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) + T.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp, then issue MMA if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 - Tx.gemm_async(tmem[0:128, 0:N], A_smem[:, :], B_smem[:, :], SFA=sfa_tmem[0:M, 0:sf_mma_k], SFB=sfb_tmem[0:N, 0:sf_mma_k], dispatch="tcgen05") # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() + Tx.gemm_async(tmem[0:128, 0:N], A_smem[:, :], B_smem[:, :], SFA=sfa_tmem[0:M, 0:sf_mma_k], SFB=sfb_tmem[0:N, 0:sf_mma_k], dispatch="tcgen05") # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() # Copy result from tmem to global - Tx.ptx.tcgen05.fence.after_thread_sync() - C_reg = Tx.alloc_local(width, dtype=C_dtype) + T.ptx.tcgen05.fence.after_thread_sync() + C_reg = T.alloc_local(width, dtype=C_dtype) C_view = C_reg.view(128, width, layout=TileLayout(S[(128, width) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[0:128, 0:N]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[tid_in_wg, 0:N], C_reg[:]) + Tx.wg.copy_async(C_view[:, :], tmem[0:128, 0:N]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[tid_in_wg, 0:N], C_reg[:]) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) # fmt: on dev = tvm.cuda(0) @@ -1328,7 +1265,7 @@ def test_gemm_block_scaled_nvfp4_cta_group_2(): A: (256, 256) float4_e2m1fn, split M across 2 CTAs (128 each). B: (64, 256) float4_e2m1fn, split N across 2 CTAs (32 each). Per-row SFA, uniform SFB. - Reference: C = dequant(A) @ dequant(B).Tx. + Reference: C = dequant(A) @ dequant(B).T. """ M_total, N_per_cta, K = 256, 32, 256 N_total = N_per_cta * 2 # 64 @@ -1374,118 +1311,106 @@ def test_gemm_block_scaled_nvfp4_cta_group_2(): SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def gemm_async_fn(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, SFA_ptr: Tx.handle, SFB_ptr: Tx.handle) -> None: # noqa: E501 - A_packed = Tx.match_buffer(A_ptr, A_packed_shape, "uint8") - B_packed = Tx.match_buffer(B_ptr, B_packed_shape, "uint8") - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - SFA_in = Tx.match_buffer(SFA_ptr, (M_total,), "uint32") - SFB_in = Tx.match_buffer(SFB_ptr, (128,), "uint32") - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cbx, cby = Tx.cta_id_in_cluster([2, 1]) - cta_id = Tx.cta_id([2]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem_packed = Tx.alloc_buffer(A_packed_per_cta, "uint8", scope="shared", layout=A_uint8_layout) # noqa: E501 - B_smem_packed = Tx.alloc_buffer(B_packed_per_cta, "uint8", scope="shared", layout=B_uint8_layout) # noqa: E501 - A_smem = Tx.decl_buffer(A_fp4_per_cta, "float4_e2m1fn", data=A_smem_packed.data, scope="shared", layout=A_fp4_layout) # noqa: E501 - B_smem = Tx.decl_buffer(B_fp4_per_cta, "float4_e2m1fn", data=B_smem_packed.data, scope="shared", layout=B_fp4_layout) # noqa: E501 - - SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) - SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + @T.prim_func + def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T.handle, SFB_ptr: T.handle) -> None: # noqa: E501 + A_packed = T.match_buffer(A_ptr, A_packed_shape, "uint8") + B_packed = T.match_buffer(B_ptr, B_packed_shape, "uint8") + C = T.match_buffer(C_ptr, C_shape, C_dtype) + SFA_in = T.match_buffer(SFA_ptr, (M_total,), "uint32") + SFB_in = T.match_buffer(SFB_ptr, (128,), "uint32") + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cbx, cby = T.cta_id_in_cluster([2, 1]) + cta_id = T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem_packed = T.alloc_buffer(A_packed_per_cta, "uint8", scope="shared", layout=A_uint8_layout) # noqa: E501 + B_smem_packed = T.alloc_buffer(B_packed_per_cta, "uint8", scope="shared", layout=B_uint8_layout) # noqa: E501 + A_smem = T.decl_buffer(A_fp4_per_cta, "float4_e2m1fn", data=A_smem_packed.data, scope="shared", layout=A_fp4_layout) # noqa: E501 + B_smem = T.decl_buffer(B_fp4_per_cta, "float4_e2m1fn", data=B_smem_packed.data, scope="shared", layout=B_fp4_layout) # noqa: E501 + + SFA_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFB_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - descSFA = Tx.alloc_buffer((1,), "uint64", scope="local") - descSFB = Tx.alloc_buffer((1,), "uint64", scope="local") + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + descSFA = T.alloc_buffer((1,), "uint64", scope="local") + descSFB = T.alloc_buffer((1,), "uint64", scope="local") - ptr: Tx.let[Tx.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = Tx.reinterpret("handle", Tx.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 - tma_mbar_cta_0 = Tx.decl_buffer([1], "uint64", data=ptr, scope="shared") + ptr: T.let[T.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = T.reinterpret("handle", T.ptx.map_shared_rank(tma_mbar.ptr_to([0]), 0)) # noqa: E501 + tma_mbar_cta_0 = T.decl_buffer([1], "uint64", data=ptr, scope="shared") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) - tmem = Tx.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=2) + tmem = T.decl_buffer((128, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - sfa_tmem = Tx.decl_buffer((M_per_cta, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 - sfb_tmem = Tx.decl_buffer((N_total, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 + sfa_tmem = T.decl_buffer((M_per_cta, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 + sfb_tmem = T.decl_buffer((N_total, sf_mma_k), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 - Tx.ptx.fence.mbarrier_init() - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() - Tx.cuda.cluster_sync() + T.ptx.fence.mbarrier_init() + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + T.cuda.cluster_sync() # TMA load A and B with multicast (each CTA loads its portion) - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar_cta_0.ptr_to([0]), "cta_group": 2}) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - Tx.copy_async(A_smem_packed[:, :], A_packed[cbx * M_per_cta:(cbx + 1) * M_per_cta, :], **tma_args) # noqa: E501 - Tx.copy_async(B_smem_packed[:, :], B_packed[cbx * N_per_cta:(cbx + 1) * N_per_cta, :], **tma_args) # noqa: E501 - if cbx == 0: - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - - # Load SFA per CTA (each CTA gets its 128 rows), SFB same for both - with Tx.thread(): - SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[cbx * M_per_cta + tid_in_wg] - SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + Tx.copy_async(A_smem_packed[:, :], A_packed[cbx * M_per_cta:(cbx + 1) * M_per_cta, :], **tma_args) # noqa: E501 + Tx.copy_async(B_smem_packed[:, :], B_packed[cbx * N_per_cta:(cbx + 1) * N_per_cta, :], **tma_args) # noqa: E501 + if cbx == 0: + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[cbx * M_per_cta + tid_in_wg] + SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # Transpose scale factors if warp_id == 0: - with Tx.warp(): - Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) - Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) - Tx.cuda.cta_sync() + Tx.warp.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.warp.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) + T.cuda.cta_sync() # Copy SFA/SFB from shared to TMEM via tcgen05.cp if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 - Tx.cuda.cta_sync() - Tx.cuda.cluster_sync() + T.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=2, multicast="warpx4") # noqa: E501 + T.cuda.cta_sync() + T.cuda.cluster_sync() if cbx == 0: - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[0:128, 0:N_total], A_smem[:, :], B_smem[:, :], SFA=sfa_tmem[0:128, 0:sf_mma_k], SFB=sfb_tmem[0:N_total, 0:sf_mma_k], dispatch="tcgen05", cta_group=2) # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.ptx.tcgen05.fence.after_thread_sync() - Tx.cuda.cta_sync() + Tx.gemm_async(tmem[0:128, 0:N_total], A_smem[:, :], B_smem[:, :], SFA=sfa_tmem[0:128, 0:sf_mma_k], SFB=sfb_tmem[0:N_total, 0:sf_mma_k], dispatch="tcgen05", cta_group=2) # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=2, cta_mask=3) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.ptx.tcgen05.fence.after_thread_sync() + T.cuda.cta_sync() # Copy result from tmem to global - C_reg = Tx.alloc_local(width, dtype=C_dtype) + C_reg = T.alloc_local(width, dtype=C_dtype) C_view = C_reg.view(128, width, layout=TileLayout(S[(128, width) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[0:128, 0:width]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[cbx * M_per_cta + tid_in_wg, 0:width], C_reg[:]) - Tx.cuda.cta_sync() + Tx.wg.copy_async(C_view[:, :], tmem[0:128, 0:width]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[cbx * M_per_cta + tid_in_wg, 0:width], C_reg[:]) + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=2) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=2) # fmt: on dev = tvm.cuda(0) @@ -1589,106 +1514,95 @@ def test_gemm_block_scaled_fp8_sf_id(): SF_smem_post_layout = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def gemm_async_fn(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle, SFA_ptr: Tx.handle, SFB_ptr: Tx.handle) -> None: # noqa: E501 - A = Tx.match_buffer(A_ptr, A_shape, A_dtype) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - SFA_in = Tx.match_buffer(SFA_ptr, (128,), "uint32") - SFB_in = Tx.match_buffer(SFB_ptr, (128,), "uint32") - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) - B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) - SFA_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) - SFB_smem = Tx.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + @T.prim_func + def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T.handle, SFB_ptr: T.handle) -> None: # noqa: E501 + A = T.match_buffer(A_ptr, A_shape, A_dtype) + B = T.match_buffer(B_ptr, B_shape, B_dtype) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + SFA_in = T.match_buffer(SFA_ptr, (128,), "uint32") + SFB_in = T.match_buffer(SFB_ptr, (128,), "uint32") + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout) + SFA_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) + SFB_smem = T.alloc_buffer((4, 32), "uint32", scope="shared", layout=SF_smem_layout) SFA_smem_post = SFA_smem.view(4, 32, layout=SF_smem_post_layout) SFB_smem_post = SFB_smem.view(4, 32, layout=SF_smem_post_layout) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") - descSFA = Tx.alloc_buffer((1,), "uint64", scope="local") - descSFB = Tx.alloc_buffer((1,), "uint64", scope="local") + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + descSFA = T.alloc_buffer((1,), "uint64", scope="local") + descSFB = T.alloc_buffer((1,), "uint64", scope="local") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc(Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) - Tx.cuda.cta_sync() + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=1) + T.cuda.cta_sync() - tmem = Tx.decl_buffer(C_shape, C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 - sfa_tmem = Tx.decl_buffer((M, sf_mma_k * num_ki), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 - sfb_tmem = Tx.decl_buffer((N, sf_mma_k * num_ki), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 + tmem = T.decl_buffer(C_shape, C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + sfa_tmem = T.decl_buffer((M, sf_mma_k * num_ki), SF_dtype, scope="tmem", allocated_addr=SFA_TMEM_START, layout=sfa_layout) # noqa: E501 + sfb_tmem = T.decl_buffer((N, sf_mma_k * num_ki), SF_dtype, scope="tmem", allocated_addr=SFB_TMEM_START, layout=sfb_layout) # noqa: E501 # TMA load A and B from global to shared if tid_in_wg == 0: - with Tx.thread(): - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) - Tx.copy_async(A_smem[0:M, 0:K], A[0:M, 0:K], **tma_args) - Tx.copy_async(B_smem[0:N, 0:K], B[0:N, 0:K], **tma_args) - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - - # Load packed scale factors from global to shared memory - with Tx.thread(): - SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[tid_in_wg] - SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[0:M, 0:K], A[0:M, 0:K], **tma_args) + Tx.copy_async(B_smem[0:N, 0:K], B[0:N, 0:K], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + SFA_smem[tid_in_wg // 32, tid_in_wg % 32] = SFA_in[tid_in_wg] + SFB_smem[tid_in_wg // 32, tid_in_wg % 32] = SFB_in[tid_in_wg] + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() # Transpose scale factors in shared memory if warp_id == 0: - with Tx.warp(): - Tx.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) - Tx.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) - Tx.cuda.cta_sync() + Tx.warp.permute_layout(SFA_smem_post[:, :], SFA_smem[:, :]) + Tx.warp.permute_layout(SFB_smem_post[:, :], SFB_smem[:, :]) + T.cuda.cta_sync() # Copy SF to TMEM, then single MMA call (schedule auto-derives sf_id per ki) if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 - Tx.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 - Tx.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 - - # Single call with K=128: schedule auto-encodes descI and - # rotates sf_id=0,1,2,3 for each of the 4 ki iterations. - # SFA/SFB region covers all 4 ki positions (num_ki elements) - # so the schedule knows sf_id should rotate. - Tx.gemm_async(tmem[0:128, 0:N], A_smem[0:M, 0:K], B_smem[0:N, 0:K], SFA=sfa_tmem[0:M, 0:sf_mma_k * num_ki], SFB=sfb_tmem[0:N, 0:sf_mma_k * num_ki], dispatch="tcgen05") # noqa: E501 - - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() + T.ptx.tcgen05.encode_matrix_descriptor(descSFA.data, SFA_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFA_TMEM_START, descSFA[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 + T.ptx.tcgen05.encode_matrix_descriptor(descSFB.data, SFB_smem.access_ptr("r", offset=0), ldo=16, sdo=8 * 4 * F32_BYTES // F128_BYTES, swizzle=0) # noqa: E501 + T.ptx.tcgen05.cp(SFB_TMEM_START, descSFB[0], shape="32x128b", cta_group=1, multicast="warpx4") # noqa: E501 + + # Single call with K=128: schedule auto-encodes descI and + # rotates sf_id=0,1,2,3 for each of the 4 ki iterations. + # SFA/SFB region covers all 4 ki positions (num_ki elements) + # so the schedule knows sf_id should rotate. + Tx.gemm_async(tmem[0:128, 0:N], A_smem[0:M, 0:K], B_smem[0:N, 0:K], SFA=sfa_tmem[0:M, 0:sf_mma_k * num_ki], SFB=sfb_tmem[0:N, 0:sf_mma_k * num_ki], dispatch="tcgen05") # noqa: E501 + + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() # Copy result from tmem to global - Tx.ptx.tcgen05.fence.after_thread_sync() - C_reg = Tx.alloc_local(N, dtype=C_dtype) + T.ptx.tcgen05.fence.after_thread_sync() + C_reg = T.alloc_local(N, dtype=C_dtype) C_view = C_reg.view(128, N, layout=TileLayout(S[(128, N) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[0:128, 0:N]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[tid_in_wg, 0:N], C_reg[:]) + Tx.wg.copy_async(C_view[:, :], tmem[0:128, 0:N]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[tid_in_wg, 0:N], C_reg[:]) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=1) # fmt: on def per_block_quantize_fp8(mat, block_size=32): @@ -1947,70 +1861,63 @@ def test_gemm_tcgen05_arbitrary_tiles(task): B_gmem_kw = {"layout": B_gmem_layout} if B_gmem_layout is not None else {} # fmt: off - @Tx.prim_func - def gemm_async(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, A_shape, A_dtype, **A_gmem_kw) - B = Tx.match_buffer(B_ptr, B_shape, B_dtype, **B_gmem_kw) - C = Tx.match_buffer(C_ptr, C_shape, C_dtype) - - Tx.device_entry() - warp_id = Tx.warp_id([(1) * 4]) - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid_in_wg = Tx.thread_id_in_wg([128]) - - A_smem = Tx.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout, align=1024) - B_smem = Tx.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout, align=1024) - tmem_addr = Tx.alloc_shared([1], "uint32") - tma_mbar = Tx.alloc_shared([1], "uint64") - mma_mbar = Tx.alloc_shared([1], "uint64") + @T.prim_func + def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, A_shape, A_dtype, **A_gmem_kw) + B = T.match_buffer(B_ptr, B_shape, B_dtype, **B_gmem_kw) + C = T.match_buffer(C_ptr, C_shape, C_dtype) + + T.device_entry() + warp_id = T.warp_id([(1) * 4]) + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + + A_smem = T.alloc_buffer(A_shape, A_dtype, scope="shared", layout=A_layout, align=1024) + B_smem = T.alloc_buffer(B_shape, B_dtype, scope="shared", layout=B_layout, align=1024) + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") if tid_in_wg == 0: - with Tx.thread(): - Tx.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) - Tx.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.cta_sync() + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.alloc( - Tx.address_of(tmem_addr), n_cols=cols_alloc, cta_group=cta_group - ) - Tx.cuda.cta_sync() - tmem = Tx.decl_buffer((M, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(M, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + T.ptx.tcgen05.alloc( + T.address_of(tmem_addr), n_cols=cols_alloc, cta_group=cta_group + ) + T.cuda.cta_sync() + tmem = T.decl_buffer((M, C_shape[1]), C_dtype, scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(M, C_shape[1]) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 if tid_in_wg == 0: - with Tx.thread(): - tma_args = Tx.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) - Tx.copy_async(A_smem[tuple(r_gmem_A)], A[tuple(r_gmem_A)], **tma_args) - Tx.copy_async(B_smem[tuple(r_gmem_B)], B[tuple(r_gmem_B)], **tma_args) - Tx.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) - Tx.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[tuple(r_gmem_A)], A[tuple(r_gmem_A)], **tma_args) + Tx.copy_async(B_smem[tuple(r_gmem_B)], B[tuple(r_gmem_B)], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() if tid_in_wg == 0: - with Tx.thread(): - Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], transA=transA, transB=transB, dispatch="tcgen05", cta_group=cta_group) # noqa: E501 - Tx.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=cta_group) - Tx.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) - Tx.cuda.cta_sync() - - Tx.ptx.tcgen05.fence.after_thread_sync() - C_reg = Tx.alloc_local(N, dtype=C_dtype) + Tx.gemm_async(tmem[tuple(r_tmem_C)], A_smem[tuple(r_smem_A)], B_smem[tuple(r_smem_B)], transA=transA, transB=transB, dispatch="tcgen05", cta_group=cta_group) # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=cta_group) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + + T.ptx.tcgen05.fence.after_thread_sync() + C_reg = T.alloc_local(N, dtype=C_dtype) C_view = C_reg.view(M, N, layout=TileLayout(S[(M, N) : (1@axis_tid_in_wg, 1)])) if wg_id == 0: - with Tx.warpgroup(): - Tx.copy_async(C_view[:, :], tmem[tuple(r_tmem_C)]) - Tx.ptx.tcgen05.wait.ld() - Tx.cuda.cta_sync() - with Tx.thread(): - Tx.copy(C[tid_in_wg, C_region[1][0]:C_region[1][1]], C_reg[:]) + Tx.wg.copy_async(C_view[:, :], tmem[tuple(r_tmem_C)]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[tid_in_wg, C_region[1][0]:C_region[1][1]], C_reg[:]) if warp_id == 0: - with Tx.warp(): - Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) - Tx.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=cta_group) + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=cta_group) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=cols_alloc, cta_group=cta_group) # fmt: on dev = tvm.cuda(0) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py index 87617f667284..9aba8b4316dd 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=missing-function-docstring -"""Tests for ``Tx.permute_layout``. +"""Tests for ``T.permute_layout``. Coverage: @@ -41,7 +41,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import S, SwizzleLayout, TileLayout # Helpers exposed by the dispatcher module for direct algorithm tests. @@ -191,19 +192,17 @@ def test_sf_blockwise_transpose(name, pipe, blk, dtype): post = TileLayout(S[shape:dst_strides]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, shape, dtype, layout=pre) - B_buf = Tx.match_buffer(B, shape, dtype, layout=post) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - for s in Tx.serial(0, pipe): - Tx.permute_layout( - B_buf[s, 0:high, 0:4, 0:32], A_buf[s, 0:high, 0:4, 0:32] - ) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, shape, dtype, layout=pre) + B_buf = T.match_buffer(B, shape, dtype, layout=post) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + for s in T.serial(0, pipe): + Tx.warp.permute_layout( + B_buf[s, 0:high, 0:4, 0:32], A_buf[s, 0:high, 0:4, 0:32] + ) # fmt: on np.random.seed(0) @@ -239,16 +238,14 @@ def test_identity_passes_through_as_copy(): layout = TileLayout(S[shape : (32, 1)]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, shape, "uint32", layout=layout) - B_buf = Tx.match_buffer(B, shape, "uint32", layout=layout) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_layout(B_buf, A_buf) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, shape, "uint32", layout=layout) + B_buf = T.match_buffer(B, shape, "uint32", layout=layout) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.warp.permute_layout(B_buf, A_buf) # fmt: on np.random.seed(0) @@ -275,16 +272,14 @@ def test_generic_transpose(shape, src_strides, dst_strides, dtype): post = TileLayout(S[shape:dst_strides]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, shape, dtype, layout=pre) - B_buf = Tx.match_buffer(B, shape, dtype, layout=post) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_layout(B_buf, A_buf) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, shape, dtype, layout=pre) + B_buf = T.match_buffer(B, shape, dtype, layout=post) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.warp.permute_layout(B_buf, A_buf) # fmt: on np.random.seed(0) @@ -304,16 +299,14 @@ def f(A: Tx.handle, B: Tx.handle): def _build_and_assert_rejected(shape, src_layout, dst_layout, dtype, msg_substr): # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, shape, dtype, layout=src_layout) - B_buf = Tx.match_buffer(B, shape, dtype, layout=dst_layout) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_layout(B_buf, A_buf) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, shape, dtype, layout=src_layout) + B_buf = T.match_buffer(B, shape, dtype, layout=dst_layout) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.warp.permute_layout(B_buf, A_buf) # fmt: on target = tvm.target.Target("cuda") @@ -330,16 +323,14 @@ def test_reject_dtype_mismatch(): layout = TileLayout(S[shape : (32, 1)]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, shape, "uint32", layout=layout) - B_buf = Tx.match_buffer(B, shape, "uint16", layout=layout) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_layout(B_buf, A_buf) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, shape, "uint32", layout=layout) + B_buf = T.match_buffer(B, shape, "uint16", layout=layout) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.warp.permute_layout(B_buf, A_buf) # fmt: on target = tvm.target.Target("cuda") @@ -353,16 +344,14 @@ def test_reject_shape_mismatch(): dst_layout = TileLayout(S[(8, 16) : (16, 1)]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=src_layout) - B_buf = Tx.match_buffer(B, (8, 16), "uint32", layout=dst_layout) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_layout(B_buf, A_buf) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, (4, 32), "uint32", layout=src_layout) + B_buf = T.match_buffer(B, (8, 16), "uint32", layout=dst_layout) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.warp.permute_layout(B_buf, A_buf) # fmt: on target = tvm.target.Target("cuda") @@ -381,16 +370,14 @@ def test_reject_swizzle_layout(): plain = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=swizzled) - B_buf = Tx.match_buffer(B, (4, 32), "uint32", layout=plain) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.permute_layout(B_buf, A_buf) + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, (4, 32), "uint32", layout=swizzled) + B_buf = T.match_buffer(B, (4, 32), "uint32", layout=plain) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.warp.permute_layout(B_buf, A_buf) # fmt: on target = tvm.target.Target("cuda") @@ -404,15 +391,14 @@ def test_reject_non_warp_scope(): layout_post = TileLayout(S[(4, 32) : (1, 4)]) # fmt: off - @Tx.prim_func - def f(A: Tx.handle, B: Tx.handle): - A_buf = Tx.match_buffer(A, (4, 32), "uint32", layout=layout_pre) - B_buf = Tx.match_buffer(B, (4, 32), "uint32", layout=layout_post) - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([32]) - with Tx.cta(): - Tx.permute_layout(B_buf, A_buf) # cta scope, not warp + @T.prim_func + def f(A: T.handle, B: T.handle): + A_buf = T.match_buffer(A, (4, 32), "uint32", layout=layout_pre) + B_buf = T.match_buffer(B, (4, 32), "uint32", layout=layout_post) + T.device_entry() + T.cta_id([1]) + T.thread_id([32]) + Tx.cta.permute_layout(B_buf, A_buf) # cta scope, not warp # fmt: on target = tvm.target.Target("cuda") diff --git a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py index 3009e6420955..0474ad2dc46a 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py @@ -19,7 +19,8 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import R, S, TileLayout, laneid, wg_local_layout @@ -65,31 +66,29 @@ def test_reduction_shared( g_layout_dst = s_layout_dst = TileLayout(S[dst_shape]) # fmt: off - @Tx.prim_func - def test_reduction(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) - B = Tx.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([thread_cnt]) - - with Tx.cta(): - A_smem = Tx.alloc_buffer(s_shape_src, dtype, scope="shared", layout=s_layout_src) - B_smem = Tx.alloc_buffer(s_shape_dst, dtype, scope="shared", layout=s_layout_dst) - - Tx.copy(A_smem[tuple(copy_slice_src)], A[tuple(copy_slice_src)]) - if accum: - Tx.copy(B_smem[tuple(copy_slice_dst)], B[tuple(copy_slice_dst)]) - Tx.cuda.cta_sync() - if op_type == "sum": - Tx.sum(B_smem[tuple(reduce_slice_dst)], A_smem[tuple(reduce_slice_src)], axes=axes, accum=accum) # noqa: E501 - elif op_type == "max": - Tx.max(B_smem[tuple(reduce_slice_dst)], A_smem[tuple(reduce_slice_src)], axes=axes, accum=accum) # noqa: E501 - elif op_type == "min": - Tx.min(B_smem[tuple(reduce_slice_dst)], A_smem[tuple(reduce_slice_src)], axes=axes, accum=accum) # noqa: E501 - Tx.cuda.cta_sync() - Tx.copy(B[tuple(copy_slice_dst)], B_smem[tuple(copy_slice_dst)]) + @T.prim_func + def test_reduction(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) + B = T.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) + + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([thread_cnt]) + A_smem = T.alloc_buffer(s_shape_src, dtype, scope="shared", layout=s_layout_src) + B_smem = T.alloc_buffer(s_shape_dst, dtype, scope="shared", layout=s_layout_dst) + + Tx.cta.copy(A_smem[tuple(copy_slice_src)], A[tuple(copy_slice_src)]) + if accum: + Tx.cta.copy(B_smem[tuple(copy_slice_dst)], B[tuple(copy_slice_dst)]) + T.cuda.cta_sync() + if op_type == "sum": + Tx.cta.sum(B_smem[tuple(reduce_slice_dst)], A_smem[tuple(reduce_slice_src)], axes=axes, accum=accum) # noqa: E501 + elif op_type == "max": + Tx.cta.max(B_smem[tuple(reduce_slice_dst)], A_smem[tuple(reduce_slice_src)], axes=axes, accum=accum) # noqa: E501 + elif op_type == "min": + Tx.cta.min(B_smem[tuple(reduce_slice_dst)], A_smem[tuple(reduce_slice_src)], axes=axes, accum=accum) # noqa: E501 + T.cuda.cta_sync() + Tx.cta.copy(B[tuple(copy_slice_dst)], B_smem[tuple(copy_slice_dst)]) # fmt: on target = tvm.target.Target("cuda") @@ -146,82 +145,76 @@ def test_reduction_shared_subscope(exec_scope, op_type, accum): # fmt: off if exec_scope == "warp": - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) - B = Tx.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) - Tx.device_entry() - warp_id = Tx.warp_id([(256) // 32]) - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([256]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(list(src_shape), dtype, scope="shared", layout=s_layout_src) # noqa: E501 - B_smem = Tx.alloc_buffer(list(dst_shape), dtype, scope="shared", layout=s_layout_dst) # noqa: E501 - Tx.copy(A_smem, A) - if accum: - Tx.copy(B_smem, B) - Tx.cuda.cta_sync() - if warp_id == 5: - with Tx.warp(): - if op_type == "sum": - Tx.sum(B_smem, A_smem, axes=axes, accum=accum) - elif op_type == "max": - Tx.max(B_smem, A_smem, axes=axes, accum=accum) - elif op_type == "min": - Tx.min(B_smem, A_smem, axes=axes, accum=accum) - Tx.cuda.cta_sync() - Tx.copy(B, B_smem) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) + B = T.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) + T.device_entry() + warp_id = T.warp_id([(256) // 32]) + _bx = T.cta_id([1]) + _tid = T.thread_id([256]) + A_smem = T.alloc_buffer(list(src_shape), dtype, scope="shared", layout=s_layout_src) + B_smem = T.alloc_buffer(list(dst_shape), dtype, scope="shared", layout=s_layout_dst) + Tx.cta.copy(A_smem, A) + if accum: + Tx.cta.copy(B_smem, B) + T.cuda.cta_sync() + if warp_id == 5: + if op_type == "sum": + Tx.warp.sum(B_smem, A_smem, axes=axes, accum=accum) + elif op_type == "max": + Tx.warp.max(B_smem, A_smem, axes=axes, accum=accum) + elif op_type == "min": + Tx.warp.min(B_smem, A_smem, axes=axes, accum=accum) + T.cuda.cta_sync() + Tx.cta.copy(B, B_smem) elif exec_scope == "warpgroup": - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) - B = Tx.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) - Tx.device_entry() - wg_id = Tx.warpgroup_id([(256) // 128]) - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([256]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(list(src_shape), dtype, scope="shared", layout=s_layout_src) # noqa: E501 - B_smem = Tx.alloc_buffer(list(dst_shape), dtype, scope="shared", layout=s_layout_dst) # noqa: E501 - Tx.copy(A_smem, A) - if accum: - Tx.copy(B_smem, B) - Tx.cuda.cta_sync() - if wg_id == 0: - with Tx.warpgroup(): - if op_type == "sum": - Tx.sum(B_smem, A_smem, axes=axes, accum=accum) - elif op_type == "max": - Tx.max(B_smem, A_smem, axes=axes, accum=accum) - elif op_type == "min": - Tx.min(B_smem, A_smem, axes=axes, accum=accum) - Tx.cuda.cta_sync() - Tx.copy(B, B_smem) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) + B = T.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) + T.device_entry() + wg_id = T.warpgroup_id([(256) // 128]) + _bx = T.cta_id([1]) + _tid = T.thread_id([256]) + A_smem = T.alloc_buffer(list(src_shape), dtype, scope="shared", layout=s_layout_src) + B_smem = T.alloc_buffer(list(dst_shape), dtype, scope="shared", layout=s_layout_dst) + Tx.cta.copy(A_smem, A) + if accum: + Tx.cta.copy(B_smem, B) + T.cuda.cta_sync() + if wg_id == 0: + if op_type == "sum": + Tx.wg.sum(B_smem, A_smem, axes=axes, accum=accum) + elif op_type == "max": + Tx.wg.max(B_smem, A_smem, axes=axes, accum=accum) + elif op_type == "min": + Tx.wg.min(B_smem, A_smem, axes=axes, accum=accum) + T.cuda.cta_sync() + Tx.cta.copy(B, B_smem) elif exec_scope == "thread": - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) - B = Tx.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([256]) - with Tx.cta(): - A_smem = Tx.alloc_buffer(list(src_shape), dtype, scope="shared", layout=s_layout_src) # noqa: E501 - B_smem = Tx.alloc_buffer(list(dst_shape), dtype, scope="shared", layout=s_layout_dst) # noqa: E501 - Tx.copy(A_smem, A) - if accum: - Tx.copy(B_smem, B) - Tx.cuda.cta_sync() - if _tid == 65: - with Tx.thread(): - if op_type == "sum": - Tx.sum(B_smem, A_smem, axes=axes, accum=accum) - elif op_type == "max": - Tx.max(B_smem, A_smem, axes=axes, accum=accum) - elif op_type == "min": - Tx.min(B_smem, A_smem, axes=axes, accum=accum) - Tx.cuda.cta_sync() - Tx.copy(B, B_smem) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, dtype, layout=g_layout_src) + B = T.match_buffer(B_ptr, dst_shape, dtype, layout=g_layout_dst) + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([256]) + A_smem = T.alloc_buffer(list(src_shape), dtype, scope="shared", layout=s_layout_src) + B_smem = T.alloc_buffer(list(dst_shape), dtype, scope="shared", layout=s_layout_dst) + Tx.cta.copy(A_smem, A) + if accum: + Tx.cta.copy(B_smem, B) + T.cuda.cta_sync() + if _tid == 65: + if op_type == "sum": + Tx.sum(B_smem, A_smem, axes=axes, accum=accum) + elif op_type == "max": + Tx.max(B_smem, A_smem, axes=axes, accum=accum) + elif op_type == "min": + Tx.min(B_smem, A_smem, axes=axes, accum=accum) + T.cuda.cta_sync() + Tx.cta.copy(B, B_smem) # fmt: on target = tvm.target.Target("cuda") @@ -294,38 +287,36 @@ def decompose_flat(flat_idx, shape): return indices # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, list(src_shape), dtype, layout=TileLayout(S[src_shape])) - B = Tx.match_buffer(B_ptr, list(dst_shape), dtype, layout=TileLayout(S[dst_shape])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([1]) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, list(src_shape), dtype, layout=TileLayout(S[src_shape])) + B = T.match_buffer(B_ptr, list(dst_shape), dtype, layout=TileLayout(S[dst_shape])) - with Tx.thread(): - A_local = Tx.alloc_buffer(list(src_shape), dtype, scope="local") - B_local = Tx.alloc_buffer(list(dst_shape), dtype, scope="local") + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([1]) + A_local = T.alloc_buffer(list(src_shape), dtype, scope="local") + B_local = T.alloc_buffer(list(dst_shape), dtype, scope="local") - for i in Tx.serial(src_total): - idx = Tx.meta_var(decompose_flat(i, src_shape)) - A_local[tuple(idx)] = A[tuple(idx)] + for i in T.serial(src_total): + idx = T.meta_var(decompose_flat(i, src_shape)) + A_local[tuple(idx)] = A[tuple(idx)] - if accum: - for i in Tx.serial(dst_total): - idx = Tx.meta_var(decompose_flat(i, dst_shape)) - B_local[tuple(idx)] = B[tuple(idx)] + if accum: + for i in T.serial(dst_total): + idx = T.meta_var(decompose_flat(i, dst_shape)) + B_local[tuple(idx)] = B[tuple(idx)] - if op_type == "sum": - Tx.sum(B_local, A_local, axes=axes, accum=accum) - elif op_type == "max": - Tx.max(B_local, A_local, axes=axes, accum=accum) - elif op_type == "min": - Tx.min(B_local, A_local, axes=axes, accum=accum) + if op_type == "sum": + Tx.sum(B_local, A_local, axes=axes, accum=accum) + elif op_type == "max": + Tx.max(B_local, A_local, axes=axes, accum=accum) + elif op_type == "min": + Tx.min(B_local, A_local, axes=axes, accum=accum) - for i in Tx.serial(dst_total): - idx = Tx.meta_var(decompose_flat(i, dst_shape)) - B[tuple(idx)] = B_local[tuple(idx)] + for i in T.serial(dst_total): + idx = T.meta_var(decompose_flat(i, dst_shape)) + B[tuple(idx)] = B_local[tuple(idx)] # fmt: on target = tvm.target.Target("cuda") @@ -394,11 +385,11 @@ def row_major_strides(dims): s *= d return strides - acc_view_layout = Tx.TileLayout( - Tx.S[src_shape : (1 @ laneid, *tuple(row_major_strides(inner_dims)))] + acc_view_layout = T.TileLayout( + T.S[src_shape : (1 @ laneid, *tuple(row_major_strides(inner_dims)))] ) - red_view_layout = Tx.TileLayout( - Tx.S[dst_shape : (1 @ laneid, *tuple(row_major_strides(dst_dims)))] + red_view_layout = T.TileLayout( + T.S[dst_shape : (1 @ laneid, *tuple(row_major_strides(dst_dims)))] ) g_layout_a = TileLayout(S[src_shape]) g_layout_b = TileLayout(S[dst_shape]) @@ -420,49 +411,44 @@ def decompose_flat(flat_idx, shape): return indices # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, list(src_shape), dtype, layout=g_layout_a) - B = Tx.match_buffer(B_ptr, list(dst_shape), dtype, layout=g_layout_b) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([thread_cnt]) - - acc = Tx.alloc_buffer(list((1, *inner_dims)), dtype=dtype, scope="local", layout=g_layout_a) - red = Tx.alloc_buffer(list((1, *dst_dims)), dtype=dtype, scope="local", layout=g_layout_b) - - with Tx.thread(): - for i in Tx.serial(src_local_total): - idx = Tx.meta_var(decompose_flat(i, inner_dims)) - acc[(0, *list(idx))] = A[(lane_id, *list(idx))] - if accum: - for i in Tx.serial(dst_local_total): - idx = Tx.meta_var(decompose_flat(i, dst_dims)) - red[(0, *list(idx))] = B[(lane_id, *list(idx))] - with Tx.warp(): - acc_view = acc.view(*src_shape, layout=acc_view_layout) - red_view = red.view(*dst_shape, layout=red_view_layout) - if slice_end is not None: - if op_type == "sum": - Tx.sum(red_view, acc_view[:, slice_end // 2:slice_end], axes=axes, accum=accum) - elif op_type == "max": - Tx.max(red_view, acc_view[:, slice_end // 2:slice_end], axes=axes, accum=accum) - elif op_type == "min": - Tx.min(red_view, acc_view[:, slice_end // 2:slice_end], axes=axes, accum=accum) - else: - if op_type == "sum": - Tx.sum(red_view, acc_view, axes=axes, accum=accum) - elif op_type == "max": - Tx.max(red_view, acc_view, axes=axes, accum=accum) - elif op_type == "min": - Tx.min(red_view, acc_view, axes=axes, accum=accum) - - with Tx.thread(): - for i in Tx.serial(dst_local_total): - idx = Tx.meta_var(decompose_flat(i, dst_dims)) - B[(lane_id, *list(idx))] = red[(0, *list(idx))] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, list(src_shape), dtype, layout=g_layout_a) + B = T.match_buffer(B_ptr, list(dst_shape), dtype, layout=g_layout_b) + + T.device_entry() + _bx = T.cta_id([1]) + _warp_id = T.warp_id([1]) + lane_id = T.lane_id([thread_cnt]) + + acc = T.alloc_buffer(list((1, *inner_dims)), dtype=dtype, scope="local", layout=g_layout_a) + red = T.alloc_buffer(list((1, *dst_dims)), dtype=dtype, scope="local", layout=g_layout_b) + for i in T.serial(src_local_total): + idx = T.meta_var(decompose_flat(i, inner_dims)) + acc[(0, *list(idx))] = A[(lane_id, *list(idx))] + if accum: + for i in T.serial(dst_local_total): + idx = T.meta_var(decompose_flat(i, dst_dims)) + red[(0, *list(idx))] = B[(lane_id, *list(idx))] + acc_view = acc.view(*src_shape, layout=acc_view_layout) + red_view = red.view(*dst_shape, layout=red_view_layout) + if slice_end is not None: + if op_type == "sum": + Tx.warp.sum(red_view, acc_view[:, slice_end // 2:slice_end], axes=axes, accum=accum) + elif op_type == "max": + Tx.warp.max(red_view, acc_view[:, slice_end // 2:slice_end], axes=axes, accum=accum) + elif op_type == "min": + Tx.warp.min(red_view, acc_view[:, slice_end // 2:slice_end], axes=axes, accum=accum) + else: + if op_type == "sum": + Tx.warp.sum(red_view, acc_view, axes=axes, accum=accum) + elif op_type == "max": + Tx.warp.max(red_view, acc_view, axes=axes, accum=accum) + elif op_type == "min": + Tx.warp.min(red_view, acc_view, axes=axes, accum=accum) + for i in T.serial(dst_local_total): + idx = T.meta_var(decompose_flat(i, dst_dims)) + B[(lane_id, *list(idx))] = red[(0, *list(idx))] # fmt: on target = tvm.target.Target("cuda") @@ -517,87 +503,77 @@ def test_reduction_local_view_complex(n_groups, n_warps, op_type, dtype, shuffle acc_shape, red_shape = (16, NUM_COL), (16, 4) # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape_a, dtype, layout=g_layout_a) - B = Tx.match_buffer(B_ptr, g_shape_b, dtype, layout=g_layout_b) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([n_groups]) - warp_id_in_wg = Tx.warp_id_in_wg([n_warps // n_groups]) - lane_id = Tx.lane_id([thread_cnt]) - - with Tx.thread(): - # acc layout - atom = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - warp_layout = Tx.TileLayout(Tx.S[(8, 4) : (4@laneid, 1@laneid)]) - warp_atom = atom.tile(warp_layout, (8, 4), (1, 2)) - tile = Tx.TileLayout(Tx.S[(2, NUM_COL // 8) : (1, 2)]) - acc_layout = warp_atom.tile(tile, (2, NUM_COL // 8), (8, 8)) - acc = Tx.alloc_buffer( - [2, NUM_COL // 4], - dtype=dtype, - scope="local", - layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), - ) - - # red layout - red_atom = Tx.TileLayout(Tx.S[(1, 1) : (1, 1)]) - red_warp_atom = red_atom.tile(warp_layout, (8, 4), (1, 1)) - red_tile = Tx.TileLayout(Tx.S[(2, 1) : (1, 1)]) - red_layout = red_warp_atom.tile(red_tile, (2, 1), (8, 4)) - red = Tx.alloc_buffer( - [2], - dtype=dtype, - scope="local", - layout=red_atom.tile(red_tile, (2, 1), (1, 1)), + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape_a, dtype, layout=g_layout_a) + B = T.match_buffer(B_ptr, g_shape_b, dtype, layout=g_layout_b) + + T.device_entry() + _bx = T.cta_id([1]) + wg_id = T.warpgroup_id([n_groups]) + warp_id_in_wg = T.warp_id_in_wg([n_warps // n_groups]) + lane_id = T.lane_id([thread_cnt]) + # acc layout + atom = T.TileLayout(T.S[(1, 2) : (2, 1)]) + warp_layout = T.TileLayout(T.S[(8, 4) : (4@laneid, 1@laneid)]) + warp_atom = atom.tile(warp_layout, (8, 4), (1, 2)) + tile = T.TileLayout(T.S[(2, NUM_COL // 8) : (1, 2)]) + acc_layout = warp_atom.tile(tile, (2, NUM_COL // 8), (8, 8)) + acc = T.alloc_buffer( + [2, NUM_COL // 4], + dtype=dtype, + scope="local", + layout=atom.tile(tile, (2, NUM_COL // 8), (1, 2)), + ) + + # red layout + red_atom = T.TileLayout(T.S[(1, 1) : (1, 1)]) + red_warp_atom = red_atom.tile(warp_layout, (8, 4), (1, 1)) + red_tile = T.TileLayout(T.S[(2, 1) : (1, 1)]) + red_layout = red_warp_atom.tile(red_tile, (2, 1), (8, 4)) + red = T.alloc_buffer( + [2], + dtype=dtype, + scope="local", + layout=red_atom.tile(red_tile, (2, 1), (1, 1)), + ) + for i in T.serial(NUM_COL // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + acc[j, i * 2 + vec] = A[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] + + # Pre-load B into red for accumulation + if accum: + for i in T.unroll(2): + red[i] = B[ + wg_id * 64 + warp_id_in_wg * 16 + i * 8 + lane_id // 4, + lane_id % 4, + ] + acc_view = acc.view(*acc_shape, layout=acc_layout) + red_view = red.view(*red_shape, layout=red_layout) + if op_type == "sum": + Tx.warp.sum(red_view, acc_view, thread_reduce=shuffle, accum=accum) + elif op_type == "max": + Tx.warp.max(red_view, acc_view, thread_reduce=shuffle, accum=accum) + elif op_type == "min": + Tx.warp.min(red_view, acc_view, thread_reduce=shuffle, accum=accum) + # perform an additional shuffle step if not shuffled above + if not shuffle: + if op_type == "sum": + Tx.warp.sum(red_view, red_view, thread_reduce=True) + elif op_type == "max": + Tx.warp.max(red_view, red_view, thread_reduce=True) + elif op_type == "min": + Tx.warp.min(red_view, red_view, thread_reduce=True) + # Write red into B + for i in T.unroll(2): + B[wg_id * 64 + warp_id_in_wg * 16 + i * 8 + lane_id // 4, lane_id % 4] = ( + red[i] ) - # Load A into acc - with Tx.thread(): - for i in Tx.serial(NUM_COL // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - acc[j, i * 2 + vec] = A[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] - - # Pre-load B into red for accumulation - if accum: - with Tx.thread(): - for i in Tx.unroll(2): - red[i] = B[ - wg_id * 64 + warp_id_in_wg * 16 + i * 8 + lane_id // 4, - lane_id % 4, - ] - - # Reduce - with Tx.warp(): - acc_view = acc.view(*acc_shape, layout=acc_layout) - red_view = red.view(*red_shape, layout=red_layout) - if op_type == "sum": - Tx.sum(red_view, acc_view, thread_reduce=shuffle, accum=accum) - elif op_type == "max": - Tx.max(red_view, acc_view, thread_reduce=shuffle, accum=accum) - elif op_type == "min": - Tx.min(red_view, acc_view, thread_reduce=shuffle, accum=accum) - # perform an additional shuffle step if not shuffled above - if not shuffle: - if op_type == "sum": - Tx.sum(red_view, red_view, thread_reduce=True) - elif op_type == "max": - Tx.max(red_view, red_view, thread_reduce=True) - elif op_type == "min": - Tx.min(red_view, red_view, thread_reduce=True) - # Write red into B - with Tx.thread(): - for i in Tx.unroll(2): - B[wg_id * 64 + warp_id_in_wg * 16 + i * 8 + lane_id // 4, lane_id % 4] = ( - red[i] - ) - # fmt: on target = tvm.target.Target("cuda") @@ -649,35 +625,33 @@ def test_reduction_local_optimized_3input_maxmin(reduction_len, op_type, accum): dtype = "float32" # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, [reduction_len], dtype, layout=TileLayout(S[reduction_len])) - B = Tx.match_buffer(B_ptr, [1], dtype, layout=TileLayout(S[1])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([1]) - - with Tx.thread(): - A_local = Tx.alloc_buffer([reduction_len], dtype, scope="local") - B_local = Tx.alloc_buffer([1], dtype, scope="local") - - # Load from global to local - for i in Tx.serial(reduction_len): - A_local[i] = A[i] - - # Initialize B_local for accum test - if accum: - B_local[0] = B[0] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, [reduction_len], dtype, layout=TileLayout(S[reduction_len])) + B = T.match_buffer(B_ptr, [1], dtype, layout=TileLayout(S[1])) + + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([1]) + A_local = T.alloc_buffer([reduction_len], dtype, scope="local") + B_local = T.alloc_buffer([1], dtype, scope="local") + + # Load from global to local + for i in T.serial(reduction_len): + A_local[i] = A[i] + + # Initialize B_local for accum test + if accum: + B_local[0] = B[0] - # Thread-level reduction - if op_type == "max": - Tx.max(B_local, A_local, accum=accum) - elif op_type == "min": - Tx.min(B_local, A_local, accum=accum) + # Thread-level reduction + if op_type == "max": + Tx.max(B_local, A_local, accum=accum) + elif op_type == "min": + Tx.min(B_local, A_local, accum=accum) - # Store result to global - B[0] = B_local[0] + # Store result to global + B[0] = B_local[0] # fmt: on target = tvm.target.Target("cuda") @@ -719,32 +693,30 @@ def test_reduction_local_optimized_packed_add_sum(reduction_len, accum): dtype = "float32" # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, [reduction_len], dtype, layout=TileLayout(S[reduction_len])) - B = Tx.match_buffer(B_ptr, [1], dtype, layout=TileLayout(S[1])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - _tid = Tx.thread_id([1]) - - with Tx.thread(): - A_local = Tx.alloc_buffer([reduction_len], dtype, scope="local") - B_local = Tx.alloc_buffer([1], dtype, scope="local") - - # Load from global to local - for i in Tx.serial(reduction_len): - A_local[i] = A[i] - - # Initialize B_local for accum test - if accum: - B_local[0] = B[0] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, [reduction_len], dtype, layout=TileLayout(S[reduction_len])) + B = T.match_buffer(B_ptr, [1], dtype, layout=TileLayout(S[1])) + + T.device_entry() + _bx = T.cta_id([1]) + _tid = T.thread_id([1]) + A_local = T.alloc_buffer([reduction_len], dtype, scope="local") + B_local = T.alloc_buffer([1], dtype, scope="local") + + # Load from global to local + for i in T.serial(reduction_len): + A_local[i] = A[i] + + # Initialize B_local for accum test + if accum: + B_local[0] = B[0] - # Thread-level sum reduction - Tx.sum(B_local, A_local, accum=accum) + # Thread-level sum reduction + Tx.sum(B_local, A_local, accum=accum) - # Store result to global - B[0] = B_local[0] + # Store result to global + B[0] = B_local[0] # fmt: on # Use sm_100a target for packed add sum dispatch @@ -792,33 +764,25 @@ def test_reduction_op_warp_shuffle(op_type, dtype): dst_layout = TileLayout(S[1:1] + R[N : 1 @ laneid]) # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) - B = Tx.match_buffer(B_ptr, g_shape, dtype, layout=g_layout) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - - with Tx.thread(): - src_local = Tx.alloc_buffer([1], dtype, scope="local") - dst_local = Tx.alloc_buffer([1], dtype, scope="local") - - with Tx.thread(): - src_local[0] = A[lane_id] - - with Tx.warp(): - src_view = src_local.view(N, layout=src_layout) - dst_view = dst_local.view(1, layout=dst_layout) - if op_type == "sum": - Tx.sum(dst_view, src_view) - elif op_type == "max": - Tx.max(dst_view, src_view) - - with Tx.thread(): - B[lane_id] = dst_local[0] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) + B = T.match_buffer(B_ptr, g_shape, dtype, layout=g_layout) + + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + src_local = T.alloc_buffer([1], dtype, scope="local") + dst_local = T.alloc_buffer([1], dtype, scope="local") + src_local[0] = A[lane_id] + src_view = src_local.view(N, layout=src_layout) + dst_view = dst_local.view(1, layout=dst_layout) + if op_type == "sum": + Tx.warp.sum(dst_view, src_view) + elif op_type == "max": + Tx.warp.max(dst_view, src_view) + B[lane_id] = dst_local[0] # fmt: on target = tvm.target.Target("cuda") @@ -864,36 +828,28 @@ def test_reduction_op_warp_shuffle_multi_elem(op_type, dtype): dst_layout = TileLayout(S[ELEMS_PER_THREAD:1] + R[N_LANES : 1 @ laneid]) # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, g_shape, dtype, layout=g_layout) dst_lay = TileLayout(S[ELEMS_PER_THREAD]) - B = Tx.match_buffer(B_ptr, [ELEMS_PER_THREAD], dtype, layout=dst_lay) - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - - with Tx.thread(): - src_local = Tx.alloc_buffer([ELEMS_PER_THREAD], dtype, scope="local") - dst_local = Tx.alloc_buffer([ELEMS_PER_THREAD], dtype, scope="local") - - with Tx.thread(): - for i in Tx.serial(ELEMS_PER_THREAD): - src_local[i] = A[lane_id * ELEMS_PER_THREAD + i] - - with Tx.warp(): - src_view = src_local.view(TOTAL, layout=src_layout) - dst_view = dst_local.view(ELEMS_PER_THREAD, layout=dst_layout) - if op_type == "sum": - Tx.sum(dst_view, src_view) - elif op_type == "max": - Tx.max(dst_view, src_view) - - with Tx.thread(): - for i in Tx.serial(ELEMS_PER_THREAD): - B[i] = dst_local[i] + B = T.match_buffer(B_ptr, [ELEMS_PER_THREAD], dtype, layout=dst_lay) + + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + src_local = T.alloc_buffer([ELEMS_PER_THREAD], dtype, scope="local") + dst_local = T.alloc_buffer([ELEMS_PER_THREAD], dtype, scope="local") + for i in T.serial(ELEMS_PER_THREAD): + src_local[i] = A[lane_id * ELEMS_PER_THREAD + i] + src_view = src_local.view(TOTAL, layout=src_layout) + dst_view = dst_local.view(ELEMS_PER_THREAD, layout=dst_layout) + if op_type == "sum": + Tx.warp.sum(dst_view, src_view) + elif op_type == "max": + Tx.warp.max(dst_view, src_view) + for i in T.serial(ELEMS_PER_THREAD): + B[i] = dst_local[i] # fmt: on target = tvm.target.Target("cuda") @@ -920,7 +876,7 @@ def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: def test_reduction_warp_shuffle_multi_warp_loop(): - """Test intra-warp + cross-warp reduction via Tx.sum in a for loop with multiple warps. + """Test intra-warp + cross-warp reduction via T.sum in a for loop with multiple warps. Validates the scope alternation pattern (thread → warp → thread) inside a loop, which is needed for replacing manual warp shuffle reductions in tirx-kernels. @@ -935,65 +891,47 @@ def test_reduction_warp_shuffle_multi_warp_loop(): dst_layout = TileLayout(S[1:1] + R[BDX : 1 @ laneid]) # fmt: off - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, [N_ITER, N], "float32", scope="global") - B = Tx.match_buffer(B_ptr, [N_ITER], "float32", scope="global") - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - ty = Tx.warp_id([BDY]) - tx = Tx.lane_id([BDX]) - thread_id = Tx.meta_var(ty * BDX + tx) - - with Tx.cta(): - pool = Tx.SMEMPool() - sum_smem = pool.alloc([BDY], "float32") - pool.commit() - - with Tx.thread(): - partial_buf = Tx.alloc_buffer([1], "float32", scope="local") - result_buf = Tx.alloc_buffer([1], "float32", scope="local") - cross_buf = Tx.alloc_buffer([1], "float32", scope="local") - cross_res = Tx.alloc_buffer([1], "float32", scope="local") - - for it in Tx.serial(N_ITER): - # Phase 1: each thread loads its value - with Tx.thread(): - partial_buf[0] = A[it, thread_id] - - # Phase 2: intra-warp reduction - with Tx.warp(): - src_v = partial_buf.view(BDX, layout=src_layout) - dst_v = result_buf.view(1, layout=dst_layout) - Tx.sum(dst_v, src_v) - - # Phase 3: write per-warp result to smem - with Tx.thread(): - sum_smem[ty] = result_buf[0] - Tx.cuda.cta_sync() - - # Phase 4: cross-warp reduction (warp 0 only) - if ty == 0: - with Tx.thread(): - if tx < BDY: - cross_buf[0] = sum_smem[tx] - else: - cross_buf[0] = Tx.float32(0) - with Tx.warp(): - cs = cross_buf.view(BDX, layout=src_layout) - cd = cross_res.view(1, layout=dst_layout) - Tx.sum(cd, cs) - with Tx.thread(): - sum_smem[0] = cross_res[0] - Tx.cuda.cta_sync() - - # Phase 5: one thread writes result to global - with Tx.thread(): - if tx == 0: - if ty == 0: - B[it] = sum_smem[0] - Tx.cuda.cta_sync() + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, [N_ITER, N], "float32", scope="global") + B = T.match_buffer(B_ptr, [N_ITER], "float32", scope="global") + + T.device_entry() + cta_id = T.cta_id([1]) + ty = T.warp_id([BDY]) + tx = T.lane_id([BDX]) + thread_id = T.meta_var(ty * BDX + tx) + pool = T.SMEMPool() + sum_smem = pool.alloc([BDY], "float32") + pool.commit() + partial_buf = T.alloc_buffer([1], "float32", scope="local") + result_buf = T.alloc_buffer([1], "float32", scope="local") + cross_buf = T.alloc_buffer([1], "float32", scope="local") + cross_res = T.alloc_buffer([1], "float32", scope="local") + + for it in T.serial(N_ITER): + partial_buf[0] = A[it, thread_id] + src_v = partial_buf.view(BDX, layout=src_layout) + dst_v = result_buf.view(1, layout=dst_layout) + Tx.warp.sum(dst_v, src_v) + sum_smem[ty] = result_buf[0] + T.cuda.cta_sync() + + # Phase 4: cross-warp reduction (warp 0 only) + if ty == 0: + if tx < BDY: + cross_buf[0] = sum_smem[tx] + else: + cross_buf[0] = T.float32(0) + cs = cross_buf.view(BDX, layout=src_layout) + cd = cross_res.view(1, layout=dst_layout) + Tx.warp.sum(cd, cs) + sum_smem[0] = cross_res[0] + T.cuda.cta_sync() + if tx == 0: + if ty == 0: + B[it] = sum_smem[0] + T.cuda.cta_sync() # fmt: on target = tvm.target.Target("cuda") @@ -1020,33 +958,27 @@ def test_reduction_warpgroup_wg_local_layout(op_name): dev = tvm.cuda(0) target = tvm.target.Target("cuda") - @Tx.prim_func - def test_func(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) - B = Tx.match_buffer(B_ptr, (rows, 1), dtype, layout=TileLayout(S[(rows, 1)])) - - Tx.device_entry() - _bx = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([1]) - tid = Tx.thread_id_in_wg([rows]) - - src = Tx.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) - dst = Tx.alloc_buffer((rows, 1), dtype, scope="local", layout=wg_local_layout(1)) - - with Tx.thread(): - src_local = src.local(cols) - for i in Tx.serial(cols): - src_local[i] = A[tid, i] - - with Tx.warpgroup(): - if op_name == "sum": - Tx.sum(dst, src, axes=[-1], accum=False) - else: - Tx.max(dst, src, axes=[-1], accum=False) - - with Tx.thread(): - dst_local = dst.local(1) - B[tid, 0] = dst_local[0] + @T.prim_func + def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (rows, cols), dtype, layout=TileLayout(S[(rows, cols)])) + B = T.match_buffer(B_ptr, (rows, 1), dtype, layout=TileLayout(S[(rows, 1)])) + + T.device_entry() + _bx = T.cta_id([1]) + wg_id = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([rows]) + + src = T.alloc_buffer((rows, cols), dtype, scope="local", layout=wg_local_layout(cols)) + dst = T.alloc_buffer((rows, 1), dtype, scope="local", layout=wg_local_layout(1)) + src_local = src.local(cols) + for i in T.serial(cols): + src_local[i] = A[tid, i] + if op_name == "sum": + Tx.wg.sum(dst, src, axes=[-1], accum=False) + else: + Tx.wg.max(dst, src, axes=[-1], accum=False) + dst_local = dst.local(1) + B[tid, 0] = dst_local[0] with target: np.random.seed(0) diff --git a/tests/python/tirx/operator/tile_primitive/test_dispatcher.py b/tests/python/tirx/operator/tile_primitive/test_dispatcher.py index 95aa14472759..5ff5fe6caeaf 100644 --- a/tests/python/tirx/operator/tile_primitive/test_dispatcher.py +++ b/tests/python/tirx/operator/tile_primitive/test_dispatcher.py @@ -59,8 +59,8 @@ def __init__(self, op): self.op = op self.args = [] # not used by the tested predicates - # Use TRN copy; predicate requires exec_scope == "kernel". - op_call = _OpCall(Op.get("tirx.copy")) + # Use TRN copy; predicate requires exec_scope == "thread". + op_call = _OpCall(Op.get("tirx.tile.copy")) sctx = _DummySctx(target_kind="trn", exec_scope="warp") # intentionally wrong with pytest.raises(RuntimeError) as e: @@ -69,7 +69,7 @@ def __init__(self, op): out = str(e.value) print(out) # Header + per-variant reason must be printed in table format - assert "TIRx schedule dispatch failed: op=tirx.copy target=trn" in out + assert "TIRx schedule dispatch failed: op=tirx.tile.copy target=trn" in out assert "Variant" in out # table header present assert "default" in out # variant name present assert "rejected: exec_scope" in out @@ -88,15 +88,15 @@ def __init__(self, op): self.dispatch = "__nonexistent__" self.args = [] - op_call = _OpCall(Op.get("tirx.copy")) - sctx = _DummySctx(target_kind="trn", exec_scope="kernel") + op_call = _OpCall(Op.get("tirx.tile.copy")) + sctx = _DummySctx(target_kind="trn", exec_scope="thread") with pytest.raises(RuntimeError) as e: run_dispatch(op_call, sctx) msg = str(e.value) print(msg) - assert "TIRx schedule dispatch failed: op=tirx.copy target=trn" in msg + assert "TIRx schedule dispatch failed: op=tirx.tile.copy target=trn" in msg assert "no variant named '__nonexistent__' is registered" in msg @@ -112,15 +112,15 @@ def __init__(self, op): self.args = [] # Use TRN compose_op; variant implementation raises NotImplementedError - op_call = _OpCall(Op.get("tirx.compose_op")) - sctx = _DummySctx(target_kind="trn", exec_scope="kernel") + op_call = _OpCall(Op.get("tirx.tile.compose_op")) + sctx = _DummySctx(target_kind="trn", exec_scope="thread") with pytest.raises(RuntimeError) as e: run_dispatch(op_call, sctx) msg = str(e.value) print(msg) - assert "TIRx schedule dispatch failed: op=tirx.compose_op target=trn" in msg + assert "TIRx schedule dispatch failed: op=tirx.tile.compose_op target=trn" in msg assert "default" in msg assert "exception — NotImplementedError" in msg # opcall content and backtrace should be included inside the table @@ -136,11 +136,11 @@ def test_dispatch_prints_real_opcall_ir(): from tvm.tirx.operator.tile_primitive.dispatcher import run_dispatch from tvm.tirx.stmt import TilePrimitiveCall - # Build a real TIRx TilePrimitiveCall: tirx.copy(A[0:64], B[0:64]) + # Build a real TIRx TilePrimitiveCall: tirx.tile.copy(A[0:64], B[0:64]) A = decl_buffer((64,), "float32", scope="global") B = decl_buffer((64,), "float32", scope="shared") real_opcall = TilePrimitiveCall( - A[0:64], B[0:64], op=Op.get("tirx.copy"), workspace={}, config={} + A[0:64], B[0:64], op=Op.get("tirx.tile.copy"), workspace={}, config={} ) # Force predicate rejection to trigger formatted error with opcall IR @@ -151,8 +151,8 @@ def test_dispatch_prints_real_opcall_ir(): out = str(e.value) print(out) # Verify header and that the opcall IR is included in the table - assert "TIRx schedule dispatch failed: op=tirx.copy target=trn" in out + assert "TIRx schedule dispatch failed: op=tirx.tile.copy target=trn" in out assert "Variant" in out assert "opcall:" in out # IR should mention the operator name - assert "tirx.copy" in out + assert "tirx.tile.copy" in out diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py index 473bf659ec36..268ef0eae6f3 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_binary_trn.py @@ -19,7 +19,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -28,8 +29,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -38,7 +37,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -65,7 +64,7 @@ def assert_structural_equal(lhs, rhs, *args, **kwargs): ], ) def test_simple_binary(op_type, operands_type): - const = Tx.float32(3.0) + const = T.float32(3.0) src1_shape = [128, 512] if operands_type != "region_broadcast_lhs" else [128, 1] src1_layout = TileLayout(S[src1_shape : (1 @ P, 1 @ F)]) src2_shape = [128, 512] if operands_type != "region_broadcast_rhs" else [128, 1] @@ -75,12 +74,12 @@ def test_simple_binary(op_type, operands_type): Tx_func = Tx_func_map[op_type] # fmt: off - @Tx.prim_func + @T.prim_func def binary() ->None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) if operands_type == "region_region" or operands_type.startswith("region_broadcast"): Tx_func(C_sbuf, A_sbuf, B_sbuf) elif operands_type == "const_region": @@ -88,29 +87,27 @@ def binary() ->None: elif operands_type == "region_const": Tx_func(C_sbuf, A_sbuf, const) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer(src1_shape, scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer(src2_shape, scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer(dst_shape, scope="trn.sbuf") - for b_loop in Tx.serial(0, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if operands_type == "region_region": - Tx.nki.tensortensor(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], B_sbuf[p_loop, f_loop], op_type) # noqa: E501 - elif operands_type == "region_const": - Tx.nki.tensorscalar(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], Tx.float32(3.0), op_type, Tx.bool(False)) # noqa: E501 - elif operands_type == "const_region": - Tx.nki.tensorscalar(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], Tx.float32(3.0), op_type, Tx.bool(True)) # noqa: E501 - elif operands_type == "region_broadcast_rhs": - Tx.nki.tensorscalar(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], B_sbuf[p_loop, 0], op_type, Tx.bool(False)) # noqa: E501 - elif operands_type == "region_broadcast_lhs": - Tx.nki.tensorscalar(C_sbuf[p_loop, f_loop], B_sbuf[p_loop, f_loop], A_sbuf[p_loop, 0], op_type, Tx.bool(True)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer(src1_shape, scope="trn.sbuf") + B_sbuf = T.alloc_buffer(src2_shape, scope="trn.sbuf") + C_sbuf = T.alloc_buffer(dst_shape, scope="trn.sbuf") + for b_loop in T.serial(0, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if operands_type == "region_region": + T.nki.tensortensor(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], B_sbuf[p_loop, f_loop], op_type) # noqa: E501 + elif operands_type == "region_const": + T.nki.tensorscalar(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], T.float32(3.0), op_type, T.bool(False)) # noqa: E501 + elif operands_type == "const_region": + T.nki.tensorscalar(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], T.float32(3.0), op_type, T.bool(True)) # noqa: E501 + elif operands_type == "region_broadcast_rhs": + T.nki.tensorscalar(C_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop], B_sbuf[p_loop, 0], op_type, T.bool(False)) # noqa: E501 + elif operands_type == "region_broadcast_lhs": + T.nki.tensorscalar(C_sbuf[p_loop, f_loop], B_sbuf[p_loop, f_loop], A_sbuf[p_loop, 0], op_type, T.bool(True)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -138,7 +135,7 @@ def test_binary_complex(op_type, operands_type): dst_shape = [512, 512] dst_layout = TileLayout(S[(128, 2048) : (1 @ P, 1 @ F)]) - const = Tx.float32(3.0) + const = T.float32(3.0) Tx_func = Tx_func_map[op_type] src1_view_shape = [128, 8, 512] @@ -150,12 +147,12 @@ def test_binary_complex(op_type, operands_type): dst_view_shape = [128, 4, 4, 128] # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) A_sbuf_view = A_sbuf.view(*src1_view_shape) B_sbuf_view = B_sbuf.view(*src2_view_shape) C_sbuf_view = C_sbuf.view(*dst_view_shape) @@ -174,33 +171,31 @@ def binary() -> None: f_extent = 128 if operands_type == "region_broadcast_lhs" else 512 b_extent = 4 if operands_type == "region_broadcast_lhs" else 1 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer(src1_layout_data_iter, scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer(src2_layout_data_iter, scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - A_sbuf_view = Tx.decl_buffer(src1_layout_data_iter, data=A_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 - B_sbuf_view = Tx.decl_buffer(src2_layout_data_iter, data=B_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 - C_sbuf_view = Tx.decl_buffer((128, 2048), data=C_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 - for i, b_loop in Tx.grid(4, b_extent): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, f_extent, annotations={"nki_dim":"F"}): - if operands_type == "region_region": - Tx.nki.tensortensor(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], B_sbuf_view[p_loop, i * 512 + f_loop], op_type) # noqa: E501 - elif operands_type == "const_region": - Tx.nki.tensorscalar(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], Tx.float32(3.0), op_type, Tx.bool(True)) # noqa: E501 - elif operands_type == "region_const": - Tx.nki.tensorscalar(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], Tx.float32(3.0), op_type, Tx.bool(False)) # noqa: E501 - elif operands_type == "region_broadcast_lhs": - Tx.nki.tensorscalar(C_sbuf_view[p_loop, i * 512 + b_loop * 128 + f_loop], B_sbuf_view[p_loop, i * 512 + b_loop * 128 + f_loop], A_sbuf_view[p_loop, i * 8 + b_loop], op_type, Tx.bool(True)) # noqa: E501 - elif operands_type == "region_broadcast_rhs": - Tx.nki.tensortensor(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], B_sbuf_view[p_loop, f_loop], op_type) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer(src1_layout_data_iter, scope="trn.sbuf") + B_sbuf = T.alloc_buffer(src2_layout_data_iter, scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + A_sbuf_view = T.decl_buffer(src1_layout_data_iter, data=A_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 + B_sbuf_view = T.decl_buffer(src2_layout_data_iter, data=B_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 + C_sbuf_view = T.decl_buffer((128, 2048), data=C_sbuf.data, scope="trn.sbuf", layout=None) + for i, b_loop in T.grid(4, b_extent): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, f_extent, annotations={"nki_dim":"F"}): + if operands_type == "region_region": + T.nki.tensortensor(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], B_sbuf_view[p_loop, i * 512 + f_loop], op_type) # noqa: E501 + elif operands_type == "const_region": + T.nki.tensorscalar(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], T.float32(3.0), op_type, T.bool(True)) # noqa: E501 + elif operands_type == "region_const": + T.nki.tensorscalar(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], T.float32(3.0), op_type, T.bool(False)) # noqa: E501 + elif operands_type == "region_broadcast_lhs": + T.nki.tensorscalar(C_sbuf_view[p_loop, i * 512 + b_loop * 128 + f_loop], B_sbuf_view[p_loop, i * 512 + b_loop * 128 + f_loop], A_sbuf_view[p_loop, i * 8 + b_loop], op_type, T.bool(True)) # noqa: E501 + elif operands_type == "region_broadcast_rhs": + T.nki.tensortensor(C_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop], B_sbuf_view[p_loop, f_loop], op_type) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": binary}) @@ -217,28 +212,26 @@ def test_binary_broadcast1(): dst_layout = src1_layout # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.add(C_sbuf, A_sbuf, B_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for b_loop in Tx.serial(0, 512): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.tensorscalar(C_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], B_sbuf[p_loop, b_loop], "add", Tx.bool(False)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for b_loop in T.serial(0, 512): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.tensorscalar(C_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], B_sbuf[p_loop, b_loop], "add", T.bool(False)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary}) @@ -255,28 +248,26 @@ def test_binary_broadcast2(): dst_layout = src1_layout # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.add(C_sbuf, A_sbuf, B_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for b_loop in Tx.serial(0, 128): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - Tx.nki.tensortensor(C_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 128 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 128 + f_loop], B_sbuf[p_loop, b_loop % 4 * 128 + f_loop], "add") # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for b_loop in T.serial(0, 128): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + T.nki.tensortensor(C_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 128 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 128 + f_loop], B_sbuf[p_loop, b_loop % 4 * 128 + f_loop], "add") # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary}) @@ -293,28 +284,26 @@ def test_binary_broadcast3(): dst_layout = src1_layout # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.add(C_sbuf, A_sbuf, B_sbuf[0]) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for b_loop in Tx.serial(0, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - Tx.nki.tensortensor(C_sbuf[p_loop, b_loop * 128 + f_loop], A_sbuf[p_loop, b_loop * 128 + f_loop], B_sbuf[p_loop, b_loop * 4096 + f_loop], "add") # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for b_loop in T.serial(0, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + T.nki.tensortensor(C_sbuf[p_loop, b_loop * 128 + f_loop], A_sbuf[p_loop, b_loop * 128 + f_loop], B_sbuf[p_loop, b_loop * 4096 + f_loop], "add") # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary}) @@ -331,31 +320,29 @@ def test_binary_with_guard(): dst_layout = src1_layout # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for j in range(4): Tx.add(C_sbuf[:, :, 0:j*128], A_sbuf[:, :, 0:j*128], B_sbuf[:, 0:j*128]) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for j, b_loop in Tx.grid(4, 96): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - if b_loop % 3 - j < 0: - Tx.nki.tensortensor(C_sbuf[p_loop, b_loop % 3 * 4096 + b_loop // 3 * 128 + f_loop], A_sbuf[p_loop, b_loop % 3 * 4096 + b_loop // 3 * 128 + f_loop], B_sbuf[p_loop, b_loop % 3 * 128 + f_loop], "add") # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for j, b_loop in T.grid(4, 96): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + if b_loop % 3 - j < 0: + T.nki.tensortensor(C_sbuf[p_loop, b_loop % 3 * 4096 + b_loop // 3 * 128 + f_loop], A_sbuf[p_loop, b_loop % 3 * 4096 + b_loop // 3 * 128 + f_loop], B_sbuf[p_loop, b_loop % 3 * 128 + f_loop], "add") # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": binary}) mod = tvm.tirx.transform.LowerTIRx()(mod) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py index a275215d7d9d..448b856d9aca 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py @@ -19,7 +19,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -28,8 +29,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -38,7 +37,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -59,34 +58,32 @@ def test_simple_activation_reduce(): C_layout = TileLayout(S[(128, 1) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) Tx.unary_reduce(B, C, A, "sqrt", "sum", reduce_axes=1) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - - with Tx.thread(): - const_bias = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.activation_reduce(C[p_loop, 0], B[p_loop, f_loop], A[p_loop, f_loop], "sqrt", "add", bias=const_bias[p_loop, f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "activation_reduce"}) + const_bias = T.alloc_buffer((128, 512), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A = T.alloc_buffer((128, 512), scope="trn.sbuf") + B = T.alloc_buffer((128, 512), scope="trn.sbuf") + C = T.alloc_buffer((128, 1), scope="trn.sbuf") + for b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.activation_reduce(C[p_loop, 0], B[p_loop, f_loop], A[p_loop, f_loop], "sqrt", "add", bias=const_bias[p_loop, f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": activation_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -103,34 +100,32 @@ def test_activation_reduce_in_loop(): C_layout = TileLayout(S[(2, 4, 2, 128) : (2 @ F, 4 @ F, 1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=1) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - - with Tx.thread(): - const_bias = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 16): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.activation_reduce(C[p_loop, b_loop % 8 // 2 * 4 + b_loop // 8 * 2 + b_loop % 2], B[p_loop, b_loop % 8 // 2 * 2048 + b_loop // 8 * 1024 + b_loop % 2 * 512 + f_loop], A[p_loop, i * 8192 + b_loop * 512 + f_loop], "sqrt", "add", bias=const_bias[p_loop, f_loop]) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "activation_reduce"}) + const_bias = T.alloc_buffer((128, 512), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B = T.alloc_buffer((128, 8192), scope="trn.sbuf") + C = T.alloc_buffer((128, 16), scope="trn.sbuf") + for i, b_loop in T.grid(2, 16): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.activation_reduce(C[p_loop, b_loop % 8 // 2 * 4 + b_loop // 8 * 2 + b_loop % 2], B[p_loop, b_loop % 8 // 2 * 2048 + b_loop // 8 * 1024 + b_loop % 2 * 512 + f_loop], A[p_loop, i * 8192 + b_loop * 512 + f_loop], "sqrt", "add", bias=const_bias[p_loop, f_loop]) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": activation_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -147,34 +142,32 @@ def test_activation_reduce_in_loop2(): C_layout = TileLayout(S[(2, 4, 2, 128) : (2 @ F, 4 @ F, 1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=1) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - - with Tx.thread(): - const_bias = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 16): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.activation_reduce(C[p_loop, b_loop % 8 // 2 * 4 + b_loop // 8 * 2 + b_loop % 2], B[p_loop, b_loop * 512 + f_loop], A[p_loop, i * 8192 + b_loop * 512 + f_loop], "sqrt", "add", bias=const_bias[p_loop, f_loop]) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "activation_reduce"}) + const_bias = T.alloc_buffer((128, 512), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B = T.alloc_buffer((128, 8192), scope="trn.sbuf") + C = T.alloc_buffer((128, 16), scope="trn.sbuf") + for i, b_loop in T.grid(2, 16): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.activation_reduce(C[p_loop, b_loop % 8 // 2 * 4 + b_loop // 8 * 2 + b_loop % 2], B[p_loop, b_loop * 512 + f_loop], A[p_loop, i * 8192 + b_loop * 512 + f_loop], "sqrt", "add", bias=const_bias[p_loop, f_loop]) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": activation_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -191,40 +184,38 @@ def test_activation_reduce_two_stage(): C_layout = TileLayout(S[(1, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=(0,1)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - - with Tx.thread(): - partial_reduce = Tx.alloc_buffer((128, 8), scope="trn.sbuf") - const_bias = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 1): - for reduction_b_loop in range(8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.activation_reduce(partial_reduce[p_loop, reduction_b_loop], B[p_loop, reduction_b_loop % 4 * 2048 + reduction_b_loop // 4 * 1024 + f_loop], A[p_loop, i * 8192 + reduction_b_loop * 1024 + f_loop], "sqrt", "add", const_bias[p_loop, f_loop], Tx.float32(1.0)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(8, annotations={"nki_dim": "F"}): - Tx.nki.tensorreduce(C[p_loop, 0], partial_reduce[p_loop, f_loop], "add", Tx.bool(False), -1) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "activation_reduce"}) + partial_reduce = T.alloc_buffer((128, 8), scope="trn.sbuf") + const_bias = T.alloc_buffer((128, 1024), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B = T.alloc_buffer((128, 8192), scope="trn.sbuf") + C = T.alloc_buffer((128, 1), scope="trn.sbuf") + for i, b_loop in T.grid(2, 1): + for reduction_b_loop in range(8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.activation_reduce(partial_reduce[p_loop, reduction_b_loop], B[p_loop, reduction_b_loop % 4 * 2048 + reduction_b_loop // 4 * 1024 + f_loop], A[p_loop, i * 8192 + reduction_b_loop * 1024 + f_loop], "sqrt", "add", const_bias[p_loop, f_loop], T.float32(1.0)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(8, annotations={"nki_dim": "F"}): + T.nki.tensorreduce(C[p_loop, 0], partial_reduce[p_loop, f_loop], "add", T.bool(False), -1) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": activation_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -243,31 +234,29 @@ def test_activation_reduce_with_bias_scale(): bias_layout = TileLayout(S[(128, 1) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) - bias = Tx.alloc_buffer(bias_shape, dtype="float32", scope="trn.sbuf", layout=bias_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + bias = T.alloc_buffer(bias_shape, dtype="float32", scope="trn.sbuf", layout=bias_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=1, bias=bias, scale=2.0) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - - with Tx.thread(): - A = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - bias = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 16): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.activation_reduce(C[p_loop, b_loop % 8 // 2 * 4 + b_loop // 8 * 2 + b_loop % 2], B[p_loop, b_loop * 512 + f_loop], A[p_loop, i * 8192 + b_loop * 512 + f_loop], "sqrt", "add", bias[p_loop, 0], Tx.float32(2.0)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "activation_reduce"}) + A = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B = T.alloc_buffer((128, 8192), scope="trn.sbuf") + C = T.alloc_buffer((128, 16), scope="trn.sbuf") + bias = T.alloc_buffer((128, 1), scope="trn.sbuf") + for i, b_loop in T.grid(2, 16): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.activation_reduce(C[p_loop, b_loop % 8 // 2 * 4 + b_loop // 8 * 2 + b_loop % 2], B[p_loop, b_loop * 512 + f_loop], A[p_loop, i * 8192 + b_loop * 512 + f_loop], "sqrt", "add", bias[p_loop, 0], T.float32(2.0)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": activation_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -283,28 +272,26 @@ def test_simple_tensor_scalar_reduce(): C_layout = TileLayout(S[(128, 1) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def tensor_scalar_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) Tx.binary_reduce(B, C, A, 1.0, "add", "sum", reduce_axes=1) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "tensor_scalar_reduce"}) - - with Tx.thread(): - A = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.tensorscalar_reduce(C[p_loop, 0], B[p_loop, f_loop], A[p_loop, f_loop], Tx.float32(1.0), "add", "add", Tx.bool(False)) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "tensor_scalar_reduce"}) + A = T.alloc_buffer((128, 512), scope="trn.sbuf") + B = T.alloc_buffer((128, 512), scope="trn.sbuf") + C = T.alloc_buffer((128, 1), scope="trn.sbuf") + for b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.tensorscalar_reduce(C[p_loop, 0], B[p_loop, f_loop], A[p_loop, f_loop], T.float32(1.0), "add", "add", T.bool(False)) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": tensor_scalar_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -322,13 +309,13 @@ def test_tensor_tensor_reduce_fail(): C_layout = TileLayout(S[(128, 1) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def tensor_scalar_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) - D = Tx.alloc_buffer(D_shape, dtype="float32", scope="trn.sbuf", layout=D_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + D = T.alloc_buffer(D_shape, dtype="float32", scope="trn.sbuf", layout=D_layout) Tx.binary_reduce(B, C, A, D, "add", "sum", reduce_axes=1) # fmt: off @@ -349,30 +336,28 @@ def test_tensor_scalar_reduce_complex(): reduce_dst_layout = TileLayout(S[(128, 4, 128) : (1 @ F, 128 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def tensor_scalar_reduce() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - D_sbuf = Tx.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + D_sbuf = T.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 Tx.binary_reduce(C_sbuf, D_sbuf, B_sbuf, A_sbuf, "add", "sum", reduce_axes=0) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "tensor_scalar_reduce"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - D_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for b_loop in range(512): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.tensorscalar_reduce(D_sbuf[p_loop, b_loop % 4 * 128 + b_loop // 4], C_sbuf[p_loop, b_loop % 4 * 4096 + f_loop * 128 + b_loop // 4], A_sbuf[p_loop, b_loop % 4 * 4096 + f_loop * 128 + b_loop // 4], B_sbuf[p_loop, b_loop % 4 * 128 + b_loop // 4], "add", "add", Tx.bool(True)) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "tensor_scalar_reduce"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + D_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for b_loop in range(512): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.tensorscalar_reduce(D_sbuf[p_loop, b_loop % 4 * 128 + b_loop // 4], C_sbuf[p_loop, b_loop % 4 * 4096 + f_loop * 128 + b_loop // 4], A_sbuf[p_loop, b_loop % 4 * 4096 + f_loop * 128 + b_loop // 4], B_sbuf[p_loop, b_loop % 4 * 128 + b_loop // 4], "add", "add", T.bool(True)) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": tensor_scalar_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -388,34 +373,32 @@ def test_tensor_scalar_reduce_two_stage(): reduce_dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def tensor_scalar_reduce() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(dst1_shape, "float32", scope="trn.sbuf", layout=dst1_layout) - C_sbuf = Tx.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(dst1_shape, "float32", scope="trn.sbuf", layout=dst1_layout) + C_sbuf = T.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 Tx.binary_reduce(B_sbuf, C_sbuf, A_sbuf, 1.0, "add", "sum", reduce_axes=(1, 2)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "tensor_scalar_reduce"}) - - with Tx.thread(): - partial_reduce = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for b_loop in range(4): - for reduction_b_loop in range(4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.tensorscalar_reduce(partial_reduce[p_loop, reduction_b_loop], B_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], A_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], Tx.float32(1.0), "add", "add", Tx.bool(False)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(4, annotations={"nki_dim": "F"}): - Tx.nki.tensorreduce(C_sbuf[p_loop, b_loop], partial_reduce[p_loop, f_loop], "add", Tx.bool(False), -1) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "tensor_scalar_reduce"}) + partial_reduce = T.alloc_buffer((128, 4), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for b_loop in range(4): + for reduction_b_loop in range(4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.tensorscalar_reduce(partial_reduce[p_loop, reduction_b_loop], B_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], A_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], T.float32(1.0), "add", "add", T.bool(False)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(4, annotations={"nki_dim": "F"}): + T.nki.tensorreduce(C_sbuf[p_loop, b_loop], partial_reduce[p_loop, f_loop], "add", T.bool(False), -1) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": tensor_scalar_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -434,32 +417,30 @@ def test_vector_chain(): dst_layout = src1_layout # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - _C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - D_sbuf = Tx.alloc_buffer(src3_shape, "float32", scope="trn.sbuf", layout=src3_layout) - E_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + _C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + D_sbuf = T.alloc_buffer(src3_shape, "float32", scope="trn.sbuf", layout=src3_layout) + E_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.binary_chain(E_sbuf, A_sbuf, B_sbuf, D_sbuf, "add", "add", reverse1=True) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - _C_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - D_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - E_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for b_loop in Tx.serial(0, 512): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.scalar_tensor_scalar(E_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], B_sbuf[p_loop, b_loop], D_sbuf[p_loop, b_loop % 4], "add", "add", Tx.bool(False), Tx.bool(True)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + _C_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + D_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + E_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for b_loop in T.serial(0, 512): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.scalar_tensor_scalar(E_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], B_sbuf[p_loop, b_loop], D_sbuf[p_loop, b_loop % 4], "add", "add", T.bool(False), T.bool(True)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary}) @@ -478,32 +459,30 @@ def test_vector_chain_2(): dst_layout = src1_layout # fmt: off - @Tx.prim_func + @T.prim_func def binary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - _C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - D_sbuf = Tx.alloc_buffer(src3_shape, "float32", scope="trn.sbuf", layout=src3_layout) - E_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + _C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + D_sbuf = T.alloc_buffer(src3_shape, "float32", scope="trn.sbuf", layout=src3_layout) + E_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.binary_chain(E_sbuf, A_sbuf, B_sbuf, D_sbuf, "add", "add", reverse1=True) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - _C_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - D_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - E_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for b_loop in Tx.serial(0, 512): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.scalar_tensor_tensor(E_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], B_sbuf[p_loop, b_loop], D_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], "add", "add", Tx.bool(False), Tx.bool(True)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + _C_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + D_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + E_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for b_loop in T.serial(0, 512): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.scalar_tensor_tensor(E_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], A_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], B_sbuf[p_loop, b_loop], D_sbuf[p_loop, b_loop % 4 * 4096 + b_loop // 4 * 32 + f_loop], "add", "add", T.bool(False), T.bool(True)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary}) @@ -518,27 +497,25 @@ def test_reduce_negate(): dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.reduce_negate(B_sbuf[:, i], A_sbuf[:, :, i], reduce_op="sum", reduce_axes=-2) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(B_sbuf[p_loop, i], A_sbuf[p_loop, f_loop * 4 + i], "add", True, -1) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(B_sbuf[p_loop, i], A_sbuf[p_loop, f_loop * 4 + i], "add", True, -1) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -554,31 +531,29 @@ def test_binary_reduce_guard(): reduce_dst_layout = TileLayout(S[(4, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def binary_reduce() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - C_sbuf = Tx.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + C_sbuf = T.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 for j in range(4): for i in range(4): Tx.binary_reduce(B_sbuf[0:128*(j+1), 0:128*(i+1)], C_sbuf[0:128*(j+1)], A_sbuf[0:128*(j+1), 0:128*(i+1)], 0.0, "add", "sum", [-1]) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary_reduce"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for j, i, b_loop in Tx.grid(4, 4, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if b_loop - j < 1 and f_loop < i * 128 + 128: - Tx.nki.tensorscalar_reduce(C_sbuf[p_loop, b_loop], B_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], Tx.float32(0.0), "add", "add", Tx.bool(False)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary_reduce"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for j, i, b_loop in T.grid(4, 4, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if b_loop - j < 1 and f_loop < i * 128 + 128: + T.nki.tensorscalar_reduce(C_sbuf[p_loop, b_loop], B_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], T.float32(0.0), "add", "add", T.bool(False)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -595,37 +570,35 @@ def test_unary_reduce_guard(): reduce_dst_layout = TileLayout(S[(4, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def unary_reduce() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - C_sbuf = Tx.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + C_sbuf = T.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 for j in range(4): for i in range(4): Tx.unary_reduce(B_sbuf[0:128*(j+1), 0:128*(i+1)], C_sbuf[0:128*(j+1)], A_sbuf[0:128*(j+1), 0:128*(i+1)], "sqrt", "sum", reduce_axes=[-1]) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary_reduce"}) - - with Tx.thread(): - const_bias = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for j, i, b_loop in Tx.grid(4, 4, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - if b_loop - j < 1 and f_loop < i * 128 + 128: - Tx.nki.activation_reduce(C_sbuf[p_loop, b_loop], B_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], "sqrt", "add", const_bias[p_loop, f_loop], Tx.float32(1.0)) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "unary_reduce"}) + const_bias = T.alloc_buffer((128, 512), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for j, i, b_loop in T.grid(4, 4, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + if b_loop - j < 1 and f_loop < i * 128 + 128: + T.nki.activation_reduce(C_sbuf[p_loop, b_loop], B_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], "sqrt", "add", const_bias[p_loop, f_loop], T.float32(1.0)) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": unary_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -643,30 +616,28 @@ def test_binary_chain_guard(): src2_layout = TileLayout(S[(4, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def binary_chain() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(src2_shape, "float32", scope="trn.sbuf", layout=src2_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for j in range(4): for i in range(4): Tx.binary_chain(C_sbuf[0:128*(j+1), 0:128*(i+1)], A_sbuf[0:128*(j+1), 0:128*(i+1)], B_sbuf[0:128*(j+1), 0], 1.0, "add", "sub", reverse1=True) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "binary_chain"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for j, i, b_loop in Tx.grid(4, 4, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if b_loop - j < 1 and f_loop < i * 128 + 128: - Tx.nki.scalar_tensor_scalar(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], B_sbuf[p_loop, b_loop], Tx.float32(1.0), "add", "sub", Tx.bool(False), Tx.bool(True)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "binary_chain"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for j, i, b_loop in T.grid(4, 4, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if b_loop - j < 1 and f_loop < i * 128 + 128: + T.nki.scalar_tensor_scalar(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], B_sbuf[p_loop, b_loop], T.float32(1.0), "add", "sub", T.bool(False), T.bool(True)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": binary_chain}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -683,42 +654,40 @@ def test_activation_reduce_two_stage_workspace(): C_layout = TileLayout(S[(1, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - intermediate_buffer = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + intermediate_buffer = T.alloc_buffer((128, 16), scope="trn.sbuf") + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=(0,1), workspace={"partial_reduce": intermediate_buffer}) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - - with Tx.thread(): - const_bias = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - intermediate_buffer = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - A = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - C = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 1): - for reduction_b_loop in range(8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.activation_reduce(intermediate_buffer[p_loop, reduction_b_loop], B[p_loop, reduction_b_loop % 4 * 2048 + reduction_b_loop // 4 * 1024 + f_loop], A[p_loop, i * 8192 + reduction_b_loop * 1024 + f_loop], "sqrt", "add", const_bias[p_loop, f_loop], Tx.float32(1.0)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(8, annotations={"nki_dim": "F"}): - Tx.nki.tensorreduce(C[p_loop, 0], intermediate_buffer[p_loop, f_loop], "add", Tx.bool(False), -1) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "activation_reduce"}) + const_bias = T.alloc_buffer((128, 1024), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + intermediate_buffer = T.alloc_buffer((128, 16), scope="trn.sbuf") + A = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B = T.alloc_buffer((128, 8192), scope="trn.sbuf") + C = T.alloc_buffer((128, 1), scope="trn.sbuf") + for i, b_loop in T.grid(2, 1): + for reduction_b_loop in range(8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.activation_reduce(intermediate_buffer[p_loop, reduction_b_loop], B[p_loop, reduction_b_loop % 4 * 2048 + reduction_b_loop // 4 * 1024 + f_loop], A[p_loop, i * 8192 + reduction_b_loop * 1024 + f_loop], "sqrt", "add", const_bias[p_loop, f_loop], T.float32(1.0)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(8, annotations={"nki_dim": "F"}): + T.nki.tensorreduce(C[p_loop, 0], intermediate_buffer[p_loop, f_loop], "add", T.bool(False), -1) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": activation_reduce}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -735,35 +704,33 @@ def test_tensor_scalar_reduce_two_stage_workspace(): reduce_dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def tensor_scalar_reduce() -> None: - Tx.device_entry() - intermediate_buffer = Tx.alloc_buffer((128, 8), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(dst1_shape, "float32", scope="trn.sbuf", layout=dst1_layout) - C_sbuf = Tx.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 + T.device_entry() + intermediate_buffer = T.alloc_buffer((128, 8), scope="trn.sbuf") + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(dst1_shape, "float32", scope="trn.sbuf", layout=dst1_layout) + C_sbuf = T.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 Tx.binary_reduce(B_sbuf, C_sbuf, A_sbuf, 1.0, "add", "sum", reduce_axes=(1, 2), workspace={"partial_reduce": intermediate_buffer}) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "tensor_scalar_reduce"}) - - with Tx.thread(): - intermediate_buffer = Tx.alloc_buffer((128, 8), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for b_loop in range(4): - for reduction_b_loop in range(4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.tensorscalar_reduce(intermediate_buffer[p_loop, reduction_b_loop], B_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], A_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], Tx.float32(1.0), "add", "add", Tx.bool(False)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(4, annotations={"nki_dim": "F"}): - Tx.nki.tensorreduce(C_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", Tx.bool(False), -1) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "tensor_scalar_reduce"}) + intermediate_buffer = T.alloc_buffer((128, 8), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for b_loop in range(4): + for reduction_b_loop in range(4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.tensorscalar_reduce(intermediate_buffer[p_loop, reduction_b_loop], B_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], A_sbuf[p_loop, reduction_b_loop * 4096 + b_loop * 1024 + f_loop], T.float32(1.0), "add", "add", T.bool(False)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(4, annotations={"nki_dim": "F"}): + T.nki.tensorreduce(C_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", T.bool(False), -1) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": tensor_scalar_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -772,31 +739,29 @@ def expected(): def test_unary_reduce_complex(): # fmt: off - @Tx.prim_func + @T.prim_func def unary_reduce(): - Tx.device_entry() - p = Tx.alloc_buffer((128, 8192), "float16", scope="trn.sbuf", layout="PF") - rowsum_p = Tx.alloc_buffer((2, 128, 1), scope="trn.sbuf", layout="FPF") - qk = Tx.alloc_buffer((2, 128, 8192), scope="trn.sbuf", layout="FPF") - running_max = Tx.alloc_buffer((16384, 1), dtype="float32", scope="trn.sbuf", layout="PF") + T.device_entry() + p = T.alloc_buffer((128, 8192), "float16", scope="trn.sbuf", layout="PF") + rowsum_p = T.alloc_buffer((2, 128, 1), scope="trn.sbuf", layout="FPF") + qk = T.alloc_buffer((2, 128, 8192), scope="trn.sbuf", layout="FPF") + running_max = T.alloc_buffer((16384, 1), dtype="float32", scope="trn.sbuf", layout="PF") for i in range(4): Tx.unary_reduce(p[0:128, 0:8192], rowsum_p[i % 2, 0:128, 0], qk[i % 2, 0:128, 0:8192], "exp", "sum", bias=running_max[i * 128:i * 128 + 128, 0]) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary_reduce"}) - - with Tx.thread(): - p = Tx.alloc_buffer((128, 8192), "float16", scope="trn.sbuf") - rowsum_p = Tx.alloc_buffer((128, 2), scope="trn.sbuf") - qk = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - running_max = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(8192, annotations={"nki_dim": "F"}): - Tx.nki.activation_reduce(rowsum_p[p_loop, i % 2], p[p_loop, f_loop], qk[p_loop, i % 2 * 8192 + f_loop], "exp", "add", running_max[p_loop, i], Tx.float32(1.0)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "unary_reduce"}) + p = T.alloc_buffer((128, 8192), "float16", scope="trn.sbuf") + rowsum_p = T.alloc_buffer((128, 2), scope="trn.sbuf") + qk = T.alloc_buffer((128, 16384), scope="trn.sbuf") + running_max = T.alloc_buffer((128, 128), scope="trn.sbuf") + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(8192, annotations={"nki_dim": "F"}): + T.nki.activation_reduce(rowsum_p[p_loop, i % 2], p[p_loop, f_loop], qk[p_loop, i % 2 * 8192 + f_loop], "exp", "add", running_max[p_loop, i], T.float32(1.0)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": unary_reduce}) mod = tvm.tirx.transform.LowerTIRx()(mod) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py index 3e6ec9262bdd..4be47a7ed147 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py @@ -18,7 +18,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -27,8 +28,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -37,7 +36,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -51,30 +50,29 @@ def assert_structural_equal(lhs, rhs, *args, **kwargs): def test_simple_copy(): src_shape = [128, 512] - src_layout = Tx.TileLayout(Tx.S[(128, 512) : (512, 1)]) + src_layout = T.TileLayout(T.S[(128, 512) : (512, 1)]) dst_shape = [128, 512] dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(A_sbuf, A) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) - A = Tx.match_buffer(A_ptr, (128, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((65536,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for b_loop in Tx.serial(0, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim": "F"}): - Tx.nki.load(A_sbuf[p_loop, f_loop], A_1[p_loop * 512 + f_loop]) + A = T.match_buffer(A_ptr, (128, 512), layout=None) + A_1 = T.decl_buffer((65536,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for b_loop in T.serial(0, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim": "F"}): + T.nki.load(A_sbuf[p_loop, f_loop], A_1[p_loop * 512 + f_loop]) with target: mod = tvm.IRModule({"main": copy}) @@ -89,26 +87,25 @@ def test_simple_copy_2(): dst_shape = [128, 512] dst_layout = TileLayout(S[(128, 4, 128) : (4 @ F, 1 @ F, 1 @ P)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(A_sbuf, A) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) - A = Tx.match_buffer(A_ptr, (128, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((65536,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for b_loop in Tx.serial(0, 512): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 1, annotations={"nki_dim": "F"}): - Tx.nki.load(A_sbuf[p_loop, b_loop], A_1[b_loop * 128 + p_loop]) + A = T.match_buffer(A_ptr, (128, 512), layout=None) + A_1 = T.decl_buffer((65536,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for b_loop in T.serial(0, 512): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 1, annotations={"nki_dim": "F"}): + T.nki.load(A_sbuf[p_loop, b_loop], A_1[b_loop * 128 + p_loop]) with target: mod = tvm.IRModule({"main": copy}) @@ -118,33 +115,32 @@ def expected(A_ptr: Tx.handle): def test_copy_in_a_loop(): src_shape = [512, 512] - src_layout = Tx.TileLayout(Tx.S[(4, 128, 512) : (512 * 128, 512, 1)]) + src_layout = T.TileLayout(T.S[(4, 128, 512) : (512 * 128, 512, 1)]) dst_shape = [512, 512] dst_layout = TileLayout(S[(4, 128, 512) : (512 @ F, 1 @ P, 1 @ F)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.copy(A_sbuf[i * 128 : i * 128 + 128, :], A[i * 128 : i * 128 + 128, :]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (512, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((262144,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim": "F"}): - Tx.nki.load( - A_sbuf[p_loop, i * 512 + f_loop], A_1[i * 65536 + p_loop * 512 + f_loop] - ) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (512, 512), layout=None) + A_1 = T.decl_buffer((262144,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim": "F"}): + T.nki.load( + A_sbuf[p_loop, i * 512 + f_loop], A_1[i * 65536 + p_loop * 512 + f_loop] + ) with target: mod = tvm.IRModule({"main": copy}) @@ -154,40 +150,37 @@ def expected(A_ptr: Tx.handle): def test_copy_in_a_loop_2(): src_shape = [512, 512] - src_layout = Tx.TileLayout(Tx.S[(128, 2048) : (2048, 1)]) + src_layout = T.TileLayout(T.S[(128, 2048) : (2048, 1)]) dst_shape = [512, 512] dst_layout = TileLayout(S[(128, 2048) : (1 @ P, 1 @ F)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) A_sbuf_view = A_sbuf.view(128, 4, 512) A_view = A.view(128, 4, 512) for i in range(4): Tx.copy(A_sbuf_view[:, i, :], A_view[:, i, :]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (512, 512), layout=None) - with Tx.thread(): - _A_flat = Tx.decl_buffer((262144,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - A_sbuf_view = Tx.decl_buffer( - (128, 2048), data=A_sbuf.data, scope="trn.sbuf", layout=None - ) - A_view = Tx.decl_buffer((262144,), data=A.data, layout=None) - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim": "F"}): - Tx.nki.load( - A_sbuf_view[p_loop, i * 512 + f_loop], - A_view[p_loop * 2048 + i * 512 + f_loop], - ) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (512, 512), layout=None) + _A_flat = T.decl_buffer((262144,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + A_sbuf_view = T.decl_buffer((128, 2048), data=A_sbuf.data, scope="trn.sbuf", layout=None) + A_view = T.decl_buffer((262144,), data=A.data, layout=None) + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim": "F"}): + T.nki.load( + A_sbuf_view[p_loop, i * 512 + f_loop], + A_view[p_loop * 2048 + i * 512 + f_loop], + ) with target: mod = tvm.IRModule({"main": copy}) @@ -203,38 +196,36 @@ def test_copy_transpose(): dst_layout = TileLayout(S[(2048, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(B_sbuf, A_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - identity = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for b_loop in range(16): - for extend_b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for lhs_f_loop in Tx.serial(128, annotations={"nki_dim": "lhs_F"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "rhs_F"}): - Tx.nki.matmul(acc_psum[b_loop % 8, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], Tx.bool(True)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy(B_sbuf[p_loop, f_loop * 16 + b_loop], acc_psum[b_loop % 8, p_loop, f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "copy"}) + identity = T.alloc_buffer((128, 128), scope="trn.sbuf") + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for b_loop in range(16): + for extend_b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for lhs_f_loop in T.serial(128, annotations={"nki_dim": "lhs_F"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "rhs_F"}): + T.nki.matmul(acc_psum[b_loop % 8, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], T.bool(True)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.tensor_copy(B_sbuf[p_loop, f_loop * 16 + b_loop], acc_psum[b_loop % 8, p_loop, f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) @@ -251,40 +242,38 @@ def test_copy_transpose_2(): dst_layout = TileLayout(S[(4, 128, 128, 4) : (4 @ F, 16 @ F, 1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.copy(B_sbuf[i, :], A_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - identity = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for i in range(4): - for b_loop in range(4): - for extend_b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for lhs_f_loop in Tx.serial(128, annotations={"nki_dim": "lhs_F"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "rhs_F"}): - Tx.nki.matmul(acc_psum[b_loop, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, lhs_f_loop * 4 + b_loop], identity[p_loop, rhs_f_loop], Tx.bool(True)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy(B_sbuf[p_loop, f_loop * 16 + i * 4 + b_loop], acc_psum[b_loop, p_loop, f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "copy"}) + identity = T.alloc_buffer((128, 128), scope="trn.sbuf") + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for i in range(4): + for b_loop in range(4): + for extend_b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for lhs_f_loop in T.serial(128, annotations={"nki_dim": "lhs_F"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "rhs_F"}): + T.nki.matmul(acc_psum[b_loop, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, lhs_f_loop * 4 + b_loop], identity[p_loop, rhs_f_loop], T.bool(True)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.tensor_copy(B_sbuf[p_loop, f_loop * 16 + i * 4 + b_loop], acc_psum[b_loop, p_loop, f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -299,31 +288,29 @@ def test_copy_different_f(): dst_shape = [512, 64] dst_layout = TileLayout(S[(4, 128, 4, 4, 4) : (64 @ F, 1 @ P, 4 @ F, 16 @ F, 1 @ F)]) - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(B_sbuf, A_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 256), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 256), scope="trn.sbuf") - for b_loop in Tx.serial(0, 64): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 4, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy( - B_sbuf[ - p_loop, - b_loop // 16 * 64 + b_loop % 4 * 16 + b_loop % 16 // 4 * 4 + f_loop, - ], - A_sbuf[p_loop, b_loop * 4 + f_loop], - ) + T.func_attr({"global_symbol": "copy"}) + A_sbuf = T.alloc_buffer((128, 256), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 256), scope="trn.sbuf") + for b_loop in T.serial(0, 64): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 4, annotations={"nki_dim": "F"}): + T.nki.tensor_copy( + B_sbuf[ + p_loop, + b_loop // 16 * 64 + b_loop % 4 * 16 + b_loop % 16 // 4 * 4 + f_loop, + ], + A_sbuf[p_loop, b_loop * 4 + f_loop], + ) with target: mod = tvm.IRModule({"main": copy}) @@ -337,32 +324,28 @@ def test_copy_different_shape(): dst_shape = [4, 128, 4] dst_layout = TileLayout(S[(4, 128, 4) : (4 @ F, 1 @ P, 1 @ F)]) - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) B_sbuf_view = B_sbuf.view(512, 4) Tx.copy(B_sbuf_view, A_sbuf[:, 0:4]) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 256), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - _B_sbuf_view = Tx.decl_buffer( - (128, 16), data=B_sbuf.data, scope="trn.sbuf", layout=None - ) - for b_loop in Tx.serial(0, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 4, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy( - B_sbuf[p_loop, b_loop * 4 + f_loop], - A_sbuf[p_loop, b_loop * 64 + f_loop], - ) + T.func_attr({"global_symbol": "copy"}) + A_sbuf = T.alloc_buffer((128, 256), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 16), scope="trn.sbuf") + _B_sbuf_view = T.decl_buffer((128, 16), data=B_sbuf.data, scope="trn.sbuf", layout=None) + for b_loop in T.serial(0, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 4, annotations={"nki_dim": "F"}): + T.nki.tensor_copy( + B_sbuf[p_loop, b_loop * 4 + f_loop], + A_sbuf[p_loop, b_loop * 64 + f_loop], + ) with target: mod = tvm.IRModule({"main": copy}) @@ -376,27 +359,26 @@ def test_copy_irregular_shape(): dst_shape = [128, 512] dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.copy(A[:, i * 512 : i * 512 + 512], A_sbuf) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) - A = Tx.match_buffer(A_ptr, (128, 10000), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((1280000,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim": "F"}): - Tx.nki.store(A_1[p_loop * 10000 + i * 512 + f_loop], A_sbuf[p_loop, f_loop]) + A = T.match_buffer(A_ptr, (128, 10000), layout=None) + A_1 = T.decl_buffer((1280000,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim": "F"}): + T.nki.store(A_1[p_loop * 10000 + i * 512 + f_loop], A_sbuf[p_loop, f_loop]) with target: mod = tvm.IRModule({"main": copy}) @@ -411,28 +393,27 @@ def test_copy_different_shape_dim(): dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(32): Tx.copy(A_sbuf, A[i, :, :]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (32, 128, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((2097152,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for i, b_loop in Tx.grid(32, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.load(A_sbuf[p_loop, f_loop], A_1[i * 65536 + p_loop * 128 + f_loop]) - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (32, 128, 512), layout=None) + A_1 = T.decl_buffer((2097152,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for i, b_loop in T.grid(32, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.load(A_sbuf[p_loop, f_loop], A_1[i * 65536 + p_loop * 128 + f_loop]) + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -445,30 +426,29 @@ def test_copy_with_offset(): dst_shape = [512, 512] dst_layout = TileLayout(S[(4, 128, 512) : (512 @ F, 1 @ P, 1 @ F)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(2): Tx.copy(A_sbuf[i * 256 : i * 256 + 256, :], A) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (256, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((131072,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 2): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim": "F"}): - Tx.nki.load( - A_sbuf[p_loop, i * 1024 + b_loop * 512 + f_loop], - A_1[b_loop * 65536 + p_loop * 512 + f_loop], - ) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (256, 512), layout=None) + A_1 = T.decl_buffer((131072,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for i, b_loop in T.grid(2, 2): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim": "F"}): + T.nki.load( + A_sbuf[p_loop, i * 1024 + b_loop * 512 + f_loop], + A_1[b_loop * 65536 + p_loop * 512 + f_loop], + ) with target: mod = tvm.IRModule({"main": copy}) @@ -478,34 +458,33 @@ def expected(A_ptr: Tx.handle): def test_large_dma_copy(): src_shape = [512, 4096] - src_layout = Tx.TileLayout(Tx.S[(4, 128, 4096) : (4096 * 128, 4096, 1)]) + src_layout = T.TileLayout(T.S[(4, 128, 4096) : (4096 * 128, 4096, 1)]) dst_shape = [512, 4096] dst_layout = TileLayout(S[(4, 128, 4096) : (4096 @ F, 1 @ P, 1 @ F)]) - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.copy(A_sbuf[i * 128 : i * 128 + 128, :], A[i * 128 : i * 128 + 128, :]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (512, 4096), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((2097152,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 4096, annotations={"nki_dim": "F"}): - Tx.nki.load( - A_sbuf[p_loop, i * 4096 + f_loop], - A_1[i * 524288 + p_loop * 4096 + f_loop], - ) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (512, 4096), layout=None) + A_1 = T.decl_buffer((2097152,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 4096, annotations={"nki_dim": "F"}): + T.nki.load( + A_sbuf[p_loop, i * 4096 + f_loop], + A_1[i * 524288 + p_loop * 4096 + f_loop], + ) with target: mod = tvm.IRModule({"main": copy}) @@ -519,29 +498,27 @@ def test_copy_with_inst_size_limit(): dst_shape = src_shape dst_layout = src_layout - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - B_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + B_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.copy(A_sbuf[i * 128 : i * 128 + 128, :], B_sbuf[i * 128 : i * 128 + 128, :]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - B_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy( - A_sbuf[p_loop, i * 4096 + b_loop * 512 + f_loop], - B_sbuf[p_loop, i * 4096 + b_loop * 512 + f_loop], - ) + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + B_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + for i, b_loop in T.grid(4, 8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim": "F"}): + T.nki.tensor_copy( + A_sbuf[p_loop, i * 4096 + b_loop * 512 + f_loop], + B_sbuf[p_loop, i * 4096 + b_loop * 512 + f_loop], + ) with target: mod = tvm.IRModule({"main": copy}) @@ -551,32 +528,31 @@ def expected(A_ptr: Tx.handle): def test_copy_with_complex_index(): A_shape = [4096, 4096] - A_layout = Tx.TileLayout(Tx.S[(4096, 4096) : (1, 4096)]) + A_layout = T.TileLayout(T.S[(4096, 4096) : (1, 4096)]) A_sbuf_shape = (2, 2048, 1024) A_sbuf_layout = TileLayout(S[(2, 2048, 8, 128) : (16384 @ F, 1 @ F, 2048 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle, ) -> None: - A = Tx.match_buffer(A_ptr, A_shape, "float32", layout=A_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(A_sbuf_shape, "float32", scope="trn.sbuf", layout=A_sbuf_layout) + @T.prim_func + def copy(A_ptr: T.handle, ) -> None: + A = T.match_buffer(A_ptr, A_shape, "float32", layout=A_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(A_sbuf_shape, "float32", scope="trn.sbuf", layout=A_sbuf_layout) Tx.copy(A_sbuf[1, 0:2048, 0:1024], A[2048: 4096, 3072:4096]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (4096, 4096), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((16777216,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 32768), scope="trn.sbuf") - for b_loop in Tx.serial(0, 8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 2048, annotations={"nki_dim":"F"}): - Tx.nki.load(A_sbuf[p_loop, b_loop * 2048 + f_loop + 16384], A_1[b_loop * 524288 + p_loop * 4096 + f_loop + 12584960]) # noqa: E501 - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (4096, 4096), layout=None) + A_1 = T.decl_buffer((16777216,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 32768), scope="trn.sbuf") + for b_loop in T.serial(0, 8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 2048, annotations={"nki_dim":"F"}): + T.nki.load(A_sbuf[p_loop, b_loop * 2048 + f_loop + 16384], A_1[b_loop * 524288 + p_loop * 4096 + f_loop + 12584960]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -585,32 +561,31 @@ def expected(A_ptr: Tx.handle): def test_copy_with_complex_index_2(): A_sbuf_shape = [4096, 4096] - A_sbuf_layout = Tx.TileLayout(Tx.S[(4096, 32, 128) : (1 @ F, 4096 @ F, 1 @ P)]) + A_sbuf_layout = T.TileLayout(T.S[(4096, 32, 128) : (1 @ F, 4096 @ F, 1 @ P)]) A_shape = (2, 2048, 1024) - A_layout = Tx.TileLayout(Tx.S[(2, 2048, 1024) : (2048 * 1024, 1, 2048)]) + A_layout = T.TileLayout(T.S[(2, 2048, 1024) : (2048 * 1024, 1, 2048)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle, ) -> None: - A = Tx.match_buffer(A_ptr, A_shape, "float32", layout=A_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(A_sbuf_shape, "float32", scope="trn.sbuf", layout=A_sbuf_layout) + @T.prim_func + def copy(A_ptr: T.handle, ) -> None: + A = T.match_buffer(A_ptr, A_shape, "float32", layout=A_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(A_sbuf_shape, "float32", scope="trn.sbuf", layout=A_sbuf_layout) Tx.copy(A_sbuf[2048: 4096, 3072:4096], A[1, 0:2048, 0:1024]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (2, 2048, 1024), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((4194304,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 131072), scope="trn.sbuf") - for b_loop in Tx.serial(0, 8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 2048, annotations={"nki_dim":"F"}): - Tx.nki.load(A_sbuf[p_loop, b_loop * 4096 + f_loop + 100352], A_1[b_loop * 262144 + p_loop * 2048 + f_loop + 2097152]) # noqa: E501 - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (2, 2048, 1024), layout=None) + A_1 = T.decl_buffer((4194304,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 131072), scope="trn.sbuf") + for b_loop in T.serial(0, 8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 2048, annotations={"nki_dim":"F"}): + T.nki.load(A_sbuf[p_loop, b_loop * 4096 + f_loop + 100352], A_1[b_loop * 262144 + p_loop * 2048 + f_loop + 2097152]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) @@ -625,44 +600,42 @@ def test_copy_transpose_with_workspace(): dst_layout = TileLayout(S[(2048, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - identity = Tx.alloc_buffer((128, 128), "float32", scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((1, 128, 512), "float32", scope="trn.psum", allocated_addr=(0, 0)) # noqa: E501 - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + identity = T.alloc_buffer((128, 128), "float32", scope="trn.sbuf") + acc_psum = T.alloc_buffer((1, 128, 512), "float32", scope="trn.psum", allocated_addr=(0, 0)) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) Tx.copy(B_sbuf, A_sbuf, workspace={"identity": identity, "acc_psum": acc_psum}) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - identity = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((1, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) - for b_loop in range(16): - for extend_b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for lhs_f_loop in Tx.serial(128, annotations={"nki_dim": "lhs_F"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "rhs_F"}): - Tx.nki.matmul(acc_psum[0, lhs_f_loop, extend_b_loop * 128 + rhs_f_loop], A_sbuf[p_loop, b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], Tx.bool(True)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy(B_sbuf[p_loop, f_loop * 16 + b_loop], acc_psum[0, p_loop, f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "copy"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + identity = T.alloc_buffer((128, 128), scope="trn.sbuf") + acc_psum = T.alloc_buffer((1, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) + for b_loop in range(16): + for extend_b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for lhs_f_loop in T.serial(128, annotations={"nki_dim": "lhs_F"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "rhs_F"}): + T.nki.matmul(acc_psum[0, lhs_f_loop, extend_b_loop * 128 + rhs_f_loop], A_sbuf[p_loop, b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], T.bool(True)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.tensor_copy(B_sbuf[p_loop, f_loop * 16 + b_loop], acc_psum[0, p_loop, f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -671,35 +644,34 @@ def expected(): def test_copy_with_guard(): src_shape = [512, 512] - src_layout = Tx.TileLayout(Tx.S[(4, 128, 512) : (512 * 128, 512, 1)]) + src_layout = T.TileLayout(T.S[(4, 128, 512) : (512 * 128, 512, 1)]) dst_shape = [512, 512] dst_layout = TileLayout(S[(4, 128, 512) : (512 @ F, 1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for j in range(4): for i in range(4): Tx.copy(A_sbuf[i * 128 : i * 128 + 128, 0:128*j], A[i * 128 : i * 128 + 128, 0:128*j]) # noqa: E501 - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (512, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((262144,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for j, i, b_loop in Tx.grid(4, 4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 384, annotations={"nki_dim":"F"}): - if f_loop < j * 128: - Tx.nki.load(A_sbuf[p_loop, i * 512 + f_loop], A_1[i * 65536 + p_loop * 512 + f_loop]) # noqa: E501 - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (512, 512), layout=None) + A_1 = T.decl_buffer((262144,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for j, i, b_loop in T.grid(4, 4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 384, annotations={"nki_dim":"F"}): + if f_loop < j * 128: + T.nki.load(A_sbuf[p_loop, i * 512 + f_loop], A_1[i * 65536 + p_loop * 512 + f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -709,35 +681,34 @@ def expected(A_ptr: Tx.handle): def test_copy_with_guard_2(): src_shape = [512, 512] - src_layout = Tx.TileLayout(Tx.S[(4, 128, 512) : (512 * 128, 512, 1)]) + src_layout = T.TileLayout(T.S[(4, 128, 512) : (512 * 128, 512, 1)]) dst_shape = [512, 512] dst_layout = TileLayout(S[(4, 128, 512) : (512 @ F, 1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for j in range(4): for i in range(4): Tx.copy(A_sbuf[0:128*j, 0:128*i], A[0:128*j, 0:128*i]) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - A = Tx.match_buffer(A_ptr, (512, 512), layout=None) - with Tx.thread(): - A_1 = Tx.decl_buffer((262144,), data=A.data, layout=None) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for j, i, b_loop in Tx.grid(4, 4, 3): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 384, annotations={"nki_dim":"F"}): - if b_loop - j < 0 and f_loop < i * 128: - Tx.nki.load(A_sbuf[p_loop, b_loop * 512 + f_loop], A_1[b_loop * 65536 + p_loop * 512 + f_loop]) # noqa: E501 - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + + A = T.match_buffer(A_ptr, (512, 512), layout=None) + A_1 = T.decl_buffer((262144,), data=A.data, layout=None) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for j, i, b_loop in T.grid(4, 4, 3): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 384, annotations={"nki_dim":"F"}): + if b_loop - j < 0 and f_loop < i * 128: + T.nki.load(A_sbuf[p_loop, b_loop * 512 + f_loop], A_1[b_loop * 65536 + p_loop * 512 + f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -752,42 +723,40 @@ def test_copy_transpose_with_guard(): dst_layout = TileLayout(S[(2048, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): for j in range(4): Tx.copy(B_sbuf[i * 128 : i * 128 + 128, 0:128*j], A_sbuf[i * 128 : i * 128 + 128, 0:128*j]) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - identity = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for i, j, b_loop in Tx.grid(4, 4, 3): - for extend_b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for lhs_f_loop in Tx.serial(128, annotations={"nki_dim": "lhs_F"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "rhs_F"}): - if b_loop - j < 0: - Tx.nki.matmul(acc_psum[b_loop, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, i * 512 + b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], Tx.bool(True)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - if b_loop - j < 0: - Tx.nki.tensor_copy(B_sbuf[p_loop, i * 512 + f_loop * 4 + b_loop], acc_psum[b_loop, p_loop, f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "copy"}) + identity = T.alloc_buffer((128, 128), scope="trn.sbuf") + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for i, j, b_loop in T.grid(4, 4, 3): + for extend_b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for lhs_f_loop in T.serial(128, annotations={"nki_dim": "lhs_F"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "rhs_F"}): + if b_loop - j < 0: + T.nki.matmul(acc_psum[b_loop, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, i * 512 + b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], T.bool(True)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(128, annotations={"nki_dim": "F"}): + if b_loop - j < 0: + T.nki.tensor_copy(B_sbuf[p_loop, i * 512 + f_loop * 4 + b_loop], acc_psum[b_loop, p_loop, f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -803,26 +772,24 @@ def test_copy_with_specified_max_inst_size(): dst_layout = src_layout # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(A_sbuf, B_sbuf, max_inst_size=128) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf", layout=None) - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf", layout=None) - for b_loop in Tx.serial(0, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy(A_sbuf[p_loop, b_loop * 128 + f_loop], B_sbuf[p_loop, b_loop * 128 + f_loop]) # noqa: E501 - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf", layout=None) + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf", layout=None) + for b_loop in T.serial(0, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.tensor_copy(A_sbuf[p_loop, b_loop * 128 + f_loop], B_sbuf[p_loop, b_loop * 128 + f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -831,39 +798,37 @@ def expected(A_ptr: Tx.handle): def test_copy_transpose_with_extended_f(): # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((128, 2048), "float32", scope="trn.sbuf", layout="PF") - B_sbuf = Tx.alloc_buffer((128, 2048), "float32", scope="trn.sbuf", layout="FP") + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer((128, 2048), "float32", scope="trn.sbuf", layout="PF") + B_sbuf = T.alloc_buffer((128, 2048), "float32", scope="trn.sbuf", layout="FP") Tx.copy(B_sbuf, A_sbuf) - @Tx.prim_func - def expected(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "copy"}) - - with Tx.thread(): - identity = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for b_loop in range(4): - for extend_b_loop in range(4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for lhs_f_loop in Tx.serial(128, annotations={"nki_dim": "lhs_F"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "rhs_F"}): - Tx.nki.matmul(acc_psum[b_loop, lhs_f_loop, extend_b_loop * 128 + rhs_f_loop], A_sbuf[p_loop, b_loop * 512 + extend_b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], Tx.bool(True)) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.tensor_copy(B_sbuf[p_loop, b_loop * 512 + f_loop], acc_psum[b_loop, p_loop, f_loop]) # noqa: E501 - - # fmt: on + @T.prim_func + def expected(A_ptr: T.handle): + T.func_attr({"global_symbol": "copy"}) + identity = T.alloc_buffer((128, 128), scope="trn.sbuf") + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for b_loop in range(4): + for extend_b_loop in range(4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for lhs_f_loop in T.serial(128, annotations={"nki_dim": "lhs_F"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "rhs_F"}): + T.nki.matmul(acc_psum[b_loop, lhs_f_loop, extend_b_loop * 128 + rhs_f_loop], A_sbuf[p_loop, b_loop * 512 + extend_b_loop * 128 + lhs_f_loop], identity[p_loop, rhs_f_loop], T.bool(True)) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.tensor_copy(B_sbuf[p_loop, b_loop * 512 + f_loop], acc_psum[b_loop, p_loop, f_loop]) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": copy}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py index fc61569a3281..18beb0390638 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py @@ -19,7 +19,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -28,8 +29,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -38,7 +37,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -57,29 +56,27 @@ def test_simple_gemm(): C_layout = TileLayout(S[(128, 128) : (1 @ P, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((128, 128), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((128, 128), "float32", scope="trn.psum", layout=C_layout) Tx.gemm(C_psum, A_sbuf, B_sbuf, C_psum) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((1, 128, 128), scope="trn.psum") - for lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(1, 1, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[0, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, lhs_f_loop], B_sbuf[p_loop, rhs_f_loop], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 128), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 128), scope="trn.sbuf") + C_psum = T.alloc_buffer((1, 128, 128), scope="trn.psum") + for lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(1, 1, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[0, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, lhs_f_loop], B_sbuf[p_loop, rhs_f_loop], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -93,29 +90,27 @@ def test_larger_gemm(): C_layout = TileLayout(S[(2, 128, 2, 128) : (256 @ F, 1 @ P, 128 @ F, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((256, 512), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((256, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((256, 512), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((256, 256), "float32", scope="trn.psum", layout=C_layout) Tx.gemm(C_psum, A_sbuf, B_sbuf, C_psum) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((1, 128, 512), scope="trn.psum") - for lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 1, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 256, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[0, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, lhs_b_loop * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 1024), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 1024), scope="trn.sbuf") + C_psum = T.alloc_buffer((1, 128, 512), scope="trn.psum") + for lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 1, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 256, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[0, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, lhs_b_loop * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -129,12 +124,12 @@ def test_gemm_in_a_loop(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ P, 128 @ F, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -144,21 +139,19 @@ def gemm() -> None: C_psum[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((2, 128, 512), scope="trn.psum") - for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 2, 2, 1, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 256, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_psum = T.alloc_buffer((2, 128, 512), scope="trn.psum") + for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 2, 2, 1, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 256, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -172,12 +165,12 @@ def test_gemm_with_stride(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ P, 128 @ F, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 512, 2), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((512, 2, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 512, 2), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((512, 2, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -187,21 +180,19 @@ def gemm() -> None: C_psum[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4095), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((2, 128, 512), scope="trn.psum") - for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 2, 2, 1, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 256, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + reduction_b_loop * 256 + k * 128 + lhs_f_loop], B_sbuf[p_loop, reduction_b_loop * 1024 + k * 512 + rhs_f_loop * 2], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4095), scope="trn.sbuf") + C_psum = T.alloc_buffer((2, 128, 512), scope="trn.psum") + for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 2, 2, 1, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 256, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + reduction_b_loop * 256 + k * 128 + lhs_f_loop], B_sbuf[p_loop, reduction_b_loop * 1024 + k * 512 + rhs_f_loop * 2], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) @@ -216,12 +207,12 @@ def test_gemm_swap_lhs_rhs(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ F, 128 @ F, 1 @ P)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -231,21 +222,19 @@ def gemm() -> None: C_psum[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((2, 128, 512), scope="trn.psum") - for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 2, 2, 2, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[i, lhs_f_loop, rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_psum = T.alloc_buffer((2, 128, 512), scope="trn.psum") + for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 2, 2, 2, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[i, lhs_f_loop, rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -259,12 +248,12 @@ def test_gemm_with_sbuf_output(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ F, 128 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_sbuf = Tx.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_sbuf = T.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -273,27 +262,25 @@ def gemm() -> None: B_sbuf[512 * k : 512 * k + 512, :], C_sbuf[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - buffer = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - for i, k, lhs_b_loop, rhs_b_loop in Tx.grid(2, 2, 2, 2): - for reduction_b_loop in range(4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(buffer[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - Tx.nki.tensor_copy(C_sbuf[lhs_f_loop, i * 512 + rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], buffer[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + buffer = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 1024), scope="trn.sbuf") + for i, k, lhs_b_loop, rhs_b_loop in T.grid(2, 2, 2, 2): + for reduction_b_loop in range(4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(buffer[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + T.nki.tensor_copy(C_sbuf[lhs_f_loop, i * 512 + rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], buffer[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -309,12 +296,12 @@ def test_gemm_different_shape(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ F, 128 @ F, 1 @ P)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((2, 512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((2, 512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -324,21 +311,19 @@ def gemm() -> None: C_psum[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((2, 128, 512), scope="trn.psum") - for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 2, 2, 2, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[i, lhs_f_loop, rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop + 4096], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 8192), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_psum = T.alloc_buffer((2, 128, 512), scope="trn.psum") + for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 2, 2, 2, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[i, lhs_f_loop, rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop + 4096], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -352,29 +337,27 @@ def test_gemm_too_large_f_size(): C_layout = TileLayout(S[(2, 128, 1024) : (1024 @ F, 1 @ P, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((256, 128), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((128, 1024), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((256, 1024), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((256, 128), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((128, 1024), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((256, 1024), "float32", scope="trn.psum", layout=C_layout) Tx.gemm(C_psum, A_sbuf, B_sbuf, C_psum) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 256), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((4, 128, 512), scope="trn.psum") - for lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 2, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 512, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, lhs_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, rhs_b_loop * 512 + rhs_f_loop], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 256), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 1024), scope="trn.sbuf") + C_psum = T.alloc_buffer((4, 128, 512), scope="trn.psum") + for lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 2, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 512, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop], A_sbuf[p_loop, lhs_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, rhs_b_loop * 512 + rhs_f_loop], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -388,13 +371,13 @@ def test_gemm_sbuf_output_with_workspace(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ F, 128 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_sbuf = Tx.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) - C_psum = Tx.alloc_buffer((1, 128, 512), "float32", scope="trn.psum", allocated_addr=(0, 0)) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_sbuf = T.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) + C_psum = T.alloc_buffer((1, 128, 512), "float32", scope="trn.psum", allocated_addr=(0, 0)) for i in range(2): for k in range(2): Tx.gemm( @@ -404,27 +387,25 @@ def gemm() -> None: C_sbuf[256 * i : 256 * i + 256, :], workspace={"acc_psum": C_psum} ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((1, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - for i, k, lhs_b_loop, rhs_b_loop in Tx.grid(2, 2, 2, 2): - for reduction_b_loop in range(4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[0, lhs_f_loop, rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - Tx.nki.tensor_copy(C_sbuf[lhs_f_loop, i * 512 + rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], C_psum[0, lhs_f_loop, rhs_f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 1024), scope="trn.sbuf") + C_psum = T.alloc_buffer((1, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + for i, k, lhs_b_loop, rhs_b_loop in T.grid(2, 2, 2, 2): + for reduction_b_loop in range(4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[0, lhs_f_loop, rhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, i * 2048 + rhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + T.nki.tensor_copy(C_sbuf[lhs_f_loop, i * 512 + rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], C_psum[0, lhs_f_loop, rhs_f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -439,12 +420,12 @@ def test_gemm_pf_mismatch_fail(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ P, 128 @ F, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((256, 1024), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((256, 1024), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -467,12 +448,12 @@ def test_gemm_transpose_AB(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ P, 128 @ F, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((1024, 512), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((256, 1024), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((1024, 512), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((256, 1024), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -484,22 +465,20 @@ def gemm() -> None: transpose_B=True, ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((2, 128, 512), scope="trn.psum") - for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(2, 2, 2, 1, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 256, annotations={"nki_dim":"rhs_F"}): - Tx.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 - - #fmt: off + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_psum = T.alloc_buffer((2, 128, 512), scope="trn.psum") + for i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(2, 2, 2, 1, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 256, annotations={"nki_dim":"rhs_F"}): + T.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 + + #fmt: off with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -513,12 +492,12 @@ def test_gemm_guard(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ F, 128 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_sbuf = Tx.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_sbuf = T.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) for i in range(2): for j in range(2): for k in range(2): @@ -528,29 +507,27 @@ def gemm() -> None: B_sbuf[0: 512 * (k + 1), 0: 128 * (j + 1)], C_sbuf[0: 256 * i, 0: 128 * (j + 1)], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - for i, j, k, lhs_b_loop, rhs_b_loop in Tx.grid(2, 2, 2, 2, 2): - for reduction_b_loop in range(8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"rhs_F"}): - if reduction_b_loop - k * 4 < 4 and lhs_b_loop - j < 1 and 0 < i and reduction_b_loop - k * 4 < 4: # noqa: E501 - Tx.nki.matmul(acc_psum[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop], B_sbuf[p_loop, reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, rhs_b_loop * 1024 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for rhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"F"}): - if 0 < i and lhs_b_loop - j < 1: - Tx.nki.tensor_copy(C_sbuf[lhs_f_loop, rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], acc_psum[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop]) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 1024), scope="trn.sbuf") + for i, j, k, lhs_b_loop, rhs_b_loop in T.grid(2, 2, 2, 2, 2): + for reduction_b_loop in range(8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"rhs_F"}): + if reduction_b_loop - k * 4 < 4 and lhs_b_loop - j < 1 and 0 < i and reduction_b_loop - k * 4 < 4: # noqa: E501 + T.nki.matmul(acc_psum[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop], B_sbuf[p_loop, reduction_b_loop * 256 + lhs_b_loop * 128 + lhs_f_loop], A_sbuf[p_loop, rhs_b_loop * 1024 + reduction_b_loop * 128 + rhs_f_loop], True) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for rhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"F"}): + if 0 < i and lhs_b_loop - j < 1: + T.nki.tensor_copy(C_sbuf[lhs_f_loop, rhs_b_loop * 256 + lhs_b_loop * 128 + rhs_f_loop], acc_psum[lhs_b_loop * 2 + rhs_b_loop, lhs_f_loop, rhs_f_loop]) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -566,12 +543,12 @@ def test_gemm_guard2(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ P, 128 @ F, 1 @ F)]).to_psum() # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((512, 256), "float32", scope="trn.psum", layout=C_layout) for j in range(4): for i in range(2): for k in range(2): @@ -581,22 +558,20 @@ def gemm() -> None: B_sbuf[512 * k : 512 * k + (j+1) * 128, :], C_psum[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - C_psum = Tx.alloc_buffer((2, 128, 512), scope="trn.psum") - for j, i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in Tx.grid(4, 2, 2, 2, 1, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for lhs_f_loop in Tx.serial(0, 128, annotations={"nki_dim":"lhs_F"}): - for rhs_f_loop in Tx.serial(0, 256, annotations={"nki_dim":"rhs_F"}): - if reduction_b_loop - j < 1 and reduction_b_loop - j < 1: - Tx.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "gemm"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + C_psum = T.alloc_buffer((2, 128, 512), scope="trn.psum") + for j, i, k, lhs_b_loop, rhs_b_loop, reduction_b_loop in T.grid(4, 2, 2, 2, 1, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for lhs_f_loop in T.serial(0, 128, annotations={"nki_dim":"lhs_F"}): + for rhs_f_loop in T.serial(0, 256, annotations={"nki_dim":"rhs_F"}): + if reduction_b_loop - j < 1 and reduction_b_loop - j < 1: + T.nki.matmul(C_psum[i, lhs_f_loop, lhs_b_loop * 256 + rhs_f_loop], A_sbuf[p_loop, i * 2048 + lhs_b_loop * 1024 + k * 512 + reduction_b_loop * 128 + lhs_f_loop], B_sbuf[p_loop, k * 1024 + reduction_b_loop * 256 + rhs_f_loop], True) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": gemm}) mod = tvm.tirx.transform.LowerTIRx()(mod) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py index 85da5955739d..14c0f5dea795 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py @@ -18,7 +18,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.transform.trn import TrnPrivateBufferAlloc @@ -32,27 +33,27 @@ def test_copy_transpose(): dst_layout = TileLayout(S[(2048, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def copy() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(B_sbuf, A_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "copy"}) - Tx.device_entry() - identity = Tx.alloc_buffer((128, 128), scope="trn.sbuf") - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for rhs_f_loop in Tx.serial(128, annotations={"nki_dim": "F"}): - Tx.nki.identity(identity[p_loop, rhs_f_loop], 128) - A_sbuf = Tx.alloc_buffer((512, 512), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 2048) : (1 @ P, 1@F)])) - B_sbuf = Tx.alloc_buffer((512, 512), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(2048, 128) : (1@F, 1@P)])) + T.func_attr({"global_symbol": "copy"}) + T.device_entry() + identity = T.alloc_buffer((128, 128), scope="trn.sbuf") + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for rhs_f_loop in T.serial(128, annotations={"nki_dim": "F"}): + T.nki.identity(identity[p_loop, rhs_f_loop], 128) + A_sbuf = T.alloc_buffer((512, 512), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 2048) : (1 @ P, 1@F)])) + B_sbuf = T.alloc_buffer((512, 512), scope="trn.sbuf", + layout=T.TileLayout(T.S[(2048, 128) : (1@F, 1@P)])) Tx.copy(B_sbuf[0:512, 0:512], A_sbuf[0:512, 0:512], workspace={"acc_psum": acc_psum, "identity": identity}) # noqa: E501 # fmt: on @@ -69,11 +70,11 @@ def test_normal_copy(): dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(A_sbuf, A) # fmt: on with target: @@ -87,31 +88,31 @@ def test_unary_with_bias_scale(): src_layout = TileLayout(S[(128, 4096) : (1 @ P, 1 @ F)]) dst_shape = src_shape dst_layout = src_layout - bias = Tx.float32(1.0) - scale = Tx.float32(2.0) + bias = T.float32(1.0) + scale = T.float32(2.0) # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.exp(C_sbuf, A_sbuf, bias=bias, scale=scale) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - Tx.device_entry() - const_bias = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(1.0)) - A_sbuf = Tx.alloc_buffer((512, 1024), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4096) : (1@P, 1@F)])) - C_sbuf = Tx.alloc_buffer((512, 1024), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4096) : (1@P, 1@F)])) - Tx.exp(C_sbuf[0:512, 0:1024], A_sbuf[0:512, 0:1024], Tx.float32(1.0), Tx.float32(2.0), workspace={"const_bias": const_bias}) # noqa: E501 + T.func_attr({"global_symbol": "unary"}) + T.device_entry() + const_bias = T.alloc_buffer((128, 512), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(1.0)) + A_sbuf = T.alloc_buffer((512, 1024), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4096) : (1@P, 1@F)])) + C_sbuf = T.alloc_buffer((512, 1024), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4096) : (1@P, 1@F)])) + Tx.exp(C_sbuf[0:512, 0:1024], A_sbuf[0:512, 0:1024], T.float32(1.0), T.float32(2.0), workspace={"const_bias": const_bias}) # noqa: E501 # fmt: on with target: mod = tvm.IRModule({"main": unary}) @@ -126,22 +127,22 @@ def test_reduction_two_stage(): dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.sum(B_sbuf, A_sbuf, axes=(1, 3)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - Tx.device_entry() - partial_reduce = Tx.alloc_buffer((128, 32), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 32, 4, 32), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 32 * 32 * 4) : (1@P, 1@F)])) - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4) : (1@P, 1@F)])) + T.func_attr({"global_symbol": "reduction"}) + T.device_entry() + partial_reduce = T.alloc_buffer((128, 32), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 32, 4, 32), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 32 * 32 * 4) : (1@P, 1@F)])) + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4) : (1@P, 1@F)])) Tx.sum(B_sbuf[0:128, 0:4], A_sbuf[0:128, 0:32, 0:4, 0:32], [1, 3], False, workspace={"partial_reduce": partial_reduce}) # noqa: E501 # fmt: on @@ -158,12 +159,12 @@ def test_gemm(): C_layout = TileLayout(S[(4, 128, 2, 128) : (256 @ F, 1 @ F, 128 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) - C_sbuf = Tx.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((512, 1024), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((1024, 256), "float32", scope="trn.sbuf", layout=B_layout) + C_sbuf = T.alloc_buffer((512, 256), "float32", scope="trn.sbuf", layout=C_layout) for i in range(2): for k in range(2): Tx.gemm( @@ -172,19 +173,19 @@ def gemm() -> None: B_sbuf[512 * k : 512 * k + 512, :], C_sbuf[256 * i : 256 * i + 256, :], ) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "gemm"}) - Tx.device_entry() - acc_psum = Tx.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) - A_sbuf = Tx.alloc_buffer((512, 1024), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(4, 128, 8, 128) : (1024@F, 1@F, 1@F, 1@P)])) # noqa: E501 - B_sbuf = Tx.alloc_buffer((1024, 256), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(8, 128, 2, 128) : (256@F, 1@P, 128@F, 1@F)])) # noqa: E501 - C_sbuf = Tx.alloc_buffer((512, 256), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(4, 128, 2, 128) : (256@F, 1@F, 128@F, 1@P)])) # noqa: E501 - for i, k in Tx.grid(2, 2): - Tx.gemm(C_sbuf[256 * i:256 * i + 256, 0:256], A_sbuf[256 * i:256 * i + 256, 512 * k:512 * k + 512], B_sbuf[512 * k:512 * k + 512, 0:256], C_sbuf[256 * i:256 * i + 256, 0:256], False, False, Tx.float32(1.0), Tx.float32(0.0), workspace={"acc_psum": acc_psum}) # noqa: E501 + T.func_attr({"global_symbol": "gemm"}) + T.device_entry() + acc_psum = T.alloc_buffer((8, 128, 512), scope="trn.psum", allocated_addr=[0, 0]) + A_sbuf = T.alloc_buffer((512, 1024), scope="trn.sbuf", + layout=T.TileLayout(T.S[(4, 128, 8, 128) : (1024@F, 1@F, 1@F, 1@P)])) # noqa: E501 + B_sbuf = T.alloc_buffer((1024, 256), scope="trn.sbuf", + layout=T.TileLayout(T.S[(8, 128, 2, 128) : (256@F, 1@P, 128@F, 1@F)])) # noqa: E501 + C_sbuf = T.alloc_buffer((512, 256), scope="trn.sbuf", + layout=T.TileLayout(T.S[(4, 128, 2, 128) : (256@F, 1@F, 128@F, 1@P)])) # noqa: E501 + for i, k in T.grid(2, 2): + Tx.gemm(C_sbuf[256 * i:256 * i + 256, 0:256], A_sbuf[256 * i:256 * i + 256, 512 * k:512 * k + 512], B_sbuf[512 * k:512 * k + 512, 0:256], C_sbuf[256 * i:256 * i + 256, 0:256], False, False, T.float32(1.0), T.float32(0.0), workspace={"acc_psum": acc_psum}) # noqa: E501 # fmt: on with target: mod = tvm.IRModule({"main": gemm}) @@ -201,26 +202,26 @@ def test_binary_reduce_two_stage(): reduce_dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def tensor_scalar_reduce() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) - B_sbuf = Tx.alloc_buffer(dst1_shape, "float32", scope="trn.sbuf", layout=dst1_layout) - C_sbuf = Tx.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 + T.device_entry() + A_sbuf = T.alloc_buffer(src1_shape, "float32", scope="trn.sbuf", layout=src1_layout) + B_sbuf = T.alloc_buffer(dst1_shape, "float32", scope="trn.sbuf", layout=dst1_layout) + C_sbuf = T.alloc_buffer(reduce_dst_shape, "float32", scope="trn.sbuf", layout=reduce_dst_layout) # noqa: E501 Tx.binary_reduce(B_sbuf, C_sbuf, A_sbuf, 1.0, "add", "sum", reduce_axes=(1, 2)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "tensor_scalar_reduce"}) - Tx.device_entry() - partial_reduce = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((512, 1024, 4), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4096, 4) : (1 @ P, 1 @ F, 4096 @ F)])) # noqa: E501 - B_sbuf = Tx.alloc_buffer((512, 1024, 4), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4096, 4) : (1 @ P, 1 @ F, 4096 @ F)])) # noqa: E501 - C_sbuf = Tx.alloc_buffer((512,), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4) : (1 @ P, 1 @ F)])) - Tx.binary_reduce(B_sbuf[0:512, 0:1024, 0:4], C_sbuf[0:512], A_sbuf[0:512, 0:1024, 0:4], Tx.float32(1.0), "add", "sum", [1, 2], workspace={"partial_reduce": partial_reduce}) # noqa: E501 + T.func_attr({"global_symbol": "tensor_scalar_reduce"}) + T.device_entry() + partial_reduce = T.alloc_buffer((128, 4), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((512, 1024, 4), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4096, 4) : (1 @ P, 1 @ F, 4096 @ F)])) + B_sbuf = T.alloc_buffer((512, 1024, 4), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4096, 4) : (1 @ P, 1 @ F, 4096 @ F)])) + C_sbuf = T.alloc_buffer((512,), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4) : (1 @ P, 1 @ F)])) + Tx.binary_reduce(B_sbuf[0:512, 0:1024, 0:4], C_sbuf[0:512], A_sbuf[0:512, 0:1024, 0:4], T.float32(1.0), "add", "sum", [1, 2], workspace={"partial_reduce": partial_reduce}) # noqa: E501 # fmt: on with target: mod = tvm.IRModule({"main": tensor_scalar_reduce}) @@ -237,31 +238,31 @@ def test_activation_reduce_two_stage(): C_layout = TileLayout(S[(1, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=(0,1)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - Tx.device_entry() - partial_reduce = Tx.alloc_buffer((128, 8), scope="trn.sbuf") - const_bias = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A = Tx.alloc_buffer((32, 512, 128), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(16 * 1024, 128) : (1@F, 1@P)])) - B = Tx.alloc_buffer((16, 512, 128), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(2, 4, 1024, 128) : (1024@F, 2048@F, 1@F, 1@P)])) # noqa: E501 - C = Tx.alloc_buffer((1, 128), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(1, 128) : (1@F, 1@P)])) + T.func_attr({"global_symbol": "activation_reduce"}) + T.device_entry() + partial_reduce = T.alloc_buffer((128, 8), scope="trn.sbuf") + const_bias = T.alloc_buffer((128, 1024), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A = T.alloc_buffer((32, 512, 128), scope="trn.sbuf", + layout=T.TileLayout(T.S[(16 * 1024, 128) : (1@F, 1@P)])) + B = T.alloc_buffer((16, 512, 128), scope="trn.sbuf", + layout=T.TileLayout(T.S[(2, 4, 1024, 128) : (1024@F, 2048@F, 1@F, 1@P)])) + C = T.alloc_buffer((1, 128), scope="trn.sbuf", + layout=T.TileLayout(T.S[(1, 128) : (1@F, 1@P)])) for i in range(2): Tx.unary_reduce(B[0:16, 0:512, 0:128], C[0, 0:128], A[i * 16:i * 16 + 16, 0:512, 0:128], "sqrt", "sum", None, None, [0, 1], workspace={"const_bias": const_bias, "partial_reduce": partial_reduce}) # noqa: E501 # fmt: on @@ -280,32 +281,32 @@ def test_partial_workspace_specify(): C_layout = TileLayout(S[(1, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def activation_reduce(): - Tx.device_entry() - partial_reduce = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - A = Tx.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) - B = Tx.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) - C = Tx.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) + T.device_entry() + partial_reduce = T.alloc_buffer((128, 16), scope="trn.sbuf") + A = T.alloc_buffer(A_shape, dtype="float32", scope="trn.sbuf", layout=A_layout) + B = T.alloc_buffer(B_shape, dtype="float32", scope="trn.sbuf", layout=B_layout) + C = T.alloc_buffer(C_shape, dtype="float32", scope="trn.sbuf", layout=C_layout) for i in range(2): Tx.unary_reduce(B, C, A[i*16:i*16+16], "sqrt", "sum", reduce_axes=(0,1), workspace={"partial_reduce": partial_reduce}) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "activation_reduce"}) - Tx.device_entry() - const_bias = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - partial_reduce = Tx.alloc_buffer((128, 16), scope="trn.sbuf") - A = Tx.alloc_buffer((32, 512, 128), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(16 * 1024, 128) : (1@F, 1@P)])) - B = Tx.alloc_buffer((16, 512, 128), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(2, 4, 1024, 128) : (1024@F, 2048@F, 1@F, 1@P)])) # noqa: E501 - C = Tx.alloc_buffer((1, 128), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(1, 128) : (1@F, 1@P)])) + T.func_attr({"global_symbol": "activation_reduce"}) + T.device_entry() + const_bias = T.alloc_buffer((128, 1024), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + partial_reduce = T.alloc_buffer((128, 16), scope="trn.sbuf") + A = T.alloc_buffer((32, 512, 128), scope="trn.sbuf", + layout=T.TileLayout(T.S[(16 * 1024, 128) : (1@F, 1@P)])) + B = T.alloc_buffer((16, 512, 128), scope="trn.sbuf", + layout=T.TileLayout(T.S[(2, 4, 1024, 128) : (1024@F, 2048@F, 1@F, 1@P)])) + C = T.alloc_buffer((1, 128), scope="trn.sbuf", + layout=T.TileLayout(T.S[(1, 128) : (1@F, 1@P)])) for i in range(2): Tx.unary_reduce(B[0:16, 0:512, 0:128], C[0, 0:128], A[i * 16:i * 16 + 16, 0:512, 0:128], "sqrt", "sum", None, None, [0, 1], workspace={"const_bias": const_bias, "partial_reduce": partial_reduce}) # noqa: E501 # fmt: on @@ -320,31 +321,31 @@ def test_workspace_reuse(): src_layout = TileLayout(S[(128, 4096) : (1 @ P, 1 @ F)]) dst_shape = src_shape dst_layout = src_layout - scale = Tx.float32(2.0) + scale = T.float32(2.0) # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.exp(C_sbuf, A_sbuf, bias=0.0, scale=scale, max_inst_size=1024) Tx.exp(C_sbuf, C_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - Tx.device_entry() - const_bias = Tx.alloc_buffer((128, 1024), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(1024, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(0.0)) - A_sbuf = Tx.alloc_buffer((512, 1024), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4096) : (1 @ P, 1 @ F)])) - C_sbuf = Tx.alloc_buffer((512, 1024), scope="trn.sbuf", - layout=Tx.TileLayout(Tx.S[(128, 4096) : (1 @ P, 1 @ F)])) - Tx.exp(C_sbuf[0:512, 0:1024], A_sbuf[0:512, 0:1024], Tx.float32(0.0), Tx.float32(2.0), workspace={"const_bias": const_bias}, max_inst_size=1024) # noqa: E501 + T.func_attr({"global_symbol": "unary"}) + T.device_entry() + const_bias = T.alloc_buffer((128, 1024), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(1024, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(0.0)) + A_sbuf = T.alloc_buffer((512, 1024), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4096) : (1 @ P, 1 @ F)])) + C_sbuf = T.alloc_buffer((512, 1024), scope="trn.sbuf", + layout=T.TileLayout(T.S[(128, 4096) : (1 @ P, 1 @ F)])) + Tx.exp(C_sbuf[0:512, 0:1024], A_sbuf[0:512, 0:1024], T.float32(0.0), T.float32(2.0), workspace={"const_bias": const_bias}, max_inst_size=1024) # noqa: E501 Tx.exp(C_sbuf[0:512, 0:1024], C_sbuf[0:512, 0:1024], None, None, workspace={"const_bias": const_bias}) # noqa: E501 # fmt: on @@ -362,12 +363,12 @@ def test_no_rewrite_with_existing_workspace(): dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - intermediate_buffer = Tx.alloc_buffer((128, 64), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + intermediate_buffer = T.alloc_buffer((128, 64), scope="trn.sbuf") + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.sum(B_sbuf, A_sbuf, axes=(1, 3), workspace={"partial_reduce": intermediate_buffer}) # fmt: on with target: @@ -383,12 +384,12 @@ def test_no_rewrite_with_psum_output(): C_layout = TileLayout(S[(128, 128) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def gemm() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=A_layout) - B_sbuf = Tx.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=B_layout) - C_psum = Tx.alloc_buffer((128, 128), "float32", scope="trn.psum", layout=C_layout) + T.device_entry() + A_sbuf = T.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=A_layout) + B_sbuf = T.alloc_buffer((128, 128), "float32", scope="trn.sbuf", layout=B_layout) + C_psum = T.alloc_buffer((128, 128), "float32", scope="trn.psum", layout=C_layout) Tx.gemm(C_psum, A_sbuf, B_sbuf, C_psum) # fmt: on with target: diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py index fe88accff700..ef8146b76286 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py @@ -19,7 +19,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -28,8 +29,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -38,7 +37,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -66,27 +65,25 @@ def test_simple_reduction(op_type): tx_func = Tx_func_map[op_type] # fmt: off - @Tx.prim_func + @T.prim_func def reduction() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) tx_func(B_sbuf, A_sbuf, axes=-1) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(B_sbuf[p_loop, 0], A_sbuf[p_loop, f_loop], opcode, False, -1) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 1), scope="trn.sbuf") + for b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(B_sbuf[p_loop, 0], A_sbuf[p_loop, f_loop], opcode, False, -1) + + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -100,27 +97,25 @@ def test_reduction_with_multiple_axes(): dst_layout = TileLayout(S[128 : 1 @ P]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.sum(B_sbuf, A_sbuf, axes=(1, 2), max_inst_size=2048) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 1), scope="trn.sbuf") - for b_loop in range(1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 2048, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(B_sbuf[p_loop, 0], A_sbuf[p_loop, f_loop], "add", False, -1) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 1), scope="trn.sbuf") + for b_loop in range(1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 2048, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(B_sbuf[p_loop, 0], A_sbuf[p_loop, f_loop], "add", False, -1) + + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -134,27 +129,25 @@ def test_reduction_in_loop(): dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): Tx.sum(B_sbuf[:, i], A_sbuf[:, :, i], axes=-2) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(B_sbuf[p_loop, i], A_sbuf[p_loop, f_loop * 4 + i], "add", False, -1) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(B_sbuf[p_loop, i], A_sbuf[p_loop, f_loop * 4 + i], "add", False, -1) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -168,33 +161,31 @@ def test_reduction_two_stage(): dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.sum(B_sbuf, A_sbuf, axes=(1, 3)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - intermediate_buffer = Tx.alloc_buffer((128, 32), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for b_loop in range(4): - for reduction_b_loop in range(32): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], A_sbuf[p_loop, reduction_b_loop * 128 + b_loop * 32 + f_loop], "add", False, -1) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(B_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", False, -1) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + intermediate_buffer = T.alloc_buffer((128, 32), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for b_loop in range(4): + for reduction_b_loop in range(32): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], A_sbuf[p_loop, reduction_b_loop * 128 + b_loop * 32 + f_loop], "add", False, -1) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(B_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", False, -1) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -209,40 +200,38 @@ def test_reduction_with_guard(): dst_layout = TileLayout(S[(4, 128) : (1 @ F, 1 @ P)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): for j in range(4): Tx.sum(B_sbuf[0: (i+1) * 128, 0], A_sbuf[0: (i+1) * 128, 0: (j+1) * 256], max_inst_size=512) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - intermediate_buffer = Tx.alloc_buffer((128, 2), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for i, j in Tx.grid(4, 4): - for b_loop in range(4): - for reduction_b_loop in range(2): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - if ( - b_loop - i < 1 - and reduction_b_loop * 512 + f_loop < j * 256 + 256 - ): - Tx.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], A_sbuf[p_loop, b_loop * 2048 + reduction_b_loop * 512 + f_loop], "add", Tx.bool(False), -1) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(2, annotations={"nki_dim": "F"}): - if b_loop - i < 1 and f_loop * 2 - j < 1: - Tx.nki.tensorreduce(B_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", Tx.bool(False), -1) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + intermediate_buffer = T.alloc_buffer((128, 2), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 8192), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for i, j in T.grid(4, 4): + for b_loop in range(4): + for reduction_b_loop in range(2): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + if ( + b_loop - i < 1 + and reduction_b_loop * 512 + f_loop < j * 256 + 256 + ): + T.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], A_sbuf[p_loop, b_loop * 2048 + reduction_b_loop * 512 + f_loop], "add", T.bool(False), -1) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(2, annotations={"nki_dim": "F"}): + if b_loop - i < 1 and f_loop * 2 - j < 1: + T.nki.tensorreduce(B_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", T.bool(False), -1) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -258,34 +247,32 @@ def test_reduction_two_stage_workspace(): dst_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def reduction(): - Tx.device_entry() - intermediate_buffer = Tx.alloc_buffer((128, 64), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + intermediate_buffer = T.alloc_buffer((128, 64), scope="trn.sbuf") + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.sum(B_sbuf, A_sbuf, axes=(1, 3), workspace={"partial_reduce": intermediate_buffer}) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "reduction"}) - - with Tx.thread(): - intermediate_buffer = Tx.alloc_buffer((128, 64), scope="trn.sbuf") - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - for b_loop in range(4): - for reduction_b_loop in range(32): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], A_sbuf[p_loop, reduction_b_loop * 128 + b_loop * 32 + f_loop], "add", False, -1) # noqa: E501 - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 32, annotations={"nki_dim":"F"}): - Tx.nki.tensorreduce(B_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", False, -1) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "reduction"}) + intermediate_buffer = T.alloc_buffer((128, 64), scope="trn.sbuf") + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + for b_loop in range(4): + for reduction_b_loop in range(32): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(intermediate_buffer[p_loop, reduction_b_loop], A_sbuf[p_loop, reduction_b_loop * 128 + b_loop * 32 + f_loop], "add", False, -1) # noqa: E501 + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 32, annotations={"nki_dim":"F"}): + T.nki.tensorreduce(B_sbuf[p_loop, b_loop], intermediate_buffer[p_loop, f_loop], "add", False, -1) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": reduction}) mod = tvm.tirx.transform.LowerTIRx()(mod) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py index 8daa7205ca0e..477620eb7a9d 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_select_trn.py @@ -18,7 +18,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -27,8 +28,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -37,7 +36,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -56,26 +55,24 @@ def test_select(): dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def select() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.select(B_sbuf, A_sbuf, 0.0, lambda i, j: i < j) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "select"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for b_loop in Tx.serial(0, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.affine_select(B_sbuf[p_loop, f_loop], p_loop < f_loop, A_sbuf[p_loop, f_loop], Tx.float32(0.0)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "select"}) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for b_loop in T.serial(0, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.affine_select(B_sbuf[p_loop, f_loop], p_loop < f_loop, A_sbuf[p_loop, f_loop], T.float32(0.0)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": select}) @@ -91,28 +88,26 @@ def test_select_in_loop(): dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func + @T.prim_func def select() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(2): Tx.select(B_sbuf, A_sbuf[i*16, :, :], 0.0, lambda a, b: (i+1)* a < b) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "select"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 16384), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for i, b_loop in Tx.grid(2, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.affine_select(B_sbuf[p_loop, f_loop], (i + 1) * p_loop < f_loop, A_sbuf[p_loop, i * 8192 + f_loop], Tx.float32(0.0)) # noqa: E501 - - # fmt: on + T.func_attr({"global_symbol": "select"}) + A_sbuf = T.alloc_buffer((128, 16384), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for i, b_loop in T.grid(2, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.affine_select(B_sbuf[p_loop, f_loop], (i + 1) * p_loop < f_loop, A_sbuf[p_loop, i * 8192 + f_loop], T.float32(0.0)) # noqa: E501 + + # fmt: on with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -127,26 +122,24 @@ def test_select_expr_affine(): dst_layout = src_layout # fmt: off - @Tx.prim_func + @T.prim_func def select() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.select(B_sbuf, A_sbuf, 0.0, lambda i, j: i < j) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "select"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for b_loop in Tx.serial(0, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.affine_select(B_sbuf[p_loop, b_loop * 512 + f_loop], b_loop * 128 + p_loop < f_loop, A_sbuf[p_loop, b_loop * 512 + f_loop], Tx.float32(0.0)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "select"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for b_loop in T.serial(0, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.affine_select(B_sbuf[p_loop, b_loop * 512 + f_loop], b_loop * 128 + p_loop < f_loop, A_sbuf[p_loop, b_loop * 512 + f_loop], T.float32(0.0)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -161,29 +154,27 @@ def test_select_with_guard(): dst_layout = src_layout # fmt: off - @Tx.prim_func + @T.prim_func def select() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): for j in range(4): Tx.select(B_sbuf[0: (i+1) * 128, 0: (j+1) * 128], A_sbuf[0: (i+1) * 128, 0: (j+1) * 128], 0.0, lambda a, b: a < b) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "select"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - for i, j, b_loop in Tx.grid(4, 4, 4): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if b_loop - i < 1 and f_loop < j * 128 + 128: - Tx.nki.affine_select(B_sbuf[p_loop, b_loop * 512 + f_loop], b_loop * 128 + p_loop < f_loop, A_sbuf[p_loop, b_loop * 512 + f_loop], Tx.float32(0.0)) # noqa: E501 - # fmt: on + T.func_attr({"global_symbol": "select"}) + A_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + for i, j, b_loop in T.grid(4, 4, 4): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if b_loop - i < 1 and f_loop < j * 128 + 128: + T.nki.affine_select(B_sbuf[p_loop, b_loop * 512 + f_loop], b_loop * 128 + p_loop < f_loop, A_sbuf[p_loop, b_loop * 512 + f_loop], T.float32(0.0)) # noqa: E501 + # fmt: on with target: mod = tvm.IRModule({"main": select}) mod = tvm.tirx.transform.LowerTIRx()(mod) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py index 588f0999d70c..db6e968b36a3 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py @@ -19,7 +19,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal as _assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.stmt_functor import ir_transform @@ -28,8 +29,6 @@ def _strip_exec_scope_stmt(stmt): def _postorder(node): - if isinstance(node, tvm.tirx.ExecScopeStmt): - return node.body if isinstance(node, tvm.tirx.AttrStmt) and node.attr_key == "tirx.device_entry": return node.body return node @@ -38,7 +37,7 @@ def _postorder(node): stmt, preorder=lambda _node: None, postorder=_postorder, - only_enable=["tirx.ExecScopeStmt", "tirx.AttrStmt"], + only_enable=["tirx.AttrStmt"], ) @@ -56,40 +55,38 @@ def assert_structural_equal(lhs, rhs, *args, **kwargs): @pytest.mark.parametrize("op_type", ["reciprocal", "memset"]) def test_simple_unary(op_type): src_shape = [128, 512] - src_layout = Tx.TileLayout(Tx.S[(128, 512) : (1 @ P, 1 @ F)]) + src_layout = T.TileLayout(T.S[(128, 512) : (1 @ P, 1 @ F)]) dst_shape = [128, 512] - dst_layout = Tx.TileLayout(Tx.S[(128, 512) : (1 @ P, 1 @ F)]) + dst_layout = T.TileLayout(T.S[(128, 512) : (1 @ P, 1 @ F)]) tx_func = Tx_func_map[op_type] # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) if op_type == "memset": - tx_func(B_sbuf, Tx.float32(0.0)) + tx_func(B_sbuf, T.float32(0.0)) else: tx_func(B_sbuf, A_sbuf) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - for b_loop in Tx.serial(0, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if op_type == "reciprocal": - Tx.nki.reciprocal( - B_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop] - ) - elif op_type == "memset": - Tx.nki.memset(B_sbuf[p_loop, f_loop], 0.0) - # fmt: on + T.func_attr({"global_symbol": "unary"}) + A_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 512), scope="trn.sbuf") + for b_loop in T.serial(0, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if op_type == "reciprocal": + T.nki.reciprocal( + B_sbuf[p_loop, f_loop], A_sbuf[p_loop, f_loop] + ) + elif op_type == "memset": + T.nki.memset(B_sbuf[p_loop, f_loop], 0.0) + # fmt: on with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -99,44 +96,42 @@ def expected(): @pytest.mark.parametrize("op_type", ["reciprocal", "memset"]) def test_unary_in_a_loop(op_type): src_shape = [1024, 512] - src_layout = Tx.TileLayout(Tx.S[(128, 4096) : (1 @ P, 1 @ F)]) + src_layout = T.TileLayout(T.S[(128, 4096) : (1 @ P, 1 @ F)]) dst_shape = [512, 512] - dst_layout = Tx.TileLayout(Tx.S[(128, 2048) : (1 @ P, 1 @ F)]) + dst_layout = T.TileLayout(T.S[(128, 2048) : (1 @ P, 1 @ F)]) Tx_func = Tx_func_map[op_type] # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) A_sbuf_view = A_sbuf.view(128, 8, 512) B_sbuf_view = B_sbuf.view(128, 4, 512) for i in range(4): if op_type == "memset": - Tx_func(B_sbuf_view[:, i, :], Tx.float32(0.0)) + Tx_func(B_sbuf_view[:, i, :], T.float32(0.0)) else: Tx_func(B_sbuf_view[:, i, :], A_sbuf_view[:, i * 2, :]) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 2048), scope="trn.sbuf") - A_sbuf_view = Tx.decl_buffer((128, 4096), data=A_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 - B_sbuf_view = Tx.decl_buffer((128, 2048), data=B_sbuf.data, scope="trn.sbuf", layout=None) # noqa: E501 - for i, b_loop in Tx.grid(4, 1): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if op_type == "reciprocal": - Tx.nki.reciprocal(B_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop]) # noqa: E501 - elif op_type == "memset": - Tx.nki.memset(B_sbuf[p_loop, i * 512 + f_loop], 0.0) - # fmt: on + T.func_attr({"global_symbol": "unary"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 2048), scope="trn.sbuf") + A_sbuf_view = T.decl_buffer((128, 4096), data=A_sbuf.data, scope="trn.sbuf", layout=None) + B_sbuf_view = T.decl_buffer((128, 2048), data=B_sbuf.data, scope="trn.sbuf", layout=None) + for i, b_loop in T.grid(4, 1): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if op_type == "reciprocal": + T.nki.reciprocal(B_sbuf_view[p_loop, i * 512 + f_loop], A_sbuf_view[p_loop, i * 1024 + f_loop]) # noqa: E501 + elif op_type == "memset": + T.nki.memset(B_sbuf[p_loop, i * 512 + f_loop], 0.0) + # fmt: on with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -148,24 +143,22 @@ def test_unary_complex1(): dst_shape = [4096, 256] # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) - Tx.memset(A_sbuf, Tx.float32(0.0)) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + Tx.memset(A_sbuf, T.float32(0.0)) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 8192), scope="trn.sbuf") - for b_loop in Tx.serial(0, 16): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.memset(A_sbuf[p_loop, b_loop * 512 + f_loop], Tx.float32(0.0)) - # fmt: on + T.func_attr({"global_symbol": "unary"}) + A_sbuf = T.alloc_buffer((128, 8192), scope="trn.sbuf") + for b_loop in T.serial(0, 16): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.memset(A_sbuf[p_loop, b_loop * 512 + f_loop], T.float32(0.0)) + # fmt: on with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -180,32 +173,30 @@ def test_unary_with_bias_scale(op_type): dst_layout = src_layout bias_shape = [512, 1] bias_layout = TileLayout(S[(128, 4) : (1 @ P, 1 @ F)]) - scale = Tx.float32(2.0) + scale = T.float32(2.0) tx_func = Tx_func_map[op_type] # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(bias_shape, "float32", scope="trn.sbuf", layout=bias_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(bias_shape, "float32", scope="trn.sbuf", layout=bias_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) tx_func(C_sbuf, A_sbuf, bias=B_sbuf, scale=scale) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - for b_loop in Tx.serial(0, 8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - Tx.nki.activation(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], op_type, B_sbuf[p_loop, b_loop//2], Tx.float32(2.0)) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "unary"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + for b_loop in T.serial(0, 8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + T.nki.activation(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], op_type, B_sbuf[p_loop, b_loop//2], T.float32(2.0)) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.LowerTIRx()(mod) @@ -218,36 +209,34 @@ def test_unary_with_bias_scale_2(op_type): src_layout = TileLayout(S[(128, 4096) : (1 @ P, 1 @ F)]) dst_shape = src_shape dst_layout = src_layout - bias = Tx.float32(1.0) - scale = Tx.float32(2.0) + bias = T.float32(1.0) + scale = T.float32(2.0) tx_func = Tx_func_map[op_type] # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) tx_func(C_sbuf, A_sbuf, bias=bias, scale=scale) - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - - with Tx.thread(): - const_bias = Tx.alloc_buffer((128, 512), scope="trn.sbuf") - with Tx.attr(0, "tensorized_nki_instruction", 1): - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.memset(const_bias[p_loop, f_loop], Tx.float32(1.0)) - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - for b_loop in Tx.serial(0, 8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(128, annotations={"nki_dim": "P"}): - for f_loop in Tx.serial(512, annotations={"nki_dim": "F"}): - Tx.nki.activation(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], op_type, const_bias[p_loop, f_loop], Tx.float32(2.0)) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "unary"}) + const_bias = T.alloc_buffer((128, 512), scope="trn.sbuf") + with T.attr(0, "tensorized_nki_instruction", 1): + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.memset(const_bias[p_loop, f_loop], T.float32(1.0)) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + for b_loop in T.serial(0, 8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(128, annotations={"nki_dim": "P"}): + for f_loop in T.serial(512, annotations={"nki_dim": "F"}): + T.nki.activation(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], op_type, const_bias[p_loop, f_loop], T.float32(2.0)) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) @@ -262,34 +251,32 @@ def test_unary_with_guard(): dst_layout = src_layout bias_shape = [512, 1] bias_layout = TileLayout(S[(4, 128) : (1 @ F, 1 @ P)]) - scale = Tx.float32(2.0) + scale = T.float32(2.0) # fmt: off - @Tx.prim_func + @T.prim_func def unary() -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) - B_sbuf = Tx.alloc_buffer(bias_shape, "float32", scope="trn.sbuf", layout=bias_layout) - C_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(src_shape, "float32", scope="trn.sbuf", layout=src_layout) + B_sbuf = T.alloc_buffer(bias_shape, "float32", scope="trn.sbuf", layout=bias_layout) + C_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) for i in range(4): for j in range(4): Tx.sqrt(C_sbuf[0: (i+1) * 128, 0: (j+1)*256], A_sbuf[0: (i+1) * 128, 0: (j+1)*256], bias=B_sbuf[0: (i+1) * 128, 0], scale=scale) # noqa: E501 - @Tx.prim_func + @T.prim_func def expected(): - Tx.func_attr({"global_symbol": "unary"}) - - with Tx.thread(): - A_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - B_sbuf = Tx.alloc_buffer((128, 4), scope="trn.sbuf") - C_sbuf = Tx.alloc_buffer((128, 4096), scope="trn.sbuf") - for i, j, b_loop in Tx.grid(4, 4, 8): - Tx.attr(0, "tensorized_nki_instruction", 1) - for p_loop in Tx.serial(0, 128, annotations={"nki_dim":"P"}): - for f_loop in Tx.serial(0, 512, annotations={"nki_dim":"F"}): - if b_loop // 2 - i < 1 and b_loop % 2 * 512 + f_loop < j * 256 + 256: - Tx.nki.activation(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], "sqrt", B_sbuf[p_loop, b_loop // 2], Tx.float32(2.0)) # noqa: E501 - # fmt: off + T.func_attr({"global_symbol": "unary"}) + A_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + B_sbuf = T.alloc_buffer((128, 4), scope="trn.sbuf") + C_sbuf = T.alloc_buffer((128, 4096), scope="trn.sbuf") + for i, j, b_loop in T.grid(4, 4, 8): + T.attr(0, "tensorized_nki_instruction", 1) + for p_loop in T.serial(0, 128, annotations={"nki_dim":"P"}): + for f_loop in T.serial(0, 512, annotations={"nki_dim":"F"}): + if b_loop // 2 - i < 1 and b_loop % 2 * 512 + f_loop < j * 256 + 256: + T.nki.activation(C_sbuf[p_loop, b_loop * 512 + f_loop], A_sbuf[p_loop, b_loop * 512 + f_loop], "sqrt", B_sbuf[p_loop, b_loop // 2], T.float32(2.0)) # noqa: E501 + # fmt: off with target: mod = tvm.IRModule({"main": unary}) mod = tvm.tirx.transform.LowerTIRx()(mod) diff --git a/tests/python/tirx/test_buffer_print.py b/tests/python/tirx/test_buffer_print.py index 1049a9d486a5..211f4d390313 100644 --- a/tests/python/tirx/test_buffer_print.py +++ b/tests/python/tirx/test_buffer_print.py @@ -21,7 +21,7 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T def generate_random_data(shape, dtype): @@ -193,17 +193,17 @@ def test_vector_add_1D(dtype, dtype_str): C_np = A_np + B_np A_tvm, B_tvm = create_tvm_arrays([A_np, B_np], DEV) - @Tx.prim_func(s_tir=True) - def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M,), dtype_str) - B = Tx.match_buffer(B_ptr, (M,), dtype_str) - C = Tx.match_buffer(C_ptr, (M,), dtype_str) + @T.prim_func(s_tir=True) + def add_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M,), dtype_str) + B = T.match_buffer(B_ptr, (M,), dtype_str) + C = T.match_buffer(C_ptr, (M,), dtype_str) - for i in Tx.grid(M): - with Tx.sblock("C"): - vi = Tx.axis.spatial(M, i) + for i in T.grid(M): + with T.sblock("C"): + vi = T.axis.spatial(M, i) C[vi] = A[vi] + B[vi] - Tx.print_buffer(C.data, dtype_str, False, False, dim_num, (M,)) + T.print_buffer(C.data, dtype_str, False, False, dim_num, (M,)) sch = tvm.s_tir.Schedule(add_func) blk = sch.get_sblock("C") @@ -229,18 +229,18 @@ def test_vector_add_2D(dtype, dtype_str): C_np = A_np + B_np A_tvm, B_tvm = create_tvm_arrays([A_np, B_np], DEV) - @Tx.prim_func(s_tir=True) - def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N), dtype_str) - B = Tx.match_buffer(B_ptr, (M, N), dtype_str) - C = Tx.match_buffer(C_ptr, (M, N), dtype_str) + @T.prim_func(s_tir=True) + def add_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N), dtype_str) + B = T.match_buffer(B_ptr, (M, N), dtype_str) + C = T.match_buffer(C_ptr, (M, N), dtype_str) - for i, j in Tx.grid(M, N): - with Tx.sblock("C"): - vi = Tx.axis.spatial(M, i) - vj = Tx.axis.spatial(N, j) + for i, j in T.grid(M, N): + with T.sblock("C"): + vi = T.axis.spatial(M, i) + vj = T.axis.spatial(N, j) C[vi, vj] = A[vi, vj] + B[vi, vj] - Tx.print_buffer(C.data, C.dtype, False, False, dim_num, (M, N)) + T.print_buffer(C.data, C.dtype, False, False, dim_num, (M, N)) sch = tvm.s_tir.Schedule(add_func) blk = sch.get_sblock("C") @@ -270,19 +270,19 @@ def test_vector_add_3D(dtype, dtype_str): A_tvm, B_tvm = create_tvm_arrays([A_np, B_np], DEV) - @Tx.prim_func(s_tir=True) - def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M, N, K), dtype_str) - B = Tx.match_buffer(B_ptr, (M, N, K), dtype_str) - C = Tx.match_buffer(C_ptr, (M, N, K), dtype_str) - - for i, j, k in Tx.grid(M, N, K): - with Tx.sblock("C"): - vi = Tx.axis.spatial(M, i) - vj = Tx.axis.spatial(N, j) - vk = Tx.axis.spatial(K, k) + @T.prim_func(s_tir=True) + def add_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M, N, K), dtype_str) + B = T.match_buffer(B_ptr, (M, N, K), dtype_str) + C = T.match_buffer(C_ptr, (M, N, K), dtype_str) + + for i, j, k in T.grid(M, N, K): + with T.sblock("C"): + vi = T.axis.spatial(M, i) + vj = T.axis.spatial(N, j) + vk = T.axis.spatial(K, k) C[vi, vj, vk] = A[vi, vj, vk] + B[vi, vj, vk] - Tx.print_buffer(C.data, C.dtype, False, False, dim_num, (M, N, K)) + T.print_buffer(C.data, C.dtype, False, False, dim_num, (M, N, K)) sch = tvm.s_tir.Schedule(add_func) blk = sch.get_sblock("C") @@ -314,18 +314,18 @@ def test_const_scalar(dtype, dtype_str): C_np = A_np + B_np A_tvm, B_tvm = create_tvm_arrays([A_np, B_np], DEV) - @Tx.prim_func(s_tir=True) - def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M,), dtype_str) - B = Tx.match_buffer(B_ptr, (M,), dtype_str) - C = Tx.match_buffer(C_ptr, (M,), dtype_str) - Ten: Tx.let = Tx.IntImm(dtype_str, 10) + @T.prim_func(s_tir=True) + def add_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M,), dtype_str) + B = T.match_buffer(B_ptr, (M,), dtype_str) + C = T.match_buffer(C_ptr, (M,), dtype_str) + Ten: T.let = T.IntImm(dtype_str, 10) - for i in Tx.grid(M): - with Tx.sblock("C"): - vi = Tx.axis.spatial(M, i) + for i in T.grid(M): + with T.sblock("C"): + vi = T.axis.spatial(M, i) C[vi] = A[vi] + B[vi] - Tx.print_buffer(Ten, "int32", False, True, dim_num, ()) + T.print_buffer(Ten, "int32", False, True, dim_num, ()) sch = tvm.s_tir.Schedule(add_func) blk = sch.get_sblock("C") @@ -351,18 +351,18 @@ def test_string(dtype, dtype_str, test_string): C_np = A_np + B_np A_tvm, B_tvm = create_tvm_arrays([A_np, B_np], DEV) - @Tx.prim_func(s_tir=True) - def add_func(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (M,), dtype_str) - B = Tx.match_buffer(B_ptr, (M,), dtype_str) - C = Tx.match_buffer(C_ptr, (M,), dtype_str) - string_var = Tx.StringImm(test_string) + @T.prim_func(s_tir=True) + def add_func(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (M,), dtype_str) + B = T.match_buffer(B_ptr, (M,), dtype_str) + C = T.match_buffer(C_ptr, (M,), dtype_str) + string_var = T.StringImm(test_string) - for i in Tx.grid(M): - with Tx.sblock("C"): - vi = Tx.axis.spatial(M, i) + for i in T.grid(M): + with T.sblock("C"): + vi = T.axis.spatial(M, i) C[vi] = A[vi] + B[vi] - Tx.print_buffer(string_var, "int8", True, False, dim_num, ()) + T.print_buffer(string_var, "int8", True, False, dim_num, ()) sch = tvm.s_tir.Schedule(add_func) blk = sch.get_sblock("C") diff --git a/tests/python/tirx/test_control_flow.py b/tests/python/tirx/test_control_flow.py index 8e0522cc7e2c..1f905bd03cc9 100644 --- a/tests/python/tirx/test_control_flow.py +++ b/tests/python/tirx/test_control_flow.py @@ -17,7 +17,7 @@ import numpy as np import tvm -from tvm.script import tirx as Tx +from tvm.script import tirx as T def run_test_break_continue(func, shape, expected): @@ -34,20 +34,19 @@ def run_test_break_continue(func, shape, expected): def test_break_continue1(): # fmt: off - @Tx.prim_func - def func(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([32]) - with Tx.thread(): - for i in Tx.serial(10): - if i == 2: - continue - if i == 7: - break - A[i] = i + @T.prim_func + def func(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([32]) + for i in T.serial(10): + if i == 2: + continue + if i == 7: + break + A[i] = i # fmt: on expected = np.array([0, 1, 0, 3, 4, 5, 6, 0, 0, 0], dtype="int32") @@ -56,25 +55,24 @@ def func(A_ptr: Tx.handle): def test_break_continue2(): # fmt: off - @Tx.prim_func - def func(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (9,), "int32") - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([32]) - with Tx.thread(): - idx = Tx.alloc_buffer((1,), "int32", scope="local") - idx[0] = 0 - for i in Tx.serial(3): - if i == 0: - idx[0] += 1 - continue - for j in Tx.serial(3): - A[idx[0]] = i * 10 + j - idx[0] += 1 - if j == 1: - break + @T.prim_func + def func(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (9,), "int32") + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([32]) + idx = T.alloc_buffer((1,), "int32", scope="local") + idx[0] = 0 + for i in T.serial(3): + if i == 0: + idx[0] += 1 + continue + for j in T.serial(3): + A[idx[0]] = i * 10 + j + idx[0] += 1 + if j == 1: + break # fmt: on expected = np.array([0, 10, 11, 20, 21, 0, 0, 0, 0], dtype="int32") @@ -83,24 +81,23 @@ def func(A_ptr: Tx.handle): def test_break_continue3(): # fmt: off - @Tx.prim_func - def func(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") - - Tx.device_entry() - cta_id = Tx.cta_id([1]) - tid = Tx.thread_id([32]) - with Tx.thread(): - i = Tx.alloc_buffer((1,), "int32", scope="local") - i[0] = 0 - while i[0] < 10: - if (i[0] % 2) == 1: - i[0] += 1 - continue - A[i[0]] = i[0] + @T.prim_func + def func(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") + + T.device_entry() + cta_id = T.cta_id([1]) + tid = T.thread_id([32]) + i = T.alloc_buffer((1,), "int32", scope="local") + i[0] = 0 + while i[0] < 10: + if (i[0] % 2) == 1: i[0] += 1 - if i[0] == 7: - break + continue + A[i[0]] = i[0] + i[0] += 1 + if i[0] == 7: + break # fmt: on expected = np.array([0, 0, 2, 0, 4, 0, 6, 0, 0, 0], dtype="int32") diff --git a/tests/python/tirx/test_hint.py b/tests/python/tirx/test_hint.py index 27076db0cf5f..88daad9b188a 100644 --- a/tests/python/tirx/test_hint.py +++ b/tests/python/tirx/test_hint.py @@ -34,15 +34,11 @@ def test_hint_statement(): @T.prim_func def func(A_ptr: T.handle) -> None: _A = T.match_buffer(A_ptr, (64,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - T.hint("persistent tile scheduler with L2 swizzle") - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + T.hint("persistent tile scheduler with L2 swizzle") + T.evaluate(0) # Walk the IR to find the AttrStmt with tirx_hint found = [False] @@ -64,15 +60,11 @@ def test_hint_context_manager(): @T.prim_func def func(A_ptr: T.handle) -> None: _A = T.match_buffer(A_ptr, (64,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - with T.hint("software pipeline, depth 4"): - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + with T.hint("software pipeline, depth 4"): + T.evaluate(0) found = [False] @@ -92,15 +84,11 @@ def test_hint_with_attrs(): @T.prim_func def func(A_ptr: T.handle) -> None: _A = T.match_buffer(A_ptr, (64,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - T.hint("scheduler", mode="persistent", depth="4") - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + T.hint("scheduler", mode="persistent", depth="4") + T.evaluate(0) found = [False] @@ -122,15 +110,11 @@ def test_hint_printer_roundtrip_statement(): @T.prim_func def func(A_ptr: T.handle) -> None: _A = T.match_buffer(A_ptr, (64,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - T.hint("persistent tile scheduler with L2 swizzle") - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + T.hint("persistent tile scheduler with L2 swizzle") + T.evaluate(0) code = func.script() assert 'hint("persistent tile scheduler with L2 swizzle")' in code @@ -144,15 +128,11 @@ def test_hint_printer_roundtrip_context_manager(): @T.prim_func def func(A_ptr: T.handle) -> None: _A = T.match_buffer(A_ptr, (64,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - with T.hint("software pipeline, depth 4"): - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + with T.hint("software pipeline, depth 4"): + T.evaluate(0) code = func.script() assert 'hint("software pipeline, depth 4")' in code @@ -166,15 +146,11 @@ def test_hint_printer_roundtrip_with_attrs(): @T.prim_func def func(A_ptr: T.handle) -> None: _A = T.match_buffer(A_ptr, (64,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - T.hint("scheduler", mode="persistent") - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + T.hint("scheduler", mode="persistent") + T.evaluate(0) code = func.script() assert 'hint("scheduler"' in code @@ -194,7 +170,7 @@ def test_hint_keyword_arg_on_tx_op(): op_call = TilePrimitiveCall( A[0:64, 0:64], A_sm[0:64, 0:64], - op=tvm.ir.Op.get("tirx.copy"), + op=tvm.ir.Op.get("tirx.tile.copy"), workspace={}, config={"hint": "3-input ptx"}, ) @@ -204,14 +180,13 @@ def test_hint_keyword_arg_on_tx_op(): def test_hint_keyword_arg_on_tx_op_roundtrip(): """Tx.op(..., hint="msg") roundtrips through printer/parser.""" - from tvm.script import tirx as Tx + from tvm.script.tirx import tile as Tx @T.prim_func def func(A_ptr: T.handle, B_ptr: T.handle): A = T.match_buffer(A_ptr, [10], "float32", scope="global") B = T.match_buffer(B_ptr, [10], "float32", scope="global") - with T.thread(): - Tx.add(B, A, T.float32(1), hint="use_fast_math") + Tx.add(B, A, T.float32(1), hint="use_fast_math") code = func.script() assert 'hint="use_fast_math"' in code @@ -226,15 +201,11 @@ def test_hint_no_message(): @T.prim_func def func(A_ptr: T.handle) -> None: A = T.match_buffer(A_ptr, (128,), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([1, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - T.hint(access=A[0:64]) - T.evaluate(0) + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + T.hint(access=A[0:64]) + T.evaluate(0) found = [False] @@ -259,15 +230,11 @@ def test_hint_access_buffer_region(): @T.prim_func def func(A_ptr: T.handle) -> None: A = T.match_buffer(A_ptr, (128, 64), "float32", scope="global") - with T.thread(): - bx, by, bz = T.cta_id([2, 1, 1]) - warp_id = T.warp_id([1]) - lane_id = T.lane_id([32]) - with T.cta(): - with T.warp(): - with T.thread(): - T.hint("partition", access=A[bx * 64 : (bx + 1) * 64, 0:64]) - T.evaluate(0) + bx, by, bz = T.cta_id([2, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + T.hint("partition", access=A[bx * 64 : (bx + 1) * 64, 0:64]) + T.evaluate(0) found = [False] diff --git a/tests/python/tirx/test_inline.py b/tests/python/tirx/test_inline.py index 33a65aab06ee..438c187c6c7c 100644 --- a/tests/python/tirx/test_inline.py +++ b/tests/python/tirx/test_inline.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Tests for T.inline / Tx.inline with Python LEGB scoping semantics.""" +"""Tests for T.inline / T.inline with Python LEGB scoping semantics.""" from tvm.ir import assert_structural_equal from tvm.script import tirx as T -from tvm.script import tirx as Tx # Module-level constant for testing global visibility MODULE_CONST = 42 @@ -201,27 +200,27 @@ def test_recursive_inline(): """Recursive inline (defined inside prim_func).""" # fmt: off - @Tx.prim_func(private=True) + @T.prim_func(private=True) def func(): - Tx.device_entry() - for x in Tx.serial(10): + T.device_entry() + for x in T.serial(10): - @Tx.inline + @T.inline def add(x, c): if c > 0: add(x, c - 1) - Tx.evaluate(x) + T.evaluate(x) add(x, 3) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected(): - Tx.device_entry() + T.device_entry() for x in range(10): - Tx.evaluate(x) - Tx.evaluate(x) - Tx.evaluate(x) - Tx.evaluate(x) + T.evaluate(x) + T.evaluate(x) + T.evaluate(x) + T.evaluate(x) # fmt: on assert_structural_equal(func, expected) diff --git a/tests/python/tirx/test_jit.py b/tests/python/tirx/test_jit.py index 637563867c27..393d640b237b 100644 --- a/tests/python/tirx/test_jit.py +++ b/tests/python/tirx/test_jit.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # ruff: noqa: F821 -"""Tests for ``@Tx.jit`` + ``Tx.constexpr``.""" +"""Tests for ``@T.jit`` + ``T.constexpr``.""" from __future__ import annotations @@ -23,26 +23,26 @@ import tvm from tvm.ir import assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T def test_int_constexpr_specializes_loop_bound(): - @Tx.jit(private=True) + @T.jit(private=True) def add( - A: Tx.Buffer((N,), "int32"), - B: Tx.Buffer((N,), "int32"), - C: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), + B: T.Buffer((N,), "int32"), + C: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, + N: T.constexpr, ): for i in range(N): C[i] = A[i] + B[i] - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected( - A: Tx.Buffer((128,), "int32"), - B: Tx.Buffer((128,), "int32"), - C: Tx.Buffer((128,), "int32"), + A: T.Buffer((128,), "int32"), + B: T.Buffer((128,), "int32"), + C: T.Buffer((128,), "int32"), ): for i in range(128): C[i] = A[i] + B[i] @@ -51,24 +51,24 @@ def expected( def test_constexpr_in_2d_buffer_shape(): - @Tx.jit(private=True) + @T.jit(private=True) def matadd( - A: Tx.Buffer((M, K), "int32"), - B: Tx.Buffer((M, K), "int32"), - C: Tx.Buffer((M, K), "int32"), + A: T.Buffer((M, K), "int32"), + B: T.Buffer((M, K), "int32"), + C: T.Buffer((M, K), "int32"), *, - M: Tx.constexpr, - K: Tx.constexpr, + M: T.constexpr, + K: T.constexpr, ): for m in range(M): for k in range(K): C[m, k] = A[m, k] + B[m, k] - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected( - A: Tx.Buffer((4, 8), "int32"), - B: Tx.Buffer((4, 8), "int32"), - C: Tx.Buffer((4, 8), "int32"), + A: T.Buffer((4, 8), "int32"), + B: T.Buffer((4, 8), "int32"), + C: T.Buffer((4, 8), "int32"), ): for m in range(4): for k in range(8): @@ -78,21 +78,21 @@ def expected( def test_constexpr_in_body_expression(): - @Tx.jit(private=True) + @T.jit(private=True) def scaled_copy( - A: Tx.Buffer((N,), "int32"), - B: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), + B: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, - SCALE: Tx.constexpr, + N: T.constexpr, + SCALE: T.constexpr, ): for i in range(N): B[i] = A[i] * SCALE - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected( - A: Tx.Buffer((16,), "int32"), - B: Tx.Buffer((16,), "int32"), + A: T.Buffer((16,), "int32"), + B: T.Buffer((16,), "int32"), ): for i in range(16): B[i] = A[i] * 3 @@ -101,11 +101,11 @@ def expected( def test_specialize_cache_returns_same_instance(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, + N: T.constexpr, ): for i in range(N): A[i] = 0 @@ -116,11 +116,11 @@ def k( def test_specialize_different_args_produce_different_funcs(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, + N: T.constexpr, ): for i in range(N): A[i] = 0 @@ -129,12 +129,12 @@ def k( def test_specialize_missing_constexpr_raises(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, - SCALE: Tx.constexpr, + N: T.constexpr, + SCALE: T.constexpr, ): for i in range(N): A[i] = SCALE @@ -144,11 +144,11 @@ def k( def test_specialize_extra_kwarg_raises(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, + N: T.constexpr, ): for i in range(N): A[i] = 0 @@ -158,22 +158,22 @@ def k( def test_jit_kernel_with_nested_inline_helper(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, + N: T.constexpr, ): - @Tx.inline + @T.inline def double(x): A[x] = A[x] * 2 for i in range(N): double(i) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected( - A: Tx.Buffer((4,), "int32"), + A: T.Buffer((4,), "int32"), ): for i in range(4): A[i] = A[i] * 2 @@ -182,19 +182,19 @@ def expected( def test_constexpr_default_value(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, - SCALE: Tx.constexpr = 7, + N: T.constexpr, + SCALE: T.constexpr = 7, ): for i in range(N): A[i] = SCALE - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected( - A: Tx.Buffer((8,), "int32"), + A: T.Buffer((8,), "int32"), ): for i in range(8): A[i] = 7 @@ -206,11 +206,11 @@ def expected( def test_specialize_returns_primfunc(): - @Tx.jit(private=True) + @T.jit(private=True) def k( - A: Tx.Buffer((N,), "int32"), + A: T.Buffer((N,), "int32"), *, - N: Tx.constexpr, + N: T.constexpr, ): for i in range(N): A[i] = 0 diff --git a/tests/python/tirx/test_layout.py b/tests/python/tirx/test_layout.py index 4cc4fa4b6481..1666d616e663 100644 --- a/tests/python/tirx/test_layout.py +++ b/tests/python/tirx/test_layout.py @@ -25,7 +25,7 @@ from tvm.arith import Analyzer from tvm.ir import assert_structural_equal from tvm.ir.type import PointerType, PrimType -from tvm.script import tirx as Tx +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tirx as Tx_builder from tvm.tirx import Var @@ -1424,7 +1424,7 @@ def test_pool_allocator_alloc_mma(): def alloc_layout(shape, dtype, swizzle_mode="auto"): with IRBuilder(): with Tx_builder.prim_func(): - pool = Tx.SMEMPool(Var("smem_ptr", PointerType(PrimType("uint8")))) + pool = T.SMEMPool(Var("smem_ptr", PointerType(PrimType("uint8")))) buf = pool.alloc_mma(shape, dtype, swizzle_mode=swizzle_mode) return buf.layout diff --git a/tests/python/tirx/test_op.py b/tests/python/tirx/test_op.py index 985240440ddf..480e6cd3ddbc 100644 --- a/tests/python/tirx/test_op.py +++ b/tests/python/tirx/test_op.py @@ -16,16 +16,13 @@ # under the License. import pytest -import tvm from tvm.ir import Op -from tvm.script import tirx as T -from tvm.script import tirx as Tx from tvm.tirx.buffer import decl_buffer from tvm.tirx.stmt import TilePrimitiveCall def _test(op: str, *args): - return TilePrimitiveCall(*args, op=Op.get("tirx." + op), workspace={}, config={}) + return TilePrimitiveCall(*args, op=Op.get("tirx.tile." + op), workspace={}, config={}) def test_copy(): @@ -47,125 +44,6 @@ def test_gemm(): _test("gemm", D[:, :], A[:, :], B[:, :], C[:, :], True, False, 1.0, 0.0) -def test_generic_op_creates_op(): - """GenericOp auto-registers unknown ops.""" - from tvm.tirx.operator.tile_primitive.ops import GenericOp - - A = decl_buffer((64,), "float32", scope="global") - B = decl_buffer((64,), "float32", scope="global") - - op_call = GenericOp(B[0:64], A[0:64], op_name="my_custom_op_1") - assert op_call.op == Op.get("tirx.my_custom_op_1") - assert len(op_call.args) == 2 - - -def test_generic_op_reuses_registered_op(): - """GenericOp reuses already-registered ops without error.""" - from tvm.tirx.operator.tile_primitive.ops import GenericOp - - A = decl_buffer((64,), "float32", scope="global") - B = decl_buffer((64,), "float32", scope="global") - - # Create twice with same name — should not error - op1 = GenericOp(B[0:64], A[0:64], op_name="my_custom_op_2") - op2 = GenericOp(B[0:64], A[0:64], op_name="my_custom_op_2") - assert op1.op == op2.op - - -def test_generic_op_with_existing_tirx_op(): - """GenericOp works with already-registered tirx ops (e.g., tirx.copy).""" - from tvm.tirx.operator.tile_primitive.ops import GenericOp - - A = decl_buffer((64,), "float32", scope="global") - B = decl_buffer((64,), "float32", scope="global") - - op_call = GenericOp(B[0:64], A[0:64], op_name="copy") - assert op_call.op == Op.get("tirx.copy") - - -def test_tx_dynamic_op_module_getattr(): - """Tx.some_undefined_op resolves via module __getattr__.""" - fn = Tx.my_dynamic_test_op - assert callable(fn) - assert fn.__name__ == "my_dynamic_test_op" - - -def test_tx_dynamic_op_in_prim_func(): - """Tx.copy_and_cast(...) works inside a prim_func without pre-registration.""" - - @T.prim_func - def func(A_ptr: T.handle, B_ptr: T.handle): - A = T.match_buffer(A_ptr, [64], "float32", scope="global") - B = T.match_buffer(B_ptr, [64], "float16", scope="global") - with T.thread(): - Tx.copy_and_cast(B, A) - - # Walk IR to find TilePrimitiveCall with op="tirx.copy_and_cast" - found = [False] - - def visit(stmt): - if isinstance(stmt, TilePrimitiveCall) and stmt.op == Op.get("tirx.copy_and_cast"): - found[0] = True - - tvm.tirx.stmt_functor.post_order_visit(func.body, visit) - assert found[0], "Expected TilePrimitiveCall with tirx.copy_and_cast not found" - - -def test_tx_dynamic_op_with_workspace(): - """Tx.some_op(..., workspace={...}) passes workspace to TilePrimitiveCall.""" - - @T.prim_func - def func(A_ptr: T.handle, B_ptr: T.handle, W_ptr: T.handle): - A = T.match_buffer(A_ptr, [64], "float32", scope="global") - B = T.match_buffer(B_ptr, [64], "float32", scope="global") - W = T.match_buffer(W_ptr, [64], "float32", scope="shared") - with T.thread(): - Tx.custom_with_ws(B, A, workspace={"tmp": W}) - - found = [False] - - def visit(stmt): - if isinstance(stmt, TilePrimitiveCall) and stmt.op == Op.get("tirx.custom_with_ws"): - assert "tmp" in stmt.workspace - found[0] = True - - tvm.tirx.stmt_functor.post_order_visit(func.body, visit) - assert found[0], "Expected TilePrimitiveCall with workspace not found" - - -def test_tx_existing_op_not_overridden(): - """Existing Tx.copy still dispatches to the registered copy op, not __getattr__.""" - - @T.prim_func - def func(A_ptr: T.handle, B_ptr: T.handle): - A = T.match_buffer(A_ptr, [64], "float32", scope="global") - B = T.match_buffer(B_ptr, [64], "float32", scope="global") - with T.thread(): - Tx.copy(B, A) - - found = [False] - - def visit(stmt): - if isinstance(stmt, TilePrimitiveCall) and stmt.op == Op.get("tirx.copy"): - found[0] = True - - tvm.tirx.stmt_functor.post_order_visit(func.body, visit) - assert found[0], "Expected TilePrimitiveCall with tirx.copy not found" - - -def test_opcall_downcast_tolerant(): - """TilePrimitiveCall.downcast returns instance as-is for unknown ops.""" - from tvm.tirx.operator.tile_primitive.ops import GenericOp - - A = decl_buffer((64,), "float32", scope="global") - B = decl_buffer((64,), "float32", scope="global") - - op_call = GenericOp(B[0:64], A[0:64], op_name="totally_unknown_op") - # downcast should not raise - result = TilePrimitiveCall.downcast(op_call) - assert result is not None - - def test_buffer_replacer_no_shared_default(): """Regression test for F4: BufferReplacer default dicts must not be shared.""" from tvm.tirx.transform.common import BufferReplacer @@ -199,13 +77,5 @@ def test_gemm_async_partial_scale_factor(): test_copy() test_fill() test_gemm() - test_generic_op_creates_op() - test_generic_op_reuses_registered_op() - test_generic_op_with_existing_tirx_op() - test_tx_dynamic_op_module_getattr() - test_tx_dynamic_op_in_prim_func() - test_tx_dynamic_op_with_workspace() - test_tx_existing_op_not_overridden() - test_opcall_downcast_tolerant() test_buffer_replacer_no_shared_default() test_gemm_async_partial_scale_factor() diff --git a/tests/python/tirx/test_op_namespace_cleanup.py b/tests/python/tirx/test_op_namespace_cleanup.py new file mode 100644 index 000000000000..5c0aa9615207 --- /dev/null +++ b/tests/python/tirx/test_op_namespace_cleanup.py @@ -0,0 +1,248 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for TIRx op namespace split between T, T.tile, and device namespaces.""" + +import pytest + +import tvm +from tvm.ir import Op, assert_structural_equal +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx +from tvm.tirx.stmt import TilePrimitiveCall + + +def _tile_calls(func): + calls = [] + + def visit(stmt): + if isinstance(stmt, TilePrimitiveCall): + calls.append(stmt) + + tvm.tirx.stmt_functor.post_order_visit(func.body, visit) + return calls + + +def _expr_calls(func): + calls = [] + + def visit(node): + if isinstance(node, tvm.tirx.Call): + calls.append(node) + + tvm.tirx.stmt_functor.post_order_visit(func.body, visit) + return calls + + +def _op_attr(op_name, attr_name): + return Op.get(op_name).get_attr(attr_name) + + +def _has_path(root, path): + cur = root + for part in path.split("."): + if not hasattr(cur, part): + return False + cur = getattr(cur, part) + return True + + +def test_tx_is_tile_shorthand_only(): + assert T.tile is Tx + assert T.tile.copy is Tx.copy + assert not hasattr(T, "copy") + assert not hasattr(Tx, "SMEMPool") + assert not hasattr(Tx, "ScopedOp") + assert not hasattr(Tx, "meta_class") + assert T.cast is not Tx.cast + assert T.sqrt is not Tx.sqrt + + +def test_tx_rejects_expression_overloads(): + x = tvm.tirx.Var("x", "float32") + y = tvm.tirx.Var("y", "int32") + + with pytest.raises(TypeError, match="tile-only"): + Tx.sqrt(x) + with pytest.raises(TypeError, match="tile-only"): + T.tile.sqrt(x) + with pytest.raises(TypeError, match="tile-only"): + Tx.cast(y, "float32") + with pytest.raises(TypeError, match="tile-only"): + T.tile.cast(y, "float32") + + +def test_builtin_expression_ops_are_not_tile_primitives(): + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "float32") + + cast = T.cast(x, "float32") + assert isinstance(cast, tvm.tirx.Cast) + assert cast.dtype == "float32" + + sqrt = T.sqrt(y) + assert sqrt.op.name == "tirx.sqrt" + + fma = T.fma(y, y, y) + assert fma.op.name == "tirx.fma" + + +def test_tile_shorthand_and_scoped_aliases_use_tile_ops(): + @T.prim_func(check_well_formed=False) + def tile_aliases(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.tile.copy(A[0:16], B[0:16]) + Tx.cast(A[0:16], B[0:16]) + T.cta.cast(A[0:16], B[0:16]) + Tx.cta.sqrt(A[0:16], B[0:16]) + + calls = _tile_calls(tile_aliases) + assert [call.op.name for call in calls] == [ + "tirx.tile.copy", + "tirx.tile.cast", + "tirx.tile.cast", + "tirx.tile.sqrt", + ] + assert [call.scope.name for call in calls] == ["thread", "thread", "cta", "cta"] + + +def test_device_intrinsic_namespaces_are_canonical_and_classified(): + buffer = tvm.tirx.decl_buffer((1,), "float32") + calls = [ + T.ptx.elect_sync(), + T.cuda.thread_fence(), + T.nvshmem.fence(), + T.nki.identity(buffer[0:1], 1), + ] + + expected = [ + ("tirx.ptx.elect_sync", "ptx"), + ("tirx.cuda.thread_fence", "cuda"), + ("tirx.nvshmem.fence", "nvshmem"), + ("tirx.nki.identity", "nki"), + ] + assert [ + (call.op.name, _op_attr(call.op.name, "TDeviceIntrinsicNamespace")) for call in calls + ] == expected + for op_name, namespace in expected: + assert _op_attr(op_name, "TIRxOpCategory") == "device_intrin" + assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace + assert _op_attr(op_name, "TTilePrimitiveKind") is None + + +def test_device_intrinsic_printer_roundtrips_canonical_namespaces(): + @T.prim_func + def device_namespaces(dst: T.handle, src: T.handle): + A = T.match_buffer(src, (1,), "float32") + R = T.alloc_buffer((1,), "float32", scope="local") + T.cuda.copy_bytes(dst, src, 16) + T.ptx.ldg32(R[0], 1, A[0], 0) + T.metal.simd_shuffle(A[0], 0) + + calls = _expr_calls(device_namespaces) + assert [call.op.name for call in calls] == [ + "tirx.cuda.copy_bytes", + "tirx.ptx.ldg32", + "tirx.metal.simd_shuffle", + ] + for op_name, namespace in [ + ("tirx.cuda.copy_bytes", "cuda"), + ("tirx.ptx.ldg32", "ptx"), + ("tirx.metal.simd_shuffle", "metal"), + ]: + assert _op_attr(op_name, "TIRxOpCategory") == "device_intrin" + assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace + assert _op_attr(op_name, "TCallEffectKind") in (1, 3) + + code = device_namespaces.script() + assert "T.cuda.copy_bytes(" in code + assert "T.ptx.ldg32(" in code + assert "T.metal.simd_shuffle(" in code + assert "T.tirx." not in code + reparsed = tvm.script.from_source(code) + assert reparsed.script() == code + assert_structural_equal(device_namespaces, reparsed) + + +def test_registered_tirx_ops_have_exactly_one_category(): + if _op_attr("tirx.sqrt", "TIRxOpCategory") is None: + pytest.skip("TIRx op categories require a rebuilt C++ runtime") + + categories = {"builtin", "tile_primitive", "device_intrin"} + tile_kinds = {"dispatch", "compose", "async", "marker"} + device_namespaces = {"cuda", "ptx", "nvshmem", "nki", "metal", "webgpu"} + flat_tile_only_names = { + "tirx.add", + "tirx.binary_chain", + "tirx.binary_reduce", + "tirx.compose_op", + "tirx.copy", + "tirx.copy_async", + "tirx.fdiv", + "tirx.fill", + "tirx.gemm", + "tirx.gemm_async", + "tirx.maximum", + "tirx.memset", + "tirx.minimum", + "tirx.mul", + "tirx.permute_layout", + "tirx.reduce_negate", + "tirx.select", + "tirx.sub", + "tirx.sum", + "tirx.unary_reduce", + "tirx.zero", + } + + missing = [] + invalid = [] + lingering_flat_tile = [] + for op_name in sorted(name for name in Op.list_op_names() if name.startswith("tirx.")): + category = _op_attr(op_name, "TIRxOpCategory") + tile_kind = _op_attr(op_name, "TTilePrimitiveKind") + device_namespace = _op_attr(op_name, "TDeviceIntrinsicNamespace") + + if category is None: + missing.append(op_name) + continue + if category not in categories: + invalid.append((op_name, category)) + continue + if op_name in flat_tile_only_names: + lingering_flat_tile.append(op_name) + + if category == "tile_primitive": + if not op_name.startswith("tirx.tile."): + lingering_flat_tile.append(op_name) + assert tile_kind in tile_kinds, op_name + assert device_namespace is None, op_name + elif category == "device_intrin": + assert tile_kind is None, op_name + assert device_namespace in device_namespaces, op_name + printer_name = _op_attr(op_name, "TScriptPrinterName") + assert printer_name is not None, op_name + assert printer_name.startswith(device_namespace + "."), op_name + assert _has_path(T, printer_name), op_name + else: + assert category == "builtin" + assert tile_kind is None, op_name + assert device_namespace is None, op_name + + assert not missing + assert not invalid + assert not lingering_flat_tile diff --git a/tests/python/tirx/test_parser_printer.py b/tests/python/tirx/test_parser_printer.py index 1e77e73d00d3..561adfc602ed 100644 --- a/tests/python/tirx/test_parser_printer.py +++ b/tests/python/tirx/test_parser_printer.py @@ -21,7 +21,7 @@ import tvm.testing from tvm.ir import PointerType, PrimType, assert_structural_equal from tvm.script import tirx as T -from tvm.script import tirx as Tx +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import laneid, warpid @@ -31,14 +31,11 @@ def from_source(code): def _make_minimal_tirx_prim_func(): source = ( - "# from tvm.script import tirx as Tx\n\n" - "@Tx.prim_func()\n" - "def f(a: Tx.handle):\n" - ' A = Tx.match_buffer(a, (1,), "float32")\n' - " with Tx.thread():\n" - " with Tx.cta():\n" - " with Tx.thread():\n" - " A[0] = Tx.float32(1)" + "# from tvm.script import tirx as T\n\n" + "@T.prim_func()\n" + "def f(a: T.handle):\n" + ' A = T.match_buffer(a, (1,), "float32")\n' + " A[0] = T.float32(1)" ) return from_source(source) @@ -49,20 +46,17 @@ def from_source_tir(code): def test_roundtrip_scopeid1(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (64,), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - A_local = Tx.alloc_buffer([1], dtype="float16", scope="local") - for i in Tx.serial(2): - A_local[0] = A[lane_id * 2 + i] + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (64,), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + A_local = T.alloc_buffer([1], dtype="float16", scope="local") + for i in T.serial(2): + A_local[0] = A[lane_id * 2 + i] # fmt: on code = test.script() @@ -72,114 +66,103 @@ def test(A_ptr: Tx.handle) -> None: def test_roundtrip_scopeid2(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - _ = Tx.match_buffer(A_ptr, (64,), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([8, 10, 12]) - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 1]) - cta_id_in_pair = Tx.cta_id_in_pair() - clx, cly, clz = Tx.cluster_id([4, 5, 12]) - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(cta_id_in_pair) - Tx.evaluate(clx + cly + clz) + @T.prim_func + def test(A_ptr: T.handle) -> None: + _ = T.match_buffer(A_ptr, (64,), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([8, 10, 12]) + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1]) + cta_id_in_pair = T.cta_id_in_pair() + clx, cly, clz = T.cluster_id([4, 5, 12]) + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(cta_id_in_pair) + T.evaluate(clx + cly + clz) # fmt: on code = test.script() - assert "cta_id_in_pair = Tx.cta_id_in_pair()" in code + assert "cta_id_in_pair = T.cta_id_in_pair()" in code assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) def test_roundtrip_scopeid_deferred(): """Deferred ScopeIdDef (extent=None) survives print→parse round-trip - as a no-arg ``Tx.cta_id()``/``Tx.thread_id()`` etc. call.""" - - # fmt: off - @Tx.prim_func(private=True) - def test(A_ptr: Tx.handle) -> None: - _ = Tx.match_buffer(A_ptr, (64,), "float32", scope="global") - Tx.device_entry() - bx = Tx.cta_id() # deferred kernel→cta - cbx = Tx.cta_id_in_cluster([2]) - clx = Tx.cluster_id([4]) - tx = Tx.thread_id() # deferred cta→thread - Tx.warp_id([4]) - Tx.lane_id([32]) - with Tx.thread(): - Tx.evaluate(bx + cbx + clx + tx) + as a no-arg ``T.cta_id()``/``T.thread_id()`` etc. call.""" + + # fmt: off + @T.prim_func(private=True) + def test(A_ptr: T.handle) -> None: + _ = T.match_buffer(A_ptr, (64,), "float32", scope="global") + T.device_entry() + bx = T.cta_id() # deferred kernel→cta + cbx = T.cta_id_in_cluster([2]) + clx = T.cluster_id([4]) + tx = T.thread_id() # deferred cta→thread + T.warp_id([4]) + T.lane_id([32]) + T.evaluate(bx + cbx + clx + tx) # fmt: on code = test.script() - assert "bx = Tx.cta_id()" in code - assert "tx = Tx.thread_id()" in code + assert "bx = T.cta_id()" in code + assert "tx = T.thread_id()" in code assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) -def test_exec_scope_filter_guard_roundtrip_with_scope_arg_sugar(): - @Tx.prim_func(private=True) - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") +def test_exec_scope_filter_guard_roundtrip(): + @T.prim_func(private=True) + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - tx = Tx.thread_id([128]) - with Tx.cta(): - with Tx.thread((0 <= tx) & (tx < 1)): - A[0] = Tx.float32(1) + T.device_entry() + T.cta_id([1]) + tx = T.thread_id([128]) + if (0 <= tx) & (tx < 1): + A[0] = T.float32(1) code = test.script() - assert "with Tx.thread(Tx.bitwise_and(0 <= tx, tx < 1)):" in code - assert "if Tx.filter(tx, 0, 1):" not in code assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) def test_roundtrip_layout(): def get_layout1(): - return Tx.TileLayout(Tx.S[(8, 8, 8, 4, 2) : (6, 4 @ laneid, 2, 1 @ laneid, 1)]) + return T.TileLayout(T.S[(8, 8, 8, 4, 2) : (6, 4 @ laneid, 2, 1 @ laneid, 1)]) def get_layout2(): - return Tx.TileLayout(Tx.S[(8, 8, 8, 4, 2) : (64, 4 @ laneid, 8, 2, 1)]) + return T.TileLayout(T.S[(8, 8, 8, 4, 2) : (64, 4 @ laneid, 8, 2, 1)]) def get_layout3(): - return Tx.TileLayout(Tx.S[(8, 16, 8, 16) : (1024, 16, 128, 1)]) + return T.TileLayout(T.S[(8, 16, 8, 16) : (1024, 16, 128, 1)]) def get_layout4(): - return Tx.SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3) + return T.SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3) def get_layout5(): - return Tx.ComposeLayout( - Tx.SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3), - Tx.TileLayout(Tx.S[(64, 64, 4) : (64, 1, 64 * 64)]), + return T.ComposeLayout( + T.SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3), + T.TileLayout(T.S[(64, 64, 4) : (64, 1, 64 * 64)]), ) # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - _ = Tx.match_buffer(A_ptr, (64,), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - C = Tx.alloc_buffer([128, 128], dtype="float16", scope="shared", layout=get_layout3()) - D = Tx.alloc_buffer([128, 32], dtype="float16", scope="shared", layout=get_layout4()) - - with Tx.cta(): - A_warp = Tx.alloc_buffer([64, 64], dtype="float16", scope="shared", layout=get_layout1()) # noqa: E501 - B_warp = Tx.alloc_buffer([64, 64], dtype="float16", scope="shared", layout=get_layout2()) # noqa: E501 - - E = Tx.alloc_buffer([64, 256], dtype="float16", scope="shared", layout=get_layout5()) - - with Tx.thread(): - Tx.evaluate(A_warp[0, 0] + B_warp[0, 0] + C[0, 0] + D[0, 0] + E[0, 0]) + @T.prim_func + def test(A_ptr: T.handle) -> None: + _ = T.match_buffer(A_ptr, (64,), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + C = T.alloc_buffer([128, 128], dtype="float16", scope="shared", layout=get_layout3()) + D = T.alloc_buffer([128, 32], dtype="float16", scope="shared", layout=get_layout4()) + A_warp = T.alloc_buffer([64, 64], dtype="float16", scope="shared", layout=get_layout1()) + B_warp = T.alloc_buffer([64, 64], dtype="float16", scope="shared", layout=get_layout2()) + + E = T.alloc_buffer([64, 256], dtype="float16", scope="shared", layout=get_layout5()) + T.evaluate(A_warp[0, 0] + B_warp[0, 0] + C[0, 0] + D[0, 0] + E[0, 0]) # fmt: on code = test.script() @@ -194,31 +177,26 @@ def test_roundtrip_layout_replica_and_offset(): of overwriting (see `_merge_offset` in `tvm.tirx.layout`).""" def get_shard_replica(): - return Tx.TileLayout(Tx.S[8 : 4 @ laneid] + Tx.R[4 : 1 @ laneid]) + return T.TileLayout(T.S[8 : 4 @ laneid] + T.R[4 : 1 @ laneid]) def get_shard_offset_single(): - return Tx.TileLayout(Tx.S[8 : 4 @ laneid] + 1 @ laneid) + return T.TileLayout(T.S[8 : 4 @ laneid] + 1 @ laneid) def get_shard_offset_multi(): - return Tx.TileLayout(Tx.S[8 : 4 @ laneid] + 1 @ laneid + 2 @ warpid + 64) + return T.TileLayout(T.S[8 : 4 @ laneid] + 1 @ laneid + 2 @ warpid + 64) def get_full(): - return Tx.TileLayout( - Tx.S[(1,) : (1,)] + Tx.R[(8, 4) : (4 @ laneid, 1 @ laneid)] + 2 @ warpid - ) + return T.TileLayout(T.S[(1,) : (1,)] + T.R[(8, 4) : (4 @ laneid, 1 @ laneid)] + 2 @ warpid) # fmt: off - @Tx.prim_func + @T.prim_func def test() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([8], dtype="float16", scope="shared", layout=get_shard_replica()) - B = Tx.alloc_buffer([8], dtype="float16", scope="shared", layout=get_shard_offset_single()) # noqa: E501 - C = Tx.alloc_buffer([8], dtype="float16", scope="shared", layout=get_shard_offset_multi()) # noqa: E501 - D = Tx.alloc_buffer([32], dtype="float16", scope="shared", layout=get_full()) - - with Tx.thread(): - Tx.evaluate(A[0] + B[0] + C[0] + D[0]) + T.device_entry() + A = T.alloc_buffer([8], dtype="float16", scope="shared", layout=get_shard_replica()) + B = T.alloc_buffer([8], dtype="float16", scope="shared", layout=get_shard_offset_single()) + C = T.alloc_buffer([8], dtype="float16", scope="shared", layout=get_shard_offset_multi()) + D = T.alloc_buffer([32], dtype="float16", scope="shared", layout=get_full()) + T.evaluate(A[0] + B[0] + C[0] + D[0]) # fmt: on code = test.script() @@ -228,19 +206,19 @@ def test() -> None: def test_print_kwargs_schedule_op_full_code(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - A = Tx.alloc_buffer((16,), "float32") - Tx.memset(A[0:16], Tx.float32(1.25), dispatch="v10", bar=7, foo=42) + A = T.alloc_buffer((16,), "float32") + Tx.memset(A[0:16], T.float32(1.25), dispatch="v10", bar=7, foo=42) # fmt: on expected = ( - "# from tvm.script import tirx as Tx\n" + "# from tvm.script import tirx as T\n" "# from tvm.tirx.layout import Axis\n\n" - "@Tx.prim_func\n" + "@T.prim_func\n" "def test():\n" - " A = Tx.alloc_buffer((16,))\n" - ' Tx.memset(A[0:16], Tx.float32(1.25), dispatch="v10", bar=7, foo=42)' + " A = T.alloc_buffer((16,))\n" + ' T.tile.memset(A[0:16], T.float32(1.25), dispatch="v10", bar=7, foo=42)' ) code = test.script() assert code == expected @@ -249,36 +227,32 @@ def test(): def test_default_script_prefix_tirx_irmodule_non_main(): - """IRModule with non-main TIRx PrimFunc should default to Tx prefix.""" + """IRModule with non-main TIRx PrimFunc should default to T prefix.""" mod = tvm.IRModule({"foo": _make_minimal_tirx_prim_func()}) code = mod.script() - assert "# from tvm.script import tirx as Tx" in code + assert "# from tvm.script import tirx as T" in code assert "# from tvm.script import tir as T" not in code - assert "@Tx.prim_func" in code + assert "@T.prim_func" in code assert "def foo(" in code - assert "with Tx.thread():" in code parsed = from_source(code) assert parsed.script() == code assert_structural_equal(mod, parsed) -L_LANE = Tx.TileLayout(Tx.S[32 : 1 @ laneid]) +L_LANE = T.TileLayout(T.S[32 : 1 @ laneid]) def test_roundtrip_buffer_view_get1(): # fmt: off - @Tx.prim_func + @T.prim_func def test() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([2], dtype="float16", scope="local") - A_layout = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - A_warp_layout = A_layout.tile(L_LANE, (8, 4), (1, 2)) - A_warp = A.view(8, 8, layout=A_warp_layout) - - with Tx.thread(): - A_local = A_warp.local(2) - A_local[0] = Tx.float16(0) + T.device_entry() + A = T.alloc_buffer([2], dtype="float16", scope="local") + A_layout = T.TileLayout(T.S[(1, 2) : (2, 1)]) + A_warp_layout = A_layout.tile(L_LANE, (8, 4), (1, 2)) + A_warp = A.view(8, 8, layout=A_warp_layout) + A_local = A_warp.local(2) + A_local[0] = T.float16(0) # fmt: on code = test.script() @@ -288,24 +262,21 @@ def test() -> None: def test_roundtrip_buffer_view_get2(): # fmt: off - @Tx.prim_func - def test(out_ptr: Tx.handle) -> None: - out = Tx.match_buffer(out_ptr, (2), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([32, 32, 1]) - tx, ty, tz = Tx.thread_id([16, 8, 1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - A = Tx.alloc_buffer([2,], dtype="float16", scope="local") - A_layout = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - B_layout = A_layout.tile(L_LANE, (8, 4), (1, 2)) - B = A.view(8, 8, layout=B_layout) - D = B.local(2) - - with Tx.thread(): - out[0] = A[0] + B[0, 0] + D[0] + @T.prim_func + def test(out_ptr: T.handle) -> None: + out = T.match_buffer(out_ptr, (2), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([32, 32, 1]) + tx, ty, tz = T.thread_id([16, 8, 1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + A = T.alloc_buffer([2,], dtype="float16", scope="local") + A_layout = T.TileLayout(T.S[(1, 2) : (2, 1)]) + B_layout = A_layout.tile(L_LANE, (8, 4), (1, 2)) + B = A.view(8, 8, layout=B_layout) + D = B.local(2) + out[0] = A[0] + B[0, 0] + D[0] # fmt: on code = test.script() assert from_source(code).script() == code @@ -314,17 +285,14 @@ def test(out_ptr: Tx.handle) -> None: def test_roundtrip_buffer_view_get3(): # fmt: off - @Tx.prim_func + @T.prim_func def test() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([8, 8], dtype="float32", scope="local") - A_f16 = A.view("float16") - A_f64 = A.view("float64") - - with Tx.thread(): - A_f16[0, 0] = Tx.float16(0) - A_f64[0, 0] = Tx.float64(0) + T.device_entry() + A = T.alloc_buffer([8, 8], dtype="float32", scope="local") + A_f16 = A.view("float16") + A_f64 = A.view("float64") + A_f16[0, 0] = T.float16(0) + A_f64[0, 0] = T.float64(0) # fmt: on code = test.script() @@ -335,22 +303,21 @@ def test() -> None: def test_roundtrip_op1(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (64,), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer([64], dtype="float32", scope="shared") - - Tx.copy(A_smem, A) - for i in range(10): - Tx.fill(A_smem, Tx.float32(0)) - Tx.gemm(A_smem, A_smem, A_smem, A_smem) - Tx.copy(A, A_smem) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (64,), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + A_smem = T.alloc_buffer([64], dtype="float32", scope="shared") + + Tx.cta.copy(A_smem, A) + for i in range(10): + Tx.cta.fill(A_smem, T.float32(0)) + Tx.cta.gemm(A_smem, A_smem, A_smem, A_smem) + Tx.cta.copy(A, A_smem) # fmt: on code = test.script() @@ -360,26 +327,25 @@ def test(A_ptr: Tx.handle) -> None: def test_roundtrip_op2(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128, 128), "float16", scope="global") - B = Tx.match_buffer(B_ptr, (128, 64), "float16", scope="global") - C = Tx.match_buffer(C_ptr, (128, 64), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer([128, 32], dtype="float16", scope="shared") - B_smem = Tx.alloc_buffer([32, 64], dtype="float16", scope="shared") - - C_local = Tx.alloc_buffer([128, 64], dtype="float32", scope="local") - for k in range(4): - Tx.copy(A_smem, A[:, k * 32 : k * 32 + 32]) - Tx.copy(B_smem, B[k * 32 : k * 32 + 32, 0:64]) - Tx.gemm(C_local, A_smem, B_smem, C_local) - Tx.copy(C, C_local) + @T.prim_func + def test(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, 128), "float16", scope="global") + B = T.match_buffer(B_ptr, (128, 64), "float16", scope="global") + C = T.match_buffer(C_ptr, (128, 64), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + A_smem = T.alloc_buffer([128, 32], dtype="float16", scope="shared") + B_smem = T.alloc_buffer([32, 64], dtype="float16", scope="shared") + + C_local = T.alloc_buffer([128, 64], dtype="float32", scope="local") + for k in range(4): + Tx.cta.copy(A_smem, A[:, k * 32 : k * 32 + 32]) + Tx.cta.copy(B_smem, B[k * 32 : k * 32 + 32, 0:64]) + Tx.cta.gemm(C_local, A_smem, B_smem, C_local) + Tx.cta.copy(C, C_local) # fmt: on code = test.script() @@ -392,34 +358,33 @@ def test_roundtrip_op3(): NUM_STAGES = 3 K = 4096 - @Tx.prim_func - def test(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128, K), "float16", scope="global") - B = Tx.match_buffer(B_ptr, (K, 64), "float16", scope="global") - C = Tx.match_buffer(C_ptr, (128, 64), "float32", scope="global") - - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer([NUM_STAGES, 128, 32], dtype="float16", scope="shared") - B_smem = Tx.alloc_buffer([NUM_STAGES, 32, 64], dtype="float16", scope="shared") - - C_local = Tx.alloc_buffer([128, 64], dtype="float32", scope="local") - for i in range(NUM_STAGES - 1): - Tx.copy(A_smem[i, :, :], A[:, i * 32 : i * 32 + 32]) - Tx.copy(B_smem[i, :, :], B[i * 32 : i * 32 + 32, :]) - - for k in range(K // 32): - copy_k = Tx.meta_var(k + NUM_STAGES - 1) - gemm_stage = Tx.meta_var(k % NUM_STAGES) - copy_stage = Tx.meta_var(copy_k % NUM_STAGES) - Tx.copy(A_smem[copy_stage, :, :], A[:, copy_k * 32 : copy_k * 32 + 32]) - Tx.copy(B_smem[copy_stage, :, :], B[copy_k * 32 : copy_k * 32 + 32, :]) - Tx.gemm(C_local, A_smem[gemm_stage, :, :], B_smem[gemm_stage, :, :], C_local) - - Tx.copy(C, C_local) + @T.prim_func + def test(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, K), "float16", scope="global") + B = T.match_buffer(B_ptr, (K, 64), "float16", scope="global") + C = T.match_buffer(C_ptr, (128, 64), "float32", scope="global") + + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + A_smem = T.alloc_buffer([NUM_STAGES, 128, 32], dtype="float16", scope="shared") + B_smem = T.alloc_buffer([NUM_STAGES, 32, 64], dtype="float16", scope="shared") + + C_local = T.alloc_buffer([128, 64], dtype="float32", scope="local") + for i in range(NUM_STAGES - 1): + Tx.cta.copy(A_smem[i, :, :], A[:, i * 32 : i * 32 + 32]) + Tx.cta.copy(B_smem[i, :, :], B[i * 32 : i * 32 + 32, :]) + + for k in range(K // 32): + copy_k = T.meta_var(k + NUM_STAGES - 1) + gemm_stage = T.meta_var(k % NUM_STAGES) + copy_stage = T.meta_var(copy_k % NUM_STAGES) + Tx.cta.copy(A_smem[copy_stage, :, :], A[:, copy_k * 32 : copy_k * 32 + 32]) + Tx.cta.copy(B_smem[copy_stage, :, :], B[copy_k * 32 : copy_k * 32 + 32, :]) + Tx.cta.gemm(C_local, A_smem[gemm_stage, :, :], B_smem[gemm_stage, :, :], C_local) + + Tx.cta.copy(C, C_local) # fmt: on code = test.script() @@ -429,13 +394,13 @@ def test(A_ptr: Tx.handle, B_ptr: Tx.handle, C_ptr: Tx.handle) -> None: def test_roundtrip_tensormap(): # fmt: off - @Tx.prim_func - def func1(A_ptr: Tx.handle): - Tx.func_attr({"global_symbol": "func"}) - _ = Tx.match_buffer(A_ptr, [128], "float32") + @T.prim_func + def func1(A_ptr: T.handle): + T.func_attr({"global_symbol": "func"}) + _ = T.match_buffer(A_ptr, [128], "float32") - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.tensormap_init", Tx.address_of(A_map), A_ptr) + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.tensormap_init", T.address_of(A_map), A_ptr) # fmt: on code = func1.script() assert from_source(code).script() == code @@ -444,29 +409,28 @@ def func1(A_ptr: Tx.handle): def test_roundtrip_tensormap_kernel_param(): # fmt: off - @Tx.prim_func - def func1(A_map: Tx.TensorMap()): - Tx.func_attr({"global_symbol": "func"}) - Tx.evaluate(Tx.address_of(A_map)) + @T.prim_func + def func1(A_map: T.TensorMap()): + T.func_attr({"global_symbol": "func"}) + T.evaluate(T.address_of(A_map)) # fmt: on code = func1.script() - assert "Tx.TensorMap()" in code + assert "T.TensorMap()" in code assert from_source(code).script() == code assert_structural_equal(func1, from_source(code)) def test_roundtrip_break_for(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") - Tx.device_entry() - with Tx.cta(): - for i in Tx.serial(10): - if i > 5: - break - A[i] = i + T.device_entry() + for i in T.serial(10): + if i > 5: + break + A[i] = i # fmt: on code = test.script() assert from_source(code).script() == code @@ -475,19 +439,18 @@ def test(A_ptr: Tx.handle): def test_roundtrip_break_while(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") - - Tx.device_entry() - with Tx.cta(): - i = Tx.alloc_buffer((1,), "int32", scope="local") - i[0] = 0 - while i[0] < 10: - A[i[0]] = i[0] * 2 - if A[i[0]] > 10: - break - i[0] = i[0] + 1 + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") + + T.device_entry() + i = T.alloc_buffer((1,), "int32", scope="local") + i[0] = 0 + while i[0] < 10: + A[i[0]] = i[0] * 2 + if A[i[0]] > 10: + break + i[0] = i[0] + 1 # fmt: on code = test.script() assert from_source(code).script() == code @@ -496,20 +459,19 @@ def test(A_ptr: Tx.handle): def test_roundtrip_break_nested(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (9,), "int32") - - Tx.device_entry() - with Tx.cta(): - idx = Tx.alloc_buffer((1,), "int32", scope="local") - idx[0] = 0 - for i in Tx.serial(3): - for j in Tx.serial(3): - A[idx[0]] = i * 10 + j - idx[0] += 1 - if j == 1: - break + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (9,), "int32") + + T.device_entry() + idx = T.alloc_buffer((1,), "int32", scope="local") + idx[0] = 0 + for i in T.serial(3): + for j in T.serial(3): + A[idx[0]] = i * 10 + j + idx[0] += 1 + if j == 1: + break # fmt: on code = test.script() assert from_source(code).script() == code @@ -518,16 +480,15 @@ def test(A_ptr: Tx.handle): def test_roundtrip_continue_for(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") - - Tx.device_entry() - with Tx.cta(): - for i in Tx.serial(10): - if (i % 2) == 0: - continue - A[i] = i + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") + + T.device_entry() + for i in T.serial(10): + if (i % 2) == 0: + continue + A[i] = i # fmt: on code = test.script() assert from_source(code).script() == code @@ -536,20 +497,19 @@ def test(A_ptr: Tx.handle): def test_roundtrip_continue_while(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") - - Tx.device_entry() - with Tx.cta(): - i = Tx.alloc_buffer((1,), "int32", scope="local") - i[0] = 0 - while i[0] < 10: - if (i[0] % 2) == 1: - i[0] += 1 - continue - A[i[0]] = i[0] + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") + + T.device_entry() + i = T.alloc_buffer((1,), "int32", scope="local") + i[0] = 0 + while i[0] < 10: + if (i[0] % 2) == 1: i[0] += 1 + continue + A[i[0]] = i[0] + i[0] += 1 # fmt: on code = test.script() assert from_source(code).script() == code @@ -558,20 +518,19 @@ def test(A_ptr: Tx.handle): def test_roundtrip_continue_nested(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (9,), "int32") - - Tx.device_entry() - with Tx.cta(): - idx = Tx.alloc_buffer((1,), dtype="int32", scope="local") - idx[0] = 0 - for i in Tx.serial(3): - for j in Tx.serial(3): - if j == 1: - continue - A[idx[0]] = i * 10 + j - idx[0] += 1 + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (9,), "int32") + + T.device_entry() + idx = T.alloc_buffer((1,), dtype="int32", scope="local") + idx[0] = 0 + for i in T.serial(3): + for j in T.serial(3): + if j == 1: + continue + A[idx[0]] = i * 10 + j + idx[0] += 1 # fmt: on code = test.script() assert from_source(code).script() == code @@ -580,18 +539,17 @@ def test(A_ptr: Tx.handle): def test_roundtrip_break_and_continue(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10,), "int32") - - Tx.device_entry() - with Tx.cta(): - for i in Tx.serial(10): - if i == 2: - continue - if i == 7: - break - A[i] = i + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10,), "int32") + + T.device_entry() + for i in T.serial(10): + if i == 2: + continue + if i == 7: + break + A[i] = i # fmt: on code = test.script() assert from_source(code).script() == code @@ -600,17 +558,16 @@ def test(A_ptr: Tx.handle): def test_roundtrip_unreachable_after_break(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (5,), "int32") - - Tx.device_entry() - with Tx.cta(): - for i in Tx.serial(5): - A[i] = i - break - # This line is never reached - A[i] = -1 + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (5,), "int32") + + T.device_entry() + for i in T.serial(5): + A[i] = i + break + # This line is never reached + A[i] = -1 # fmt: on code = test.script() assert from_source(code).script() == code @@ -619,12 +576,12 @@ def test(A_ptr: Tx.handle): def test_roundtrip_allocated_addr(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - A = Tx.alloc_buffer([10], "float32", scope="trn.sbuf", allocated_addr=1024) - for i in Tx.serial(2): - Tx.memset(A[i*5:i*5+5], Tx.float32(0.0)) + T.device_entry() + A = T.alloc_buffer([10], "float32", scope="trn.sbuf", allocated_addr=1024) + for i in T.serial(2): + Tx.memset(A[i*5:i*5+5], T.float32(0.0)) # fmt: on code = test.script() @@ -634,11 +591,11 @@ def test(): def test_roundtrip_implicit_buffer_region(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (10, 10, 10), "float32", layout=Tx.TileLayout(Tx.S[10, 10, 10])) - Tx.device_entry() - Tx.memset(A[0], Tx.float32(0.0)) + @T.prim_func + def test(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (10, 10, 10), "float32", layout=T.TileLayout(T.S[10, 10, 10])) + T.device_entry() + Tx.memset(A[0], T.float32(0.0)) # fmt: on code = test.script() @@ -648,12 +605,12 @@ def test(A_ptr: Tx.handle): def test_roundtrip_alloc_under_any_scope(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - for i in Tx.serial(10): - A = Tx.alloc_buffer([100], "float32", scope="trn.sbuf", allocated_addr=1024) - Tx.memset(A[i*10:i*10+10], Tx.float32(0.0)) + T.device_entry() + for i in T.serial(10): + A = T.alloc_buffer([100], "float32", scope="trn.sbuf", allocated_addr=1024) + Tx.memset(A[i*10:i*10+10], T.float32(0.0)) # fmt: on code = test.script() @@ -663,15 +620,15 @@ def test(): def test_roundtrip_compose_op(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - A = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - B = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - C = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") + T.device_entry() + A = T.alloc_buffer([10], "float32", scope="trn.sbuf") + B = T.alloc_buffer([10], "float32", scope="trn.sbuf") + C = T.alloc_buffer([10], "float32", scope="trn.sbuf") with Tx.compose_op(): - Tx.add(B, A, Tx.float32(1)) - Tx.add(C, B, Tx.float32(1)) + Tx.add(B, A, T.float32(1)) + Tx.add(C, B, T.float32(1)) # fmt: on code = test.script() assert from_source(code).script() == code @@ -680,13 +637,13 @@ def test(): def test_roundtrip_op_call_workspace(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, [10], "float32", scope="global") - B = Tx.match_buffer(B_ptr, [10], "float32", scope="global") - Tx.device_entry() - smem = Tx.alloc_buffer([10], "float32", scope="shared") - Tx.add(B, A, Tx.float32(1), workspace={"smem": smem}) + @T.prim_func + def test(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, [10], "float32", scope="global") + B = T.match_buffer(B_ptr, [10], "float32", scope="global") + T.device_entry() + smem = T.alloc_buffer([10], "float32", scope="shared") + Tx.add(B, A, T.float32(1), workspace={"smem": smem}) # fmt: on code = test.script() assert from_source(code).script() == code @@ -695,17 +652,17 @@ def test(A_ptr: Tx.handle, B_ptr: Tx.handle): def test_roundtrip_compose_op_call_workspace(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - A = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - B = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - C = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - psum = Tx.alloc_buffer([10], "float32", scope="trn.psum") - intermediate = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") + T.device_entry() + A = T.alloc_buffer([10], "float32", scope="trn.sbuf") + B = T.alloc_buffer([10], "float32", scope="trn.sbuf") + C = T.alloc_buffer([10], "float32", scope="trn.sbuf") + psum = T.alloc_buffer([10], "float32", scope="trn.psum") + intermediate = T.alloc_buffer([10], "float32", scope="trn.sbuf") with Tx.compose_op(workspace={"intermediate": intermediate}): - Tx.add(B, A, Tx.float32(1)) - Tx.add(C, B, Tx.float32(1), workspace={"psum": psum}) + Tx.add(B, A, T.float32(1)) + Tx.add(C, B, T.float32(1), workspace={"psum": psum}) # fmt: on code = test.script() assert from_source(code).script() == code @@ -714,12 +671,12 @@ def test(): def test_roundtrip_op_call_config(): # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, [10], "float32", scope="global") - B = Tx.match_buffer(B_ptr, [10], "float32", scope="global") - Tx.device_entry() - Tx.add(B, A, Tx.float32(1), schedule="A") + @T.prim_func + def test(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, [10], "float32", scope="global") + B = T.match_buffer(B_ptr, [10], "float32", scope="global") + T.device_entry() + Tx.add(B, A, T.float32(1), schedule="A") # fmt: on code = test.script() assert from_source(code).script() == code @@ -728,16 +685,16 @@ def test(A_ptr: Tx.handle, B_ptr: Tx.handle): def test_roundtrip_compose_op_call_config(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - A = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - B = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - C = Tx.alloc_buffer([10], "float32", scope="trn.sbuf") - psum = Tx.alloc_buffer([10], "float32", scope="trn.psum") + T.device_entry() + A = T.alloc_buffer([10], "float32", scope="trn.sbuf") + B = T.alloc_buffer([10], "float32", scope="trn.sbuf") + C = T.alloc_buffer([10], "float32", scope="trn.sbuf") + psum = T.alloc_buffer([10], "float32", scope="trn.psum") with Tx.compose_op( schedule="A"): - Tx.add(B, A, Tx.float32(1)) - Tx.add(C, B, Tx.float32(1), workspace={"psum": psum}) + Tx.add(B, A, T.float32(1)) + Tx.add(C, B, T.float32(1), workspace={"psum": psum}) # fmt: on code = test.script() assert from_source(code).script() == code @@ -746,11 +703,11 @@ def test(): def test_predicate(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - A = Tx.alloc_buffer([10, 10], "float32") - B = Tx.alloc_buffer([10, 10], "float32") + T.device_entry() + A = T.alloc_buffer([10, 10], "float32") + B = T.alloc_buffer([10, 10], "float32") Tx.select(B, A, 1.0, lambda i, j: i < j) # fmt: on code = test.script() @@ -760,12 +717,11 @@ def test(): def test_grid(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - with Tx.thread(): - for lvs in Tx.grid(10, (2, 12)): - Tx.evaluate(lvs[0] + lvs[1]) + T.device_entry() + for lvs in T.grid(10, (2, 12)): + T.evaluate(lvs[0] + lvs[1]) # fmt: on code = test.script() assert from_source(code).script() == code @@ -774,65 +730,64 @@ def test(): def test_alloc_apis(): # fmt: off - @Tx.meta_class + @T.meta_class class Test: def __init__(self, Ta, inner_pool): self.Ta = Ta self.inner_pool = inner_pool - self.Tb = Tx.shared_scalar("float16") - self.idx = Tx.local_scalar("int32") - self.inner_pool2 = Tx.decl_scalar("float16", self.inner_pool.data, "shared.dyn", 5) + self.Tb = T.shared_scalar("float16") + self.idx = T.local_scalar("int32") + self.inner_pool2 = T.decl_scalar("float16", self.inner_pool.data, "shared.dyn", 5) - @Tx.inline + @T.inline def init(self): - self.Ta = self.Ta + Tx.float16(1) - self.Tb = self.Tb + Tx.float16(2) - self.idx.buffer[0] = Tx.int32(0) - self.idx = self.idx + Tx.int32(1) - self.inner_pool2 = self.inner_pool2 + Tx.float16(1) - Tx.evaluate(Tx.address_of(self.Ta)) - Tx.evaluate(Tx.address_of(self.Tb)) - Tx.evaluate(Tx.address_of(self.idx)) - Tx.evaluate(Tx.address_of(self.inner_pool)) - Tx.evaluate(Tx.address_of(self.inner_pool2)) - - @Tx.prim_func + self.Ta = self.Ta + T.float16(1) + self.Tb = self.Tb + T.float16(2) + self.idx.buffer[0] = T.int32(0) + self.idx = self.idx + T.int32(1) + self.inner_pool2 = self.inner_pool2 + T.float16(1) + T.evaluate(T.address_of(self.Ta)) + T.evaluate(T.address_of(self.Tb)) + T.evaluate(T.address_of(self.idx)) + T.evaluate(T.address_of(self.inner_pool)) + T.evaluate(T.address_of(self.inner_pool2)) + + @T.prim_func def test(): - Tx.device_entry() + T.device_entry() # normal buffer - A = Tx.alloc_shared([10], "float16") - B = Tx.alloc_local([10], "float16") + A = T.alloc_shared([10], "float16") + B = T.alloc_local([10], "float16") # scalar buffer (alloc) - C = Tx.shared_scalar("float16") - D: Tx.float16 - pool = Tx.alloc_buffer([10], "uint8", scope="shared.dyn") + C = T.shared_scalar("float16") + D: T.float16 + pool = T.alloc_buffer([10], "uint8", scope="shared.dyn") # scalar buffer (decl) - E = Tx.decl_scalar("float16", pool.data, "shared.dyn", 0) + E = T.decl_scalar("float16", pool.data, "shared.dyn", 0) # normal 1-dim buffer with shape (1,) - F = Tx.alloc_local((1,), "float16") - with Tx.thread(): - Ta: Tx.float16 - inner_pool = Tx.decl_buffer(shape=[10], data=pool.data, dtype="uint8", scope="shared.dyn") # noqa: E501 - test = Test(Ta, inner_pool) # noqa: F821 - test.init() - A[0] = C - A[0] = C + D # noqa: F821 - A[1] = B[0] * C - D.buffer[0] = D + Tx.float16(1) # noqa: F821 - D = D + Tx.float16(1) # noqa: F821 - C = D - Tx.evaluate(E) - E = E + Tx.float16(1) - # normal 1-dim buffer with shape (1,) can be assigned directly, - # but not loaded directly - F = F[0] + Tx.float16(1) - C += D - D += E + C + D - Tx.evaluate(Tx.address_of(C)) - Tx.evaluate(C.buffer.access_ptr("rw", offset=0)) - Tx.evaluate(C.buffer.data) - Tx.evaluate(D) - Tx.evaluate(Tx.address_of(D)) + F = T.alloc_local((1,), "float16") + Ta: T.float16 + inner_pool = T.decl_buffer(shape=[10], data=pool.data, dtype="uint8", scope="shared.dyn") + test = Test(Ta, inner_pool) # noqa: F821 + test.init() + A[0] = C + A[0] = C + D # noqa: F821 + A[1] = B[0] * C + D.buffer[0] = D + T.float16(1) # noqa: F821 + D = D + T.float16(1) # noqa: F821 + C = D + T.evaluate(E) + E = E + T.float16(1) + # normal 1-dim buffer with shape (1,) can be assigned directly, + # but not loaded directly + F = F[0] + T.float16(1) + C += D + D += E + C + D + T.evaluate(T.address_of(C)) + T.evaluate(C.buffer.access_ptr("rw", offset=0)) + T.evaluate(C.buffer.data) + T.evaluate(D) + T.evaluate(T.address_of(D)) # fmt: on code = test.script() @@ -842,49 +797,48 @@ def test(): def test_alloc_apis_reject_name_argument(): with pytest.raises(TypeError): - Tx.alloc_buffer((1,), "int32", name="buf") + T.alloc_buffer((1,), "int32", name="buf") with pytest.raises(TypeError): - Tx.local_scalar("int32", name="idx") + T.local_scalar("int32", name="idx") def test_meta_class_constructor_rejects_unowned_resource(): - @Tx.meta_class + @T.meta_class class Bad: def __init__(self): - tmp = Tx.alloc_buffer((1,), "int32", scope="local") + tmp = T.alloc_buffer((1,), "int32", scope="local") with pytest.raises(tvm.error.DiagnosticError): - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() + T.device_entry() bad = Bad() def test_meta_class_multiple_instances_auto_name_owned_resources(): - @Tx.meta_class + @T.meta_class class Holder: def __init__(self, external): self.external = external - self.buf = Tx.alloc_buffer((2,), "int32", scope="local") - self.scalar = Tx.local_scalar("int32") + self.buf = T.alloc_buffer((2,), "int32", scope="local") + self.scalar = T.local_scalar("int32") - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - with Tx.thread(): - external = Tx.alloc_buffer((2,), "int32", scope="local") - first = Holder(external) - second = Holder(external) - Tx.evaluate( - first.buf[0] - + second.buf[1] - + first.scalar - + second.scalar - + first.external[0] - + second.external[1] - ) + T.device_entry() + external = T.alloc_buffer((2,), "int32", scope="local") + first = Holder(external) + second = Holder(external) + T.evaluate( + first.buf[0] + + second.buf[1] + + first.scalar + + second.scalar + + first.external[0] + + second.external[1] + ) code = test.script() bufs = _collect_buffers(test) @@ -892,29 +846,29 @@ def test(): assert "first_external" not in bufs assert "second_external" not in bufs assert {"first_buf", "second_buf", "first_scalar", "second_scalar"}.issubset(bufs) - assert 'first_buf = Tx.alloc_local((2,), "int32")' in code - assert 'second_buf = Tx.alloc_local((2,), "int32")' in code - assert "first_scalar: Tx.int32" in code - assert "second_scalar: Tx.int32" in code + assert 'first_buf = T.alloc_local((2,), "int32")' in code + assert 'second_buf = T.alloc_local((2,), "int32")' in code + assert "first_scalar: T.int32" in code + assert "second_scalar: T.int32" in code assert from_source(code).script() == code def test_macro(): # fmt: off - @Tx.inline + @T.inline def mul(x, c): - Tx.evaluate(x * c) + T.evaluate(x * c) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def test(): - Tx.device_entry() + T.device_entry() for x in range(10): - @Tx.inline + @T.inline def add(c): - Tx.evaluate(x + c) + T.evaluate(x + c) - @Tx.inline + @T.inline def two_add_and_mul(c): add(c) add(c + c) @@ -924,16 +878,16 @@ def two_add_and_mul(c): two_add_and_mul(2) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected(): - Tx.device_entry() + T.device_entry() for x in range(10): - Tx.evaluate(x + 1) - Tx.evaluate(x + 2) - Tx.evaluate(x) - Tx.evaluate(x + 2) - Tx.evaluate(x + 4) - Tx.evaluate(x * 2) + T.evaluate(x + 1) + T.evaluate(x + 2) + T.evaluate(x) + T.evaluate(x + 2) + T.evaluate(x + 4) + T.evaluate(x * 2) # fmt: on code = test.script() assert from_source(code).script() == code @@ -943,29 +897,29 @@ def expected(): def test_macro_recursive(): # fmt: off - @Tx.prim_func(private=True) + @T.prim_func(private=True) def test(): - Tx.device_entry() - for x in Tx.serial(10): + T.device_entry() + for x in T.serial(10): - @Tx.inline + @T.inline def add(x, c): if c > 0: add(x, c - 1) - Tx.evaluate(x) + T.evaluate(x) add(x, 5) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected(): - Tx.device_entry() + T.device_entry() for x in range(10): - Tx.evaluate(x) - Tx.evaluate(x) - Tx.evaluate(x) - Tx.evaluate(x) - Tx.evaluate(x) - Tx.evaluate(x) + T.evaluate(x) + T.evaluate(x) + T.evaluate(x) + T.evaluate(x) + T.evaluate(x) + T.evaluate(x) # fmt: on code = test.script() print(code) @@ -976,16 +930,15 @@ def expected(): def test_list_comprehension(): # fmt: off - @Tx.prim_func(private=True) + @T.prim_func(private=True) def test(): - Tx.device_entry() - with Tx.thread(): - acc = Tx.alloc_local([10], "bool") - regs = Tx.meta_var([acc[_] for _ in range(10)]) - Tx.evaluate(regs[0]) - Tx.evaluate(tvm.tirx.all(*regs)) - Tx.evaluate(tvm.tirx.all(*[acc[_] for _ in range(10)])) - Tx.evaluate(tvm.tirx.all(*([acc[_] for _ in range(2, 4)] + [acc[_] for _ in range(6, 8)]))) # noqa: E501 + T.device_entry() + acc = T.alloc_local([10], "bool") + regs = T.meta_var([acc[_] for _ in range(10)]) + T.evaluate(regs[0]) + T.evaluate(tvm.tirx.all(*regs)) + T.evaluate(tvm.tirx.all(*[acc[_] for _ in range(10)])) + T.evaluate(tvm.tirx.all(*([acc[_] for _ in range(2, 4)] + [acc[_] for _ in range(6, 8)]))) # fmt: on code = test.script() print(code) @@ -995,14 +948,14 @@ def test(): def test_range(): # fmt: off - @Tx.prim_func(private=True) + @T.prim_func(private=True) def test(): - l = Tx.meta_var([i for i in range(10)]) # noqa: E741 - Tx.evaluate(l[3]) + l = T.meta_var([i for i in range(10)]) # noqa: E741 + T.evaluate(l[3]) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def expected(): - Tx.evaluate(3) + T.evaluate(3) # fmt: on code = test.script() @@ -1014,34 +967,32 @@ def expected(): def test_buffer(): # fmt: off - @Tx.prim_func(private=True) + @T.prim_func(private=True) def test( - A: Tx.Buffer((10, 11), "float32", layout=None), - B: Tx.Buffer((10, 11), "float32", scope="global"), - C: Tx.Buffer((10, 11), "float32", layout="default"), - D: Tx.Buffer((10, 11), "float32", layout=Tx.TileLayout(Tx.S[(10, 11) : (1, 10)])), - E_ptr: Tx.handle, - F_ptr: Tx.handle, - G_ptr: Tx.handle, - H_ptr: Tx.handle, + A: T.Buffer((10, 11), "float32", layout=None), + B: T.Buffer((10, 11), "float32", scope="global"), + C: T.Buffer((10, 11), "float32", layout="default"), + D: T.Buffer((10, 11), "float32", layout=T.TileLayout(T.S[(10, 11) : (1, 10)])), + E_ptr: T.handle, + F_ptr: T.handle, + G_ptr: T.handle, + H_ptr: T.handle, ): - _E = Tx.match_buffer(E_ptr, [10, 11], "float16", layout=None) - _F = Tx.match_buffer(F_ptr, [10, 11], "float16", scope="global") - _G = Tx.match_buffer(G_ptr, [10, 11], "float16", layout="default") - _H = Tx.match_buffer(H_ptr, [10, 11], "float16", layout=Tx.TileLayout(Tx.S[(10, 11) : (1, 10)])) # noqa: E501 - - _A0 = Tx.decl_buffer((10, 11), "float32", data=A.data, layout=None) - _B0 = Tx.decl_buffer((10, 11), "float32", data=B.data, scope="global") - _C0 = Tx.decl_buffer((10, 11), "float32", data=C.data, layout="default") - _D0 = Tx.decl_buffer((10, 11), "float32", data=D.data, layout=Tx.TileLayout(Tx.S[(10, 11) : (1, 10)])) # noqa: E501 - - with Tx.thread(): - _A1 = Tx.alloc_buffer((10, 11), "float32", layout=None) - _B1 = Tx.alloc_buffer((10, 11), "float32", scope="global") - _C1 = Tx.alloc_buffer((10, 11), "float32", layout="default") - _D1 = Tx.alloc_buffer((10, 11), "float32", layout=Tx.TileLayout(Tx.S[(10, 11) : (1, 10)])) # noqa: E501 - - pass + _E = T.match_buffer(E_ptr, [10, 11], "float16", layout=None) + _F = T.match_buffer(F_ptr, [10, 11], "float16", scope="global") + _G = T.match_buffer(G_ptr, [10, 11], "float16", layout="default") + _H = T.match_buffer(H_ptr, [10, 11], "float16", layout=T.TileLayout(T.S[(10, 11) : (1, 10)])) # noqa: E501 + + _A0 = T.decl_buffer((10, 11), "float32", data=A.data, layout=None) + _B0 = T.decl_buffer((10, 11), "float32", data=B.data, scope="global") + _C0 = T.decl_buffer((10, 11), "float32", data=C.data, layout="default") + _D0 = T.decl_buffer((10, 11), "float32", data=D.data, layout=T.TileLayout(T.S[(10, 11) : (1, 10)])) # noqa: E501 + _A1 = T.alloc_buffer((10, 11), "float32", layout=None) + _B1 = T.alloc_buffer((10, 11), "float32", scope="global") + _C1 = T.alloc_buffer((10, 11), "float32", layout="default") + _D1 = T.alloc_buffer((10, 11), "float32", layout=T.TileLayout(T.S[(10, 11) : (1, 10)])) + + pass # fmt: on code = test.script() assert from_source(code).script() == code @@ -1050,10 +1001,10 @@ def test( def test_kwargs_op_call(): # fmt: off - @Tx.prim_func(private=True) - def test(A: Tx.Buffer((10, 10), "float32"), B: Tx.Buffer((10, 10), "float32")): - Tx.device_entry() - kwargs = Tx.meta_var({"dispatch": "tma", "cta_group": 2}) + @T.prim_func(private=True) + def test(A: T.Buffer((10, 10), "float32"), B: T.Buffer((10, 10), "float32")): + T.device_entry() + kwargs = T.meta_var({"dispatch": "tma", "cta_group": 2}) Tx.copy_async(A[:, :], B[:, :], **kwargs) # fmt: on code = test.script() @@ -1115,19 +1066,18 @@ class State: def __init__(self, counter): self.counter = counter - @Tx.inline + @T.inline def add_one(self): # PrimExpr assigned to scalar via self.attr → buffer_store succeeds - self.counter = self.counter + Tx.int32(1) + self.counter = self.counter + T.int32(1) - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - with Tx.thread(): - counter: Tx.int32 - state = Tx.meta_var(State(counter)) # noqa: F821 - state.add_one() - Tx.evaluate(state.counter) + T.device_entry() + counter: T.int32 + state = T.meta_var(State(counter)) # noqa: F821 + state.add_one() + T.evaluate(state.counter) # fmt: on code = test.script() @@ -1153,14 +1103,13 @@ def bomb(*args, **kwargs): return original(*args, **kwargs) src = """ -# from tvm.script import tirx as Tx +# from tvm.script import tirx as T -@Tx.prim_func +@T.prim_func def func(): - Tx.device_entry() - with Tx.thread(): - v: Tx.int32 - v = v + Tx.int32(1) + T.device_entry() + v: T.int32 + v = v + T.int32(1) """ # The ValueError propagates through the parser framework which wraps it # into a DiagnosticError. Before the fix the broad ``except Exception`` @@ -1171,24 +1120,23 @@ def func(): def test_scalar_annotation_syntax(): - """Test the scalar annotation syntax: x: Tx.int32 = init, x: Tx.int32, and T.let.""" + """Test the scalar annotation syntax: x: T.int32 = init, x: T.int32, and T.let.""" # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - with Tx.thread(): - # Scalar with init value - x: Tx.int32 = 0 - y: Tx.float16 = Tx.float16(1.0) - # Scalar without init - z: Tx.int32 - # Use scalars - x = x + Tx.int32(1) - z = x + Tx.int32(2) - y = y + Tx.float16(3.0) - Tx.evaluate(x + z) - Tx.evaluate(y) + T.device_entry() + # Scalar with init value + x: T.int32 = 0 + y: T.float16 = T.float16(1.0) + # Scalar without init + z: T.int32 + # Use scalars + x = x + T.int32(1) + z = x + T.int32(2) + y = y + T.float16(3.0) + T.evaluate(x + z) + T.evaluate(y) # fmt: on code = test.script() @@ -1199,39 +1147,37 @@ def test(): def test_scalar_allocbuffer_annotation_and_init_merge(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - with Tx.thread(): - phase_mma = Tx.alloc_local((1,), "int32") - phase_mma[0] = Tx.int32(0) - phase_aux = Tx.alloc_local((1,), "int32") - Tx.evaluate(phase_mma[0] + phase_aux[0]) + T.device_entry() + phase_mma = T.alloc_local((1,), "int32") + phase_mma[0] = T.int32(0) + phase_aux = T.alloc_local((1,), "int32") + T.evaluate(phase_mma[0] + phase_aux[0]) # fmt: on code = test.script() - assert "phase_mma: Tx.int32 = 0" in code - assert "phase_aux: Tx.int32" in code - assert "phase_mma = Tx.alloc_local" not in code - assert "phase_aux = Tx.alloc_local" not in code + assert "phase_mma: T.int32 = 0" in code + assert "phase_aux: T.int32" in code + assert "phase_mma = T.alloc_local" not in code + assert "phase_aux = T.alloc_local" not in code assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) def test_scalar_allocbuffer_layout_none_keeps_alloc_local(): # fmt: off - @Tx.prim_func + @T.prim_func def test(): - Tx.device_entry() - with Tx.thread(): - phase_mma = Tx.alloc_local((1,), "int32", layout=None) - phase_mma[0] = Tx.int32(0) - Tx.evaluate(phase_mma[0]) + T.device_entry() + phase_mma = T.alloc_local((1,), "int32", layout=None) + phase_mma[0] = T.int32(0) + T.evaluate(phase_mma[0]) # fmt: on code = test.script() - assert 'phase_mma = Tx.alloc_local((1,), "int32", layout=None)' in code - assert "phase_mma: Tx.int32" not in code + assert 'phase_mma = T.alloc_local((1,), "int32", layout=None)' in code + assert "phase_mma: T.int32" not in code assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) @@ -1246,8 +1192,8 @@ def test(): # fmt: on code = test.script() - assert "x: Tx.int32 = 0" in code - assert "x = Tx.alloc_buffer" not in code + assert "x: T.int32 = 0" in code + assert "x = T.alloc_buffer" not in code assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) @@ -1256,18 +1202,17 @@ def test_let_annotation_syntax(): """Test explicit LetStmt syntax: T.let[T.int32] and T.let.""" # fmt: off - @Tx.prim_func + @T.prim_func def test(): - blockIdx_x = Tx.launch_thread("blockIdx.x", 4) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) + blockIdx_x = T.launch_thread("blockIdx.x", 4) + threadIdx_x = T.launch_thread("threadIdx.x", 128) # Explicit LetStmt with type - bx: Tx.let[Tx.int32] = blockIdx_x - tx: Tx.let[Tx.int32] = threadIdx_x + bx: T.let[T.int32] = blockIdx_x + tx: T.let[T.int32] = threadIdx_x # Explicit LetStmt with auto-type - combined: Tx.let = bx + tx - Tx.device_entry() - with Tx.thread(): - Tx.evaluate(bx + tx + combined) + combined: T.let = bx + tx + T.device_entry() + T.evaluate(bx + tx + combined) # fmt: on code = test.script() @@ -1279,17 +1224,16 @@ def test(): def test_annotation_syntax_comprehensive(): """Comprehensive test for scalar annotation, T.let, banned annotations, and bare assignment.""" - # 1. T.let with Tx.Var(PointerType) — round-trip + # 1. T.let with T.Var(PointerType) — round-trip # fmt: off - @Tx.prim_func + @T.prim_func def test_let_var(): - Tx.device_entry() - smem = Tx.alloc_shared([128], "float16") - with Tx.thread(): - ptr: Tx.let[Tx.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = Tx.reinterpret( - "handle", smem.access_ptr("rw") - ) - Tx.evaluate(ptr) + T.device_entry() + smem = T.alloc_shared([128], "float16") + ptr: T.let[T.Var(name="ptr", dtype=PointerType(PrimType("uint64")))] = T.reinterpret( + "handle", smem.access_ptr("rw") + ) + T.evaluate(ptr) # fmt: on code = test_let_var.script() assert from_source(code).script() == code @@ -1317,14 +1261,13 @@ def func(): # 4. Bare assignment to new variable creates scalar — round-trip # fmt: off - @Tx.prim_func + @T.prim_func def test_bare_assign(): - Tx.device_entry() - with Tx.thread(): - tid = Tx.launch_thread("threadIdx.x", 128) - x = tid + Tx.int32(1) - x = x + Tx.int32(2) - Tx.evaluate(x) + T.device_entry() + tid = T.launch_thread("threadIdx.x", 128) + x = tid + T.int32(1) + x = x + T.int32(2) + T.evaluate(x) # fmt: on code = test_bare_assign.script() assert from_source(code).script() == code @@ -1332,16 +1275,13 @@ def test_bare_assign(): def test_roundtrip_buffer_permute(): # fmt: off - @Tx.prim_func + @T.prim_func def test() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([8, 4], dtype="float16", scope="local", - layout=Tx.TileLayout(Tx.S[(8, 4) : (4, 1)])) - B = A.permute(1, 0) - - with Tx.thread(): - B[0, 0] = Tx.float16(0) + T.device_entry() + A = T.alloc_buffer([8, 4], dtype="float16", scope="local", + layout=T.TileLayout(T.S[(8, 4) : (4, 1)])) + B = A.permute(1, 0) + B[0, 0] = T.float16(0) # fmt: on code = test.script() assert from_source(code).script() == code @@ -1350,17 +1290,14 @@ def test() -> None: def test_roundtrip_buffer_local_auto(): # fmt: off - @Tx.prim_func + @T.prim_func def test() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([2], dtype="float16", scope="local") - A_layout = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - B = A.view(8, 8, layout=A_layout.tile(L_LANE, (8, 4), (1, 2))) - - with Tx.thread(): - B_local = B.local() - B_local[0] = Tx.float16(0) + T.device_entry() + A = T.alloc_buffer([2], dtype="float16", scope="local") + A_layout = T.TileLayout(T.S[(1, 2) : (2, 1)]) + B = A.view(8, 8, layout=A_layout.tile(L_LANE, (8, 4), (1, 2))) + B_local = B.local() + B_local[0] = T.float16(0) # fmt: on code = test.script() assert from_source(code).script() == code @@ -1388,17 +1325,14 @@ def test_buffer_local_ir(): """Verify .local() auto-infer: shape from storage shard extents, layout, shared data.""" # fmt: off - @Tx.prim_func + @T.prim_func def func() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([2], dtype="float16", scope="local") - A_layout = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - B = A.view(8, 8, layout=A_layout.tile(L_LANE, (8, 4), (1, 2))) - - with Tx.thread(): - B_local = B.local() - B_local[0] = Tx.float16(0) + T.device_entry() + A = T.alloc_buffer([2], dtype="float16", scope="local") + A_layout = T.TileLayout(T.S[(1, 2) : (2, 1)]) + B = A.view(8, 8, layout=A_layout.tile(L_LANE, (8, 4), (1, 2))) + B_local = B.local() + B_local[0] = T.float16(0) # fmt: on bufs = _collect_buffers(func) @@ -1426,15 +1360,13 @@ def test_buffer_permute_ir(): """Verify .permute(1, 0): shape swapped, layout permuted, shared data.""" # fmt: off - @Tx.prim_func + @T.prim_func def func() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([8, 4], dtype="float16", scope="local", - layout=Tx.TileLayout(Tx.S[(8, 4) : (4, 1)])) - B = A.permute(1, 0) - with Tx.thread(): - B[0, 0] = Tx.float16(0) + T.device_entry() + A = T.alloc_buffer([8, 4], dtype="float16", scope="local", + layout=T.TileLayout(T.S[(8, 4) : (4, 1)])) + B = A.permute(1, 0) + B[0, 0] = T.float16(0) # fmt: on bufs = _collect_buffers(func) @@ -1457,14 +1389,12 @@ def test_buffer_view_dtype_ir(): """Verify .view('float32') on float16: dtype correct, last dim halved, shared data.""" # fmt: off - @Tx.prim_func + @T.prim_func def func() -> None: - Tx.device_entry() - with Tx.cta(): - A = Tx.alloc_buffer([8, 8], dtype="float16", scope="local") - B = A.view("float32") - with Tx.thread(): - B[0, 0] = Tx.float32(0) + T.device_entry() + A = T.alloc_buffer([8, 8], dtype="float16", scope="local") + B = A.view("float32") + B[0, 0] = T.float32(0) # fmt: on bufs = _collect_buffers(func) @@ -1515,19 +1445,18 @@ def test_buffer_region_slice(): def test_roundtrip_serial_unroll_false(): - """Tx.serial(N, unroll=False) should round-trip.""" - - # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - for _ in Tx.serial(10, unroll=False): - Tx.fill(A[0:32], Tx.float32(0)) + """T.serial(N, unroll=False) should round-trip.""" + + # fmt: off + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + for _ in T.serial(10, unroll=False): + Tx.cta.fill(A[0:32], T.float32(0)) # fmt: on code = test.script() @@ -1538,19 +1467,18 @@ def test(A_ptr: Tx.handle) -> None: def test_roundtrip_serial_unroll_true(): - """Tx.serial(N, unroll=True) should round-trip as a pragma-unroll request.""" - - # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - for _ in Tx.serial(10, unroll=True): - Tx.fill(A[0:32], Tx.float32(0)) + """T.serial(N, unroll=True) should round-trip as a pragma-unroll request.""" + + # fmt: off + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + for _ in T.serial(10, unroll=True): + Tx.cta.fill(A[0:32], T.float32(0)) # fmt: on code = test.script() @@ -1564,16 +1492,15 @@ def test_roundtrip_serial_unroll_false_with_other_annotations(): """When other annotations exist alongside disable_unroll, fall back to full dict.""" # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - for _ in Tx.serial(10, annotations={"disable_unroll": True, "custom": 42}): - Tx.fill(A[0:32], Tx.float32(0)) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + for _ in T.serial(10, annotations={"disable_unroll": True, "custom": 42}): + Tx.cta.fill(A[0:32], T.float32(0)) # fmt: on code = test.script() @@ -1586,25 +1513,25 @@ def test_roundtrip_unary_inplace(): """Single-arg unary ops (in-place) should round-trip.""" # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.exp2(A[0:32]) - Tx.sqrt(A[32:64]) - Tx.reciprocal(A[64:96]) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + Tx.warp.exp2(A[0:32]) + Tx.warp.sqrt(A[32:64]) + Tx.warp.reciprocal(A[64:96]) # fmt: on code = test.script() # Each op should appear with a single arg (no duplicate src, no trailing Nones) - assert "Tx.exp2(A[0:32])" in code, f"expected single-arg exp2, got:\n{code}" - assert "Tx.sqrt(A[32:64])" in code, f"expected single-arg sqrt, got:\n{code}" - assert "Tx.reciprocal(A[64:96])" in code, f"expected single-arg reciprocal, got:\n{code}" + assert 'T.warp.exp2(A[0:32])' in code, f"expected single-arg exp2, got:\n{code}" + assert 'T.warp.sqrt(A[32:64])' in code, f"expected single-arg sqrt, got:\n{code}" + assert 'T.warp.reciprocal(A[64:96])' in code, ( + f"expected single-arg reciprocal, got:\n{code}" + ) assert "None" not in code, f"trailing None args should be trimmed:\n{code}" assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) @@ -1614,38 +1541,37 @@ def test_roundtrip_unary_different_dst_src(): """Unary ops with different dst and src should keep both args.""" # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - with Tx.warp(): - Tx.exp2(A[0:32], B[0:32]) + @T.prim_func + def test(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + B = T.match_buffer(B_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + Tx.warp.exp2(A[0:32], B[0:32]) # fmt: on code = test.script() - assert "Tx.exp2(A[0:32], B[0:32])" in code, f"different dst/src should keep both:\n{code}" + assert 'T.warp.exp2(A[0:32], B[0:32])' in code, ( + f"different dst/src should keep both:\n{code}" + ) assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) def test_roundtrip_persistent_decorator(): - """@Tx.prim_func(persistent=True) should round-trip.""" - - # fmt: off - @Tx.prim_func(persistent=True) - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - Tx.fill(A[0:32], Tx.float32(0)) + """@T.prim_func(persistent=True) should round-trip.""" + + # fmt: off + @T.prim_func(persistent=True) + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + Tx.cta.fill(A[0:32], T.float32(0)) # fmt: on code = test.script() @@ -1659,15 +1585,14 @@ def test_roundtrip_persistent_not_present(): """Without persistent=True, the keyword should not appear.""" # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - warp_id = Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - Tx.fill(A[0:32], Tx.float32(0)) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + warp_id = T.warp_id([1]) + lane_id = T.lane_id([32]) + Tx.cta.fill(A[0:32], T.float32(0)) # fmt: on code = test.script() @@ -1679,27 +1604,26 @@ def test_warp_role(): from tvm.tirx.lang.warp_role import WarpRole # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([4]) - warp_id = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - with WarpRole(warp_id, 1, regs=48): - Tx.fill(A[0:32], Tx.float32(0)) - with WarpRole(warp_id, 0, regs=232, increase=True): - Tx.fill(A[32:64], Tx.float32(1)) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([4]) + warp_id = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + with WarpRole(warp_id, 1, regs=48): + Tx.cta.fill(A[0:32], T.float32(0)) + with WarpRole(warp_id, 0, regs=232, increase=True): + Tx.cta.fill(A[32:64], T.float32(1)) # fmt: on code = test.script() assert "warp_id == 1" in code, f"should have warp_id==1 guard:\n{code}" assert "warp_id == 0" in code, f"should have warp_id==0 guard:\n{code}" assert "setmaxnreg" in code, f"should have setmaxnreg:\n{code}" - assert "with Tx.warp(warp_id == 1):" in code, f"should have guarded Tx.warp scope:\n{code}" - assert "with Tx.warp(warp_id == 0):" in code, f"should have guarded Tx.warp scope:\n{code}" + assert "if warp_id == 1:" in code, f"should have warp_id==1 if-guard:\n{code}" + assert "if warp_id == 0:" in code, f"should have warp_id==0 if-guard:\n{code}" # The printed code is valid TIR — it should parse back assert from_source(code).script() == code assert_structural_equal(test, from_source(code)) @@ -1710,17 +1634,16 @@ def test_warpgroup_role(): from tvm.tirx.lang.warp_role import WarpgroupRole # fmt: off - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (128,), "float32", scope="global") - Tx.device_entry() - cta_id = Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([4]) - warp_id_in_wg = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - with WarpgroupRole(wg_id, 2, regs=200, increase=True): - Tx.fill(A[0:32], Tx.float32(0)) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128,), "float32", scope="global") + T.device_entry() + cta_id = T.cta_id([1]) + wg_id = T.warpgroup_id([4]) + warp_id_in_wg = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + with WarpgroupRole(wg_id, 2, regs=200, increase=True): + Tx.cta.fill(A[0:32], T.float32(0)) # fmt: on code = test.script() @@ -1731,34 +1654,31 @@ def test(A_ptr: Tx.handle) -> None: def test_vector_annotation_syntax_1d(): - """Test x: Tx.f32[N] produces the same IR as Tx.alloc_local([N], 'float32').""" + """Test x: T.f32[N] produces the same IR as T.alloc_local([N], 'float32').""" # fmt: off - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - with Tx.thread(): - v: Tx.float32[8] - Tx.evaluate(v[0]) # noqa: F821 + T.device_entry() + v: T.float32[8] + T.evaluate(v[0]) # noqa: F821 - @Tx.prim_func + @T.prim_func def func(): # noqa: F811 - Tx.device_entry() - with Tx.thread(): - v = Tx.alloc_local([8], "float32") - Tx.evaluate(v[0]) + T.device_entry() + v = T.alloc_local([8], "float32") + T.evaluate(v[0]) # fmt: on # func was redefined; compare first (annotation) with second (alloc_local). # Re-create the annotation version for comparison: # fmt: off - @Tx.prim_func + @T.prim_func def annotation_func(): - Tx.device_entry() - with Tx.thread(): - v: Tx.float32[8] - Tx.evaluate(v[0]) # noqa: F821 + T.device_entry() + v: T.float32[8] + T.evaluate(v[0]) # noqa: F821 # fmt: on # Verify both produce valid IR that round-trips through printer/parser @@ -1771,15 +1691,14 @@ def annotation_func(): def test_vector_annotation_syntax_multidim(): - """Test x: Tx.f32[M, N] produces the same IR as Tx.alloc_local([M, N], 'float32').""" + """Test x: T.f32[M, N] produces the same IR as T.alloc_local([M, N], 'float32').""" # fmt: off - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - with Tx.thread(): - m: Tx.float32[4, 8] - Tx.evaluate(m[0, 0]) # noqa: F821 + T.device_entry() + m: T.float32[4, 8] + T.evaluate(m[0, 0]) # noqa: F821 # fmt: on code = func.script() @@ -1789,17 +1708,16 @@ def func(): def test_vector_annotation_shorthand_aliases(): - """Test shorthand aliases: Tx.f32, Tx.i32, Tx.f16, etc.""" + """Test shorthand aliases: T.f32, T.i32, T.f16, etc.""" # fmt: off - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - with Tx.thread(): - a: Tx.f32[4] - b: Tx.i32[2] - c: Tx.f16[8] - Tx.evaluate(a[0] + Tx.float32(b[0]) + Tx.float32(c[0])) # noqa: F821 + T.device_entry() + a: T.f32[4] + b: T.i32[2] + c: T.f16[8] + T.evaluate(a[0] + T.float32(b[0]) + T.float32(c[0])) # noqa: F821 # fmt: on code = func.script() @@ -1808,18 +1726,17 @@ def func(): def test_scalar_annotation_shorthand(): - """Test x: Tx.f32 (scalar) shorthand produces same IR as x: Tx.float32.""" + """Test x: T.f32 (scalar) shorthand produces same IR as x: T.float32.""" # fmt: off - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - with Tx.thread(): - x: Tx.f32 = 0 - y: Tx.i32 - x = x + Tx.float32(1.0) - y = Tx.int32(2) - Tx.evaluate(x + Tx.float32(y)) + T.device_entry() + x: T.f32 = 0 + y: T.i32 + x = x + T.float32(1.0) + y = T.int32(2) + T.evaluate(x + T.float32(y)) # fmt: on code = func.script() @@ -1828,16 +1745,15 @@ def func(): def test_vector_annotation_with_python_variable_size(): - """Test x: Tx.f16[vec_size] where vec_size is a Python variable.""" + """Test x: T.f16[vec_size] where vec_size is a Python variable.""" vec_size = 16 # fmt: off - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - with Tx.thread(): - v: Tx.f16[vec_size] - Tx.evaluate(Tx.float32(v[0])) # noqa: F821 + T.device_entry() + v: T.f16[vec_size] + T.evaluate(T.float32(v[0])) # noqa: F821 # fmt: on code = func.script() @@ -1851,13 +1767,13 @@ def test_roundtrip_tmem_decl_buffer(): a .buffer suffix.""" # fmt: off - @Tx.prim_func + @T.prim_func def func(): - with Tx.launch_thread("blockIdx.x", 1): - Tx.launch_thread("threadIdx.x", 128) - addr = Tx.alloc_shared((1,), "uint32", layout=None) - addr_alias = Tx.Buffer((1,), "uint32", data=addr.data, scope="shared") - buf = Tx.decl_buffer((64,), scope="tmem", layout=None, allocated_addr=addr_alias[0]) + with T.launch_thread("blockIdx.x", 1): + T.launch_thread("threadIdx.x", 128) + addr = T.alloc_shared((1,), "uint32", layout=None) + addr_alias = T.Buffer((1,), "uint32", data=addr.data, scope="shared") + buf = T.decl_buffer((64,), scope="tmem", layout=None, allocated_addr=addr_alias[0]) # fmt: on code = func.script() @@ -1870,12 +1786,11 @@ def test_roundtrip_cuda_func_call_source_code(): inline string literal, not as a metadata reference.""" # fmt: off - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - with Tx.cta(): - desc = Tx.alloc_local((1,), "uint64") - Tx.cuda.func_call("my_func", Tx.address_of(desc[0]), source_code="\n__device__ void my_func(uint64_t* p) {\n *p = 42;\n}\n") # noqa: E501 + T.device_entry() + desc = T.alloc_local((1,), "uint64") + T.cuda.func_call("my_func", T.address_of(desc[0]), source_code="\n__device__ void my_func(uint64_t* p) {\n *p = 42;\n}\n") # noqa: E501 # fmt: on code = func.script() @@ -1887,15 +1802,15 @@ def test_roundtrip_cp_async_bulk_tensor_g2c(): """cp.async.bulk.tensor.g2c must round-trip with *coords at end.""" # fmt: off - @Tx.prim_func(check_well_formed=False) - def func(A_ptr: Tx.handle): - _ = Tx.match_buffer(A_ptr, (16, 16), "float32") - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - with Tx.launch_thread("blockIdx.x", 1): - Tx.launch_thread("threadIdx.x", 128) - A_smem = Tx.alloc_buffer((16, 16), "float32", scope="shared") - Tx.ptx.cp_async.bulk.tensor.g2c( - 2, A_smem.data, 0, Tx.address_of(A_map), 0, 1, "", 0, 0 + @T.prim_func(check_well_formed=False) + def func(A_ptr: T.handle): + _ = T.match_buffer(A_ptr, (16, 16), "float32") + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + with T.launch_thread("blockIdx.x", 1): + T.launch_thread("threadIdx.x", 128) + A_smem = T.alloc_buffer((16, 16), "float32", scope="shared") + T.ptx.cp_async.bulk.tensor.g2c( + 2, A_smem.data, 0, T.address_of(A_map), 0, 1, "", 0, 0 ) # fmt: on @@ -1908,15 +1823,15 @@ def test_roundtrip_cp_async_bulk_tensor_s2g(): """cp.async.bulk.tensor.s2g must round-trip with *coords at end.""" # fmt: off - @Tx.prim_func(check_well_formed=False) - def func(A_ptr: Tx.handle): - _ = Tx.match_buffer(A_ptr, (16, 16), "float32") - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - with Tx.launch_thread("blockIdx.x", 1): - Tx.launch_thread("threadIdx.x", 128) - A_smem = Tx.alloc_buffer((16, 16), "float32", scope="shared") - Tx.ptx.cp_async.bulk.tensor.s2g( - 2, A_smem.data, Tx.address_of(A_map), "", 0, 0 + @T.prim_func(check_well_formed=False) + def func(A_ptr: T.handle): + _ = T.match_buffer(A_ptr, (16, 16), "float32") + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + with T.launch_thread("blockIdx.x", 1): + T.launch_thread("threadIdx.x", 128) + A_smem = T.alloc_buffer((16, 16), "float32", scope="shared") + T.ptx.cp_async.bulk.tensor.s2g( + 2, A_smem.data, T.address_of(A_map), "", 0, 0 ) # fmt: on @@ -1929,14 +1844,14 @@ def test_roundtrip_cp_async_bulk_tensor_g2c_prefetch(): """cp.async.bulk.tensor.g2c_prefetch must round-trip with *coords at end.""" # fmt: off - @Tx.prim_func(check_well_formed=False) - def func(A_ptr: Tx.handle): - _ = Tx.match_buffer(A_ptr, (16, 16), "float32") - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - with Tx.launch_thread("blockIdx.x", 1): - Tx.launch_thread("threadIdx.x", 128) - Tx.ptx.cp_async.bulk.tensor.g2c_prefetch( - 2, Tx.address_of(A_map), "", 0, 0 + @T.prim_func(check_well_formed=False) + def func(A_ptr: T.handle): + _ = T.match_buffer(A_ptr, (16, 16), "float32") + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + with T.launch_thread("blockIdx.x", 1): + T.launch_thread("threadIdx.x", 128) + T.ptx.cp_async.bulk.tensor.g2c_prefetch( + 2, T.address_of(A_map), "", 0, 0 ) # fmt: on @@ -1949,15 +1864,15 @@ def test_roundtrip_cp_async_bulk_tensor_s2g_reduce(): """cp.async.bulk.tensor.s2g_reduce must round-trip with *coords at end.""" # fmt: off - @Tx.prim_func(check_well_formed=False) - def func(A_ptr: Tx.handle): - _ = Tx.match_buffer(A_ptr, (16, 16), "float32") - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - with Tx.launch_thread("blockIdx.x", 1): - Tx.launch_thread("threadIdx.x", 128) - A_smem = Tx.alloc_buffer((16, 16), "float32", scope="shared") - Tx.ptx.cp_async.bulk.tensor.s2g_reduce( - 2, A_smem.data, Tx.address_of(A_map), "", "add", 0, 0 + @T.prim_func(check_well_formed=False) + def func(A_ptr: T.handle): + _ = T.match_buffer(A_ptr, (16, 16), "float32") + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + with T.launch_thread("blockIdx.x", 1): + T.launch_thread("threadIdx.x", 128) + A_smem = T.alloc_buffer((16, 16), "float32", scope="shared") + T.ptx.cp_async.bulk.tensor.s2g_reduce( + 2, A_smem.data, T.address_of(A_map), "", "add", 0, 0 ) # fmt: on diff --git a/tests/python/tirx/test_printer_tir_namespaces.py b/tests/python/tirx/test_printer_tir_namespaces.py index 56c185f12656..c79d700c8e01 100644 --- a/tests/python/tirx/test_printer_tir_namespaces.py +++ b/tests/python/tirx/test_printer_tir_namespaces.py @@ -16,38 +16,40 @@ # under the License. +import tvm from tvm import tirx as tir +from tvm.script import tirx as T def _assert_print(obj, expected): - # Use Tx prefix so standalone TIR nodes (non-PrimFunc) print as Tx to match tirx namespace - out = obj.script(verbose_expr=True, extra_config={"tirx.prefix": "Tx"}).strip() + # Standalone TIR nodes use the canonical tirx script prefix. + out = obj.script(verbose_expr=True, extra_config={"tirx.prefix": "T"}).strip() assert out == expected.strip() def test_printer_cuda_namespace_printf(): node = tir.Evaluate(tir.op.cuda_printf("x=%d", tir.IntImm("int32", 1))) - _assert_print(node, 'Tx.cuda.printf("x=%d", 1)') + _assert_print(node, 'T.cuda.printf("x=%d", 1)') def test_printer_ptx_namespace_wgmma_commit_group(): node = tir.Evaluate(tir.op.ptx_wgmma_commit_group()) - _assert_print(node, "Tx.ptx.wgmma.commit_group()") + _assert_print(node, "T.ptx.wgmma.commit_group()") def test_printer_cuda_cluster_sync(): node = tir.Evaluate(tir.op.cuda_cluster_sync()) - _assert_print(node, "Tx.cuda.cluster_sync()") + _assert_print(node, "T.cuda.cluster_sync()") def test_printer_ptx_namespace_cp_async_wait_group(): node = tir.Evaluate(tir.op.ptx_cp_async_wait_group(tir.IntImm("int32", 0))) - _assert_print(node, "Tx.ptx.cp_async.wait_group(0)") + _assert_print(node, "T.ptx.cp_async.wait_group(0)") def test_printer_nvshmem_namespace(): node = tir.Evaluate(tir.op.nvshmem_fence()) - _assert_print(node, "Tx.nvshmem.fence()") + _assert_print(node, "T.nvshmem.fence()") def test_printer_ptx_more(): @@ -57,62 +59,62 @@ def test_printer_ptx_more(): # New API: (trans, num, dtype, smem_ptr, *dst_handles). # .x1.b16 has 1 dst register, so 1 dst handle. tir.op.ptx_ldmatrix(True, 1, ".b16", s, r), - 's = Tx.handle()\nr = Tx.handle()\nTx.ptx.ldmatrix("void", Tx.bool(True), 1, ".b16", s, r)', + 's = T.handle()\nr = T.handle()\nT.ptx.ldmatrix(T.bool(True), 1, ".b16", s, r)', ) _assert_print( # New API: (trans, num, dtype, smem_ptr, *src_handles). # .x1.b16 has 1 src register, so 1 src handle. tir.op.ptx_stmatrix(False, 1, ".b16", s, r), ( - "s = Tx.handle()\nr = Tx.handle()\nTx.ptx.stmatrix(" - 'Tx.bool(False), 1, ".b16", "m8n8", "shared", s, r)' + "s = T.handle()\nr = T.handle()\nT.ptx.stmatrix(" + 'T.bool(False), 1, ".b16", "m8n8", "shared", s, r)' ), ) - _assert_print(tir.op.ptx_setmaxnreg(True, 64), "Tx.ptx.setmaxnreg(Tx.bool(True), 64)") - _assert_print(tir.op.ptx_fetch_register(32, "laneid"), 'Tx.ptx.fetch_register(32, "laneid")') - _assert_print(tir.op.ptx_wgmma_fence(), "Tx.ptx.wgmma.fence()") - _assert_print(tir.op.ptx_wgmma_wait_group(0), "Tx.ptx.wgmma.wait_group(0)") - _assert_print(tir.op.ptx_cp_async_commit_group(), "Tx.ptx.cp_async.commit_group()") - _assert_print(tir.op.ptx_cp_async_bulk_commit_group(), "Tx.ptx.cp_async.bulk.commit_group()") + _assert_print(tir.op.ptx_setmaxnreg(True, 64), "T.ptx.setmaxnreg(T.bool(True), 64)") + _assert_print(tir.op.ptx_fetch_register(32, "laneid"), 'T.ptx.fetch_register(32, "laneid")') + _assert_print(tir.op.ptx_wgmma_fence(), "T.ptx.wgmma.fence()") + _assert_print(tir.op.ptx_wgmma_wait_group(0), "T.ptx.wgmma.wait_group(0)") + _assert_print(tir.op.ptx_cp_async_commit_group(), "T.ptx.cp_async.commit_group()") + _assert_print(tir.op.ptx_cp_async_bulk_commit_group(), "T.ptx.cp_async.bulk.commit_group()") _assert_print( tir.op.ptx_cp_async_bulk_wait_group(0, True), - "Tx.ptx.cp_async.bulk.wait_group(0, Tx.bool(True))", + "T.ptx.cp_async.bulk.wait_group(0, T.bool(True))", ) - _assert_print(tir.op.ptx_cp_async_mbarrier_arrive(0), "Tx.ptx.cp_async.mbarrier.arrive(0)") - _assert_print(tir.op.ptx_fence("acq_rel", "gpu"), 'Tx.ptx.fence("acq_rel", "gpu")') - _assert_print(tir.op.ptx_fence("sc", "cta"), 'Tx.ptx.fence("sc", "cta")') + _assert_print(tir.op.ptx_cp_async_mbarrier_arrive(0), "T.ptx.cp_async.mbarrier.arrive(0)") + _assert_print(tir.op.ptx_fence("acq_rel", "gpu"), 'T.ptx.fence("acq_rel", "gpu")') + _assert_print(tir.op.ptx_fence("sc", "cta"), 'T.ptx.fence("sc", "cta")') _assert_print( - tir.op.ptx_fence_proxy_async("shared::cta"), 'Tx.ptx.fence.proxy_async("shared::cta")' + tir.op.ptx_fence_proxy_async("shared::cta"), 'T.ptx.fence.proxy_async("shared::cta")' ) - _assert_print(tir.op.ptx_fence_proxy_async("global"), 'Tx.ptx.fence.proxy_async("global")') - _assert_print(tir.op.ptx_fence_mbarrier_init(), "Tx.ptx.fence.mbarrier_init()") - _assert_print(tir.op.ptx_elect_sync(), "Tx.ptx.elect_sync()") + _assert_print(tir.op.ptx_fence_proxy_async("global"), 'T.ptx.fence.proxy_async("global")') + _assert_print(tir.op.ptx_fence_mbarrier_init(), "T.ptx.fence.mbarrier_init()") + _assert_print(tir.op.ptx_elect_sync(), "T.ptx.elect_sync()") lane = tir.Var("lane", "int32") _assert_print( tir.op.selector(lane, tir.op.ptx_elect_sync()), - "lane = Tx.int32()\nTx.selector(lane, Tx.ptx.elect_sync())", + "lane = T.int32()\nT.selector(lane, T.ptx.elect_sync())", ) _assert_print( tir.op.ptx_ld_global_acquire(r, s), - "r = Tx.handle()\ns = Tx.handle()\nTx.ptx.ld_global_acquire(r, s)", + "r = T.handle()\ns = T.handle()\nT.ptx.ld_global_acquire(r, s)", ) _assert_print( - tir.op.ptx_map_shared_rank(r, 2), 'r = Tx.handle()\nTx.ptx.mapa(r, 2, "", "u64", "uint64")' + tir.op.ptx_map_shared_rank(r, 2), 'r = T.handle()\nT.ptx.mapa(r, 2, "", "u64", "uint64")' ) - _assert_print(tir.op.ptx_bar_arrive(0, 128), "Tx.ptx.bar.arrive(0, 128)") - _assert_print(tir.op.ptx_bar_sync(0, 128), "Tx.ptx.bar.sync(0, 128)") + _assert_print(tir.op.ptx_bar_arrive(0, 128), "T.ptx.bar.arrive(0, 128)") + _assert_print(tir.op.ptx_bar_sync(0, 128), "T.ptx.bar.sync(0, 128)") _assert_print( - tir.op.ptx_tcgen05_alloc(s, 64, 1), "s = Tx.handle()\nTx.ptx.tcgen05.alloc(s, 64, 1)" + tir.op.ptx_tcgen05_alloc(s, 64, 1), "s = T.handle()\nT.ptx.tcgen05.alloc(s, 64, 1)" ) _assert_print( - tir.op.ptx_tcgen05_dealloc(s, 64, 1), "s = Tx.handle()\nTx.ptx.tcgen05.dealloc(s, 64, 1)" + tir.op.ptx_tcgen05_dealloc(s, 64, 1), "s = T.handle()\nT.ptx.tcgen05.dealloc(s, 64, 1)" ) d = tir.Var("d", "handle") a = tir.Var("a", "handle") b = tir.Var("b", "handle") _assert_print( tir.op.ptx_tcgen05_encode_matrix_descriptor(d, a, 1, 2, 0), - "d = Tx.handle()\na = Tx.handle()\nTx.ptx.tcgen05.encode_matrix_descriptor(d, a, 1, 2, 0)", + "d = T.handle()\na = T.handle()\nT.ptx.tcgen05.encode_matrix_descriptor(d, a, 1, 2, 0)", ) _assert_print( tir.op.ptx_tcgen05_encode_instr_descriptor( @@ -131,7 +133,7 @@ def test_printer_ptx_more(): sat_d=False, is_sparse=False, ), - 'd = Tx.handle()\nTx.ptx.tcgen05.encode_instr_descriptor(d, "f16", "f16", "f16", 16, 16, 16, Tx.bool(True), Tx.bool(False), 1, Tx.bool(False), Tx.bool(False), Tx.bool(False), Tx.bool(False))', # noqa: E501 + 'd = T.handle()\nT.ptx.tcgen05.encode_instr_descriptor(d, "f16", "f16", "f16", 16, 16, 16, T.bool(True), T.bool(False), 1, T.bool(False), T.bool(False), T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( tir.op.ptx_tcgen05_encode_instr_descriptor_block_scaled( @@ -153,118 +155,154 @@ def test_printer_ptx_more(): neg_a=False, neg_b=False, ), - "d = Tx.handle()\n" - "a = Tx.handle()\n" - "b = Tx.handle()\n" - 'Tx.ptx.tcgen05.encode_instr_descriptor_block_scaled(d, "f16", "f16", "f16", "f16", "f16", a, b, 16, 16, 16, Tx.bool(True), Tx.bool(False), 1, Tx.bool(False), Tx.bool(False), Tx.bool(True))', # noqa: E501 + "d = T.handle()\n" + "a = T.handle()\n" + "b = T.handle()\n" + 'T.ptx.tcgen05.encode_instr_descriptor_block_scaled(d, "f16", "f16", "f16", "f16", "f16", a, b, 16, 16, 16, T.bool(True), T.bool(False), 1, T.bool(False), T.bool(False), T.bool(True))', # noqa: E501 ) _assert_print( tir.op.ptx_tcgen05_cp(a, d, shape="64x128b", cta_group=1, multicast="warpx2::02_13"), - "a = Tx.handle()\n" - "d = Tx.handle()\n" - 'Tx.ptx.tcgen05.cp(a, d, "64x128b", 1, "warpx2::02_13", "", 0, 0)', + "a = T.handle()\n" + "d = T.handle()\n" + 'T.ptx.tcgen05.cp(a, d, "64x128b", 1, "warpx2::02_13", "", 0, 0)', ) - _assert_print(tir.op.ptx_tcgen05_shift(a, 1), "a = Tx.handle()\nTx.ptx.tcgen05.shift(a, 1)") + _assert_print(tir.op.ptx_tcgen05_shift(a, 1), "a = T.handle()\nT.ptx.tcgen05.shift(a, 1)") _assert_print( tir.op.ptx_tcgen05_ld(a, 0, shape="16x64b", num=1, row=0, col=0, pack=False), - 'a = Tx.handle()\nTx.ptx.tcgen05.ld(a, 0, 0, "16x64b", 1, Tx.bool(False), 0)', + 'a = T.handle()\nT.ptx.tcgen05.ld(a, 0, 0, "16x64b", 1, T.bool(False), 0)', ) _assert_print( tir.op.ptx_tcgen05_st(a, 0, shape="16x64b", num=1, row=0, col=0, unpack=False), - 'a = Tx.handle()\nTx.ptx.tcgen05.st(a, 0, 0, "16x64b", 1, Tx.bool(False), 0)', + 'a = T.handle()\nT.ptx.tcgen05.st(a, 0, 0, "16x64b", 1, T.bool(False), 0)', ) - _assert_print(tir.op.ptx_tcgen05_wait_ld(), "Tx.ptx.tcgen05.wait.ld()") - _assert_print(tir.op.ptx_tcgen05_wait_st(), "Tx.ptx.tcgen05.wait.st()") + _assert_print(tir.op.ptx_tcgen05_wait_ld(), "T.ptx.tcgen05.wait.ld()") + _assert_print(tir.op.ptx_tcgen05_wait_st(), "T.ptx.tcgen05.wait.st()") _assert_print( - tir.op.ptx_tcgen05_commit(a, 1, 0), "a = Tx.handle()\nTx.ptx.tcgen05.commit(a, 1, 0)" + tir.op.ptx_tcgen05_commit(a, 1, 0), "a = T.handle()\nT.ptx.tcgen05.commit(a, 1, 0)" ) _assert_print( - tir.op.ptx_tcgen05_relinquish_alloc_permit(1), "Tx.ptx.tcgen05.relinquish_alloc_permit(1)" + tir.op.ptx_tcgen05_relinquish_alloc_permit(1), "T.ptx.tcgen05.relinquish_alloc_permit(1)" ) def test_printer_ptx_mbarrier(): bar = tir.Var("bar", "handle") _assert_print( - tir.op.ptx_mbarrier_init(bar, 32), "bar = Tx.handle()\nTx.ptx.mbarrier.init(bar, 32)" + tir.op.ptx_mbarrier_init(bar, 32), "bar = T.handle()\nT.ptx.mbarrier.init(bar, 32)" ) - _assert_print(tir.op.ptx_mbarrier_arrive(bar), "bar = Tx.handle()\nTx.ptx.mbarrier.arrive(bar)") + _assert_print(tir.op.ptx_mbarrier_arrive(bar), "bar = T.handle()\nT.ptx.mbarrier.arrive(bar)") _assert_print( tir.op.ptx_mbarrier_arrive_expect_tx(bar, 128), - "bar = Tx.handle()\nTx.ptx.mbarrier.arrive.expect_tx(bar, 128)", + "bar = T.handle()\nT.ptx.mbarrier.arrive.expect_tx(bar, 128)", ) _assert_print( - tir.op.ptx_mbarrier_try_wait(bar, 1), "bar = Tx.handle()\nTx.ptx.mbarrier.try_wait(bar, 1)" + tir.op.ptx_mbarrier_try_wait(bar, 1), "bar = T.handle()\nT.ptx.mbarrier.try_wait(bar, 1)" ) - _assert_print(tir.op.cuda_cluster_sync(), "Tx.cuda.cluster_sync()") + _assert_print(tir.op.cuda_cluster_sync(), "T.cuda.cluster_sync()") def test_printer_cuda_more(): p = tir.Var("p", "handle") - _assert_print(tir.op.cuda_thread_fence(), "Tx.cuda.thread_fence()") - _assert_print(tir.op.cuda_warp_sync(), "Tx.cuda.warp_sync()") - _assert_print(tir.op.cuda_cta_sync(), "Tx.cuda.cta_sync()") - _assert_print(tir.op.cuda_grid_sync(), "Tx.cuda.grid_sync()") - _assert_print(tir.op.cuda_cluster_sync(), "Tx.cuda.cluster_sync()") - _assert_print(tir.op.cuda_syncthreads_and(1), "Tx.cuda.syncthreads_and(1)") - _assert_print(tir.op.cuda_syncthreads_or(1), "Tx.cuda.syncthreads_or(1)") - _assert_print(tir.op.cuda_nano_sleep(100), "Tx.cuda.nano_sleep(100)") + _assert_print(tir.op.cuda_thread_fence(), "T.cuda.thread_fence()") + _assert_print(tir.op.cuda_warp_sync(), "T.cuda.warp_sync()") + _assert_print(tir.op.cuda_cta_sync(), "T.cuda.cta_sync()") + _assert_print(tir.op.cuda_grid_sync(), "T.cuda.grid_sync()") + _assert_print(tir.op.cuda_cluster_sync(), "T.cuda.cluster_sync()") + _assert_print(tir.op.cuda_syncthreads_and(1), "T.cuda.syncthreads_and(1)") + _assert_print(tir.op.cuda_syncthreads_or(1), "T.cuda.syncthreads_or(1)") + _assert_print(tir.op.cuda_nano_sleep(100), "T.cuda.nano_sleep(100)") _assert_print( tir.op.cuda_atomic_add(p, tir.IntImm("int32", 1)), - "p = Tx.handle()\nTx.cuda.atomic_add(p, 1)", + "p = T.handle()\nT.cuda.atomic_add(p, 1)", ) - _assert_print(tir.op.cuda_atomic_cas(p, 1, 2), "p = Tx.handle()\nTx.cuda.atomic_cas(p, 1, 2)") - _assert_print(tir.op.cuda_ldg(p, "float32"), 'p = Tx.handle()\nTx.cuda.ldg(p, "float32")') + _assert_print(tir.op.cuda_atomic_cas(p, 1, 2), "p = T.handle()\nT.cuda.atomic_cas(p, 1, 2)") + _assert_print(tir.op.cuda_ldg(p, "float32"), 'p = T.handle()\nT.cuda.ldg(p, "float32")') _assert_print( - tir.op.cuda_func_call("f", 1, source_code=""), 'Tx.cuda.func_call("f", 1, source_code="")' + tir.op.cuda_func_call("f", 1, source_code=""), 'T.cuda.func_call("f", 1, source_code="")' ) +def test_printer_cuda_low_level_warp_intrinsics_roundtrip(): + @T.prim_func(check_well_formed=False) + def kernel(): + x = T.int32() + mask = T.cuda.__activemask() + T.evaluate(T.cuda.__shfl_sync(mask, x, 0, 32)) + T.evaluate(T.cuda.__shfl_up_sync(mask, x, 1, 32)) + T.evaluate(T.cuda.__shfl_down_sync(mask, x, 1, 32)) + T.evaluate(T.cuda.__shfl_xor_sync(mask, x, 1, 32)) + + code = kernel.script() + assert "T.cuda.__activemask()" in code + assert "T.cuda.__shfl_sync(" in code + assert "T.cuda.__shfl_up_sync(" in code + assert "T.cuda.__shfl_down_sync(" in code + assert "T.cuda.__shfl_xor_sync(" in code + assert "T.tirx." not in code + assert tvm.script.from_source(code).script() == code + + +def test_printer_webgpu_namespace_roundtrip(): + @T.prim_func(check_well_formed=False) + def kernel(): + x = T.int32() + T.evaluate(T.webgpu.subgroup_shuffle(x, 0)) + T.evaluate(T.webgpu.subgroup_shuffle_up(x, 1)) + T.evaluate(T.webgpu.subgroup_shuffle_down(x, 1)) + + code = kernel.script() + assert "T.webgpu.subgroup_shuffle(" in code + assert "T.webgpu.subgroup_shuffle_up(" in code + assert "T.webgpu.subgroup_shuffle_down(" in code + assert "T.tirx." not in code + assert tvm.script.from_source(code).script() == code + + def test_printer_nvshmem_more(): p = tir.Var("p", "handle") - _assert_print(tir.op.nvshmem_my_pe(), "Tx.nvshmem.my_pe()") - _assert_print(tir.op.nvshmem_n_pes(), "Tx.nvshmem.n_pes()") + _assert_print(tir.op.nvshmem_my_pe(), "T.nvshmem.my_pe()") + _assert_print(tir.op.nvshmem_n_pes(), "T.nvshmem.n_pes()") _assert_print( tir.op.nvshmem_signal_op(p, 1, "set", 0), - 'p = Tx.handle()\nTx.nvshmem.signal_op(p, 1, "set", 0)', + 'p = T.handle()\nT.nvshmem.signal_op(p, 1, "set", 0)', ) _assert_print( tir.op.nvshmem_wait_until(p, "eq", 0), - 'p = Tx.handle()\nTx.nvshmem.wait_until(p, "eq", 0, "uint64_t")', + 'p = T.handle()\nT.nvshmem.wait_until(p, "eq", 0, "uint64_t")', ) - _assert_print(tir.op.nvshmem_quiet(), "Tx.nvshmem.quiet()") - _assert_print(tir.op.nvshmem_barrier_all(), "Tx.nvshmem.barrier_all()") + _assert_print(tir.op.nvshmem_quiet(), "T.nvshmem.quiet()") + _assert_print(tir.op.nvshmem_barrier_all(), "T.nvshmem.barrier_all()") _assert_print( tir.op.nvshmem_getmem_nbi(p, p, 16, 0), - "p = Tx.handle()\nTx.nvshmem.getmem_nbi(p, p, 16, 0)", + "p = T.handle()\nT.nvshmem.getmem_nbi(p, p, 16, 0)", ) _assert_print( tir.op.nvshmem_getmem_nbi_warp(p, p, 16, 0), - "p = Tx.handle()\nTx.nvshmem.getmem_nbi.warp(p, p, 16, 0)", + "p = T.handle()\nT.nvshmem.getmem_nbi.warp(p, p, 16, 0)", ) _assert_print( tir.op.nvshmem_putmem_nbi_block(p, p, 16, 0), - "p = Tx.handle()\nTx.nvshmem.putmem_nbi.block(p, p, 16, 0)", + "p = T.handle()\nT.nvshmem.putmem_nbi.block(p, p, 16, 0)", ) _assert_print( tir.op.nvshmem_putmem_nbi(p, p, 16, 0), - "p = Tx.handle()\nTx.nvshmem.putmem_nbi(p, p, 16, 0)", + "p = T.handle()\nT.nvshmem.putmem_nbi(p, p, 16, 0)", ) _assert_print( tir.op.nvshmem_putmem_nbi_warp(p, p, 16, 0), - "p = Tx.handle()\nTx.nvshmem.putmem_nbi.warp(p, p, 16, 0)", + "p = T.handle()\nT.nvshmem.putmem_nbi.warp(p, p, 16, 0)", ) _assert_print( tir.op.nvshmem_putmem_signal_nbi(p, p, 16, p, 1, "set", 0), - 'p = Tx.handle()\nTx.nvshmem.putmem_signal_nbi(p, p, 16, p, 1, "set", 0)', + 'p = T.handle()\nT.nvshmem.putmem_signal_nbi(p, p, 16, p, 1, "set", 0)', ) _assert_print( tir.op.nvshmem_putmem_signal_nbi_warp(p, p, 16, p, 1, "set", 0), - 'p = Tx.handle()\nTx.nvshmem.putmem_signal_nbi.warp(p, p, 16, p, 1, "set", 0)', + 'p = T.handle()\nT.nvshmem.putmem_signal_nbi.warp(p, p, 16, p, 1, "set", 0)', ) _assert_print( tir.op.nvshmem_putmem_signal_nbi_block(p, p, 16, p, 1, "set", 0), - 'p = Tx.handle()\nTx.nvshmem.putmem_signal_nbi.block(p, p, 16, p, 1, "set", 0)', + 'p = T.handle()\nT.nvshmem.putmem_signal_nbi.block(p, p, 16, p, 1, "set", 0)', ) @@ -275,81 +313,81 @@ def test_printer_nki_namespace(): b0 = B[0] _assert_print( tir.op.nki_load(a0, b0), - 'A = Tx.Buffer((1,), "float16")\nB = Tx.Buffer((1,), "float16")\nTx.nki.load(A, B)', + 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.load(A, B)', ) _assert_print( tir.op.nki_store(a0, b0), - 'A = Tx.Buffer((1,), "float16")\nB = Tx.Buffer((1,), "float16")\nTx.nki.store(A, B)', + 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.store(A, B)', ) _assert_print( tir.op.nki_tensor_copy(a0, b0), - 'A = Tx.Buffer((1,), "float16")\nB = Tx.Buffer((1,), "float16")\nTx.nki.tensor_copy(A, B)', + 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.tensor_copy(A, B)', ) _assert_print( tir.op.nki_matmul(a0, a0, b0), - 'A = Tx.Buffer((1,), "float16")\n' - 'B = Tx.Buffer((1,), "float16")\n' - "Tx.nki.matmul(A, A, B, Tx.bool(True))", + 'A = T.Buffer((1,), "float16")\n' + 'B = T.Buffer((1,), "float16")\n' + "T.nki.matmul(A, A, B, T.bool(True))", ) _assert_print( tir.op.nki_activation(a0, b0, "relu", 0.0, 1.0), - 'A = Tx.Buffer((1,), "float16")\n' - 'B = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.activation(A, B, "relu", Tx.float32(0.0), Tx.float32(1.0))', + 'A = T.Buffer((1,), "float16")\n' + 'B = T.Buffer((1,), "float16")\n' + 'T.nki.activation(A, B, "relu", T.float32(0.0), T.float32(1.0))', ) _assert_print( tir.op.nki_memset(a0, 0), - 'A = Tx.Buffer((1,), "float16")\nTx.nki.memset(A, 0)', + 'A = T.Buffer((1,), "float16")\nT.nki.memset(A, 0)', ) _assert_print( tir.op.nki_identity(a0, 1), - 'A = Tx.Buffer((1,), "float16")\nTx.nki.identity(A, 1)', + 'A = T.Buffer((1,), "float16")\nT.nki.identity(A, 1)', ) _assert_print( tir.op.nki_reciprocal(a0, b0), - 'A = Tx.Buffer((1,), "float16")\nB = Tx.Buffer((1,), "float16")\nTx.nki.reciprocal(A, B)', + 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.reciprocal(A, B)', ) _assert_print( tir.op.nki_tensorreduce(a0, b0, "sum", False, 0), - 'A = Tx.Buffer((1,), "float16")\n' - 'B = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.tensorreduce(A, B, "sum", Tx.bool(False), 0)', + 'A = T.Buffer((1,), "float16")\n' + 'B = T.Buffer((1,), "float16")\n' + 'T.nki.tensorreduce(A, B, "sum", T.bool(False), 0)', ) _assert_print( tir.op.nki_tensortensor(a0, a0, b0, "add"), - 'A = Tx.Buffer((1,), "float16")\n' - 'B = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.tensortensor(A, A, B, "add")', + 'A = T.Buffer((1,), "float16")\n' + 'B = T.Buffer((1,), "float16")\n' + 'T.nki.tensortensor(A, A, B, "add")', ) _assert_print( tir.op.nki_tensorscalar(a0, a0, 1.0, "mul", False), - 'A = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.tensorscalar(A, A, Tx.float32(1.0), "mul", Tx.bool(False))', + 'A = T.Buffer((1,), "float16")\n' + 'T.nki.tensorscalar(A, A, T.float32(1.0), "mul", T.bool(False))', ) _assert_print( tir.op.nki_tensorscalar_reduce(a0, a0, 1.0, "mul", "sum", False), - 'A = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.tensorscalar_reduce(A, A, Tx.float32(1.0), "mul", "sum", Tx.bool(False), Tx.bool(False))', # noqa: E501 + 'A = T.Buffer((1,), "float16")\n' + 'T.nki.tensorscalar_reduce(A, A, T.float32(1.0), "mul", "sum", T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( tir.op.nki_scalar_tensor_tensor(a0, a0, 1.0, a0, "add", "add"), - 'A = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.scalar_tensor_tensor(A, A, Tx.float32(1.0), A, "add", "add", Tx.bool(False), Tx.bool(False))', # noqa: E501 + 'A = T.Buffer((1,), "float16")\n' + 'T.nki.scalar_tensor_tensor(A, A, T.float32(1.0), A, "add", "add", T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( tir.op.nki_scalar_tensor_scalar(a0, a0, 1.0, 1.0, "add", "add"), - 'A = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.scalar_tensor_scalar(A, A, Tx.float32(1.0), Tx.float32(1.0), "add", "add", Tx.bool(False), Tx.bool(False))', # noqa: E501 + 'A = T.Buffer((1,), "float16")\n' + 'T.nki.scalar_tensor_scalar(A, A, T.float32(1.0), T.float32(1.0), "add", "add", T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( tir.op.nki_activation_reduce(a0, a0, b0, "relu", "sum", 0.0, 1.0), - 'A = Tx.Buffer((1,), "float16")\n' - 'B = Tx.Buffer((1,), "float16")\n' - 'Tx.nki.activation_reduce(A, A, B, "relu", "sum", Tx.float32(0.0), Tx.float32(1.0))', + 'A = T.Buffer((1,), "float16")\n' + 'B = T.Buffer((1,), "float16")\n' + 'T.nki.activation_reduce(A, A, B, "relu", "sum", T.float32(0.0), T.float32(1.0))', ) _assert_print( tir.op.nki_affine_select(a0, a0, a0, 1.0), - 'A = Tx.Buffer((1,), "float16")\nTx.nki.affine_select(A, A, A, Tx.float32(1.0))', + 'A = T.Buffer((1,), "float16")\nT.nki.affine_select(A, A, A, T.float32(1.0))', ) @@ -360,13 +398,13 @@ def test_printer_ptx_mma_and_wgmma(): tir.Var("b", "handle") _assert_print( tir.op.ptx_mma("m8n8k4", "row", "row", "fp16", "fp16", "fp16", "fp16", [r], [r], [r]), - 'r = Tx.handle()\nTx.ptx.mma("void", "m8n8k4", "row", "row", "fp16", "fp16", "fp16", "fp16", 1, 1, 1, 0, Tx.bool(True), r, r, r, Tx.bool(False))', # noqa: E501 + 'r = T.handle()\nT.ptx.mma("m8n8k4", "row", "row", "fp16", "fp16", "fp16", "fp16", 1, 1, 1, 0, T.bool(True), r, r, r, T.bool(False))', # noqa: E501 ) _assert_print( tir.op.ptx_wgmma_encode_matrix_descriptor(d, a, 1, 1, 0), - "d = Tx.handle()\na = Tx.handle()\nTx.ptx.wgmma.encode_matrix_descriptor(d, a, 1, 1, 0)", + "d = T.handle()\na = T.handle()\nT.ptx.wgmma.encode_matrix_descriptor(d, a, 1, 1, 0)", ) - _assert_print(tir.op.ptx_wgmma_noop_barrier(0), "Tx.ptx.wgmma.noop_barrier(0)") + _assert_print(tir.op.ptx_wgmma_noop_barrier(0), "T.ptx.wgmma.noop_barrier(0)") _assert_print( tir.op.ptx_wgmma_mma_async_ss( d, @@ -384,7 +422,7 @@ def test_printer_ptx_mma_and_wgmma(): scaleB=1.0, scaleD=True, ), - 'd = Tx.handle()\nTx.ptx.wgmma.mma_async.ss(16, 16, 16, "f16", "f16", Tx.bool(True), Tx.bool(False), Tx.float32(1.0), Tx.float32(1.0), Tx.bool(True), d, d, 0, 0)', # noqa: E501 + 'd = T.handle()\nT.ptx.wgmma.mma_async.ss(16, 16, 16, "f16", "f16", T.bool(True), T.bool(False), T.float32(1.0), T.float32(1.0), T.bool(True), d, d, 0, 0)', # noqa: E501 ) _assert_print( tir.op.ptx_wgmma_mma_async_rs( @@ -402,7 +440,7 @@ def test_printer_ptx_mma_and_wgmma(): scaleB=1.0, scaleD=True, ), - 'd = Tx.handle()\nTx.ptx.wgmma.mma_async.rs(16, 16, 16, "f16", "f16", Tx.bool(True), Tx.bool(False), Tx.float32(1.0), Tx.float32(1.0), Tx.bool(True), d, 0, 0)', # noqa: E501 + 'd = T.handle()\nT.ptx.wgmma.mma_async.rs(16, 16, 16, "f16", "f16", T.bool(True), T.bool(False), T.float32(1.0), T.float32(1.0), T.bool(True), d, 0, 0)', # noqa: E501 ) @@ -410,31 +448,30 @@ def test_printer_ptx_cp_async_tensor(): tmap = tir.Var("tm", "handle") _assert_print( tir.op.ptx_cp_async_bulk_tensor_global_to_cluster(2, tmap, 0, tmap, 0, 1, "", 0, 1, ""), - "tm = Tx.handle()\n" - 'Tx.ptx.cp_async.bulk.tensor.g2c(2, tm, 0, tm, 0, 1, Tx.uint64(0), 0, 0, 1, "")', + "tm = T.handle()\n" + 'T.ptx.cp_async.bulk.tensor.g2c(2, tm, 0, tm, 0, 1, T.uint64(0), 0, 0, 1, "")', ) _assert_print( tir.op.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( 2, tmap, 0, tmap, 0, 1, "", 0, 1, "" ), - "tm = Tx.handle()\n" - "Tx.ptx.cp_async.bulk.tensor.g2c_tile_gather4" - '(2, tm, 0, tm, 0, 1, Tx.uint64(0), 0, 0, 1, "")', + "tm = T.handle()\n" + "T.ptx.cp_async.bulk.tensor.g2c_tile_gather4" + '(2, tm, 0, tm, 0, 1, T.uint64(0), 0, 0, 1, "")', ) _assert_print( tir.op.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(2, tmap, "", 0, 0, ""), - "tm = Tx.handle()\n" - 'Tx.ptx.cp_async.bulk.tensor.g2c_prefetch(2, tm, Tx.uint64(0), 0, 0, 0, "")', + 'tm = T.handle()\nT.ptx.cp_async.bulk.tensor.g2c_prefetch(2, tm, T.uint64(0), 0, 0, 0, "")', ) _assert_print( tir.op.ptx_cp_async_bulk_tensor_shared_to_global(2, 0, tmap, "", 0, 0, ""), - 'tm = Tx.handle()\nTx.ptx.cp_async.bulk.tensor.s2g(2, 0, tm, Tx.uint64(0), 0, 0, 0, "")', + 'tm = T.handle()\nT.ptx.cp_async.bulk.tensor.s2g(2, 0, tm, T.uint64(0), 0, 0, 0, "")', ) _assert_print( tir.op.ptx_cp_async_bulk_tensor_shared_to_global_reduce(2, 0, tmap, "", "add", 0, 0, ""), - "tm = Tx.handle()\n" - "Tx.ptx.cp_async.bulk.tensor.s2g_reduce" - '(2, 0, tm, Tx.uint64(0), 0, "add", 0, 0, "")', + "tm = T.handle()\n" + "T.ptx.cp_async.bulk.tensor.s2g_reduce" + '(2, 0, tm, T.uint64(0), 0, "add", 0, 0, "")', ) @@ -445,6 +482,5 @@ def test_printer_ptx_cp_async_call(): tir.op.ptx_cp_async( sh, gl, 16, cache_hint="", prefetch_size=-1, predicate=-1, fill_mode="" ), - "sh = Tx.handle()\ngl = Tx.handle()\n" - 'Tx.ptx.cp_async("void", sh, gl, 16, Tx.uint64(0), 0, -1, -1, "")', + 'sh = T.handle()\ngl = T.handle()\nT.ptx.cp_async(sh, gl, 16, T.uint64(0), 0, -1, -1, "")', ) diff --git a/tests/python/tirx/test_roundtrip_namespaces.py b/tests/python/tirx/test_roundtrip_namespaces.py index 4a3cdce86ebf..69e0629cee31 100644 --- a/tests/python/tirx/test_roundtrip_namespaces.py +++ b/tests/python/tirx/test_roundtrip_namespaces.py @@ -17,7 +17,7 @@ import tvm from tvm.ir import assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T def from_source(code): @@ -26,16 +26,16 @@ def from_source(code): def test_roundtrip_tir_namespaces_minimal(): # Exercise a selection of namespace ops and ensure round-trip consistency - @Tx.prim_func - def func(a_ptr: Tx.handle) -> None: - A = Tx.match_buffer(a_ptr, (2, 2), "float16") - Tx.ptx.wgmma.commit_group() - Tx.cuda.cluster_sync() - Tx.ptx.cp_async.wait_group(0) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.cuda.printf("ok") - Tx.nvshmem.quiet() - Tx.nki.identity(A[0, 0], 1) + @T.prim_func + def func(a_ptr: T.handle) -> None: + A = T.match_buffer(a_ptr, (2, 2), "float16") + T.ptx.wgmma.commit_group() + T.cuda.cluster_sync() + T.ptx.cp_async.wait_group(0) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.printf("ok") + T.nvshmem.quiet() + T.nki.identity(A[0, 0], 1) code = func.script() roundtripped = from_source(code) diff --git a/tests/python/tirx/test_verifier.py b/tests/python/tirx/test_verifier.py index b0a06ba96893..5ed20e7162fe 100644 --- a/tests/python/tirx/test_verifier.py +++ b/tests/python/tirx/test_verifier.py @@ -16,37 +16,30 @@ # under the License. import pytest -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.analysis import verify_tirx_well_formed as verify def test_root_scope(): # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test1() -> None: - Tx.device_entry() + T.device_entry() pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test2() -> None: - with Tx.warp(): - with Tx.thread(): - pass + pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test3() -> None: - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - pass + pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test4() -> None: - Tx.device_entry() - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - pass + T.device_entry() + pass # fmt: on @@ -58,44 +51,26 @@ def test4() -> None: def test_nested_scope(): # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test1() -> None: - Tx.device_entry() - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - pass - with Tx.thread(): - pass - - @Tx.prim_func(check_well_formed=False) + T.device_entry() + pass + pass + + @T.prim_func(check_well_formed=False) def test2() -> None: - Tx.device_entry() - with Tx.thread(): - with Tx.cta(): - with Tx.thread(): - pass + T.device_entry() + pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test3() -> None: - Tx.device_entry() - with Tx.warp(): - with Tx.thread(): - with Tx.cta(): - with Tx.thread(): - pass - @Tx.prim_func(check_well_formed=False) + T.device_entry() + pass + @T.prim_func(check_well_formed=False) def test4() -> None: - Tx.device_entry() - with Tx.thread(): - with Tx.warpgroup(): - with Tx.warp(): - with Tx.thread(): - pass - with Tx.warpgroup(): - with Tx.warp(): - with Tx.thread(): - pass + T.device_entry() + pass + pass # fmt: on @@ -107,89 +82,71 @@ def test4() -> None: def test_scope_id_consistency(): # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test1(): - Tx.device_entry() - Tx.cta_id([32]) - Tx.warp_id([4]) - Tx.lane_id([32]) - - with Tx.thread(): - pass + T.device_entry() + T.cta_id([32]) + T.warp_id([4]) + T.lane_id([32]) + pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test2(): - Tx.device_entry() - Tx.cta_id([32]) - Tx.warp_id([4]) - Tx.lane_id([32]) - Tx.thread_id([128]) - - with Tx.thread(): - pass + T.device_entry() + T.cta_id([32]) + T.warp_id([4]) + T.lane_id([32]) + T.thread_id([128]) + pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test3(): - Tx.device_entry() - Tx.cta_id([32]) - Tx.warp_id([2]) - Tx.lane_id([32]) - Tx.thread_id([128]) - - with Tx.thread(): - pass + T.device_entry() + T.cta_id([32]) + T.warp_id([2]) + T.lane_id([32]) + T.thread_id([128]) + pass - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test4(): - Tx.device_entry() - bx, by, bz = Tx.cta_id([8, 10, 12]) - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 1]) - clx, cly, clz = Tx.cluster_id([4, 5, 12]) - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(clx + cly + clz) - - @Tx.prim_func(check_well_formed=False) + T.device_entry() + bx, by, bz = T.cta_id([8, 10, 12]) + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1]) + clx, cly, clz = T.cluster_id([4, 5, 12]) + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(clx + cly + clz) + + @T.prim_func(check_well_formed=False) def test5(): - Tx.device_entry() - bx, by, bz = Tx.cta_id([8, 10, 12]) - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 1]) - clx, cly, clz = Tx.cluster_id([3, 5, 12]) - with Tx.cta(): - with Tx.warp(): - with Tx.thread(): - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(clx + cly + clz) - - @Tx.prim_func(check_well_formed=False) + T.device_entry() + bx, by, bz = T.cta_id([8, 10, 12]) + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1]) + clx, cly, clz = T.cluster_id([3, 5, 12]) + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(clx + cly + clz) + + @T.prim_func(check_well_formed=False) def test6(): - Tx.device_entry() - clx, cly, clz = Tx.cluster_id([4, 5, 12]) - bx, by, bz = Tx.cta_id([8, 10, 12]) - with Tx.cluster(): - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 1]) - with Tx.warp(): - with Tx.thread(): - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(clx + cly + clz) - - @Tx.prim_func(check_well_formed=False) + T.device_entry() + clx, cly, clz = T.cluster_id([4, 5, 12]) + bx, by, bz = T.cta_id([8, 10, 12]) + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1]) + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(clx + cly + clz) + + @T.prim_func(check_well_formed=False) def test7(): - Tx.device_entry() - clx, cly, clz = Tx.cluster_id([3, 5, 12]) - bx, by, bz = Tx.cta_id([8, 10, 12]) - with Tx.cluster(): - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 1]) - with Tx.warp(): - with Tx.thread(): - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(clx + cly + clz) + T.device_entry() + clx, cly, clz = T.cluster_id([3, 5, 12]) + bx, by, bz = T.cta_id([8, 10, 12]) + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1]) + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(clx + cly + clz) # fmt: on @@ -208,116 +165,105 @@ def test7(): def test_layout(): ### TileLayout # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test1(): - Tx.device_entry() - Tx.cta_id([32]) - Tx.warp_id([4]) - Tx.lane_id([32]) + T.device_entry() + T.cta_id([32]) + T.warp_id([4]) + T.lane_id([32]) + A = T.alloc_buffer((2,), layout=T.TileLayout(T.S[2, 1])) - with Tx.thread(): - A = Tx.alloc_buffer((2,), layout=Tx.TileLayout(Tx.S[2, 1])) - - A[0] = 0 + A[0] = 0 # fmt: on verify(test1) ### SwizzleLayout # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test2(): - Tx.device_entry() - Tx.cta_id([32]) - Tx.warp_id([4]) - Tx.lane_id([32]) - - with Tx.thread(): - A = Tx.alloc_buffer((512,), scope="shared", layout=Tx.SwizzleLayout(3, 3, 3)) + T.device_entry() + T.cta_id([32]) + T.warp_id([4]) + T.lane_id([32]) + A = T.alloc_buffer((512,), scope="shared", layout=T.SwizzleLayout(3, 3, 3)) - A[0] = 0 + A[0] = 0 # fmt: on verify(test2) def test_host(): # fmt: off - @Tx.prim_func(check_well_formed=False) - def test1(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) - - A_map: Tx.let[Tx.handle("tensormap")] = Tx.tvm_stack_alloca("tensormap", 1) - Tx.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", 2, A.data, 16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0) # noqa: E501 - - Tx.device_entry() - for blockIdx in Tx.thread_binding(1, thread="blockIdx.x"): - for threadIdx in Tx.thread_binding(128, thread="threadIdx.x"): - with Tx.thread(): - bar = Tx.alloc_buffer((1,), "uint64", scope="shared", align=8) - phase = Tx.alloc_buffer((1,), "int32", scope="local") - A_smem = Tx.alloc_buffer((16, 16), "float32", scope="shared", align=128) - - phase[0] = 0 - if threadIdx == 0: - Tx.ptx.mbarrier.init(bar.data, 1) - Tx.ptx.fence.proxy_async("shared::cta") - Tx.ptx.cp_async.bulk.tensor.g2c(2, A_smem.data, bar.data, Tx.address_of(A_map), 0, 1, "", 0, 0) # noqa: E501 - Tx.ptx.mbarrier.arrive.expect_tx(bar.data, 16*16*4) - Tx.ptx.mbarrier.try_wait(bar.data, phase[0]) - phase[0] = phase[0] ^ 1 - Tx.print_buffer(A_smem.data, "float32", False, False, 2, 16*16) + @T.prim_func(check_well_formed=False) + def test1(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) + + A_map: T.let[T.handle("tensormap")] = T.tvm_stack_alloca("tensormap", 1) + T.call_packed("runtime.cuTensorMapEncodeTiled", A_map, "float32", 2, A.data, 16, 16, 64, 16, 16, 1, 1, 0, 0, 0, 0) # noqa: E501 + + T.device_entry() + for blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for threadIdx in T.thread_binding(128, thread="threadIdx.x"): + bar = T.alloc_buffer((1,), "uint64", scope="shared", align=8) + phase = T.alloc_buffer((1,), "int32", scope="local") + A_smem = T.alloc_buffer((16, 16), "float32", scope="shared", align=128) + + phase[0] = 0 + if threadIdx == 0: + T.ptx.mbarrier.init(bar.data, 1) + T.ptx.fence.proxy_async("shared::cta") + T.ptx.cp_async.bulk.tensor.g2c(2, A_smem.data, bar.data, T.address_of(A_map), 0, 1, "", 0, 0) # noqa: E501 + T.ptx.mbarrier.arrive.expect_tx(bar.data, 16*16*4) + T.ptx.mbarrier.try_wait(bar.data, phase[0]) + phase[0] = phase[0] ^ 1 + T.print_buffer(A_smem.data, "float32", False, False, 2, 16*16) # fmt: on verify(test1) def test_device_func(): + # Per-call exec-scope migration: scope is now attached per op via the + # ``T.op[scope](...)`` subscription surface instead of a ``with T.cta():`` + # region. ``test1`` exercises a per-call-scoped op; ``test2`` the plain + # (unscoped) op. The old multi-root-scope negative case asserted the removed + # "only one root scope" verifier rule and no longer has an equivalent, so it + # is dropped. # fmt: off - @Tx.prim_func(check_well_formed=False) - def test1(A: Tx.Buffer((128,), "float32")): - with Tx.cta(): - Tx.thread_id([128]) - Tx.fill(A, 0.) - - @Tx.prim_func(check_well_formed=False) - def test2(A: Tx.Buffer((128,), "float32")): - Tx.device_entry() - Tx.cta_id([128]) - Tx.thread_id([128]) + @T.prim_func(check_well_formed=False) + def test1(A: T.Buffer((128,), "float32")): + T.device_entry() + T.cta_id([1]) + T.thread_id([128]) + Tx.cta.fill(A, 0.) + + @T.prim_func(check_well_formed=False) + def test2(A: T.Buffer((128,), "float32")): + T.device_entry() + T.cta_id([128]) + T.thread_id([128]) Tx.fill(A, 0.) - - @Tx.prim_func(check_well_formed=False) - def test3(A: Tx.Buffer((128,), "float32")): - with Tx.cta(): - Tx.thread_id([128]) - Tx.fill(A, 0.) - with Tx.cta(): - Tx.thread_id([128]) - Tx.fill(A, 0.) # fmt: on verify(test1, device_func=True) verify(test2, device_func=True) - with pytest.raises(Exception, match="Only one root scope is allowed in device function"): - verify(test3, device_func=True) def test_preferred_cluster_validation(): # fmt: off # Valid: cluster→cta with preferred_extents matching size - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test1() -> None: - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([2, 1], preferred=[2, 2]) - tx = Tx.thread_id([128]) - with Tx.thread(): - Tx.evaluate(cbx + cby + tx) + T.device_entry() + cbx, cby = T.cta_id_in_cluster([2, 1], preferred=[2, 2]) + tx = T.thread_id([128]) + T.evaluate(cbx + cby + tx) # Invalid: preferred size doesn't match extents size (caught at verify time) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test2() -> None: - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([2, 1], preferred=[2]) - tx = Tx.thread_id([128]) - with Tx.thread(): - Tx.evaluate(cbx + cby + tx) + T.device_entry() + cbx, cby = T.cta_id_in_cluster([2, 1], preferred=[2]) + tx = T.thread_id([128]) + T.evaluate(cbx + cby + tx) # fmt: on verify(test1) @@ -327,13 +273,12 @@ def test2() -> None: # Invalid: preferred on a non-cluster→cta scope (caught at IR build time) with pytest.raises(Exception): # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def test3() -> None: - Tx.device_entry() - bx = Tx.cta_id([128], preferred=[256]) - tx = Tx.thread_id([128]) - with Tx.thread(): - Tx.evaluate(bx + tx) + T.device_entry() + bx = T.cta_id([128], preferred=[256]) + tx = T.thread_id([128]) + T.evaluate(bx + tx) # fmt: on @@ -343,33 +288,30 @@ def test_scope_id_deferred_relaxed_at_construction(): deferred to LowerTIRx.""" # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def partial_only_cta(): - Tx.device_entry() - bx = Tx.cta_id() # deferred kernel→cta, no closure source - tx = Tx.thread_id([128]) # explicit - with Tx.thread(): - Tx.evaluate(bx + tx) + T.device_entry() + bx = T.cta_id() # deferred kernel→cta, no closure source + tx = T.thread_id([128]) # explicit + T.evaluate(bx + tx) - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def all_deferred(): - Tx.device_entry() - bx = Tx.cta_id() - wg = Tx.warpgroup_id() - warp = Tx.warp_id_in_wg() - lane = Tx.lane_id() - with Tx.thread(): - Tx.evaluate(bx + wg + warp + lane) - - @Tx.prim_func(check_well_formed=False) + T.device_entry() + bx = T.cta_id() + wg = T.warpgroup_id() + warp = T.warp_id_in_wg() + lane = T.lane_id() + T.evaluate(bx + wg + warp + lane) + + @T.prim_func(check_well_formed=False) def mixed(): - Tx.device_entry() + T.device_entry() # kCtaWarp=4, kWarpThread=32 → kCtaThread=128 derivable. - Tx.warp_id([4]) - Tx.lane_id([32]) - Tx.thread_id() # deferred kCtaThread, resolvable via closure - with Tx.thread(): - pass + T.warp_id([4]) + T.lane_id([32]) + T.thread_id() # deferred kCtaThread, resolvable via closure + pass # fmt: on # All three accepted by well-formed: deferred extents are tolerated. @@ -383,17 +325,16 @@ def test_scope_id_deferred_consistency_still_enforced(): must still be enforced by the closure check.""" # fmt: off - @Tx.prim_func(check_well_formed=False) + @T.prim_func(check_well_formed=False) def inconsistent(): # 4 warps * 32 lanes = 128 threads, but explicit thread_id says 64 -> error. - Tx.device_entry() - Tx.cta_id([32]) - Tx.warp_id([4]) - Tx.lane_id([32]) - Tx.thread_id() # deferred (shouldn't shadow the conflict) - Tx.thread_id([64]) # conflicts with derived kCtaThread=128 - with Tx.thread(): - pass + T.device_entry() + T.cta_id([32]) + T.warp_id([4]) + T.lane_id([32]) + T.thread_id() # deferred (shouldn't shadow the conflict) + T.thread_id([64]) # conflicts with derived kCtaThread=128 + pass # fmt: on with pytest.raises(Exception, match="Inconsistent extents for scope"): diff --git a/tests/python/tirx/transform/test_stmt_functor.py b/tests/python/tirx/transform/test_stmt_functor.py index cce208d706c2..af8605d841bf 100644 --- a/tests/python/tirx/transform/test_stmt_functor.py +++ b/tests/python/tirx/transform/test_stmt_functor.py @@ -22,7 +22,8 @@ import tvm.testing from tvm import tirx as tir from tvm.ir import Range -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.expr import EQ, GT, LT, Add, IntImm, Mul, Sub, Var from tvm.tirx.stmt_functor import StmtExprMutator, StmtExprVisitor, StmtMutator, StmtVisitor @@ -657,8 +658,8 @@ def create_test_statements(): if_then_else = tir.IfThenElse(tir.LT(x, int_imm), evaluate_stmt, evaluate_stmt) # Break and continue statements inside a for loop - @Tx.prim_func - def func(A: Tx.Buffer((10,), "int32")): + @T.prim_func + def func(A: T.Buffer((10,), "int32")): for x in range(10): A[x] = x + 1 if x == 5: @@ -666,15 +667,15 @@ def func(A: Tx.Buffer((10,), "int32")): continue # DeclBuffer - buffer_decl = tir.DeclBuffer(Tx.buffer((10,), "int32"), evaluate_stmt) + buffer_decl = tir.DeclBuffer(T.buffer((10,), "int32"), evaluate_stmt) # TilePrimitiveCall — extract the TilePrimitiveCall from the kernel body, then wrap in an SBlock - @Tx.prim_func - def op_call(A: Tx.Buffer((10,), "int32"), B: Tx.Buffer((10,), "int32")): - Tx.device_entry() + @T.prim_func + def op_call(A: T.Buffer((10,), "int32"), B: T.Buffer((10,), "int32")): + T.device_entry() Tx.add(A, B, 1.0) - # op_call.body is ExecScopeStmt, op_call.body.body is TilePrimitiveCall + # op_call.body is the tirx.device_entry AttrStmt, op_call.body.body is TilePrimitiveCall op_call_stmt = op_call.body.body op_call_block = tir.SBlock([], [], [], "op_call_block", op_call_stmt) @@ -1009,7 +1010,7 @@ def visit_int_imm_(self, op): def test_mutator_transformation(): - """Test that mutator actually transforms the ASTx.""" + """Test that mutator actually transforms the AST.""" evaluate_stmt = create_test_statements()["evaluate"] mutator = NegateIntImmMutator() result = mutator.visit_stmt(evaluate_stmt) @@ -1092,9 +1093,9 @@ def __init__(self): def visit_var_(self, op): self.vars.add(op.name) - @Tx.prim_func - def op_call_with_config(A: Tx.Buffer((10,), "int32"), B: Tx.Buffer((10,), "int32")): - Tx.device_entry() + @T.prim_func + def op_call_with_config(A: T.Buffer((10,), "int32"), B: T.Buffer((10,), "int32")): + T.device_entry() Tx.add(A, B, 1.0) op_call_stmt = op_call_with_config.body.body @@ -1124,9 +1125,9 @@ def test_op_call_config_mutated(): """ from tvm.tirx.stmt_functor import substitute - @Tx.prim_func - def op_call_with_config(A: Tx.Buffer((10,), "int32"), B: Tx.Buffer((10,), "int32")): - Tx.device_entry() + @T.prim_func + def op_call_with_config(A: T.Buffer((10,), "int32"), B: T.Buffer((10,), "int32")): + T.device_entry() Tx.add(A, B, 1.0) op_call_stmt = op_call_with_config.body.body diff --git a/tests/python/tirx/transform/test_transform_lower_tirx.py b/tests/python/tirx/transform/test_transform_lower_tirx.py index 33e0d028e83a..037e415fe9f6 100644 --- a/tests/python/tirx/transform/test_transform_lower_tirx.py +++ b/tests/python/tirx/transform/test_transform_lower_tirx.py @@ -19,27 +19,13 @@ import tvm import tvm.testing -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.function import PrimFunc from tvm.tirx.layout import laneid, warpid, wg_local_layout -from tvm.tirx.stmt import ExecScopeStmt -from tvm.tirx.stmt_functor import post_order_visit from tvm.tirx.transform import LowerTIRx, StmtSimplify -def _contains_exec_scope(mod): - found = [False] - - def _visit(node): - if isinstance(node, ExecScopeStmt): - found[0] = True - - for _gv, base_func in mod.functions.items(): - if isinstance(base_func, PrimFunc): - post_order_visit(base_func.body, _visit) - return found[0] - - def compare(before, after, transform): """Compare lowered output against expected ``after`` IR.""" if isinstance(before, PrimFunc): @@ -51,7 +37,6 @@ def compare(before, after, transform): with tvm.target.Target("cuda"): lowered = transform()(before) lowered.show() - assert not _contains_exec_scope(lowered) tvm.ir.assert_structural_equal(lowered, after, map_free_vars=False) @@ -63,200 +48,182 @@ def _int_triple(side, axis): return tuple(int(x) for x in side[axis]) -L_LANE = Tx.TileLayout(Tx.S[32 : 1 @ laneid]) +L_LANE = T.TileLayout(T.S[32 : 1 @ laneid]) def test_lower_view_get(): - @Tx.prim_func(private=True) - def before1(in_buf: Tx.Buffer(64, "float32"), out: Tx.Buffer(64, "float32")) -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - A = Tx.alloc_buffer([2], dtype="float16", scope="local", layout=Tx.TileLayout(Tx.S[2:1])) + @T.prim_func(private=True) + def before1(in_buf: T.Buffer(64, "float32"), out: T.Buffer(64, "float32")) -> None: + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + T.warp_id([1]) + lane_id = T.lane_id([32]) + A = T.alloc_buffer([2], dtype="float16", scope="local", layout=T.TileLayout(T.S[2:1])) B_layout = A.layout.tile(L_LANE, (32,), (2,)) - with Tx.warp(): - B = A.view(64, layout=B_layout) - with Tx.thread(): - A_local = B.local(2) - for i in Tx.vectorized(2): - A_local[i] = Tx.float32(in_buf[lane_id * 2 + i]) - with Tx.warp(): - B = A.view(64, layout=B_layout) - with Tx.thread(): - A_local = B.local(2) - for i in Tx.vectorized(2): - out[lane_id * 2 + i] = Tx.float32(A_local[i]) - - @Tx.prim_func(private=True) - def after1(in_buf_handle: Tx.handle, out_handle: Tx.handle): - in_buf = Tx.match_buffer(in_buf_handle, (64,), layout=None) - out = Tx.match_buffer(out_handle, (64,), layout=None) - out_1 = Tx.decl_buffer((64,), data=out.data, layout=None) - in_buf_1 = Tx.decl_buffer((64,), data=in_buf.data, layout=None) - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 32) - blockIdx_y = Tx.launch_thread("blockIdx.y", 1) - blockIdx_z = Tx.launch_thread("blockIdx.z", 1) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + B = A.view(64, layout=B_layout) + A_local = B.local(2) + for i in T.vectorized(2): + A_local[i] = T.float32(in_buf[lane_id * 2 + i]) + B_1 = A.view(64, layout=B_layout) + A_local_1 = B_1.local(2) + for i in T.vectorized(2): + out[lane_id * 2 + i] = T.float32(A_local_1[i]) + + @T.prim_func(private=True) + def after1(in_buf_handle: T.handle, out_handle: T.handle): + in_buf = T.match_buffer(in_buf_handle, (64,), layout=None) + out = T.match_buffer(out_handle, (64,), layout=None) + out_1 = T.decl_buffer((64,), data=out.data, layout=None) + in_buf_1 = T.decl_buffer((64,), data=in_buf.data, layout=None) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + blockIdx_y = T.launch_thread("blockIdx.y", 1) + blockIdx_z = T.launch_thread("blockIdx.z", 1) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - v: Tx.let[Tx.int32] = warp_id_in_cta - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - Tx.evaluate(v) - A = Tx.alloc_local((2,), "float16", layout=None) - B = Tx.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) - A_local = Tx.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) - for i in Tx.vectorized(2): - A_local[i] = Tx.Cast("float16", in_buf_1[threadIdx_x * 2 + i]) - B_1 = Tx.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) - A_local_1 = Tx.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) - for i in Tx.vectorized(2): - out_1[threadIdx_x * 2 + i] = Tx.Cast("float32", A_local_1[i]) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + v: T.let[T.int32] = warp_id_in_cta + lane_id: T.let[T.int32] = threadIdx_x % 32 + T.evaluate(v) + A = T.alloc_local((2,), "float16", layout=None) + B = T.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) + A_local = T.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) + for i in T.vectorized(2): + A_local[i] = T.Cast("float16", in_buf_1[threadIdx_x * 2 + i]) + B_1 = T.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) + A_local_1 = T.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) + for i in T.vectorized(2): + out_1[threadIdx_x * 2 + i] = T.Cast("float32", A_local_1[i]) compare(before1, after1, LowerTIRx) - @Tx.prim_func(private=True) - def before2( - in_buf: Tx.Buffer((16, 16), "float32"), out: Tx.Buffer((16, 16), "float32") - ) -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.thread(): - atom = Tx.TileLayout(Tx.S[(1, 2) : (2, 1)]) - tile = Tx.TileLayout(Tx.S[(2, 2) : (2, 1)]) - warp_atom = atom.tile(L_LANE, (8, 4), (1, 2)) - A = Tx.alloc_buffer( - [4, 2], dtype="float32", scope="local", layout=atom.tile(tile, (2, 2), (1, 2)) - ) - B_layout = warp_atom.tile(tile, (2, 2), (8, 8)) - with Tx.warp(): - B = A.view(16, 16, layout=B_layout) - with Tx.thread(): - A_local = B.local(2, 2, 2) - for i in Tx.unroll(4): - for j in Tx.vectorized(2): - A_local[i // 2, i % 2, j] = in_buf[ - i // 2 * 8 + lane_id // 4, i % 2 * 8 + lane_id % 4 + j - ] - with Tx.warp(): - B = A.view(16, 16, layout=B_layout) - with Tx.thread(): - A_local = B.local(8) - for i in Tx.vectorized(2): - out[ - lane_id // 4 * 8 + i // 2 * 8 + lane_id % 4, lane_id % 4 * 2 + i % 2 - ] = A_local[i] - - @Tx.prim_func(private=True) - def after2(in_buf_handle: Tx.handle, out_handle: Tx.handle): - in_buf = Tx.match_buffer(in_buf_handle, (16, 16), layout=None) - out = Tx.match_buffer(out_handle, (16, 16), layout=None) - out_1 = Tx.decl_buffer((256,), data=out.data, layout=None) - in_buf_1 = Tx.decl_buffer((256,), data=in_buf.data, layout=None) - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 32) - blockIdx_y = Tx.launch_thread("blockIdx.y", 1) - blockIdx_z = Tx.launch_thread("blockIdx.z", 1) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + @T.prim_func(private=True) + def before2(in_buf: T.Buffer((16, 16), "float32"), out: T.Buffer((16, 16), "float32")) -> None: + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + T.warp_id([1]) + lane_id = T.lane_id([32]) + atom = T.TileLayout(T.S[(1, 2) : (2, 1)]) + tile = T.TileLayout(T.S[(2, 2) : (2, 1)]) + warp_atom = atom.tile(L_LANE, (8, 4), (1, 2)) + A = T.alloc_buffer( + [4, 2], dtype="float32", scope="local", layout=atom.tile(tile, (2, 2), (1, 2)) ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - v: Tx.let[Tx.int32] = warp_id_in_cta - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - Tx.evaluate(v) - A = Tx.alloc_local((8,), layout=None) - B = Tx.decl_buffer((256,), data=A.data, scope="local", layout=None) - A_local = Tx.decl_buffer((8,), data=A.data, scope="local", layout=None) - for i in Tx.unroll(4): - for j in Tx.vectorized(2): + B_layout = warp_atom.tile(tile, (2, 2), (8, 8)) + B = A.view(16, 16, layout=B_layout) + A_local = B.local(2, 2, 2) + for i in T.unroll(4): + for j in T.vectorized(2): + A_local[i // 2, i % 2, j] = in_buf[ + i // 2 * 8 + lane_id // 4, i % 2 * 8 + lane_id % 4 + j + ] + B_1 = A.view(16, 16, layout=B_layout) + A_local_1 = B_1.local(8) + for i in T.vectorized(2): + out[lane_id // 4 * 8 + i // 2 * 8 + lane_id % 4, lane_id % 4 * 2 + i % 2] = A_local_1[i] + + @T.prim_func(private=True) + def after2(in_buf_handle: T.handle, out_handle: T.handle): + in_buf = T.match_buffer(in_buf_handle, (16, 16), layout=None) + out = T.match_buffer(out_handle, (16, 16), layout=None) + out_1 = T.decl_buffer((256,), data=out.data, layout=None) + in_buf_1 = T.decl_buffer((256,), data=in_buf.data, layout=None) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + blockIdx_y = T.launch_thread("blockIdx.y", 1) + blockIdx_z = T.launch_thread("blockIdx.z", 1) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + ) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + v: T.let[T.int32] = warp_id_in_cta + lane_id: T.let[T.int32] = threadIdx_x % 32 + T.evaluate(v) + A = T.alloc_local((8,), layout=None) + B = T.decl_buffer((256,), data=A.data, scope="local", layout=None) + A_local = T.decl_buffer((8,), data=A.data, scope="local", layout=None) + for i in T.unroll(4): + for j in T.vectorized(2): A_local[i * 2 + j] = in_buf_1[ i // 2 * 128 + threadIdx_x // 4 * 16 + i % 2 * 8 + j + threadIdx_x % 4 ] - B_1 = Tx.decl_buffer((256,), data=A.data, scope="local", layout=None) - A_local_1 = Tx.decl_buffer((8,), data=A.data, scope="local", layout=None) - for i in Tx.vectorized(2): + B_1 = T.decl_buffer((256,), data=A.data, scope="local", layout=None) + A_local_1 = T.decl_buffer((8,), data=A.data, scope="local", layout=None) + for i in T.vectorized(2): out_1[threadIdx_x // 4 * 128 + threadIdx_x % 4 * 18 + i] = A_local_1[i] compare(before2, after2, LowerTIRx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before3_wgmma_layout( - in_buf: Tx.Buffer((128, 128), "float32"), out: Tx.Buffer((128, 128), "float32") + in_buf: T.Buffer((128, 128), "float32"), out: T.Buffer((128, 128), "float32") ) -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - wg_id = Tx.warpgroup_id([2]) - warp_id_in_wg = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - with Tx.thread(): - atom = Tx.TileLayout(Tx.S[1, 2]) - warp_atom = atom.tile(L_LANE, (8, 4), (1, 2)) - tile = Tx.TileLayout(Tx.S[(2, 128 // 8) : (1, 2)]) - warp_layout = warp_atom.tile(tile, (2, 128 // 8), (8, 8)) - L_warp = Tx.TileLayout(Tx.S[8 : 1 @ warpid]) - layout = warp_layout.tile(L_warp, (8, 1), (16, 128)) - acc = Tx.alloc_buffer( - [64], - dtype="float32", - scope="local", - layout=atom.tile(tile, (2, 128 // 8), (1, 2)), - ) - with Tx.cta(): - A = acc.view(128, 128, layout=layout) - with Tx.thread(): - acc_local = A.local(16, 2, 2, layout=atom.tile(tile, (2, 128 // 8), (1, 2))) - for i in Tx.serial(128 // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - acc_local[i, j, vec] = in_buf[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] - with Tx.cta(): - A = acc.view(128, 128, layout=layout) - with Tx.thread(): - acc_local = A.local(64, layout=atom.tile(tile, (2, 128 // 8), (1, 2))) - for i in Tx.serial(128 // 8): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): - out[ - wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, - i * 8 + lane_id % 4 * 2 + vec, - ] = acc_local[i * 4 + j * 2 + vec] - - @Tx.prim_func(private=True) - def after3_wgmma_layout(in_buf_handle: Tx.handle, out_handle: Tx.handle): - in_buf = Tx.match_buffer(in_buf_handle, (128, 128), layout=None) - out = Tx.match_buffer(out_handle, (128, 128), layout=None) - out_1 = Tx.decl_buffer((16384,), data=out.data, layout=None) - in_buf_1 = Tx.decl_buffer((16384,), data=in_buf.data, layout=None) - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 256) - blockIdx_y = Tx.launch_thread("blockIdx.y", 1) - blockIdx_z = Tx.launch_thread("blockIdx.z", 1) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + wg_id = T.warpgroup_id([2]) + warp_id_in_wg = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + atom = T.TileLayout(T.S[1, 2]) + warp_atom = atom.tile(L_LANE, (8, 4), (1, 2)) + tile = T.TileLayout(T.S[(2, 128 // 8) : (1, 2)]) + warp_layout = warp_atom.tile(tile, (2, 128 // 8), (8, 8)) + L_warp = T.TileLayout(T.S[8 : 1 @ warpid]) + layout = warp_layout.tile(L_warp, (8, 1), (16, 128)) + acc = T.alloc_buffer( + [64], + dtype="float32", + scope="local", + layout=atom.tile(tile, (2, 128 // 8), (1, 2)), + ) + A = acc.view(128, 128, layout=layout) + acc_local = A.local(16, 2, 2, layout=atom.tile(tile, (2, 128 // 8), (1, 2))) + for i in T.serial(128 // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + acc_local[i, j, vec] = in_buf[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] + A_1 = acc.view(128, 128, layout=layout) + acc_local_1 = A_1.local(64, layout=atom.tile(tile, (2, 128 // 8), (1, 2))) + for i in T.serial(128 // 8): + for j in T.unroll(2): + for vec in T.vectorized(2): + out[ + wg_id * 64 + warp_id_in_wg * 16 + j * 8 + lane_id // 4, + i * 8 + lane_id % 4 * 2 + vec, + ] = acc_local_1[i * 4 + j * 2 + vec] + + @T.prim_func(private=True) + def after3_wgmma_layout(in_buf_handle: T.handle, out_handle: T.handle): + in_buf = T.match_buffer(in_buf_handle, (128, 128), layout=None) + out = T.match_buffer(out_handle, (128, 128), layout=None) + out_1 = T.decl_buffer((16384,), data=out.data, layout=None) + in_buf_1 = T.decl_buffer((16384,), data=in_buf.data, layout=None) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 256) + blockIdx_y = T.launch_thread("blockIdx.y", 1) + blockIdx_z = T.launch_thread("blockIdx.z", 1) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - wg_id: Tx.let[Tx.int32] = warp_id_in_cta // 4 - warp_id_in_wg: Tx.let[Tx.int32] = warp_id_in_cta % 4 - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - acc = Tx.alloc_local((64,), layout=None) - B = Tx.decl_buffer((16384,), data=acc.data, scope="local", layout=None) - acc_local = Tx.decl_buffer((64,), data=acc.data, scope="local", layout=None) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + wg_id: T.let[T.int32] = warp_id_in_cta // 4 + warp_id_in_wg: T.let[T.int32] = warp_id_in_cta % 4 + lane_id: T.let[T.int32] = threadIdx_x % 32 + acc = T.alloc_local((64,), layout=None) + B = T.decl_buffer((16384,), data=acc.data, scope="local", layout=None) + acc_local = T.decl_buffer((64,), data=acc.data, scope="local", layout=None) for i in range(16): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): + for j in T.unroll(2): + for vec in T.vectorized(2): acc_local[i % 8 * 8 + j * 4 + i // 8 * 2 + vec] = in_buf_1[ warp_id_in_cta * 2048 + j * 1024 @@ -265,11 +232,11 @@ def after3_wgmma_layout(in_buf_handle: Tx.handle, out_handle: Tx.handle): + threadIdx_x % 4 * 2 + vec ] - B_1 = Tx.decl_buffer((16384,), data=acc.data, scope="local", layout=None) - acc_local_1 = Tx.decl_buffer((64,), data=acc.data, scope="local", layout=None) + B_1 = T.decl_buffer((16384,), data=acc.data, scope="local", layout=None) + acc_local_1 = T.decl_buffer((64,), data=acc.data, scope="local", layout=None) for i in range(16): - for j in Tx.unroll(2): - for vec in Tx.vectorized(2): + for j in T.unroll(2): + for vec in T.vectorized(2): out_1[ warp_id_in_cta * 2048 + j * 1024 @@ -281,214 +248,202 @@ def after3_wgmma_layout(in_buf_handle: Tx.handle, out_handle: Tx.handle): compare(before3_wgmma_layout, after3_wgmma_layout, LowerTIRx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before4_multi_view_get( - in_buf: Tx.Buffer(64, "float32"), out: Tx.Buffer(64, "float32") + in_buf: T.Buffer(64, "float32"), out: T.Buffer(64, "float32") ) -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.thread(): - A = Tx.alloc_buffer( - [2], dtype="float16", scope="local", layout=Tx.TileLayout(Tx.S[2:1]) - ) - B_layout = A.layout.tile(L_LANE, (32,), (2,)) - with Tx.warp(): - B = A.view(64, layout=B_layout) - B_1 = A.view(64, layout=B_layout) - with Tx.thread(): - A_local = B.local(2) - A_local[0] = Tx.float32(in_buf[lane_id * 2]) - A_local_1 = B_1.local(2) - A_local_1[1] = Tx.float32(in_buf[lane_id * 2 + 1]) - "\n write A into out\n " - with Tx.warp(): - B = A.view(64, layout=B_layout) - B_1 = A.view(64, layout=B_layout) - with Tx.thread(): - A_local = B.local(2) - out[lane_id * 2] = Tx.float32(A_local[0]) - A_local_1 = B_1.local(2) - out[lane_id * 2 + 1] = Tx.float32(A_local_1[1]) - - @Tx.prim_func(private=True) - def after4_multi_view_get(in_buf_handle: Tx.handle, out_handle: Tx.handle): - in_buf = Tx.match_buffer(in_buf_handle, (64,), layout=None) - out = Tx.match_buffer(out_handle, (64,), layout=None) - out_1 = Tx.decl_buffer((64,), data=out.data, layout=None) - in_buf_1 = Tx.decl_buffer((64,), data=in_buf.data, layout=None) - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 32) - blockIdx_y = Tx.launch_thread("blockIdx.y", 1) - blockIdx_z = Tx.launch_thread("blockIdx.z", 1) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + T.warp_id([1]) + lane_id = T.lane_id([32]) + A = T.alloc_buffer([2], dtype="float16", scope="local", layout=T.TileLayout(T.S[2:1])) + B_layout = A.layout.tile(L_LANE, (32,), (2,)) + B = A.view(64, layout=B_layout) + B_1 = A.view(64, layout=B_layout) + A_local = B.local(2) + A_local[0] = T.float32(in_buf[lane_id * 2]) + A_local_1 = B_1.local(2) + A_local_1[1] = T.float32(in_buf[lane_id * 2 + 1]) + "\n write A into out\n " + B_2 = A.view(64, layout=B_layout) + B_3 = A.view(64, layout=B_layout) + A_local_2 = B_2.local(2) + out[lane_id * 2] = T.float32(A_local_2[0]) + A_local_3 = B_3.local(2) + out[lane_id * 2 + 1] = T.float32(A_local_3[1]) + + @T.prim_func(private=True) + def after4_multi_view_get(in_buf_handle: T.handle, out_handle: T.handle): + in_buf = T.match_buffer(in_buf_handle, (64,), layout=None) + out = T.match_buffer(out_handle, (64,), layout=None) + out_1 = T.decl_buffer((64,), data=out.data, layout=None) + in_buf_1 = T.decl_buffer((64,), data=in_buf.data, layout=None) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + blockIdx_y = T.launch_thread("blockIdx.y", 1) + blockIdx_z = T.launch_thread("blockIdx.z", 1) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - v: Tx.let[Tx.int32] = warp_id_in_cta - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - Tx.evaluate(v) - A = Tx.alloc_local((2,), "float16", layout=None) - B = Tx.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) - B_1 = Tx.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) - A_local = Tx.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) - A_local[0] = Tx.Cast("float16", in_buf_1[threadIdx_x * 2]) - A_local_1 = Tx.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) - A_local_1[1] = Tx.Cast("float16", in_buf_1[threadIdx_x * 2 + 1]) - B_2 = Tx.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) - B_3 = Tx.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) - A_local_2 = Tx.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) - out_1[threadIdx_x * 2] = Tx.Cast("float32", A_local_2[0]) - A_local_3 = Tx.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) - out_1[threadIdx_x * 2 + 1] = Tx.Cast("float32", A_local_3[1]) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + v: T.let[T.int32] = warp_id_in_cta + lane_id: T.let[T.int32] = threadIdx_x % 32 + T.evaluate(v) + A = T.alloc_local((2,), "float16", layout=None) + B = T.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) + B_1 = T.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) + A_local = T.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) + A_local[0] = T.Cast("float16", in_buf_1[threadIdx_x * 2]) + A_local_1 = T.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) + A_local_1[1] = T.Cast("float16", in_buf_1[threadIdx_x * 2 + 1]) + B_2 = T.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) + B_3 = T.decl_buffer((64,), "float16", data=A.data, scope="local", layout=None) + A_local_2 = T.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) + out_1[threadIdx_x * 2] = T.Cast("float32", A_local_2[0]) + A_local_3 = T.decl_buffer((2,), "float16", data=A.data, scope="local", layout=None) + out_1[threadIdx_x * 2 + 1] = T.Cast("float32", A_local_3[1]) compare(before4_multi_view_get, after4_multi_view_get, LowerTIRx) def test_lower_scope_id(): - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before1() -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([3, 4, 5]) - tx = Tx.thread_id([32]) - Tx.evaluate(bx + by + bz + tx) + T.device_entry() + bx, by, bz = T.cta_id([3, 4, 5]) + tx = T.thread_id([32]) + T.evaluate(bx + by + bz + tx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def after1() -> None: - blockIdx_x = Tx.launch_thread("blockIdx.x", 3) - threadIdx_x = Tx.launch_thread("threadIdx.x", 32) - blockIdx_y = Tx.launch_thread("blockIdx.y", 4) - blockIdx_z = Tx.launch_thread("blockIdx.z", 5) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + blockIdx_x = T.launch_thread("blockIdx.x", 3) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + blockIdx_y = T.launch_thread("blockIdx.y", 4) + blockIdx_z = T.launch_thread("blockIdx.z", 5) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - tx: Tx.let[Tx.int32] = threadIdx_x - Tx.evaluate(bx + by + bz + tx) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + tx: T.let[T.int32] = threadIdx_x + T.evaluate(bx + by + bz + tx) compare(before1, after1, LowerTIRx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before2() -> None: - Tx.device_entry() - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 2]) - bx, by, bz = Tx.cta_id([8, 8, 8]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - Tx.evaluate(bx + by + bz + warp_id + lane_id + cbx + cby + cbz) - - @Tx.prim_func(private=True) + T.device_entry() + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 2]) + bx, by, bz = T.cta_id([8, 8, 8]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + T.evaluate(bx + by + bz + warp_id + lane_id + cbx + cby + cbz) + + @T.prim_func(private=True) def after2() -> None: - clusterCtaIdx_x = Tx.launch_thread("clusterCtaIdx.x", 2) - blockIdx_z = Tx.launch_thread("blockIdx.z", 8) - clusterCtaIdx_y = Tx.launch_thread("clusterCtaIdx.y", 2) - clusterCtaIdx_z = Tx.launch_thread("clusterCtaIdx.z", 2) - blockIdx_x = Tx.launch_thread("blockIdx.x", 8) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - blockIdx_y = Tx.launch_thread("blockIdx.y", 8) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + clusterCtaIdx_x = T.launch_thread("clusterCtaIdx.x", 2) + blockIdx_z = T.launch_thread("blockIdx.z", 8) + clusterCtaIdx_y = T.launch_thread("clusterCtaIdx.y", 2) + clusterCtaIdx_z = T.launch_thread("clusterCtaIdx.z", 2) + blockIdx_x = T.launch_thread("blockIdx.x", 8) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + blockIdx_y = T.launch_thread("blockIdx.y", 8) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - cbx: Tx.let[Tx.int32] = clusterCtaIdx_x - cby: Tx.let[Tx.int32] = clusterCtaIdx_y - cbz: Tx.let[Tx.int32] = clusterCtaIdx_z - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - warp_id: Tx.let[Tx.int32] = warp_id_in_cta - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - Tx.evaluate(bx + by + bz + warp_id + lane_id + cbx + cby + cbz) + cbx: T.let[T.int32] = clusterCtaIdx_x + cby: T.let[T.int32] = clusterCtaIdx_y + cbz: T.let[T.int32] = clusterCtaIdx_z + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + warp_id: T.let[T.int32] = warp_id_in_cta + lane_id: T.let[T.int32] = threadIdx_x % 32 + T.evaluate(bx + by + bz + warp_id + lane_id + cbx + cby + cbz) compare(before2, after2, LowerTIRx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before3() -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([8, 10, 12]) - cbx, cby, cbz = Tx.cta_id_in_cluster([2, 2, 1]) - clx, cly, clz = Tx.cluster_id([4, 5, 12]) - wg_id = Tx.warpgroup_id([3]) - warp_id_in_wg = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - tid_in_wg = Tx.thread_id_in_wg([128]) - with Tx.cta(): - with Tx.warpgroup(): - with Tx.thread(): - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(clx + cly + clz) - Tx.evaluate(wg_id + warp_id_in_wg + lane_id + tid_in_wg) - - @Tx.prim_func(private=True) + T.device_entry() + bx, by, bz = T.cta_id([8, 10, 12]) + cbx, cby, cbz = T.cta_id_in_cluster([2, 2, 1]) + clx, cly, clz = T.cluster_id([4, 5, 12]) + wg_id = T.warpgroup_id([3]) + warp_id_in_wg = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + tid_in_wg = T.thread_id_in_wg([128]) + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(clx + cly + clz) + T.evaluate(wg_id + warp_id_in_wg + lane_id + tid_in_wg) + + @T.prim_func(private=True) def after3() -> None: - clusterCtaIdx_x = Tx.launch_thread("clusterCtaIdx.x", 2) - blockIdx_z = Tx.launch_thread("blockIdx.z", 12) - clusterCtaIdx_y = Tx.launch_thread("clusterCtaIdx.y", 2) - clusterCtaIdx_z = Tx.launch_thread("clusterCtaIdx.z", 1) - blockIdx_x = Tx.launch_thread("blockIdx.x", 8) - threadIdx_x = Tx.launch_thread("threadIdx.x", 384) - blockIdx_y = Tx.launch_thread("blockIdx.y", 10) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + clusterCtaIdx_x = T.launch_thread("clusterCtaIdx.x", 2) + blockIdx_z = T.launch_thread("blockIdx.z", 12) + clusterCtaIdx_y = T.launch_thread("clusterCtaIdx.y", 2) + clusterCtaIdx_z = T.launch_thread("clusterCtaIdx.z", 1) + blockIdx_x = T.launch_thread("blockIdx.x", 8) + threadIdx_x = T.launch_thread("threadIdx.x", 384) + blockIdx_y = T.launch_thread("blockIdx.y", 10) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - cbx: Tx.let[Tx.int32] = clusterCtaIdx_x - cby: Tx.let[Tx.int32] = clusterCtaIdx_y - cbz: Tx.let[Tx.int32] = clusterCtaIdx_z - clx: Tx.let[Tx.int32] = Tx.ptx.fetch_register(32, "clusterid.x") - cly: Tx.let[Tx.int32] = Tx.ptx.fetch_register(32, "clusterid.y") - clz: Tx.let[Tx.int32] = Tx.ptx.fetch_register(32, "clusterid.z") - wg_id: Tx.let[Tx.int32] = warp_id_in_cta // 4 - warp_id: Tx.let[Tx.int32] = warp_id_in_cta % 4 - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - tid_in_wg: Tx.let[Tx.int32] = threadIdx_x % 128 - Tx.evaluate(bx + by + bz) - Tx.evaluate(cbx + cby + cbz) - Tx.evaluate(clx + cly + clz) - Tx.evaluate(wg_id + warp_id + lane_id + tid_in_wg) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + cbx: T.let[T.int32] = clusterCtaIdx_x + cby: T.let[T.int32] = clusterCtaIdx_y + cbz: T.let[T.int32] = clusterCtaIdx_z + clx: T.let[T.int32] = T.ptx.fetch_register(32, "clusterid.x") + cly: T.let[T.int32] = T.ptx.fetch_register(32, "clusterid.y") + clz: T.let[T.int32] = T.ptx.fetch_register(32, "clusterid.z") + wg_id: T.let[T.int32] = warp_id_in_cta // 4 + warp_id_in_wg: T.let[T.int32] = warp_id_in_cta % 4 + lane_id: T.let[T.int32] = threadIdx_x % 32 + tid_in_wg: T.let[T.int32] = threadIdx_x % 128 + T.evaluate(bx + by + bz) + T.evaluate(cbx + cby + cbz) + T.evaluate(clx + cly + clz) + T.evaluate(wg_id + warp_id_in_wg + lane_id + tid_in_wg) compare(before3, after3, LowerTIRx) def test_lower_scope_id2(): - @Tx.inline + @T.inline def func(warp_id, tx): - with Tx.cta(): - wg_id = Tx.warpgroup_id([2]) - with Tx.thread(): - Tx.evaluate(wg_id + warp_id + tx) + wg_id = T.warpgroup_id([2]) + T.evaluate(wg_id + warp_id + tx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before(): - Tx.device_entry() - bx, by, bz = Tx.cta_id([3, 4, 5]) - warp_id = Tx.warp_id([8]) - tx = Tx.thread_id([256]) + T.device_entry() + bx, by, bz = T.cta_id([3, 4, 5]) + warp_id = T.warp_id([8]) + tx = T.thread_id([256]) func(warp_id, tx) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def after(): - blockIdx_x = Tx.launch_thread("blockIdx.x", 3) - threadIdx_x = Tx.launch_thread("threadIdx.x", 256) - blockIdx_y = Tx.launch_thread("blockIdx.y", 4) - blockIdx_z = Tx.launch_thread("blockIdx.z", 5) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + blockIdx_x = T.launch_thread("blockIdx.x", 3) + threadIdx_x = T.launch_thread("threadIdx.x", 256) + blockIdx_y = T.launch_thread("blockIdx.y", 4) + blockIdx_z = T.launch_thread("blockIdx.z", 5) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - wg_id: Tx.let[Tx.int32] = warp_id_in_cta // 4 - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - warp_id: Tx.let[Tx.int32] = warp_id_in_cta - tx: Tx.let[Tx.int32] = threadIdx_x - Tx.evaluate(wg_id + warp_id + tx) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + warp_id: T.let[T.int32] = warp_id_in_cta + tx: T.let[T.int32] = threadIdx_x + wg_id: T.let[T.int32] = warp_id_in_cta // 4 + T.evaluate(wg_id + warp_id + tx) compare(before, after, LowerTIRx) @@ -496,108 +451,102 @@ def after(): @pytest.mark.skip( reason=( "Tested multi-kernel-per-PrimFunc behavior where a second sibling " - "`with Tx.thread():` would redefine scope-ids and produce a second " - "launch. The Tx.device_entry() refactor allows only one device-region " + "`with T.thread():` would redefine scope-ids and produce a second " + "launch. The T.device_entry() refactor allows only one device-region " "marker per PrimFunc; this case is out of scope." ) ) def test_lower_scope_id3(): - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before(): - Tx.device_entry() - bx, by, bz = Tx.cta_id([3, 4, 5]) - warp_id = Tx.warp_id([4]) - tx = Tx.thread_id([128]) - with Tx.cta(): - with Tx.thread(): - Tx.evaluate(bx + by + bz + warp_id + tx) - bx, by, bz = Tx.cta_id([6, 7, 8]) - warp_id = Tx.warp_id([8]) - tx = Tx.thread_id([256]) - with Tx.cta(): - with Tx.thread(): - Tx.evaluate(bx + by + bz + warp_id + tx) - - @Tx.prim_func(private=True) + T.device_entry() + bx, by, bz = T.cta_id([3, 4, 5]) + warp_id = T.warp_id([4]) + tx = T.thread_id([128]) + T.evaluate(bx + by + bz + warp_id + tx) + bx, by, bz = T.cta_id([6, 7, 8]) + warp_id = T.warp_id([8]) + tx = T.thread_id([256]) + T.evaluate(bx + by + bz + warp_id + tx) + + @T.prim_func(private=True) def after(): - with Tx.launch_thread("blockIdx.x", 3) as blockIdx_x: - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - blockIdx_y = Tx.launch_thread("blockIdx.y", 4) - blockIdx_z = Tx.launch_thread("blockIdx.z", 5) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + with T.launch_thread("blockIdx.x", 3) as blockIdx_x: + threadIdx_x = T.launch_thread("threadIdx.x", 128) + blockIdx_y = T.launch_thread("blockIdx.y", 4) + blockIdx_z = T.launch_thread("blockIdx.z", 5) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - warp_id: Tx.let[Tx.int32] = warp_id_in_cta - tx: Tx.let[Tx.int32] = threadIdx_x - Tx.evaluate(bx + by + bz + warp_id + tx) - blockIdx_x = Tx.launch_thread("blockIdx.x", 6) - threadIdx_x = Tx.launch_thread("threadIdx.x", 256) - blockIdx_y = Tx.launch_thread("blockIdx.y", 7) - blockIdx_z = Tx.launch_thread("blockIdx.z", 8) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + warp_id: T.let[T.int32] = warp_id_in_cta + tx: T.let[T.int32] = threadIdx_x + T.evaluate(bx + by + bz + warp_id + tx) + blockIdx_x = T.launch_thread("blockIdx.x", 6) + threadIdx_x = T.launch_thread("threadIdx.x", 256) + blockIdx_y = T.launch_thread("blockIdx.y", 7) + blockIdx_z = T.launch_thread("blockIdx.z", 8) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - warp_id: Tx.let[Tx.int32] = warp_id_in_cta - tx: Tx.let[Tx.int32] = threadIdx_x - Tx.evaluate(bx + by + bz + warp_id + tx) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + warp_id: T.let[T.int32] = warp_id_in_cta + tx: T.let[T.int32] = threadIdx_x + T.evaluate(bx + by + bz + warp_id + tx) compare(before, after, LowerTIRx) def test_lower_layout(): - @Tx.prim_func(private=True) - def before(A: Tx.Buffer((128, 32), "float16")) -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - Tx.warp_id([4]) - Tx.lane_id([32]) - tid = Tx.thread_id([128]) - with Tx.cta(): - A_smem = Tx.alloc_buffer( - [128, 32], dtype="float16", scope="shared", layout=Tx.SwizzleLayout(3, 3, 3) - ) - with Tx.thread(): - thread_col = Tx.meta_var(4) - thread_row = Tx.meta_var(32) - for tile in Tx.serial(128 // thread_row): - row = Tx.meta_var(tile * thread_row + tid // thread_col) - col = Tx.meta_var(tid % thread_col * 8) - for vec in Tx.vectorized(8): - A_smem[row, col + vec] = A[bx * 128 + row, col + vec] - - @Tx.prim_func(private=True) - def after(A_handle: Tx.handle) -> None: - A = Tx.match_buffer(A_handle, (128, 32), "float16", layout=None) - A_1 = Tx.decl_buffer((4096,), "float16", data=A.data, layout=None) - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - blockIdx_y = Tx.launch_thread("blockIdx.y", 1) - blockIdx_z = Tx.launch_thread("blockIdx.z", 1) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + @T.prim_func(private=True) + def before(A: T.Buffer((128, 32), "float16")) -> None: + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + T.warp_id([4]) + T.lane_id([32]) + tid = T.thread_id([128]) + A_smem = T.alloc_buffer( + [128, 32], dtype="float16", scope="shared", layout=T.SwizzleLayout(3, 3, 3) + ) + thread_col = T.meta_var(4) + thread_row = T.meta_var(32) + for tile in T.serial(128 // thread_row): + row = T.meta_var(tile * thread_row + tid // thread_col) + col = T.meta_var(tid % thread_col * 8) + for vec in T.vectorized(8): + A_smem[row, col + vec] = A[bx * 128 + row, col + vec] + + @T.prim_func(private=True) + def after(A_handle: T.handle) -> None: + A = T.match_buffer(A_handle, (128, 32), "float16", layout=None) + A_1 = T.decl_buffer((4096,), "float16", data=A.data, layout=None) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + blockIdx_y = T.launch_thread("blockIdx.y", 1) + blockIdx_z = T.launch_thread("blockIdx.z", 1) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - v: Tx.let[Tx.int32] = warp_id_in_cta - v_1: Tx.let[Tx.int32] = threadIdx_x % 32 - tid: Tx.let[Tx.int32] = threadIdx_x - Tx.evaluate(v) - Tx.evaluate(v_1) - A_smem = Tx.alloc_shared((4096,), "float16", layout=None) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + v: T.let[T.int32] = warp_id_in_cta + v_1: T.let[T.int32] = threadIdx_x % 32 + tid: T.let[T.int32] = threadIdx_x + T.evaluate(v) + T.evaluate(v_1) + A_smem = T.alloc_shared((4096,), "float16", layout=None) for tile in range(4): - for vec in Tx.vectorized(8): + for vec in T.vectorized(8): A_smem[ - Tx.shift_left( - Tx.bitwise_xor( + T.shift_left( + T.bitwise_xor( tile * 128 + threadIdx_x, - Tx.shift_right(Tx.bitwise_and(tile * 128 + threadIdx_x, 56), 3), + T.shift_right(T.bitwise_and(tile * 128 + threadIdx_x, 56), 3), ), 3, ) @@ -608,82 +557,75 @@ def after(A_handle: Tx.handle) -> None: def test_lower_opcall_fail(): - @Tx.prim_func - def test(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, (64,), "float32", scope="global") - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - Tx.warp_id([1]) - Tx.lane_id([32]) - with Tx.cta(): - A_smem = Tx.alloc_buffer([64], dtype="float32", scope="shared") - Tx.copy(A[0:64], A_smem[0:64]) - for i in range(10): - Tx.fill(A_smem[0:64], Tx.float32(0)) - Tx.gemm(A_smem, A_smem, A_smem, A_smem) - Tx.copy(A_smem[0:64], A[0:64]) + @T.prim_func + def test(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (64,), "float32", scope="global") + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + T.warp_id([1]) + T.lane_id([32]) + A_smem = T.alloc_buffer([64], dtype="float32", scope="shared") + Tx.cta.copy(A[0:64], A_smem[0:64]) + for i in range(10): + Tx.cta.fill(A_smem[0:64], T.float32(0)) + Tx.cta.gemm(A_smem, A_smem, A_smem, A_smem) + Tx.cta.copy(A_smem[0:64], A[0:64]) with pytest.raises(Exception): LowerTIRx()(tvm.IRModule({"main": test})) def test_lower_decl_buffer_access_ptr(): - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before(): - Tx.device_entry() - Tx.cta_id([1]) - Tx.thread_id([128]) - with Tx.cta(): - buf = Tx.alloc_buffer([1024], "uint8", scope="shared.dyn") - A = Tx.decl_buffer([128], "float16", buf.data, elem_offset=32) - with Tx.thread(): - Tx.evaluate(A.access_ptr("rw", offset=A.elem_offset_of([64]))) - - @Tx.prim_func(private=True) + T.device_entry() + T.cta_id([1]) + T.thread_id([128]) + buf = T.alloc_buffer([1024], "uint8", scope="shared.dyn") + A = T.decl_buffer([128], "float16", buf.data, elem_offset=32) + T.evaluate(A.access_ptr("rw", offset=A.elem_offset_of([64]))) + + @T.prim_func(private=True) def after(): - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - v: Tx.let[Tx.int32] = blockIdx_x - v_1: Tx.let[Tx.int32] = threadIdx_x - Tx.evaluate(v) - Tx.evaluate(v_1) - buf = Tx.alloc_buffer((1024,), "uint8", scope="shared.dyn", layout=None) - A = Tx.decl_buffer( + v: T.let[T.int32] = blockIdx_x + v_1: T.let[T.int32] = threadIdx_x + T.evaluate(v) + T.evaluate(v_1) + buf = T.alloc_buffer((1024,), "uint8", scope="shared.dyn", layout=None) + A = T.decl_buffer( (128,), "float16", data=buf.data, elem_offset=32, scope="shared.dyn", layout=None ) - Tx.tvm_access_ptr( - Tx.type_annotation("float16"), buf.data, Tx.Add(32, 64), Tx.Sub(128, 64), 3 - ) + T.tvm_access_ptr(T.type_annotation("float16"), buf.data, T.Add(32, 64), T.Sub(128, 64), 3) compare(before, after, LowerTIRx) def test_lower_separate_scope_id_def(): - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before(): - Tx.device_entry() - Tx.cta_id([1]) - with Tx.cta(): - tx = Tx.thread_id([128]) - if tx == 0: - with Tx.thread(): - Tx.evaluate(tx) - - @Tx.prim_func(private=True) + T.device_entry() + T.cta_id([1]) + tx = T.thread_id([128]) + if tx == 0: + T.evaluate(tx) + + @T.prim_func(private=True) def after(): - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - tx: Tx.let[Tx.int32] = threadIdx_x - v: Tx.let[Tx.int32] = blockIdx_x - Tx.evaluate(v) + v: T.let[T.int32] = blockIdx_x + tx: T.let[T.int32] = threadIdx_x + T.evaluate(v) if tx == 0: - Tx.evaluate(tx) + T.evaluate(tx) compare(before, after, LowerTIRx) @@ -699,24 +641,22 @@ def test_lower_exec_context_infers_plain_predicate_for_dispatch(): def _probe(op_call, sctx): seen.append({"scope_kind": sctx.scope_kind, "inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - if (warp_id == 0) & (lane_id == 0): - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + if (warp_id == 0) & (lane_id == 0): + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -740,30 +680,27 @@ def test_lower_exec_context_infers_warpgroup_range_predicate_for_dispatch(): def _probe(op_call, sctx): seen.append({"scope_kind": sctx.scope_kind, "inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([2]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - with Tx.cta(): - if wg_id == 0: - with Tx.warpgroup(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) - if (0 <= wg_id) & (wg_id < 1): - with Tx.warpgroup(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) - with Tx.warpgroup((0 <= wg_id) & (wg_id < 1)): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + wg_id = T.warpgroup_id([2]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + if wg_id == 0: + Tx.wg.copy(B[0:1], A[0:1], dispatch=variant) + if (0 <= wg_id) & (wg_id < 1): + Tx.wg.copy(B[0:1], A[0:1], dispatch=variant) + if (0 <= wg_id) & (wg_id < 1): + Tx.wg.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -788,23 +725,21 @@ def test_lower_exec_context_tracks_cta_thread_range_predicate_for_dispatch(): def _probe(op_call, sctx): seen.append({"scope_kind": sctx.scope_kind, "inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - tid = Tx.thread_id([256]) - with Tx.cta(): - if (0 <= tid) & (tid < 128): - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + tid = T.thread_id([256]) + if (0 <= tid) & (tid < 128): + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -828,22 +763,21 @@ def test_lower_exec_context_tracks_cta_thread_single_warp_range_predicate(): def _probe(op_call, sctx): seen.append({"scope_kind": sctx.scope_kind, "inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - tid = Tx.thread_id([256]) - with Tx.cta(): - with Tx.thread((34 <= tid) & (tid < 40)): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + tid = T.thread_id([256]) + if (34 <= tid) & (tid < 40): + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -867,25 +801,23 @@ def test_lower_exec_context_tracks_warpgroup_thread_range_predicate(): def _probe(op_call, sctx): seen.append({"scope_kind": sctx.scope_kind, "inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([2]) - tid_in_wg = Tx.thread_id_in_wg([128]) - with Tx.cta(): - if wg_id == 1: - with Tx.warpgroup(): - if (32 <= tid_in_wg) & (tid_in_wg < 64): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + wg_id = T.warpgroup_id([2]) + tid_in_wg = T.thread_id_in_wg([128]) + if wg_id == 1: + if (32 <= tid_in_wg) & (tid_in_wg < 64): + Tx.wg.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -909,24 +841,22 @@ def test_lower_exec_context_tracks_dependent_conjunctive_predicate(): def _probe(op_call, sctx): seen.append({"scope_kind": sctx.scope_kind, "inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([2]) - tid_in_wg = Tx.thread_id_in_wg([128]) - with Tx.cta(): - if ((32 <= tid_in_wg) & (tid_in_wg < 64)) & (wg_id == 1): - with Tx.warpgroup(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + wg_id = T.warpgroup_id([2]) + tid_in_wg = T.thread_id_in_wg([128]) + if ((32 <= tid_in_wg) & (tid_in_wg < 64)) & (wg_id == 1): + Tx.wg.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -940,74 +870,67 @@ def before(A_ptr: Tx.handle, B_ptr: Tx.handle): def test_lower_exec_context_keeps_plain_predicate_condition(): - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([2]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - with Tx.cta(): - if wg_id == 0: - Tx.evaluate(A[0]) + @T.prim_func(private=True) + def before(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + wg_id = T.warpgroup_id([2]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + if wg_id == 0: + T.evaluate(A[0]) with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) - script = lowered.script(extra_config={"tirx.prefix": "Tx"}) + script = lowered.script(extra_config={"tirx.prefix": "T"}) assert "if wg_id == 0:" in script assert "0 <= wg_id" not in script assert "wg_id < 1" not in script def test_lower_exec_context_keeps_plain_scope_predicate_condition(): - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([2]) - Tx.warp_id_in_wg([4]) - Tx.lane_id([32]) - with Tx.cta(): - if wg_id == 0: - with Tx.warpgroup(): - with Tx.thread(): - A[0] = Tx.float32(1) + @T.prim_func(private=True) + def before(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + wg_id = T.warpgroup_id([2]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + if wg_id == 0: + A[0] = T.float32(1) with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) - script = lowered.script(extra_config={"tirx.prefix": "Tx"}) + script = lowered.script(extra_config={"tirx.prefix": "T"}) assert "if wg_id == 0:" in script assert "0 <= wg_id" not in script assert "wg_id < 1" not in script def test_simplify_uses_floor_div_scope_predicate_as_context_fact(): - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (16,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - wg_id = Tx.warpgroup_id([2]) - warp_id = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - if wg_id == 0: - with Tx.warpgroup(): - with Tx.thread(): - A[warp_id] = Tx.float32(lane_id) + @T.prim_func(private=True) + def before(A_ptr: T.handle): + A = T.match_buffer(A_ptr, (16,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + wg_id = T.warpgroup_id([2]) + warp_id = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + if wg_id == 0: + A[warp_id] = T.float32(lane_id) with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) simplified = StmtSimplify()(lowered) - script = simplified.script(extra_config={"tirx.prefix": "Tx"}) + script = simplified.script(extra_config={"tirx.prefix": "T"}) assert "if warp_id_in_cta // 4 == 0:" in script assert "if 0 <= warp_id_in_cta" not in script - assert "A_1[warp_id_in_cta] = Tx.Cast" in script + assert "A_1[warp_id_in_cta] = T.Cast" in script assert "A_1[warp_id_in_cta % 4]" not in script @@ -1020,38 +943,35 @@ def test_lower_exec_context_selector_filter_for_elect_sync(): @register_dispatch("copy", "cuda", variant=variant, priority=10_000) def _probe(op_call, sctx): - seen.append(sctx.inter["laneid"][1].script(extra_config={"tirx.prefix": "Tx"})) + seen.append(sctx.inter["laneid"][1].script(extra_config={"tirx.prefix": "T"})) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - Tx.warp_id([1]) - lane_id = Tx.lane_id([32]) - with Tx.warp(): - if Tx.ptx.elect_sync(): - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) - if Tx.ptx.elect_sync() != 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) - with Tx.thread(Tx.ptx.elect_sync()): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + T.warp_id([1]) + lane_id = T.lane_id([32]) + if T.ptx.elect_sync(): + Tx.copy(B[0:1], A[0:1], dispatch=variant) + if T.ptx.elect_sync() != 0: + Tx.copy(B[0:1], A[0:1], dispatch=variant) + if T.ptx.elect_sync(): + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) assert len(seen) == 3 - assert any("Tx.selector(lane_id, Tx.ptx.elect_sync())" in item for item in seen) - assert any("Tx.selector(lane_id, Tx.ptx.elect_sync() != Tx.uint32(0))" in item for item in seen) + assert any("T.selector(lane_id, T.ptx.elect_sync())" in item for item in seen) + assert any("T.selector(lane_id, T.ptx.elect_sync() != T.uint32(0))" in item for item in seen) def test_lower_exec_context_scope_guard_mixes_structural_and_selector(): @@ -1065,23 +985,22 @@ def test_lower_exec_context_scope_guard_mixes_structural_and_selector(): def _probe(op_call, sctx): seen.append({"inter": sctx.inter, "intra": sctx.intra}) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id([1]) - warp_id = Tx.warp_id([4]) - lane_id = Tx.lane_id([32]) - with Tx.cta(): - with Tx.thread((warp_id == 0) & Tx.ptx.elect_sync()): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id([1]) + warp_id = T.warp_id([4]) + lane_id = T.lane_id([32]) + if (warp_id == 0) & T.ptx.elect_sync(): + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -1090,8 +1009,8 @@ def before(A_ptr: Tx.handle, B_ptr: Tx.handle): assert _int_pair(seen[0]["inter"], "warpid") == (1, 0) assert int(seen[0]["inter"]["laneid"][0]) == 1 assert ( - seen[0]["inter"]["laneid"][1].script(extra_config={"tirx.prefix": "Tx"}) - == "Tx.selector(lane_id, Tx.ptx.elect_sync())" + seen[0]["inter"]["laneid"][1].script(extra_config={"tirx.prefix": "T"}) + == "T.selector(lane_id, T.ptx.elect_sync())" ) assert len(seen[0]["intra"]) == 0 @@ -1107,23 +1026,21 @@ def test_lower_exec_context_tracks_factorized_cta_predicate(): def _probe(op_call, sctx): seen.append(sctx.inter) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([2, 3]) - Tx.thread_id([32]) - with Tx.cta(): - if cbx == 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + cbx, cby = T.cta_id_in_cluster([2, 3]) + T.thread_id([32]) + if cbx == 0: + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -1145,9 +1062,9 @@ def test_lower_exec_context_keeps_kernel_cta_predicate_out_of_cluster_active_set def _probe_kernel(op_call, sctx): seen["kernel"] = sctx.inter - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl @@ -1155,27 +1072,24 @@ def impl(): def _probe_cluster(op_call, sctx): seen["cluster"] = sctx.inter - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - bx = Tx.cta_id([8]) - cbx = Tx.cta_id_in_cluster([2]) - Tx.thread_id([32]) - with Tx.cta(): - if bx == 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=kernel_variant) - if cbx == 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=cluster_variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + bx = T.cta_id([8]) + cbx = T.cta_id_in_cluster([2]) + T.thread_id([32]) + if bx == 0: + Tx.copy(B[0:1], A[0:1], dispatch=kernel_variant) + if cbx == 0: + Tx.copy(B[0:1], A[0:1], dispatch=cluster_variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -1196,23 +1110,21 @@ def test_lower_exec_context_tracks_cta_axis_modulo_predicate(): def _probe(op_call, sctx): seen.append(sctx.inter) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([4, 2]) - Tx.thread_id([32]) - with Tx.cta(): - if cbx % 2 == 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + cbx, cby = T.cta_id_in_cluster([4, 2]) + T.thread_id([32]) + if cbx % 2 == 0: + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -1233,24 +1145,22 @@ def test_lower_exec_context_tracks_cta_id_in_pair_predicate(): def _probe(op_call, sctx): seen.append(sctx.inter) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([4, 2]) - cta_id_in_pair = Tx.cta_id_in_pair() - Tx.thread_id([32]) - with Tx.cta(): - if cta_id_in_pair == 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + cbx, cby = T.cta_id_in_cluster([4, 2]) + cta_id_in_pair = T.cta_id_in_pair() + T.thread_id([32]) + if cta_id_in_pair == 0: + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) @@ -1272,9 +1182,9 @@ def test_lower_exec_context_tracks_two_cta_pair_predicates(): def _probe_zero(op_call, sctx): seen["zero"] = sctx.inter - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl @@ -1282,27 +1192,24 @@ def impl(): def _probe_one(op_call, sctx): seen["one"] = sctx.inter - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - Tx.cta_id_in_cluster([2]) - cta_id_in_pair = Tx.cta_id_in_pair() - Tx.thread_id([32]) - with Tx.cta(): - if cta_id_in_pair == 0: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=zero_variant) - if cta_id_in_pair == 1: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=one_variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + T.cta_id_in_cluster([2]) + cta_id_in_pair = T.cta_id_in_pair() + T.thread_id([32]) + if cta_id_in_pair == 0: + Tx.copy(B[0:1], A[0:1], dispatch=zero_variant) + if cta_id_in_pair == 1: + Tx.copy(B[0:1], A[0:1], dispatch=one_variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -1323,25 +1230,23 @@ def test_lower_exec_context_tracks_cta_id_in_pair_after_axis_predicate(): def _probe(op_call, sctx): seen.append(sctx.inter) - @Tx.prim_func(private=True) + @T.prim_func(private=True) def impl(): - Tx.evaluate(0) + T.evaluate(0) return impl - @Tx.prim_func(private=True) - def before(A_ptr: Tx.handle, B_ptr: Tx.handle): - A = Tx.match_buffer(A_ptr, (1,), "float32", scope="global") - B = Tx.match_buffer(B_ptr, (1,), "float32", scope="global") - Tx.device_entry() - cbx, cby = Tx.cta_id_in_cluster([3, 2]) - cta_id_in_pair = Tx.cta_id_in_pair() - Tx.thread_id([32]) - with Tx.cta(): - if cbx == 0: - if cta_id_in_pair == 1: - with Tx.thread(): - Tx.copy(B[0:1], A[0:1], dispatch=variant) + @T.prim_func(private=True) + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (1,), "float32", scope="global") + B = T.match_buffer(B_ptr, (1,), "float32", scope="global") + T.device_entry() + cbx, cby = T.cta_id_in_cluster([3, 2]) + cta_id_in_pair = T.cta_id_in_pair() + T.thread_id([32]) + if cbx == 0: + if cta_id_in_pair == 1: + Tx.copy(B[0:1], A[0:1], dispatch=variant) with tvm.target.Target("cuda"): LowerTIRx()(tvm.IRModule({"main": before})) @@ -1352,66 +1257,63 @@ def before(A_ptr: Tx.handle, B_ptr: Tx.handle): def test_lower_buffer_offset(): - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before(): - Tx.device_entry() - Tx.cta_id([1]) - with Tx.cta(): - Tx.thread_id([128]) - with Tx.thread(): - A = Tx.alloc_buffer([64, 64], "float16", scope="local") - A0 = Tx.decl_buffer([64], "float16", A.data, elem_offset=A.elem_offset_of([32, 32])) - with Tx.thread(): - Tx.evaluate(Tx.address_of(A0[32])) - - @Tx.prim_func(private=True) + T.device_entry() + T.cta_id([1]) + T.thread_id([128]) + A = T.alloc_buffer([64, 64], "float16", scope="local") + A0 = T.decl_buffer([64], "float16", A.data, elem_offset=A.elem_offset_of([32, 32])) + T.evaluate(T.address_of(A0[32])) + + @T.prim_func(private=True) def after(): - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - v: Tx.let[Tx.int32] = threadIdx_x - v_1: Tx.let[Tx.int32] = blockIdx_x - Tx.evaluate(v_1) - Tx.evaluate(v) - A = Tx.alloc_local((4096,), "float16", layout=None) - A0 = Tx.decl_buffer( + v: T.let[T.int32] = blockIdx_x + v_1: T.let[T.int32] = threadIdx_x + T.evaluate(v) + T.evaluate(v_1) + A = T.alloc_local((4096,), "float16", layout=None) + A0 = T.decl_buffer( (64,), "float16", data=A.data, elem_offset=2080, scope="local", layout=None ) - Tx.address_of(A0[32]) + T.address_of(A0[32]) compare(before, after, LowerTIRx) def test_lower_alloc_decl_buffer_outside_of_parser(): - @Tx.meta_class + @T.meta_class class State: def __init__(self, smem): - self.A = Tx.alloc_local([1], "float16") - self.B = Tx.alloc_local([1], "float16") - self.C = Tx.decl_buffer([1], "float16", smem, elem_offset=0, scope="shared.dyn") + self.A = T.alloc_local([1], "float16") + self.B = T.alloc_local([1], "float16") + self.C = T.decl_buffer([1], "float16", smem, elem_offset=0, scope="shared.dyn") def int_var1(val): - buf = Tx.local_scalar("int32") + buf = T.local_scalar("int32") if val is not None: - Tx.buffer_store(buf.buffer, val, 0) + T.buffer_store(buf.buffer, val, 0) return buf def int_var2(val): - buf = Tx.alloc_local([1], "int32") + buf = T.alloc_local([1], "int32") if val is not None: - Tx.buffer_store(buf, val, 0) + T.buffer_store(buf, val, 0) return buf - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before(): - Tx.device_entry() - smem = Tx.alloc_buffer([100], "uint8", scope="shared.dyn") + T.device_entry() + smem = T.alloc_buffer([100], "uint8", scope="shared.dyn") state = State(smem.data) - state.A[0] = Tx.float16(1) - state.B[0] = Tx.float16(2) - state.C[0] = Tx.float16(3) + state.A[0] = T.float16(1) + state.B[0] = T.float16(2) + state.C[0] = T.float16(3) D = int_var1(1) D = D + 1 E = int_var1(2) @@ -1421,27 +1323,27 @@ def before(): G = int_var2(4) G[0] = G[0] + 4 - @Tx.prim_func(private=True) + @T.prim_func(private=True) def after(): - smem = Tx.alloc_buffer([100], "uint8", scope="shared.dyn", layout=None) - A = Tx.alloc_local((1,), "float16", layout=None) - B = Tx.alloc_local((1,), "float16", layout=None) - C = Tx.decl_buffer( + smem = T.alloc_buffer([100], "uint8", scope="shared.dyn", layout=None) + A = T.alloc_local((1,), "float16", layout=None) + B = T.alloc_local((1,), "float16", layout=None) + C = T.decl_buffer( (1,), "float16", data=smem.data, elem_offset=0, scope="shared.dyn", layout=None ) - A[0] = Tx.float16(1) - B[0] = Tx.float16(2) - C[0] = Tx.float16(3) - D = Tx.alloc_local((1,), "int32", layout=None) + A[0] = T.float16(1) + B[0] = T.float16(2) + C[0] = T.float16(3) + D = T.alloc_local((1,), "int32", layout=None) D = 1 D = D[0] + 1 - E = Tx.alloc_local((1,), "int32", layout=None) + E = T.alloc_local((1,), "int32", layout=None) E = 2 E = E[0] + 2 - F = Tx.alloc_local((1,), "int32", layout=None) + F = T.alloc_local((1,), "int32", layout=None) F = 3 F = F[0] + 3 - G = Tx.alloc_local((1,), "int32", layout=None) + G = T.alloc_local((1,), "int32", layout=None) G = 4 G = G[0] + 4 @@ -1451,42 +1353,38 @@ def after(): def test_alloc_buffer_with_thread_axis_layout(): """alloc_buffer with thread-axis layout should lower to 1D physical buffer with memory-axis span.""" # noqa: E501 - @Tx.prim_func(private=True) - def before(out: Tx.Buffer((128, 4), "float32")) -> None: - Tx.device_entry() - bx, by, bz = Tx.cta_id([1, 1, 1]) - Tx.warpgroup_id([1]) - warp_id = Tx.warp_id_in_wg([4]) - lane_id = Tx.lane_id([32]) - with Tx.warpgroup(): - with Tx.thread(): - reg_wg = Tx.alloc_buffer( - (128, 4), "float32", scope="local", layout=wg_local_layout(4) - ) - reg = reg_wg.local(4) - for i in Tx.serial(4): - reg[i] = out[lane_id + warp_id * 32, i] - - @Tx.prim_func(private=True) - def after(out_handle: Tx.handle): - out = Tx.match_buffer(out_handle, (128, 4), layout=None) - out_1 = Tx.decl_buffer((512,), data=out.data, layout=None) - blockIdx_x = Tx.launch_thread("blockIdx.x", 1) - threadIdx_x = Tx.launch_thread("threadIdx.x", 128) - blockIdx_y = Tx.launch_thread("blockIdx.y", 1) - blockIdx_z = Tx.launch_thread("blockIdx.z", 1) - warp_id_in_cta: Tx.let[Tx.int32] = Tx.tvm_warp_shuffle( - Tx.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 + @T.prim_func(private=True) + def before(out: T.Buffer((128, 4), "float32")) -> None: + T.device_entry() + bx, by, bz = T.cta_id([1, 1, 1]) + T.warpgroup_id([1]) + warp_id = T.warp_id_in_wg([4]) + lane_id = T.lane_id([32]) + reg_wg = T.alloc_buffer((128, 4), "float32", scope="local", layout=wg_local_layout(4)) + reg = reg_wg.local(4) + for i in T.serial(4): + reg[i] = out[lane_id + warp_id * 32, i] + + @T.prim_func(private=True) + def after(out_handle: T.handle): + out = T.match_buffer(out_handle, (128, 4), layout=None) + out_1 = T.decl_buffer((512,), data=out.data, layout=None) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + blockIdx_y = T.launch_thread("blockIdx.y", 1) + blockIdx_z = T.launch_thread("blockIdx.z", 1) + warp_id_in_cta: T.let[T.int32] = T.tvm_warp_shuffle( + T.uint32(4294967295), threadIdx_x // 32, 0, 32, 32 ) - bx: Tx.let[Tx.int32] = blockIdx_x - by: Tx.let[Tx.int32] = blockIdx_y - bz: Tx.let[Tx.int32] = blockIdx_z - v: Tx.let[Tx.int32] = warp_id_in_cta // 4 - warp_id: Tx.let[Tx.int32] = warp_id_in_cta % 4 - lane_id: Tx.let[Tx.int32] = threadIdx_x % 32 - Tx.evaluate(v) - reg_wg = Tx.alloc_local((4,), layout=None) - reg = Tx.decl_buffer((4,), data=reg_wg.data, scope="local", layout=None) + bx: T.let[T.int32] = blockIdx_x + by: T.let[T.int32] = blockIdx_y + bz: T.let[T.int32] = blockIdx_z + v: T.let[T.int32] = warp_id_in_cta // 4 + warp_id: T.let[T.int32] = warp_id_in_cta % 4 + lane_id: T.let[T.int32] = threadIdx_x % 32 + T.evaluate(v) + reg_wg = T.alloc_local((4,), layout=None) + reg = T.decl_buffer((4,), data=reg_wg.data, scope="local", layout=None) for i in range(4): reg[i] = out_1[warp_id_in_cta % 4 * 128 + threadIdx_x % 32 * 4 + i] @@ -1502,13 +1400,13 @@ def test_scope_id_compliment_no_div_by_zero(): """ with pytest.raises(Exception): - @Tx.prim_func - def func(A: Tx.Buffer((1,))): - Tx.device_entry() - cb_m, cb_n = Tx.cta_id_in_cluster([2, 2]) - bx = Tx.cta_id([1]) - tx = Tx.thread_id([128]) - Tx.evaluate(bx + cb_m + cb_n + tx) + @T.prim_func + def func(A: T.Buffer((1,))): + T.device_entry() + cb_m, cb_n = T.cta_id_in_cluster([2, 2]) + bx = T.cta_id([1]) + tx = T.thread_id([128]) + T.evaluate(bx + cb_m + cb_n + tx) def test_scope_id_compliment_non_divisible(): @@ -1519,13 +1417,13 @@ def test_scope_id_compliment_non_divisible(): """ with pytest.raises(Exception): - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - bx = Tx.cta_id([1]) - wid = Tx.warp_id([3]) - tx = Tx.thread_id([100]) - Tx.evaluate(bx + wid + tx) + T.device_entry() + bx = T.cta_id([1]) + wid = T.warp_id([3]) + tx = T.thread_id([100]) + T.evaluate(bx + wid + tx) def test_empty_kernel_no_thread_id(): @@ -1534,13 +1432,11 @@ def test_empty_kernel_no_thread_id(): Before the fix, this would crash late in codegen with poor diagnostics. """ - @Tx.prim_func + @T.prim_func def func(): - Tx.device_entry() - bx = Tx.cta_id([32]) - with Tx.cta(): - with Tx.thread(): - Tx.evaluate(bx) + T.device_entry() + bx = T.cta_id([32]) + T.evaluate(bx) with pytest.raises(Exception, match="kernel has no thread launch parameters"): with tvm.target.Target("cuda"): @@ -1548,17 +1444,16 @@ def func(): def test_lower_preferred_cluster(): - @Tx.prim_func(private=True) + @T.prim_func(private=True) def before() -> None: - Tx.device_entry() - bx = Tx.cta_id([8]) - cbx, cby = Tx.cta_id_in_cluster([2, 1], preferred=[2, 2]) - tx = Tx.thread_id([128]) - Tx.evaluate(bx + cbx + cby + tx) + T.device_entry() + bx = T.cta_id([8]) + cbx, cby = T.cta_id_in_cluster([2, 1], preferred=[2, 2]) + tx = T.thread_id([128]) + T.evaluate(bx + cbx + cby + tx) with tvm.target.Target("cuda"): after_mod = LowerTIRx()(tvm.IRModule({"main": before})) - assert not _contains_exec_scope(after_mod) after_str = str(after_mod["main"]) assert 'launch_thread("clusterCtaIdx.x", 2)' in after_str assert 'launch_thread("clusterCtaIdx.y", 1)' in after_str diff --git a/tests/python/tirx/transform/test_transform_naive_allocator.py b/tests/python/tirx/transform/test_transform_naive_allocator.py index 16e48a86b774..7d77c6114c00 100644 --- a/tests/python/tirx/transform/test_transform_naive_allocator.py +++ b/tests/python/tirx/transform/test_transform_naive_allocator.py @@ -18,7 +18,8 @@ import tvm import tvm.testing from tvm.ir import assert_structural_equal -from tvm.script import tirx as Tx +from tvm.script import tirx as T +from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout from tvm.tirx.transform.trn import TrnNaiveAllocator @@ -30,19 +31,19 @@ def test_one_alloc(): dst_layout = TileLayout(S[(128, 512) : (1 @ P, 1 @ F)]) # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) + @T.prim_func + def copy(A_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout) Tx.copy(A_sbuf, A) - @Tx.prim_func - def expected(A_ptr: Tx.handle) -> None: - Tx.func_attr({"global_symbol": "copy"}) - A = Tx.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout, allocated_addr=[0]) # noqa: E501 + @T.prim_func + def expected(A_ptr: T.handle) -> None: + T.func_attr({"global_symbol": "copy"}) + A = T.match_buffer(A_ptr, src_shape, "float32", layout=src_layout) + T.device_entry() + A_sbuf = T.alloc_buffer(dst_shape, "float32", scope="trn.sbuf", layout=dst_layout, allocated_addr=[0]) # noqa: E501 Tx.copy(A_sbuf, A) # fmt: on @@ -53,19 +54,19 @@ def expected(A_ptr: Tx.handle) -> None: def test_two_alloc(): # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") Tx.copy(B_sbuf[0:256, :], A_sbuf) - @Tx.prim_func - def expected(A_ptr: Tx.handle) -> None: - Tx.func_attr({"global_symbol": "copy"}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 + @T.prim_func + def expected(A_ptr: T.handle) -> None: + T.func_attr({"global_symbol": "copy"}) + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 Tx.copy(B_sbuf[0:256, :], A_sbuf) # fmt: on @@ -76,19 +77,19 @@ def expected(A_ptr: Tx.handle) -> None: def test_existing_alloc(): # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[1]) # noqa: E501 + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[1]) # noqa: E501 Tx.copy(B_sbuf[0:256, :], A_sbuf) - @Tx.prim_func - def expected(A_ptr: Tx.handle) -> None: - Tx.func_attr({"global_symbol": "copy"}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[4*512*4+1]) # noqa: E501 - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[1]) # noqa: E501 + @T.prim_func + def expected(A_ptr: T.handle) -> None: + T.func_attr({"global_symbol": "copy"}) + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[4*512*4+1]) # noqa: E501 + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[1]) # noqa: E501 Tx.copy(B_sbuf[0:256, :], A_sbuf) # fmt: on @@ -99,21 +100,21 @@ def expected(A_ptr: Tx.handle) -> None: def test_workspace(): # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") - C_sbuf = Tx.alloc_buffer([128, 1024], "float32", scope="trn.sbuf") + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") + C_sbuf = T.alloc_buffer([128, 1024], "float32", scope="trn.sbuf") Tx.copy(B_sbuf[0:256, :], A_sbuf, workspace={"C": C_sbuf}) - @Tx.prim_func - def expected(A_ptr: Tx.handle) -> None: - Tx.func_attr({"global_symbol": "copy"}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 - C_sbuf = Tx.alloc_buffer([128, 1024], "float32", scope="trn.sbuf", allocated_addr=[2*512*4+4*512*4]) # noqa: E501 + @T.prim_func + def expected(A_ptr: T.handle) -> None: + T.func_attr({"global_symbol": "copy"}) + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 + C_sbuf = T.alloc_buffer([128, 1024], "float32", scope="trn.sbuf", allocated_addr=[2*512*4+4*512*4]) # noqa: E501 Tx.copy(B_sbuf[0:256, :], A_sbuf, workspace={"C": C_sbuf}) # fmt: on @@ -124,21 +125,21 @@ def expected(A_ptr: Tx.handle) -> None: def test_other_scope_alloc(): # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") - C_sbuf = Tx.alloc_buffer([8, 128, 512], "float32", scope="global") + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") + C_sbuf = T.alloc_buffer([8, 128, 512], "float32", scope="global") Tx.copy(B_sbuf[0:256, :], A_sbuf, workspace={"C": C_sbuf}) - @Tx.prim_func - def expected(A_ptr: Tx.handle) -> None: - Tx.func_attr({"global_symbol": "copy"}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 - C_sbuf = Tx.alloc_buffer([8, 128, 512], "float32", scope="global") + @T.prim_func + def expected(A_ptr: T.handle) -> None: + T.func_attr({"global_symbol": "copy"}) + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 + C_sbuf = T.alloc_buffer([8, 128, 512], "float32", scope="global") Tx.copy(B_sbuf[0:256, :], A_sbuf, workspace={"C": C_sbuf}) # fmt: on @@ -149,20 +150,20 @@ def expected(A_ptr: Tx.handle) -> None: def test_buffer_views(): # fmt: off - @Tx.prim_func - def copy(A_ptr: Tx.handle) -> None: - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") + @T.prim_func + def copy(A_ptr: T.handle) -> None: + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF") + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF") B_view = B_sbuf.view(2, 256, 512) Tx.copy(B_view[0], A_sbuf) - @Tx.prim_func - def expected(A_ptr: Tx.handle) -> None: - Tx.func_attr({"global_symbol": "copy"}) - Tx.device_entry() - A_sbuf = Tx.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 - B_sbuf = Tx.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 + @T.prim_func + def expected(A_ptr: T.handle) -> None: + T.func_attr({"global_symbol": "copy"}) + T.device_entry() + A_sbuf = T.alloc_buffer([256, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[0]) # noqa: E501 + B_sbuf = T.alloc_buffer([512, 512], "float32", scope="trn.sbuf", layout="PF", allocated_addr=[2*512*4]) # noqa: E501 B_view = B_sbuf.view(2, 256, 512) Tx.copy(B_view[0], A_sbuf) # fmt: on From 708a23ba6584c9043106bf8ce18905367619f5e5 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Fri, 5 Jun 2026 13:26:44 -0700 Subject: [PATCH 2/4] refactor(op): remove tile primitive kind attrs (#661) * refactor(op): remove tile primitive kind attrs * refactor(op): move kernel replace point to builtin --- include/tvm/tirx/builtin.h | 5 + include/tvm/tirx/op_attr_types.h | 11 -- include/tvm/tirx/tirx_op.h | 7 -- python/tvm/tirx/op.py | 5 + .../tile_primitive/cuda/copy_async/tma.py | 5 +- .../tile_primitive/cuda/gemm_async/tcgen05.py | 4 +- .../tvm/tirx/operator/tile_primitive/ops.py | 60 ----------- .../tile_primitive/trn/copy/default.py | 3 +- .../tile_primitive/trn/private_alloc.py | 5 +- .../tile_primitive/trn/unary/utils.py | 3 +- python/tvm/tirx/script/builder/ir.py | 2 + python/tvm/tirx/script/builder/tirx.py | 6 -- python/tvm/tirx/script/tile.py | 2 - python/tvm/tirx/transform/common.py | 17 ++- src/tirx/analysis/verify_tirx_well_formed.cc | 7 +- src/tirx/ir/tirx_stmt.cc | 7 +- src/tirx/op/builtin.cc | 4 + src/tirx/op/tirx.cc | 100 +++++++----------- src/tirx/script/printer/stmt.cc | 27 ++--- src/tirx/transform/tile_primitive_dispatch.cc | 5 +- .../cuda/copy_async/test_tma.py | 4 +- .../python/tirx/test_op_namespace_cleanup.py | 29 +++-- 22 files changed, 113 insertions(+), 205 deletions(-) diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h index 1a11598fa427..f25e48daf330 100644 --- a/include/tvm/tirx/builtin.h +++ b/include/tvm/tirx/builtin.h @@ -500,6 +500,11 @@ TVM_DLL const Op& tvm_call_trace_packed_lowered(); */ TVM_DLL const Op& tvm_storage_sync(); +/*! + * \brief Marker where a transform should replace generated kernel initialization. + */ +TVM_DLL const Op& tvm_kernel_replace_point(); + /*! * \brief See pseudo code * diff --git a/include/tvm/tirx/op_attr_types.h b/include/tvm/tirx/op_attr_types.h index 7ebd87ed6f3c..2b8a7428ca0b 100644 --- a/include/tvm/tirx/op_attr_types.h +++ b/include/tvm/tirx/op_attr_types.h @@ -92,17 +92,6 @@ using TScriptDtypePrintLocation = int64_t; */ using TIRxOpCategory = ffi::String; -/*! - * \brief Tile primitive subcategory. - * - * Expected values: - * - "dispatch" - * - "compose" - * - "async" - * - "marker" - */ -using TTilePrimitiveKind = ffi::String; - /*! * \brief Device intrinsic namespace. * diff --git a/include/tvm/tirx/tirx_op.h b/include/tvm/tirx/tirx_op.h index 772f6ce34f06..ad2ec8e80fee 100644 --- a/include/tvm/tirx/tirx_op.h +++ b/include/tvm/tirx/tirx_op.h @@ -231,13 +231,6 @@ TVM_DLL const Op& compose_op(); TVM_DLL const Op& permute_layout(); -/*! - * \brief See pesudo code below: - * - * tvm_kernel_replace_point() - */ -TVM_DLL const Op& tvm_kernel_replace_point(); - } // namespace tirx } // namespace tvm diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index 1d2c7ec2d167..4fdbc283969a 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -729,6 +729,11 @@ def tvm_storage_sync(storage_scope, is_load=False, num_blocks=-1): return call_intrin("void", "tirx.tvm_storage_sync", storage_scope, is_load, num_blocks) +def tvm_kernel_replace_point(): + """Mark where a transform should replace generated kernel initialization.""" + return call_intrin("void", "tirx.tvm_kernel_replace_point") + + def tvm_global_barrier_kinit(): """Initialize the global barrier. diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py index 7e19713e492b..7fd773103f1e 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py @@ -49,7 +49,6 @@ import tvm from tvm.arith import Analyzer from tvm.script import tirx as T -from tvm.script.tirx import tile as Tx from tvm.tirx import Buffer, PrimFunc from tvm.tirx.layout import ComposeLayout, Layout, S, SwizzleLayout, TileLayout from tvm.tirx.operator.tile_primitive import ( @@ -1237,7 +1236,7 @@ def create_tensor_map(): 2, # CU_TENSOR_MAP_L2_PROMOTION_L2_128B oob_fill_kind, ) - Tx.tvm_kernel_replace_point() + T.tvm_kernel_replace_point() # fmt: on sctx.add_init_stmt(create_tensor_map.body, host=True) @@ -1255,7 +1254,7 @@ def create_tensor_map(): def prefetch_tensor_map(): if warp_id_in_cta == 0: T.ptx.prefetch_tensormap(T.address_of(tensor_map)) - Tx.tvm_kernel_replace_point() + T.tvm_kernel_replace_point() # fmt: on sctx.add_init_stmt(prefetch_tensor_map.body) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py index 5aac270467e3..c19bcda622e8 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py @@ -30,6 +30,7 @@ from tvm.runtime import DataType from tvm.script import tirx as T from tvm.tirx import PrimFunc +from tvm.tirx import op as tirx_op from tvm.tirx.layout import ( ComposeLayout, Iter, @@ -41,7 +42,6 @@ tmem_datapath_layout, ) from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch -from tvm.tirx.operator.tile_primitive.ops import KernelReplacePoint from tvm.tirx.stmt import AllocBuffer, Evaluate, SeqStmt, TilePrimitiveCall from ..common import get_st_extent, smem_desc_add_16B_offset @@ -717,7 +717,7 @@ def _try_atom(atom, atom_shape): # Descriptors with identical construction parameters are cached and reused # across dispatch calls via sctx.shared_state. B_base = [0] * len(B_buffer.shape) - krp = KernelReplacePoint(workspace={}, config={}) + krp = Evaluate(tirx_op.tvm_kernel_replace_point()) def _make_lo_uniform(desc): """Shuffle the lower 32 bits of the descriptor to ensure warp-uniformity.""" diff --git a/python/tvm/tirx/operator/tile_primitive/ops.py b/python/tvm/tirx/operator/tile_primitive/ops.py index 21e02793fd02..7455a1ae7456 100644 --- a/python/tvm/tirx/operator/tile_primitive/ops.py +++ b/python/tvm/tirx/operator/tile_primitive/ops.py @@ -21,50 +21,6 @@ from tvm.tirx import PrimExpr from tvm.tirx.stmt import TilePrimitiveCall -_DISPATCH_OPS = { - "zero", - "sqrt", - "exp", - "exp2", - "reciprocal", - "add", - "sub", - "mul", - "fdiv", - "maximum", - "minimum", - "copy", - "fill", - "gemm", - "sum", - "max", - "min", - "memset", - "reduce_negate", - "binary_reduce", - "unary_reduce", - "binary_chain", - "select", - "cast", - "fma", - "silu", -} -_COMPOSE_OPS = {"compose_op"} -_ASYNC_OPS = {"copy_async", "gemm_async"} -_MARKER_OPS = {"tvm_kernel_replace_point"} - - -def _tile_primitive_kind(op_name: str) -> str: - if op_name in _DISPATCH_OPS: - return "dispatch" - if op_name in _COMPOSE_OPS: - return "compose" - if op_name in _ASYNC_OPS: - return "async" - if op_name in _MARKER_OPS: - return "marker" - return "dispatch" - def get_tirx_op(op_name: str): assert isinstance(op_name, str) @@ -454,22 +410,6 @@ class Select(BinaryOp): predicate = ArgProperty(3) -class KernelReplacePoint(TilePrimitiveCall): - """A placeholder for kernel replacement points in TIR scheduling.""" - - op = get_tirx_op("tvm_kernel_replace_point") - - @property - def srcs(self) -> list[PrimExpr]: - """Get the source expressions (inputs) of the operator.""" - return [] - - @property - def dsts(self) -> list[PrimExpr]: - """Get the destination expressions (outputs) of the operator.""" - return [] - - ### Compose Ops ### class BinaryReduce(TilePrimitiveCall): """Combine a binary operation with a reduction operation. diff --git a/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py b/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py index 0005723ec193..b1a0b2078681 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py @@ -18,7 +18,6 @@ """Implementation of copy operator dispatchs.""" from tvm.script import tirx as T -from tvm.script.tirx import tile as Tx from tvm.tirx import PrimFunc from tvm.tirx.operator.tile_primitive import ( DispatchContext, @@ -100,7 +99,7 @@ def identity_init(): for p_loop in T.serial(0, p_size, annotations={nki_dim: "P"}): for rhs_f_loop in T.serial(0, rhs_f_size, annotations={nki_dim: "F"}): T.evaluate(T.nki.identity(identity_tensor[p_loop, rhs_f_loop], p_size)) - Tx.tvm_kernel_replace_point() + T.tvm_kernel_replace_point() sctx.add_init_stmt(identity_init.body) else: diff --git a/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py b/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py index 90b97aeb62fd..fe3f0a54bba1 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py @@ -18,7 +18,6 @@ from typing import Any from tvm.script import tirx as T -from tvm.script.tirx import tile as Tx from tvm.tirx import Buffer, FloatImm, Stmt from tvm.tirx.operator.tile_primitive.dispatch_context import DispatchContext from tvm.tirx.operator.tile_primitive.ops import ( @@ -62,7 +61,7 @@ def const_bias_init(): for p_loop in T.serial(0, par_size, annotations={"nki_dim": "P"}): for f_loop in T.serial(0, max_inst_size, annotations={nki_dim: "F"}): T.evaluate(T.nki.memset(new_buffer[p_loop, f_loop], bias)) - Tx.tvm_kernel_replace_point() + T.tvm_kernel_replace_point() buffer_dict[("const_bias", bias.value)] = (new_buffer, const_bias_init.body) return {"const_bias": ("const_bias", bias.value)} @@ -112,7 +111,7 @@ def identity_init(): for p_loop in T.serial(0, par_size, annotations={nki_dim: "P"}): for rhs_f_loop in T.serial(0, par_size, annotations={nki_dim: "F"}): T.evaluate(T.nki.identity(new_buffer[p_loop, rhs_f_loop], par_size)) - Tx.tvm_kernel_replace_point() + T.tvm_kernel_replace_point() buffer_dict["identity"] = (new_buffer, identity_init.body) return {"identity": "identity"} diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py b/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py index 24d1704923de..7a757609b09a 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py +++ b/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py @@ -19,7 +19,6 @@ from tvm.arith.analyzer import Analyzer from tvm.script import tirx as T -from tvm.script.tirx import tile as Tx from tvm.tirx import BufferRegion, FloatImm from ...common import MapOpType @@ -110,7 +109,7 @@ def const_bias_init(): for p_loop in T.serial(0, shape[0], annotations={nki_dim: "P"}): for f_loop in T.serial(0, shape[1], annotations={nki_dim: "F"}): T.evaluate(T.nki.memset(bias_buffer[p_loop, f_loop], bias)) - Tx.tvm_kernel_replace_point() + T.tvm_kernel_replace_point() sctx.add_init_stmt(const_bias_init.body) else: diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 7f527413e375..754f41fa3b23 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -3732,6 +3732,7 @@ def visit(ns_obj, dotted_prefix): tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) tvm_storage_sync = _tir_op.tvm_storage_sync +tvm_kernel_replace_point = _op_wrapper(_tir_op.tvm_kernel_replace_point) tvm_global_barrier_kinit = _tir_op.tvm_global_barrier_kinit tvm_warp_shuffle = _tir_op.tvm_warp_shuffle tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up @@ -4063,6 +4064,7 @@ def visit(ns_obj, dotted_prefix): "tvm_fill_fragment", "tvm_store_matrix_sync", "tvm_storage_sync", + "tvm_kernel_replace_point", "tvm_global_barrier_kinit", "tvm_warp_shuffle", "tvm_warp_shuffle_up", diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py index 23f702ebc570..f2d211d6485d 100644 --- a/python/tvm/tirx/script/builder/tirx.py +++ b/python/tvm/tirx/script/builder/tirx.py @@ -1201,11 +1201,6 @@ def compose_op( return _ffi_api.ComposeOp(workspace, config, dispatch) # pylint: disable=no-member -def tvm_kernel_replace_point(): - """A placeholder for the kernel replace point, used in TIRx op scheduling.""" - return f_insert(tirx_op.KernelReplacePoint(workspace={}, config={})) - - @ScopedOp def binary_reduce( binary_output: BufferRegion | Buffer, @@ -1671,7 +1666,6 @@ def _to_region(b): "sub", "sum", "thread", - "tvm_kernel_replace_point", "unary_reduce", "warp", "warpgroup", diff --git a/python/tvm/tirx/script/tile.py b/python/tvm/tirx/script/tile.py index bbc2c131bad0..42fe3914bedc 100644 --- a/python/tvm/tirx/script/tile.py +++ b/python/tvm/tirx/script/tile.py @@ -106,7 +106,6 @@ def wrapper(*args, scope=None, **kwargs): thread = _builder.ScopeNamespace("thread", "thread") compose_op = _builder.compose_op -tvm_kernel_replace_point = _builder.tvm_kernel_replace_point __all__ = [ *_SCOPED_TILE_OP_NAMES, @@ -114,7 +113,6 @@ def wrapper(*args, scope=None, **kwargs): "compose_op", "cta", "thread", - "tvm_kernel_replace_point", "warp", "warpgroup", "wg", diff --git a/python/tvm/tirx/transform/common.py b/python/tvm/tirx/transform/common.py index 16995c1d6c5e..d90903daf967 100644 --- a/python/tvm/tirx/transform/common.py +++ b/python/tvm/tirx/transform/common.py @@ -16,12 +16,15 @@ # under the License. +from tvm.ir import Op from tvm.tirx import ( AllocBuffer, BufferLoad, BufferRegion, BufferStore, + Call, DeclBuffer, + Evaluate, PrimExpr, Stmt, TilePrimitiveCall, @@ -174,17 +177,11 @@ def __init__(self, body: Stmt): super().__init__() self.body = body - def visit_op_call_(self, op: TilePrimitiveCall): - # Deferred import: tile_primitive's class bodies call Op.get() (FFI), - # not runtime-safe. Only reached in compiler mode. - from tvm.tirx.operator.tile_primitive.ops import ( # pylint: disable=import-outside-toplevel - KernelReplacePoint, - ) - - op = TilePrimitiveCall.downcast(op) - if isinstance(op, KernelReplacePoint): + def visit_evaluate_(self, op: Evaluate): + value = op.value + if isinstance(value, Call) and value.op.same_as(Op.get("tirx.tvm_kernel_replace_point")): return self.body - return super().visit_op_call_(op) + return super().visit_evaluate_(op) def seek_kernel_replace_point(stmt: Stmt, body: Stmt) -> Stmt: diff --git a/src/tirx/analysis/verify_tirx_well_formed.cc b/src/tirx/analysis/verify_tirx_well_formed.cc index aabda41ade99..67c8424cba41 100644 --- a/src/tirx/analysis/verify_tirx_well_formed.cc +++ b/src/tirx/analysis/verify_tirx_well_formed.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -60,9 +61,9 @@ class ExecScopeVerifier : public Verifier { void VisitStmt_(const tirx::TilePrimitiveCallNode* op, ffi::reflection::AccessPath path) override { - static const tvm::OpAttrMap& tirx_op_map_ = Op::GetAttrMap("TIsTIRxOp"); - Verify(tirx_op_map_.count(op->op)) - << "TIRxError: TilePrimitiveCall at " << path << " has unknown TIRX op " << op->op; + static const auto& category_map = Op::GetAttrMap("TIRxOpCategory"); + Verify(category_map.get(op->op, ffi::String("")) == "tile_primitive") + << "TIRxError: TilePrimitiveCall at " << path << " has non-tile op " << op->op; } }; diff --git a/src/tirx/ir/tirx_stmt.cc b/src/tirx/ir/tirx_stmt.cc index 81f392048dc1..58c95d90b1e8 100644 --- a/src/tirx/ir/tirx_stmt.cc +++ b/src/tirx/ir/tirx_stmt.cc @@ -36,10 +36,9 @@ TilePrimitiveCall::TilePrimitiveCall(tvm::Op op, ffi::Array args, ffi::Map workspace, ffi::Map config, ffi::Optional dispatch, ExecScope scope) { - // Check if the op is a TIRX op. - static const auto& tirx_op_map = Op::GetAttrMap("TIsTIRxOp"); - TVM_FFI_ICHECK_EQ(tirx_op_map.count(op), 1) - << "Only TIRX ops can be used in tirx::TilePrimitiveCall"; + static const auto& category_map = Op::GetAttrMap("TIRxOpCategory"); + TVM_FFI_ICHECK(category_map.get(op, ffi::String("")) == "tile_primitive") + << "Only tile primitive ops can be used in tirx::TilePrimitiveCall"; // Construct the TilePrimitiveCall. ffi::ObjectPtr n = ffi::make_object(); n->op = std::move(op); diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc index c62de1a4ae15..c2ad5559d608 100644 --- a/src/tirx/op/builtin.cc +++ b/src/tirx/op/builtin.cc @@ -266,6 +266,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered) TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(tvm_kernel_replace_point) + .set_num_inputs(0) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); diff --git a/src/tirx/op/tirx.cc b/src/tirx/op/tirx.cc index 0410bb5f2157..5ff54c45b613 100644 --- a/src/tirx/op/tirx.cc +++ b/src/tirx/op/tirx.cc @@ -33,18 +33,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { DispatchContextNode::RegisterReflection(); } /********************* Utils **********************/ -#define TIRX_DEFINE_TILE_FUNC(OpName) \ - const Op& OpName() { \ - static const Op& op = Op::Get("tirx.tile." #OpName); \ - return op; \ - } \ - TVM_REGISTER_OP("tirx.tile." #OpName) \ - .set_attr("TScriptPrinterName", ffi::String(#OpName), /*plevel=*/9) \ - .set_attr("TIRxOpCategory", ffi::String("tile_primitive"), /*plevel=*/9) \ - .set_attr("TIsTIRxOp", true) - -#define TIRX_DEFINE_TILE_OP(OpName, Kind) \ - TIRX_DEFINE_TILE_FUNC(OpName).set_attr("TTilePrimitiveKind", Kind) +#define TIRX_DEFINE_TILE_FUNC(OpName) \ + const Op& OpName() { \ + static const Op& op = Op::Get("tirx.tile." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tirx.tile." #OpName) \ + .set_attr("TScriptPrinterName", ffi::String(#OpName), /*plevel=*/9) \ + .set_attr("TIRxOpCategory", ffi::String("tile_primitive"), /*plevel=*/9) + +#define TIRX_DEFINE_TILE_OP(OpName) TIRX_DEFINE_TILE_FUNC(OpName) /********************* Context utils **********************/ template @@ -142,53 +140,37 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("tirx.DispatchContextSharedStateGet", &DispatchContextNode::SharedStateGet); } -/********************* Dispatch Ops **********************/ -#define TIRX_DEFINE_DISPATCH_OP(OpName) \ - TIRX_DEFINE_TILE_OP(OpName, ffi::String("dispatch")).set_attr("TIsDispatchOp", true) - -TIRX_DEFINE_DISPATCH_OP(zero); -TIRX_DEFINE_DISPATCH_OP(sqrt); -TIRX_DEFINE_DISPATCH_OP(exp); -TIRX_DEFINE_DISPATCH_OP(exp2); -TIRX_DEFINE_DISPATCH_OP(add); -TIRX_DEFINE_DISPATCH_OP(sub); -TIRX_DEFINE_DISPATCH_OP(mul); -TIRX_DEFINE_DISPATCH_OP(fdiv); -TIRX_DEFINE_DISPATCH_OP(minimum); -TIRX_DEFINE_DISPATCH_OP(maximum); -TIRX_DEFINE_DISPATCH_OP(copy); -TIRX_DEFINE_DISPATCH_OP(fill); -TIRX_DEFINE_DISPATCH_OP(gemm); -TIRX_DEFINE_DISPATCH_OP(reciprocal); -TIRX_DEFINE_DISPATCH_OP(sum); -TIRX_DEFINE_DISPATCH_OP(max); -TIRX_DEFINE_DISPATCH_OP(min); -TIRX_DEFINE_DISPATCH_OP(memset); -TIRX_DEFINE_DISPATCH_OP(reduce_negate); -TIRX_DEFINE_DISPATCH_OP(binary_reduce); -TIRX_DEFINE_DISPATCH_OP(unary_reduce); -TIRX_DEFINE_DISPATCH_OP(binary_chain); -TIRX_DEFINE_DISPATCH_OP(select); -TIRX_DEFINE_DISPATCH_OP(cast); -TIRX_DEFINE_DISPATCH_OP(fma); -TIRX_DEFINE_DISPATCH_OP(silu); -TIRX_DEFINE_DISPATCH_OP(permute_layout); - -/********************* Compose Ops **********************/ -#define TIRX_DEFINE_COMPOSE_OP(OpName) \ - TIRX_DEFINE_TILE_OP(OpName, ffi::String("compose")).set_attr("TIsComposeOp", true) - -TIRX_DEFINE_COMPOSE_OP(compose_op); - -/********************* Async Ops **********************/ -#define TIRX_DEFINE_ASYNC_OP(OpName) \ - TIRX_DEFINE_TILE_OP(OpName, ffi::String("async")).set_attr("TIsAsyncOp", true) - -TIRX_DEFINE_ASYNC_OP(copy_async); -TIRX_DEFINE_ASYNC_OP(gemm_async); - -/********************* Misc Ops **********************/ -TIRX_DEFINE_TILE_OP(tvm_kernel_replace_point, ffi::String("marker")); +/********************* Tile Ops **********************/ +TIRX_DEFINE_TILE_OP(zero); +TIRX_DEFINE_TILE_OP(sqrt); +TIRX_DEFINE_TILE_OP(exp); +TIRX_DEFINE_TILE_OP(exp2); +TIRX_DEFINE_TILE_OP(add); +TIRX_DEFINE_TILE_OP(sub); +TIRX_DEFINE_TILE_OP(mul); +TIRX_DEFINE_TILE_OP(fdiv); +TIRX_DEFINE_TILE_OP(minimum); +TIRX_DEFINE_TILE_OP(maximum); +TIRX_DEFINE_TILE_OP(copy); +TIRX_DEFINE_TILE_OP(fill); +TIRX_DEFINE_TILE_OP(gemm); +TIRX_DEFINE_TILE_OP(reciprocal); +TIRX_DEFINE_TILE_OP(sum); +TIRX_DEFINE_TILE_OP(max); +TIRX_DEFINE_TILE_OP(min); +TIRX_DEFINE_TILE_OP(memset); +TIRX_DEFINE_TILE_OP(reduce_negate); +TIRX_DEFINE_TILE_OP(binary_reduce); +TIRX_DEFINE_TILE_OP(unary_reduce); +TIRX_DEFINE_TILE_OP(binary_chain); +TIRX_DEFINE_TILE_OP(select); +TIRX_DEFINE_TILE_OP(cast); +TIRX_DEFINE_TILE_OP(fma); +TIRX_DEFINE_TILE_OP(silu); +TIRX_DEFINE_TILE_OP(permute_layout); +TIRX_DEFINE_TILE_OP(compose_op); +TIRX_DEFINE_TILE_OP(copy_async); +TIRX_DEFINE_TILE_OP(gemm_async); } // namespace tirx } // namespace tvm diff --git a/src/tirx/script/printer/stmt.cc b/src/tirx/script/printer/stmt.cc index 4d3c21c88d5d..21ae61135d9b 100644 --- a/src/tirx/script/printer/stmt.cc +++ b/src/tirx/script/printer/stmt.cc @@ -90,13 +90,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } - static const auto& tirx_op_map = Op::GetAttrMap("TIsTIRxOp"); - static const auto& dispatch_op_map = Op::GetAttrMap("TIsDispatchOp"); - static const auto& compose_op_map = Op::GetAttrMap("TIsComposeOp"); - static const auto& async_op_map = Op::GetAttrMap("TIsAsyncOp"); static const auto& category_map = Op::GetAttrMap("TIRxOpCategory"); - TVM_FFI_ICHECK(tirx_op_map.get(op, false)) - << "Only TIRX ops can be used in tirx::TilePrimitiveCall"; + bool is_tile_primitive = category_map.get(op, ffi::String("")) == "tile_primitive"; + TVM_FFI_ICHECK(is_tile_primitive) + << "Only tile primitive ops can be used in tirx::TilePrimitiveCall"; ffi::String name = op_names.get(op, op->name); // Per-call execution scope is printed as a namespace prefix on the op, // e.g. ``T.warp.copy(...)``. ``warpgroup`` prints as ``wg``. The @@ -123,13 +120,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (ns.has_value()) { return TIRx(d, ns.value())->Attr(op_name); } - if (category_map.get(op, ffi::String("")) == "tile_primitive") { - return TIRx(d, "tile")->Attr(op_name); - } - return TIRx(d, op_name); + return TIRx(d, "tile")->Attr(op_name); }; - if (dispatch_op_map.get(op, false) || async_op_map.get(op, false)) { - // Dispatch ops + if (!op.same_as(tirx::compose_op())) { // Trim trailing None args (e.g. optional bias=None, scale=None) size_t n_args = op_call->args.size(); while (n_args > 0 && @@ -160,8 +153,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return OpCallDoc(scoped_callee(name), args, d->AsDoc(op_call->workspace, p->Attr("workspace")), d->AsDoc(op_call->config, p->Attr("config")), disp); - } else if (compose_op_map.get(op, false)) { - // Compose ops + } else { With f(d, op_call); ffi::Array stmts; for (size_t i = 0, n = op_call->args.size(); i < n; ++i) { @@ -191,13 +183,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } return ScopeDoc(std::nullopt, scoped_callee("compose_op")->Call({}, kw_keys, kw_values), (*f)->stmts); - } else { - // Misc ops - ffi::Array args; - for (size_t i = 0, n = op_call->args.size(); i < n; ++i) { - args.push_back(d->AsDoc(op_call->args[i], p->Attr("args")->ArrayItem(i))); - } - return OpCallDoc(scoped_callee(name), args, {}, {}, std::nullopt); } }); TVM_SCRIPT_REPR(tirx::TilePrimitiveCallNode, ReprPrintTIR); diff --git a/src/tirx/transform/tile_primitive_dispatch.cc b/src/tirx/transform/tile_primitive_dispatch.cc index cb431e474d66..727bceaa0ed3 100644 --- a/src/tirx/transform/tile_primitive_dispatch.cc +++ b/src/tirx/transform/tile_primitive_dispatch.cc @@ -294,8 +294,9 @@ class TilePrimitiveDispatcher : public StmtExprMutator { } private: - Stmt VisitStmt_(const tirx::TilePrimitiveCallNode* op) final { - if (op->op == tirx::tvm_kernel_replace_point()) { + Stmt VisitStmt_(const EvaluateNode* op) final { + const auto* call = op->value.as(); + if (call != nullptr && call->op.same_as(tirx::builtin::tvm_kernel_replace_point())) { return body_; } return StmtExprMutator::VisitStmt_(op); diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py index 1c4bf5221625..933b866bdb64 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py @@ -36,7 +36,7 @@ ) from tvm.tirx.operator.tile_primitive.dispatch_context import DispatchContext from tvm.tirx.operator.tile_primitive.ops import CopyAsync -from tvm.tirx.stmt import DeclBuffer, TilePrimitiveCall +from tvm.tirx.stmt import DeclBuffer from tvm.tirx.stmt_functor import StmtExprVisitor # =========================================================================== @@ -159,7 +159,7 @@ def _build_expected_host_init(dtype, encode_args): + [IntImm("int32", v) for v in encode_args[1:]] ) encode_call = tvm.tirx.Call("int32", tvm.ir.Op.get("tirx.tvm_call_packed"), call_args) - replace_point = TilePrimitiveCall(op=tvm.ir.Op.get("tirx.tile.tvm_kernel_replace_point")) + replace_point = tvm.tirx.Evaluate(tvm.tirx.op.tvm_kernel_replace_point()) return tvm.tirx.SeqStmt( [tvm.tirx.Bind(A_tensormap, stack_alloca), tvm.tirx.Evaluate(encode_call), replace_point] ) diff --git a/tests/python/tirx/test_op_namespace_cleanup.py b/tests/python/tirx/test_op_namespace_cleanup.py index 5c0aa9615207..0bbfcff3e86d 100644 --- a/tests/python/tirx/test_op_namespace_cleanup.py +++ b/tests/python/tirx/test_op_namespace_cleanup.py @@ -100,6 +100,29 @@ def test_builtin_expression_ops_are_not_tile_primitives(): assert fma.op.name == "tirx.fma" +def test_kernel_replace_point_is_builtin_marker_not_tile_primitive(): + assert _op_attr("tirx.tvm_kernel_replace_point", "TIRxOpCategory") == "builtin" + assert "tirx.tile.tvm_kernel_replace_point" not in Op.list_op_names() + assert hasattr(T, "tvm_kernel_replace_point") + assert not hasattr(Tx, "tvm_kernel_replace_point") + + @T.prim_func(check_well_formed=False) + def marker(): + T.tvm_kernel_replace_point() + + calls = _expr_calls(marker) + assert [call.op.name for call in calls] == ["tirx.tvm_kernel_replace_point"] + assert _tile_calls(marker) == [] + + code = marker.script() + assert "T.tvm_kernel_replace_point()" in code + assert "tvm_kernel_replace_point" in code + assert "T.tile.tvm_kernel_replace_point" not in code + assert "Tx.tvm_kernel_replace_point" not in code + reparsed = tvm.script.from_source(code) + assert_structural_equal(marker, reparsed) + + def test_tile_shorthand_and_scoped_aliases_use_tile_ops(): @T.prim_func(check_well_formed=False) def tile_aliases(a: T.handle, b: T.handle): @@ -141,7 +164,6 @@ def test_device_intrinsic_namespaces_are_canonical_and_classified(): for op_name, namespace in expected: assert _op_attr(op_name, "TIRxOpCategory") == "device_intrin" assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace - assert _op_attr(op_name, "TTilePrimitiveKind") is None def test_device_intrinsic_printer_roundtrips_canonical_namespaces(): @@ -183,7 +205,6 @@ def test_registered_tirx_ops_have_exactly_one_category(): pytest.skip("TIRx op categories require a rebuilt C++ runtime") categories = {"builtin", "tile_primitive", "device_intrin"} - tile_kinds = {"dispatch", "compose", "async", "marker"} device_namespaces = {"cuda", "ptx", "nvshmem", "nki", "metal", "webgpu"} flat_tile_only_names = { "tirx.add", @@ -214,7 +235,6 @@ def test_registered_tirx_ops_have_exactly_one_category(): lingering_flat_tile = [] for op_name in sorted(name for name in Op.list_op_names() if name.startswith("tirx.")): category = _op_attr(op_name, "TIRxOpCategory") - tile_kind = _op_attr(op_name, "TTilePrimitiveKind") device_namespace = _op_attr(op_name, "TDeviceIntrinsicNamespace") if category is None: @@ -229,10 +249,8 @@ def test_registered_tirx_ops_have_exactly_one_category(): if category == "tile_primitive": if not op_name.startswith("tirx.tile."): lingering_flat_tile.append(op_name) - assert tile_kind in tile_kinds, op_name assert device_namespace is None, op_name elif category == "device_intrin": - assert tile_kind is None, op_name assert device_namespace in device_namespaces, op_name printer_name = _op_attr(op_name, "TScriptPrinterName") assert printer_name is not None, op_name @@ -240,7 +258,6 @@ def test_registered_tirx_ops_have_exactly_one_category(): assert _has_path(T, printer_name), op_name else: assert category == "builtin" - assert tile_kind is None, op_name assert device_namespace is None, op_name assert not missing From 47fd0d66b404f2dea6182f8a9a4c36519b604584 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Fri, 5 Jun 2026 18:30:30 -0400 Subject: [PATCH 3/4] fix(lower-tirx): recognize canonical PTX ops in warp memory lowering --- src/tirx/transform/lower_warp_memory.cc | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/tirx/transform/lower_warp_memory.cc b/src/tirx/transform/lower_warp_memory.cc index 99c815bf6630..a30f27a859ca 100644 --- a/src/tirx/transform/lower_warp_memory.cc +++ b/src/tirx/transform/lower_warp_memory.cc @@ -49,6 +49,18 @@ namespace tvm { namespace tirx { +namespace { + +bool IsOp(const CallNode* call, const Op& compat_op, const char* canonical_name) { + if (call->op.same_as(compat_op)) { + return true; + } + const auto* op_node = call->op.as(); + return op_node != nullptr && op_node->name == canonical_name; +} + +} // namespace + // Rewrite Rule // // There is no special warp memory in most GPUs. @@ -117,13 +129,14 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { private: /// Visitor implementation void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as() == buffer_) { + if (IsOp(op, builtin::ptx_ldmatrix(), "tirx.ptx.ldmatrix") && + op->args[3].as() == buffer_) { UpdatePattern(op->args[4]); } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as() == buffer_) { auto* local_size = op->args[0].as(); TVM_FFI_ICHECK(local_size) << "Integer expected for the first argument of mma_fill"; warp_coeff_ = local_size->value; - } else if (op->op.same_as(builtin::ptx_ldmatrix_legacy()) && + } else if (IsOp(op, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy") && op->args[3].as() == buffer_) { // ldmatrix writes the warp buffer; its local_offset carries // ``... + lift(local_size) * tx`` from which the warp coefficient @@ -295,11 +308,11 @@ class WarpAccessRewriter : protected StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) override { - if (op->op.same_as(builtin::ptx_mma())) { + if (IsOp(op, builtin::ptx_mma(), "tirx.ptx.mma")) { return RewriteIndicesAt(op, {6, 8, 10}); } - if (op->op.same_as(builtin::ptx_ldmatrix())) { + if (IsOp(op, builtin::ptx_ldmatrix(), "tirx.ptx.ldmatrix")) { return RewriteIndicesAt(op, {3}); } @@ -312,10 +325,10 @@ class WarpAccessRewriter : protected StmtExprMutator { } // Legacy variants: (ptr_var, offset) pairs in apache positions. - if (op->op.same_as(builtin::ptx_mma_legacy())) { + if (IsOp(op, builtin::ptx_mma_legacy(), "tirx.ptx.mma_legacy")) { return RewriteIndicesAt(op, {6, 8, 10}); } - if (op->op.same_as(builtin::ptx_ldmatrix_legacy())) { + if (IsOp(op, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) { // args: trans, num, type, local_ptr, local_offset, smem_ptr_call, smem_offset // Only local_ptr is a raw warp buffer Var; smem_ptr is an // access_ptr Call wrapping a shared-scope var. From fe71f443d47b65f760fe75ea4dbc64de917d9dc4 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Fri, 5 Jun 2026 19:39:49 -0400 Subject: [PATCH 4/4] docs(op): fix PTX ldmatrix docstring indentation --- python/tvm/tirx/op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index 4fdbc283969a..4d80ac378e5f 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -4953,7 +4953,7 @@ def ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles): """TVM intrinsic for ldmatrix.sync.aligned.m8n8.x{num}{.trans}.shared.{dtype}. Mirrors the PTX ISA destination form: each output register is a separate - operand. Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for + operand. Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for each destination — the slots may be non-contiguous. Parameters