1 change: 1 addition & 0 deletions python/tilus/backends/contexts/__init__.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tilus.backends.contexts.const_reg_ctx import ConstRegTensorEmitContext
from tilus.backends.contexts.contexts import EmitContexts
from tilus.backends.contexts.global_view_ctx import GlobalTensorView, GlobalTensorViewContext
from tilus.backends.contexts.gmem_alloc_ctx import GlobalMemoryAllocationContext
105 changes: 105 additions & 0 deletions python/tilus/backends/contexts/const_reg_ctx.py
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from hidet.ir.expr import Expr, Var

from tilus.backends.context import BaseEmitContext
from tilus.extensions.hidet.ir.tools import rewrite
from tilus.ir.tensor import RegisterTensor


@dataclass
class ConstRegTensorInfo:
"""Represents a CTA-invariant register tensor whose elements can be computed from their logical indices.

The value at logical indices ``(i0, i1, ...)`` is obtained by substituting
``axes[0] = i0, axes[1] = i1, ...`` into ``expr``.

Attributes
----------
axes: list[Var]
Variables representing the logical indices, one per tensor dimension.
expr: Expr
The expression computing the tensor element value, parameterized by ``axes``.
"""

axes: list[Var]
expr: Expr


class ConstRegTensorEmitContext(BaseEmitContext):
"""Tracks RegisterTensors whose elements are CTA-invariant expressions.

Some RegisterTensors (e.g., barrier addresses from AllocBarrierInst) have values that are
constant during the lifetime of a CTA and can be expressed as a closed-form function of the
logical indices. Instead of materializing these as arrays (which may be spilled to local memory
by nvcc), consumers can use this context to obtain the value via arithmetic.

The normal array materialization is still emitted as a fallback. This context enables emitters
for specific instructions (e.g., SliceRegisterInst) to bypass array indexing and use the
arithmetic expression directly.
"""

def __post_init__(self):
self._tracked: dict[RegisterTensor, ConstRegTensorInfo] = {}

def register(self, tensor: RegisterTensor, axes: list[Var], expr: Expr) -> None:
"""Register a CTA-invariant register tensor.

Parameters
----------
tensor: RegisterTensor
The tensor to track.
axes: list[Var]
Variables representing the logical indices used in ``expr``, one per tensor dimension.
expr: Expr
The expression computing the tensor element value, parameterized by ``axes``.
"""
self._tracked[tensor] = ConstRegTensorInfo(axes=axes, expr=expr)

def is_tracked(self, tensor: RegisterTensor) -> bool:
"""Check if a tensor is tracked as CTA-invariant."""
return tensor in self._tracked

def get_info(self, tensor: RegisterTensor) -> ConstRegTensorInfo:
"""Get the tracking info for a CTA-invariant tensor."""
return self._tracked[tensor]

def get_value(self, tensor: RegisterTensor, logical_indices: Sequence[Expr]) -> Expr:
"""Compute the value of a CTA-invariant tensor at the given logical indices.

Parameters
----------
tensor: RegisterTensor
The tracked tensor.
logical_indices: Sequence[Expr]
The logical index expressions (may be runtime variables), one per tensor dimension.

Returns
-------
ret: Expr
The value expression with axis variables substituted by the given indices.
"""
info = self._tracked[tensor]
if len(info.axes) != len(logical_indices):
raise ValueError(
f"Expected {len(info.axes)} indices, got {len(logical_indices)} for tensor with shape {tensor.shape}"
)
mapping = dict(zip(info.axes, logical_indices))
return rewrite(info.expr, mapping)
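
Note: a minimal usage sketch of the new context; the `codegen`, `base_addr`, and `barriers_tensor` names below are hypothetical stand-ins for objects supplied by the surrounding emitter infrastructure, not names from this PR:

from hidet.ir.dtypes import uint32, uint64
from hidet.ir.expr import Var

from tilus.backends.contexts.const_reg_ctx import ConstRegTensorEmitContext

ctx = ConstRegTensorEmitContext(codegen)  # normally constructed and owned by EmitContexts

# Track a 1-D tensor whose element at logical index i is base_addr + 8 * i.
axis = Var("i", type=uint32)
ctx.register(barriers_tensor, axes=[axis], expr=base_addr + axis * uint32(uint64.nbytes))

# Consumers can then recover the value at any (possibly runtime) index
# arithmetically, without reading the materialized array:
assert ctx.is_tracked(barriers_tensor)
value = ctx.get_value(barriers_tensor, [uint32(3)])  # base_addr + 24

Under the hood, get_value substitutes the given indices for the axis variables via rewrite, so the result is an ordinary hidet expression.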
2 changes: 2 additions & 0 deletions python/tilus/backends/contexts/contexts.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from tilus.backends.context import BaseEmitContext
from tilus.backends.contexts.const_reg_ctx import ConstRegTensorEmitContext
from tilus.backends.contexts.global_view_ctx import GlobalTensorViewContext
from tilus.backends.contexts.gmem_alloc_ctx import GlobalMemoryAllocationContext
from tilus.backends.contexts.invariant_ctx import InvariantTrackingContext
@@ -36,6 +37,7 @@ def __init__(self, codegen):
self.tcgen05_ctx: Tcgen05EmitContext = Tcgen05EmitContext(codegen)
self.barrier_alloc_ctx: BarrierAllocContext = BarrierAllocContext(codegen)
self.sync_ctx: SyncContext = SyncContext(codegen)
self.const_reg_ctx: ConstRegTensorEmitContext = ConstRegTensorEmitContext(codegen)

def contexts(self) -> list[BaseEmitContext]:
"""Get all contexts as a list.
16 changes: 12 additions & 4 deletions python/tilus/backends/contexts/mbarrier_alloc_ctx.py
@@ -65,14 +65,22 @@ def finalize(self):
sb.append(syncthreads())
self.kernel_prepend(sb.finish())

def allocate_barriers(self, counts: Sequence[Expr | int]) -> list[Var]:
def allocate_barriers(self, counts: Sequence[Expr | int]) -> tuple[Var, list[Var]]:
"""
Allocate a list of barriers with given counts.
Allocate a contiguous list of barriers with the given counts.

Each barrier is a 64-bit data structure stored in shared memory. This function returns both the
address of the first barrier in the shared memory space and the per-barrier addresses.

Returns
-------
ret: tuple[Var, list[Var]]
A tuple of (base_addr, barrier_vars) where base_addr is the address of the first barrier
and barrier_vars are the addresses of each individual barrier. The barriers are guaranteed
to be contiguously allocated, so barrier_vars[i] = base_addr + i * 8.
"""
base_offset = len(self.barriers)
barrier_vars = [Var("barrier_{}".format(c), type=uint32) for c in counts]
self.counts.extend([uint32(c) if isinstance(c, int) else c for c in counts])
self.barriers.extend(barrier_vars)
return barrier_vars
base_addr = self.barrier_addr + uint32(base_offset * uint64.nbytes)
return base_addr, barrier_vars
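
A hedged illustration of the contiguity guarantee (the counts and the `contexts` handle are made up for the example):

# Allocate three barriers; each mbarrier occupies uint64.nbytes == 8 bytes of shared memory.
base_addr, barrier_vars = contexts.barrier_alloc_ctx.allocate_barriers(counts=[128, 128, 256])
# At runtime, once the context's finalize step has bound the variables:
#   barrier_vars[0] == base_addr
#   barrier_vars[1] == base_addr + 8
#   barrier_vars[2] == base_addr + 16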
5 changes: 2 additions & 3 deletions python/tilus/backends/contexts/sync_ctx.py
@@ -49,8 +49,7 @@ def sync(self) -> Optional[Expr]:
key = (thread_begin, thread_end)
if key not in self.thread_group_barrier:
# allocate a new barrier for this thread group
self.thread_group_barrier[key] = self.contexts.barrier_alloc_ctx.allocate_barriers(
counts=[thread_end - thread_begin]
)[0]
_, barrier_vars = self.contexts.barrier_alloc_ctx.allocate_barriers(counts=[thread_end - thread_begin])
self.thread_group_barrier[key] = barrier_vars[0]
barrier_addr = self.thread_group_barrier[key]
return mbarrier_sync_shared(barrier_addr)
14 changes: 11 additions & 3 deletions python/tilus/backends/emitters/cuda/fence.py
@@ -14,6 +14,9 @@
# limitations under the License.


from hidet.ir.dtypes import uint32, uint64
from hidet.ir.expr import Var

from tilus.backends.emitter import BaseInstEmitter, register_emitter
from tilus.ir.instructions.cuda.mbarrier import (
AllocBarrierInst,
@@ -28,7 +31,12 @@ def emit(self, inst: AllocBarrierInst) -> None:
out_var = self.get_or_allocate_var(out)

counts = [c if c is not None else self.current_num_threads for c in inst.counts]
barriers = self.contexts.barrier_alloc_ctx.allocate_barriers(counts=counts)
base_addr, barrier_vars = self.contexts.barrier_alloc_ctx.allocate_barriers(counts=counts)

for i in range(len(barriers)):
    self.buffer_store(out_var, indices=[i], value=barriers[i])
for i in range(len(barrier_vars)):
    self.buffer_store(out_var, indices=[i], value=barrier_vars[i])

# Register as CTA-invariant tensor: value(i) = base_addr + i * uint64.nbytes
axis = Var("i", type=uint32)
expr = base_addr + axis * uint32(uint64.nbytes)
self.contexts.const_reg_ctx.register(out, axes=[axis], expr=expr)
14 changes: 10 additions & 4 deletions python/tilus/backends/emitters/cuda/mbarrier.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from hidet.ir.dtypes import uint32
from hidet.ir.dtypes import uint32, uint64
from hidet.ir.expr import Var

from tilus.backends.emitter import BaseInstEmitter, register_emitter
from tilus.extensions.hidet.ir.primitives.cuda.fence import fence_view_async
@@ -42,10 +43,15 @@ def emit(self, inst: AllocBarrierInst) -> None:
out_var = self.get_or_allocate_var(out)

counts = [c if c is not None else self.current_num_threads for c in inst.counts]
barriers = self.contexts.barrier_alloc_ctx.allocate_barriers(counts=counts)
base_addr, barrier_vars = self.contexts.barrier_alloc_ctx.allocate_barriers(counts=counts)

for i in range(len(barriers)):
self.buffer_store(out_var, indices=[i], value=barriers[i])
for i in range(len(barrier_vars)):
self.buffer_store(out_var, indices=[i], value=barrier_vars[i])

# Register as CTA-invariant tensor: value(i) = base_addr + i * uint64.nbytes
axis = Var("i", type=uint32)
expr = base_addr + axis * uint32(uint64.nbytes)
self.contexts.const_reg_ctx.register(out, axes=[axis], expr=expr)


@register_emitter(ArriveBarrierInst, target=nvgpu_sm80)
40 changes: 40 additions & 0 deletions python/tilus/backends/emitters/regs.py
@@ -27,6 +27,15 @@ def emit(self, inst: SliceRegisterInst) -> None: # type: ignore
src_tensor: RegisterTensor = inst.register_input
src_layout: RegisterLayout = src_tensor.layout

const_reg_ctx = self.contexts.const_reg_ctx

if const_reg_ctx.is_tracked(src_tensor):
self._emit_const_reg(inst, dst_tensor, dst_layout, src_tensor, src_layout, const_reg_ctx)
else:
self._emit_default(inst, dst_tensor, dst_layout, src_tensor, src_layout)

def _emit_default(self, inst, dst_tensor, dst_layout, src_tensor, src_layout):
"""Default emission: materialize as array indexing."""
dst_var = self.get_or_allocate_var(tensor=dst_tensor, name="slice_regs")
src_var = self.tensor2var[src_tensor]

@@ -39,6 +48,37 @@ def emit(self, inst: SliceRegisterInst) -> None: # type: ignore
src_local = src_layout.get_local(global_indices=src_indices)
self.buffer_store(buf=dst_var, indices=[dst_local], value=src_var[src_local])

def _emit_const_reg(self, inst, dst_tensor, dst_layout, src_tensor, src_layout, const_reg_ctx):
"""Optimized emission: compute values from logical index expressions, bypassing array indexing."""
from hidet.ir.dtypes import int32
from hidet.ir.expr import Var

dst_var = self.get_or_allocate_var(tensor=dst_tensor, name="slice_regs")

# For each destination element, compute the logical source indices from the offsets
with self.for_range(extent=dst_layout.local_size) as dst_local:
# Get the logical indices of the destination element in the source tensor's index space
dst_global = dst_layout.get_global(spatial_index=self.current_thread, local_index=dst_local)
src_logical_indices = list(inst.offsets)
dims = range(len(src_layout.shape)) if inst.dims is None else inst.dims
for dim in dims:
src_logical_indices[dim] = src_logical_indices[dim] + dst_global[dim]

value = const_reg_ctx.get_value(src_tensor, src_logical_indices)
self.buffer_store(buf=dst_var, indices=[dst_local], value=value)

# Propagate CTA-invariant tracking to the output tensor
dst_axes = [Var(f"i{d}", type=int32) for d in range(len(dst_tensor.shape))]
dst_global = dst_layout.get_global(
spatial_index=self.current_thread, local_index=dst_axes[0] if dst_axes else int32(0)
)
src_logical_indices = list(inst.offsets)
dims = range(len(src_layout.shape)) if inst.dims is None else inst.dims
for dim in dims:
src_logical_indices[dim] = src_logical_indices[dim] + dst_global[dim]
dst_expr = const_reg_ctx.get_value(src_tensor, src_logical_indices)
const_reg_ctx.register(dst_tensor, axes=dst_axes, expr=dst_expr)
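
To make the propagation step concrete, here is a hedged sketch of the index math, ignoring the thread-distribution detail handled by dst_layout.get_global; the src_tensor, dst_tensor, base_addr, and const_reg_ctx names are assumed to exist:

from hidet.ir.dtypes import uint32
from hidet.ir.expr import Var

# Source tensor tracked as value(i) = base_addr + 8 * i (e.g., from AllocBarrierInst).
i = Var("i", type=uint32)
const_reg_ctx.register(src_tensor, axes=[i], expr=base_addr + i * uint32(8))

# Slicing with offsets=[2] along dim 0 maps destination index j to source index 2 + j,
# so the propagated expression is value(j) = base_addr + 8 * (2 + j):
j = Var("j", type=uint32)
dst_expr = const_reg_ctx.get_value(src_tensor, [uint32(2) + j])
const_reg_ctx.register(dst_tensor, axes=[j], expr=dst_expr)

The sliced tensor thus remains CTA-invariant, so further slices can keep bypassing array indexing.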


@register_emitter(SliceAssignInst)
class SliceAssignInstEmitter(BaseInstEmitter):
2 changes: 2 additions & 0 deletions python/tilus/extensions/hidet/backend/build.py
@@ -215,6 +215,8 @@ def compile(
"--diag-suppress 179",
# suppress warning no 39 like: "warning #39-D: division by zero"
"--diag-suppress 39",
# suppress warning no 550 like: "warning #550-D: variable "xxx" was set but never used"
"--diag-suppress 550",
# generate shared library (lib.so).
"--shared" if out_lib_path.endswith(".so") else "--compile",
# the linking objects.