[TIRx] Update scoped ops and CUDA launch bounds #19677
Fold the latest TIRx op-dispatch and lowering follow-ups into one upstream-facing commit. This updates the scoped Tx op surface, splits TIRx op namespaces, and adds explicit CUDA launch bounds support.
Code Review
This pull request introduces several updates to the TVM/TIRX compiler stack, including the addition of GPU monitoring and benchmarking scripts, the introduction of s_tir semantics, the renaming of layout classes to SLayout and SBijectiveLayout, new CUDA and Trainium intrinsics, and a CompareBeforeAfter testing utility. The code review identified several critical issues that need to be addressed: a parse-time failure from evaluating a PrimExpr in a Python if statement, a missing lazy import of tvm.s_tir.pipeline causing a ValueError, incorrect registration of an exit callback on attr_frame instead of the parent frame, macro compilation errors in layout.h, FFI conversion failures when passing Python Enum objects directly to CUDA intrinsics, Python < 3.10 compatibility issues with | type unions in isinstance calls, a redundant declaration of ptx_ldg32(), and an unused variable in nvcc.py.
I am having trouble creating individual review comments, so my feedback is listed below.
python/tvm/tirx/lang/pipeline.py (137-139)
Evaluating a PrimExpr (like self.leader) inside a Python if statement will raise a ValueError or RuntimeError at parse time because its truth value is ambiguous. Since self.leader defaults to Tx.cuda.thread_rank() == 0, it must be wrapped in a TVM Script conditional block using with Tx.If(self.leader): and with Tx.Then():.
    with Tx.If(self.leader):
        with Tx.Then():
            for i in Tx.unroll(self.depth):
                Tx.ptx.mbarrier.init(self.buf.ptr_to([i]), count)
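The failure mode can be reproduced without TVM. The sketch below is an illustration only: `SymbolicExpr` and `python_if_fails` are hypothetical names, not TVM APIs, and the exception type is an assumption. It shows why a value that is only known at kernel run time cannot drive a plain Python `if`.

```python
class SymbolicExpr:
    """Hypothetical stand-in for a symbolic expression such as a PrimExpr."""

    def __init__(self, text):
        self.text = text

    def __eq__(self, other):
        # Comparison builds a new symbolic node instead of returning a bool.
        return SymbolicExpr(f"({self.text} == {other!r})")

    def __bool__(self):
        # The value is unknown while the script is parsed, so refuse to branch.
        raise ValueError(f"truth value of {self.text} is ambiguous")


def python_if_fails(cond):
    """Return True if a plain Python `if` on `cond` raises ValueError."""
    try:
        if cond:
            pass
    except ValueError:
        return True
    return False


# Mirrors the shape of the default leader predicate, thread_rank() == 0.
leader = SymbolicExpr("thread_rank") == 0
```

Here `python_if_fails(leader)` is true while `python_if_fails(True)` is false, which is why the condition has to be emitted as an IR-level conditional rather than evaluated by the Python interpreter.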
python/tvm/tirx/compilation_pipeline.py (180-187)
In python/tvm/__init__.py, the import of s_tir was removed. Consequently, tvm.s_tir.pipeline is no longer imported by default, and "s_tir" is never registered into PIPELINE_MAP. When get_tir_pipeline("default") is called, it maps "default" to "s_tir", which then raises a ValueError because "s_tir" is missing from PIPELINE_MAP. Please lazily import tvm.s_tir.pipeline inside get_tir_pipeline if "s_tir" is requested but not yet registered.
    if name == "default":
        # for now, default to s_tir pipeline
        name = "s_tir"
    if name == "s_tir" and "s_tir" not in PIPELINE_MAP:
        import tvm.s_tir.pipeline
    if name not in PIPELINE_MAP:
        raise ValueError(
            f"Unknown pre-built pipeline {name},candidates are {list(PIPELINE_MAP.keys())}"
        )
    return PIPELINE_MAP[name](**kwargs)
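The register-on-first-use pattern behind this suggestion can be sketched in isolation. Everything here is a toy under stated assumptions: `_register_s_tir` stands in for the side effect of importing `tvm.s_tir.pipeline`, and `_LAZY_LOADERS` and `get_pipeline` are hypothetical names rather than TVM functions.

```python
PIPELINE_MAP = {}  # name -> factory, filled in by registration


def _register_s_tir():
    # Hypothetical stand-in for the side effect of `import tvm.s_tir.pipeline`.
    PIPELINE_MAP["s_tir"] = lambda **kwargs: ("s_tir pipeline", kwargs)


# Hypothetical table of loaders that register a pipeline on first use.
_LAZY_LOADERS = {"s_tir": _register_s_tir}


def get_pipeline(name="default", **kwargs):
    if name == "default":
        name = "s_tir"
    # Register on demand if the name is known but not loaded yet.
    if name not in PIPELINE_MAP and name in _LAZY_LOADERS:
        _LAZY_LOADERS[name]()
    if name not in PIPELINE_MAP:
        raise ValueError(
            f"Unknown pre-built pipeline {name}, candidates are {list(PIPELINE_MAP)}"
        )
    return PIPELINE_MAP[name](**kwargs)
```

The lookup only fails for names that have no loader at all, so removing an eager top-level import no longer changes what `"default"` resolves to.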
python/tvm/tirx/lang/alloc_pool.py (506-510)
attr_frame.add_callback(partial(attr_frame.__exit__, None, None, None)) adds the exit callback to attr_frame itself. However, since attr_frame is never exited via a with block, its callbacks are never executed, and the frame is never popped from the IRBuilder stack. To automatically pop attr_frame when the parent scope exits, the callback must be registered on the parent frame (the currently active frame on the IRBuilder stack before entering attr_frame).
    if isinstance(attr_frame, frame_mod.AttrFrame):
        from functools import partial
        parent_frame = tvm.script.ir_builder.IRBuilder.current().frames[-1]
        parent_frame.add_callback(partial(attr_frame.__exit__, None, None, None))
        attr_frame.__enter__()
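A toy frame stack makes the ordering concrete. The `Frame` class below is a hypothetical model, not the IRBuilder implementation; it assumes only that a frame runs its callbacks when it exits and then pops itself from a shared stack.

```python
from functools import partial


class Frame:
    """Hypothetical frame that runs its callbacks when it exits."""

    def __init__(self, name, stack, log):
        self.name, self.stack, self.log = name, stack, log
        self.callbacks = []

    def add_callback(self, fn):
        self.callbacks.append(fn)

    def __enter__(self):
        self.stack.append(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        for fn in reversed(self.callbacks):
            fn()
        # By now any nested frame has popped itself, so this frame is on top.
        assert self.stack.pop() is self
        self.log.append(f"exit {self.name}")


stack, log = [], []
with Frame("parent", stack, log):
    attr = Frame("attr", stack, log)
    parent_frame = stack[-1]  # capture the parent before entering attr
    parent_frame.add_callback(partial(attr.__exit__, None, None, None))
    attr.__enter__()
    depth_inside = len(stack)  # both frames are active here
```

After the `with` block the stack is empty and the log reads `exit attr` then `exit parent`. Had the callback been registered on `attr` itself, nothing would ever call `attr.__exit__`, and the parent's pop would find the wrong frame on top.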
include/tvm/tirx/layout.h (280-284)
The macro TVM_AXIS_REGISTER_VAR_DEF is defined as [[maybe_unused]] static ::tvm::tirx::AxisRegEntry& __make_##Axis. Since Axis is not a macro parameter, using ## with it is a preprocessing error. Furthermore, TVM_REGISTER_AXIS attempts to concatenate TVM_AXIS_REGISTER_VAR_DEF (which is a sequence of tokens) with __COUNTER__ using ## inside TVM_STR_CONCAT_, which is invalid and will fail to compile. Please simplify the macro by directly defining the variable inside TVM_REGISTER_AXIS using TVM_STR_CONCAT to generate a unique name.
    #define TVM_REGISTER_AXIS(AxisName)                                  \
      [[maybe_unused]] static ::tvm::tirx::AxisRegEntry& TVM_STR_CONCAT( \
          __make_axis_, __COUNTER__) =                                   \
          ::tvm::tirx::AxisRegEntry::RegisterOrGet(AxisName)
python/tvm/tirx/bench.py (640-655)
event_type is type-hinted as Enum (e.g., EventType.kBegin). Passing a Python Enum instance directly to Tx.timer_start_cuda or Tx.timer_end_cuda will fail because the TVM FFI does not automatically convert Python Enum objects to PrimExpr / IntImm. You must pass event_type.value instead.
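The distinction between an `Enum` member and its underlying value is plain Python and can be shown without the FFI. In this sketch the member values of `EventType` are assumed (the review only names `EventType.kBegin`), and `to_int_immediate` is a hypothetical strict converter standing in for a boundary that accepts only plain integers.

```python
from enum import Enum


class EventType(Enum):
    # Hypothetical values; the real definitions are not shown in the review.
    kBegin = 0
    kEnd = 1


def to_int_immediate(value):
    """Hypothetical strict converter that accepts only plain ints."""
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError(
            f"cannot convert {type(value).__name__} to an integer immediate"
        )
    return value


# A plain Enum member is not an int; its .value is.
member_is_int = isinstance(EventType.kBegin, int)
value_is_int = isinstance(EventType.kBegin.value, int)
```

A plain `Enum` member does not subclass `int`, so a strict converter rejects it, while `.value` passes. Deriving the enum from `IntEnum` would be another way to make members integer-like, though whether the FFI accepts those is not established here.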
    def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None):
        if self.profiler_enabled:
            Tx.timer_start_cuda(
                event_type.value,
                self.buffer.data,
                self.profiler_tag.data,
                self.profiler_write_offset.data,
                self.write_stride,
                self._leader(leader),
            )

    @Tx.inline
    def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None):
        if self.profiler_enabled:
            Tx.timer_end_cuda(
                event_type.value,
python/tvm/script/parser/core/evaluator.py (242-244)
Using | for type unions inside isinstance (e.g., isinstance(node, doc.ListComp | doc.SetComp | doc.DictComp)) is only supported in Python 3.10+. If TVM is run on Python 3.8 or 3.9, this will raise a TypeError at runtime. Use a tuple of types instead for backward compatibility.
    if isinstance(node, (doc.ListComp, doc.SetComp, doc.DictComp)):
        value = self._eval_expr(node)
        return self._add_intermediate_result(value)
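The tuple form can be checked with the standard library alone. This sketch uses Python's `ast` module as a stand-in for TVM's `doc` node classes, which is an assumption made for illustration; `is_comprehension` and `first_expr` are hypothetical helper names.

```python
import ast


def is_comprehension(node):
    # A tuple of types is accepted by isinstance on every Python 3 version.
    return isinstance(node, (ast.ListComp, ast.SetComp, ast.DictComp))


def first_expr(source):
    """Parse a single expression and return its AST node."""
    return ast.parse(source, mode="eval").body
```

For example, `is_comprehension(first_expr("[x for x in y]"))` is true, while a generator expression such as `(x for x in y)` is not matched because `ast.GeneratorExp` is not in the tuple.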
python/tvm/tirx/buffer.py (318)
Using | for type unions inside isinstance (e.g., isinstance(args[0], str | tvm.DataType)) is only supported in Python 3.10+. If TVM is run on Python 3.8 or 3.9, this will raise a TypeError at runtime. Use a tuple of types instead for backward compatibility.
    if len(args) == 1 and isinstance(args[0], (str, tvm.DataType)) and not kwargs:
include/tvm/tirx/target_builtin/cuda.h (103-109)
ptx_ldg32() is declared twice consecutively. The second declaration is redundant and should be removed.
    /*!
     * \brief tvm intrinsic for ptx predicate load with 32-bit data type.
     *
     */
    TVM_DLL const Op& ptx_ldg32();
python/tvm/contrib/nvcc.py (198)
The variable major is assigned but never used in the function _compile_cuda_nvcc. Since get_target_compute_version has no side effects, we can safely remove this line.
* refactor(op): remove tile primitive kind attrs
* refactor(op): move kernel replace point to builtin
Summary
Tx.<scope>.<op> namespaces and migrate call sites
Validation
git diff --check apache/main..HEAD
pre-commit run --from-ref apache/main --to-ref HEAD