From 3e8e2b93215a014d16d790cb7dc0daf011062a91 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Fri, 5 Sep 2025 23:13:09 +0000 Subject: [PATCH 1/7] enhance tensor slicing/indexing Signed-off-by: Yaoyao Ding --- python/tilus/ir/tensor.py | 2 +- python/tilus/lang/transpiler.py | 68 +++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/python/tilus/ir/tensor.py b/python/tilus/ir/tensor.py index e1ec97c6..bb2c3a2b 100644 --- a/python/tilus/ir/tensor.py +++ b/python/tilus/ir/tensor.py @@ -371,7 +371,7 @@ class SharedTensor(Tensor): shape: tuple[int, ...] optional_layout: Optional[SharedLayout] - def __getitem__(self, index: int | Expr) -> SharedTensor: + def __getitem__(self, indices: tuple[Expr | int, ...] | Expr | int) -> SharedTensor | Expr: raise RuntimeError("shared_tensor[...] could only be used in Tilus Script.") @staticmethod diff --git a/python/tilus/lang/transpiler.py b/python/tilus/lang/transpiler.py index e9ac86fd..cc2973e5 100644 --- a/python/tilus/lang/transpiler.py +++ b/python/tilus/lang/transpiler.py @@ -877,41 +877,53 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any: if isinstance(base, Sequence): return base[indices] - elif isinstance(base, GlobalTensor): + elif isinstance(base, (GlobalTensor, SharedTensor)): if not isinstance(indices, Sequence): - indices = (indices,) - if ( - isinstance(indices, Sequence) - and len(indices) == len(base.shape) - and not any(i is None or isinstance(i, slice) for i in indices) - ): - sb = StmtBuilder() + indices = [indices] + else: + indices = list(indices) + while len(indices) < len(base.shape): + indices.append(0) + if len(indices) > len(base.shape): + raise TilusProgramError(self, expr, "Too many indices for tensor of shape {}.".format(base.shape)) + offsets = [] + slice_dims = [] + for dim, idx in enumerate(indices): + if isinstance(idx, slice): + if idx.start is not None or idx.end is not None: + if not isinstance(idx.start, (int, hidet_ir.Expr)): + raise TilusProgramError( + self, + expr, + "Global/Shared tensors only support slicing whole dimensions: [..., :, ...], " + "do not support slicing like [..., start:, ...] or [..., :end, ...].", + ) + offsets.append(0) + slice_dims.append(dim) + else: + offsets.append(idx) + + sb = StmtBuilder() + if len(slice_dims) == 0: ptr = sb.tensor_ptr(tensor=base) - offset = base.layout(*indices) + offset = base.layout(*offsets) self.current_scope.append(sb.flush_stmts()) return ptr[offset] else: - raise TilusProgramError(self, expr, "Tilus Script does not support slicing on GlobalTensor.") + # slicing + if isinstance(base, GlobalTensor): + raise NotImplementedError("Global tensor slicing is not implemented yet.") + else: + sliced_tensor = sb.shared_slice( + tensor=base, + offsets=offsets, + slice_dims=slice_dims, + slice_shape=[base.shape[dim] for dim in slice_dims], + ) + self.current_scope.append(sb.flush_stmts()) + return sliced_tensor elif isinstance(base, RegisterTensor): raise TilusProgramError(self, expr, "Tilus Script does not support indexing/slicing on RegisterTensor.") - elif isinstance(base, SharedTensor): - sb = StmtBuilder() - if isinstance(indices, (hidet_ir.Expr, int)): - offsets = [as_expr(indices)] - for i in range(len(base.shape) - 1): - offsets.append(as_expr(0)) - sliced_tensor = sb.shared_slice( - tensor=base, - offsets=offsets, - slice_dims=range(1, len(base.shape)), - slice_shape=base.shape[1:], - ) - self.current_scope.append(sb.flush_stmts()) - return sliced_tensor - else: - raise TilusProgramError( - self, expr, "Tilus Script does not support slicing on SharedTensor with subscript syntax." - ) else: raise NotImplementedError() From 0dd396c354182b754f669ad26cb94ae000ea241f Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 01:21:05 -0400 Subject: [PATCH 2/7] support slicing/indexing of shared/global tensors Signed-off-by: Yaoyao Ding --- python/tilus/backends/codegen.py | 4 +- python/tilus/backends/emitters/gmem.py | 23 ++++++- python/tilus/backends/emitters/smem.py | 12 +++- python/tilus/ir/builders/stmt_builder.py | 48 +++++++++++++- python/tilus/ir/inst.py | 7 +++ python/tilus/ir/instructions/__init__.py | 3 + python/tilus/ir/instructions/generic.py | 63 ++++++++++++++++++- python/tilus/ir/layout/global_layout.py | 40 ++++++++++++ .../inference/inference_rules/__init__.py | 1 + .../inference_rules/allocate_shared.py | 30 +++++++++ .../inference/inference_rules/empty_rule.py | 4 +- python/tilus/ir/layout/inference/order.py | 2 + python/tilus/ir/layout/inference/rule.py | 2 +- .../inference/validation_rules/always_ok.py | 2 + python/tilus/lang/script.py | 2 +- python/tilus/lang/transpiler.py | 33 +++++++--- 16 files changed, 256 insertions(+), 20 deletions(-) create mode 100644 python/tilus/ir/layout/inference/inference_rules/allocate_shared.py diff --git a/python/tilus/backends/codegen.py b/python/tilus/backends/codegen.py index e254ad6a..49e18885 100644 --- a/python/tilus/backends/codegen.py +++ b/python/tilus/backends/codegen.py @@ -472,8 +472,8 @@ def visit_Function(self, func: Function) -> IRModule: if self.smem_workspace: self.free_shared_value(self.smem_workspace) self.smem_workspace = None - if self.smem_allocator.allocated != 0: - raise ValueError("Shared memory is not properly allocated/freed") + # if self.smem_allocator.allocated != 0: + # raise ValueError("Shared memory is not properly allocated/freed") if self.smem_allocator.maximum_allocated > get_current_target().properties.shared_memory_per_block: raise CodeGenerationFailed( "Request shared memory {} bytes, but the device only allows {} bytes.".format( diff --git a/python/tilus/backends/emitters/gmem.py b/python/tilus/backends/emitters/gmem.py index facea33a..9e259644 100644 --- a/python/tilus/backends/emitters/gmem.py +++ b/python/tilus/backends/emitters/gmem.py @@ -15,7 +15,8 @@ from hidet.ir.expr import Expr from tilus.backends.codegen import BaseInstEmitter, register_emitter -from tilus.ir.instructions import AllocateGlobalInst, GlobalViewInst +from tilus.ir import GlobalTensor +from tilus.ir.instructions import AllocateGlobalInst, GlobalIndexInst, GlobalSliceInst, GlobalViewInst from tilus.utils import cdiv @@ -34,3 +35,23 @@ def emit(self, inst: AllocateGlobalInst) -> None: ) var = self.get_or_allocate_var(tensor) self.assign(var, ptr) + + +@register_emitter(GlobalIndexInst) +class GlobalIndexInstEmitter(BaseInstEmitter): + def emit(self, inst: GlobalIndexInst) -> None: + dst = inst.dst + tensor = inst.inputs[0].as_global_tensor() + var = self.get_or_allocate_var(tensor) + offset = tensor.layout(*inst.indices) + self.assign(dst, value=var[offset]) + + +@register_emitter(GlobalSliceInst) +class GlobalSliceInstEmitter(BaseInstEmitter): + def emit(self, inst: GlobalSliceInst) -> None: + input_tensor: GlobalTensor = inst.global_input + output_tensor: GlobalTensor = inst.global_output + slice_offset = input_tensor.layout(*inst.offsets) + output_var = self.get_or_allocate_var(output_tensor) + self.assign(output_var, ~self.tensor2var[input_tensor][slice_offset]) diff --git a/python/tilus/backends/emitters/smem.py b/python/tilus/backends/emitters/smem.py index e7ec65f3..69f7d7f1 100644 --- a/python/tilus/backends/emitters/smem.py +++ b/python/tilus/backends/emitters/smem.py @@ -18,7 +18,7 @@ from hidet.ir.type import tensor_pointer_type from tilus.backends.codegen import BaseInstEmitter, register_emitter -from tilus.ir.instructions import AllocateSharedInst, FreeSharedInst, SharedSliceInst +from tilus.ir.instructions import AllocateSharedInst, FreeSharedInst, SharedIndexInst, SharedSliceInst from tilus.ir.tensor import SharedTensor @@ -62,3 +62,13 @@ def emit(self, inst: SharedSliceInst) -> None: tp=int32, init=self.shared_tensor_shared_space_addr[shared_input] + slice_offset * shared_input.dtype.nbytes, ) + + +@register_emitter(SharedIndexInst) +class SharedIndexInstEmitter(BaseInstEmitter): + def emit(self, inst: SharedIndexInst) -> None: + dst = inst.dst + tensor = inst.shared_input + var = self.get_or_allocate_var(tensor) + offset = tensor.layout(*inst.indices) + self.assign(dst, value=var[offset]) diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 3ebd2ac4..9bf0b9d3 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -51,6 +51,8 @@ ExitInst, FormatPrintInst, FreeSharedInst, + GlobalIndexInst, + GlobalSliceInst, GlobalViewInst, LoadGlobalGenericInst, LoadGlobalInst, @@ -62,6 +64,7 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, + SharedIndexInst, SharedSliceInst, SqueezeInst, StoreGlobalGenericInst, @@ -285,8 +288,10 @@ def brk(self): stmt = BreakStmt() self._stack[-1].append(stmt) - def declare(self, type: BaseType, init: Optional[Expr | float | int] = None) -> Var: - var = Var("v", type=type) + def declare(self, type: BaseType, init: Optional[Expr | float | int] = None, hint: Optional[str] = None) -> Var: + if hint is not None: + hint = "v" + var = Var(hint, type=type) self.append(DeclareStmt(var, as_expr(init) if init is not None else None)) return var @@ -364,6 +369,33 @@ def allocate_global( self.append(inst) return inst.global_output + def slice_global( + self, + tensor: GlobalTensor, + offsets: Sequence[Expr | int], + slice_dims: Sequence[int], + slice_shape: Sequence[Expr | int], + ) -> GlobalTensor: + offsets_ = [as_expr(offset) for offset in offsets] + inst = GlobalSliceInst.create( + tensor=tensor, + offsets=offsets_, + dims=slice_dims, + shape=slice_shape, + ) + self.append(inst) + return inst.global_output + + def index_global( + self, + dst: Var, + tensor: GlobalTensor, + indices: Sequence[Expr | int], + ) -> None: + indices_ = [as_expr(index) for index in indices] + inst = GlobalIndexInst.create(dst=dst, tensor=tensor, indices=indices_) + self.append(inst) + def assign_register(self, output: RegisterTensor, x: RegisterTensor) -> None: inst = AssignInst.create(output, x) self.append(inst) @@ -722,7 +754,7 @@ def free_shared(self, shared_value: SharedTensor) -> None: inst = FreeSharedInst.create(shared_value) self.append(inst) - def shared_slice( + def slice_shared( self, tensor: SharedTensor, offsets: Sequence[Expr | int], @@ -739,6 +771,16 @@ def shared_slice( self.append(inst) return inst.shared_output + def index_shared( + self, + dst: Var, + tensor: SharedTensor, + indices: Sequence[Expr | int], + ) -> None: + indices_ = [as_expr(index) for index in indices] + inst = SharedIndexInst.create(dst=dst, tensor=tensor, indices=indices_) + self.append(inst) + def load_shared( self, src: SharedTensor, diff --git a/python/tilus/ir/inst.py b/python/tilus/ir/inst.py index 106274e5..d448da53 100644 --- a/python/tilus/ir/inst.py +++ b/python/tilus/ir/inst.py @@ -60,6 +60,13 @@ def shared_input(self) -> SharedTensor: assert isinstance(x, SharedTensor) return x + @property + def global_input(self) -> GlobalTensor: + assert len(self.inputs) == 1 + x = self.inputs[0] + assert isinstance(x, GlobalTensor) + return x + @property def attributes(self) -> dict[str, Any]: attrs = {} diff --git a/python/tilus/ir/instructions/__init__.py b/python/tilus/ir/instructions/__init__.py index 10d7e7e1..db45ce62 100644 --- a/python/tilus/ir/instructions/__init__.py +++ b/python/tilus/ir/instructions/__init__.py @@ -43,6 +43,8 @@ ExitInst, FormatPrintInst, FreeSharedInst, + GlobalIndexInst, + GlobalSliceInst, GlobalViewInst, LoadGlobalGenericInst, LoadGlobalInst, @@ -54,6 +56,7 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, + SharedIndexInst, SharedSliceInst, ShuffleDownInst, ShuffleUpInst, diff --git a/python/tilus/ir/instructions/generic.py b/python/tilus/ir/instructions/generic.py index 9342bebf..015f3ac4 100644 --- a/python/tilus/ir/instructions/generic.py +++ b/python/tilus/ir/instructions/generic.py @@ -72,6 +72,48 @@ def create(dst: GlobalTensor, x: RegisterTensor, offsets: Sequence[Expr], dims: return StoreGlobalInst(output=None, inputs=(dst, x), offsets=tuple(offsets), dims=tuple(dims)) +@dataclass(frozen=True, eq=False) +class GlobalSliceInst(Instruction): + offsets: tuple[Expr, ...] + dims: Optional[tuple[int, ...]] + + @staticmethod + def create( + tensor: GlobalTensor, + offsets: Sequence[Expr], + dims: Sequence[int], + shape: Sequence[Expr | int], + ) -> SharedSliceInst: + from tilus.ir.layout.global_layout import global_slice + + output = GlobalTensor.create(dtype=tensor.dtype, layout=global_slice(tensor.layout, offsets, dims, shape)) + return SharedSliceInst( + output=output, + inputs=(tensor,), + offsets=tuple(offsets), + dims=tuple(dims) if len(dims) < len(tensor.shape) else None, + ) + + +@dataclass(frozen=True, eq=False) +class GlobalIndexInst(Instruction): + dst: Var + indices: tuple[Expr, ...] + + @staticmethod + def create( + dst: Var, + tensor: GlobalTensor, + indices: Sequence[Expr], + ) -> GlobalIndexInst: + return GlobalIndexInst( + output=None, + inputs=(tensor,), + dst=dst, + indices=tuple(indices), + ) + + @dataclass(frozen=True, eq=False) class LoadSharedInst(Instruction): @staticmethod @@ -103,7 +145,26 @@ def create( output=output, inputs=(tensor,), offsets=tuple(offsets), - dims=tuple(dims) if len(dims) < len(tensor.shape) else None, + dims=tuple(dims) if len(dims) < len(tensor.shape) else tuple(range(len(tensor.shape))), + ) + + +@dataclass(frozen=True, eq=False) +class SharedIndexInst(Instruction): + dst: Var + indices: tuple[Expr, ...] + + @staticmethod + def create( + dst: Var, + tensor: SharedTensor, + indices: Sequence[Expr], + ) -> SharedIndexInst: + return SharedIndexInst( + output=None, + inputs=(tensor,), + dst=dst, + indices=tuple(indices), ) diff --git a/python/tilus/ir/layout/global_layout.py b/python/tilus/ir/layout/global_layout.py index 53dbe161..6df144d7 100644 --- a/python/tilus/ir/layout/global_layout.py +++ b/python/tilus/ir/layout/global_layout.py @@ -218,3 +218,43 @@ def f_offset(axes: Sequence[Var]) -> Expr: return sum([axes[i] * strides[i] for i in range(len(shape))], start=int32.zero) return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset) + + +def global_slice( + layout: GlobalLayout, offsets: Sequence[Expr | int], dims: Sequence[int], shape: Sequence[Expr | int] +) -> GlobalLayout: + """Create a sliced global layout from an existing layout. + + This function creates a new global layout by slicing an existing global layout. The slicing is defined by the + specified offsets, dimensions to slice, and the shape of the resulting layout. The new layout retains the mapping + function of the original layout, adjusted for the specified offsets and dimensions. + + Parameters + ---------- + layout: GlobalLayout + The original global layout to be sliced. + offsets: Sequence[Expr | int] + The offsets for each dimension of the original layout. It should have the same length as the original layout's + shape. + dims: Sequence[int] + The dimensions to be sliced from the original layout. Each dimension should be a valid index in the original + layout's shape. + shape: Sequence[Expr | int] + The shape of the resulting sliced global layout. It should have the same length as the number of dimensions + specified in `dims`. + + Returns + ------- + ret: GlobalLayout + A new global layout that represents the sliced version of the original layout, with the specified shape and + adjusted mapping function. + """ + assert len(dims) == len(shape) <= len(layout.shape) == len(offsets) + + def f_offset(axes: Sequence[Var]) -> Expr: + indices = list(offsets) + for dim, axis in zip(dims, axes): + indices[dim] = axis + offsets[dim] + return layout(*indices) - layout(*offsets) # type: ignore[arg-type] + + return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset) diff --git a/python/tilus/ir/layout/inference/inference_rules/__init__.py b/python/tilus/ir/layout/inference/inference_rules/__init__.py index e84d5a96..90ffb326 100644 --- a/python/tilus/ir/layout/inference/inference_rules/__init__.py +++ b/python/tilus/ir/layout/inference/inference_rules/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from . import ( + allocate_shared, assign, cp_async, elementwise_binary, diff --git a/python/tilus/ir/layout/inference/inference_rules/allocate_shared.py b/python/tilus/ir/layout/inference/inference_rules/allocate_shared.py new file mode 100644 index 00000000..f246d7c2 --- /dev/null +++ b/python/tilus/ir/layout/inference/inference_rules/allocate_shared.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from tilus.ir.instructions import AllocateSharedInst +from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule +from tilus.ir.layout.shared_layout import SharedLayout, shared_row_major +from tilus.ir.tensor import SharedTensor + + +@register_rule(AllocateSharedInst) +class AllocateSharedRule(LayoutInferenceRule): + @staticmethod + def inference(ctx: LayoutInferenceContext, inst: AllocateSharedInst) -> dict[SharedTensor, SharedLayout]: + tensor = inst.shared_output + + if tensor.optional_layout is not None: + return {} + else: + return {tensor: shared_row_major(*tensor.shape)} diff --git a/python/tilus/ir/layout/inference/inference_rules/empty_rule.py b/python/tilus/ir/layout/inference/inference_rules/empty_rule.py index 6ea0be19..1452f19a 100644 --- a/python/tilus/ir/layout/inference/inference_rules/empty_rule.py +++ b/python/tilus/ir/layout/inference/inference_rules/empty_rule.py @@ -15,22 +15,22 @@ from tilus import RegisterLayout, SharedLayout from tilus.ir.instructions import ( AllocateRegisterInst, - AllocateSharedInst, FormatPrintInst, FreeSharedInst, GlobalViewInst, PrintTensorInst, + SharedIndexInst, StoreGlobalInst, ) from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule from tilus.ir.tensor import Tensor +@register_rule(SharedIndexInst) @register_rule(PrintTensorInst) @register_rule(FormatPrintInst) @register_rule(FreeSharedInst) @register_rule(AllocateRegisterInst) -@register_rule(AllocateSharedInst) @register_rule(StoreGlobalInst) class EmptyRule(LayoutInferenceRule): @staticmethod diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index 03baed08..7eb289d9 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -17,6 +17,7 @@ from tilus.ir.layout.inference.rule import LayoutInferenceRule from tilus.utils import initialize +from .inference_rules.allocate_shared import AllocateSharedRule from .inference_rules.assign import AssignRule from .inference_rules.cp_async import CopyAsyncRule from .inference_rules.elementwise_binary import BinaryRule @@ -52,6 +53,7 @@ [CopyAsyncRule], [LoadSharedInferRegisterRule], [LoadSharedInferRowMajorSharedRule], + [AllocateSharedRule], ] rule2order: dict[Type[LayoutInferenceRule], int] = {} diff --git a/python/tilus/ir/layout/inference/rule.py b/python/tilus/ir/layout/inference/rule.py index b156b261..4ec6103d 100644 --- a/python/tilus/ir/layout/inference/rule.py +++ b/python/tilus/ir/layout/inference/rule.py @@ -126,7 +126,7 @@ def get_inference_rules(inst: Type[Instruction] | Instruction) -> list[Type[Layo _inference_rules[inst_cls] = _inference_rules[parent_cls] break else: - raise ValueError(f"No layout inference rule registered for {bold(inst_cls.__name__)}") + return [] return _inference_rules[inst_cls].copy() diff --git a/python/tilus/ir/layout/inference/validation_rules/always_ok.py b/python/tilus/ir/layout/inference/validation_rules/always_ok.py index 761f2f8c..c5b64151 100644 --- a/python/tilus/ir/layout/inference/validation_rules/always_ok.py +++ b/python/tilus/ir/layout/inference/validation_rules/always_ok.py @@ -25,6 +25,7 @@ LoadSharedGenericInst, LoadSharedInst, PrintTensorInst, + SharedIndexInst, SharedSliceInst, StoreGlobalGenericInst, StoreGlobalInst, @@ -33,6 +34,7 @@ from tilus.ir.layout.inference.rule import LayoutValidationRule, register_rule +@register_rule(SharedIndexInst) @register_rule(CopyAsyncGenericInst) @register_rule(CopyAsyncInst) @register_rule(StoreGlobalGenericInst) diff --git a/python/tilus/lang/script.py b/python/tilus/lang/script.py index 9256a2a0..cba5f931 100644 --- a/python/tilus/lang/script.py +++ b/python/tilus/lang/script.py @@ -631,7 +631,7 @@ def store_shared( if dims is None: assert len(src.shape) == len(dst.shape) dims = list(range(len(src.shape))) - dst = self._builder.shared_slice(dst, offsets=offsets, slice_dims=dims, slice_shape=src.shape) + dst = self._builder.slice_shared(dst, offsets=offsets, slice_dims=dims, slice_shape=src.shape) self._builder.store_shared(dst=dst, src=src) def free_shared(self, tensor: SharedTensor) -> None: diff --git a/python/tilus/lang/transpiler.py b/python/tilus/lang/transpiler.py index cc2973e5..47b10474 100644 --- a/python/tilus/lang/transpiler.py +++ b/python/tilus/lang/transpiler.py @@ -890,7 +890,7 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any: slice_dims = [] for dim, idx in enumerate(indices): if isinstance(idx, slice): - if idx.start is not None or idx.end is not None: + if idx.start is not None or idx.stop is not None: if not isinstance(idx.start, (int, hidet_ir.Expr)): raise TilusProgramError( self, @@ -905,23 +905,33 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any: sb = StmtBuilder() if len(slice_dims) == 0: - ptr = sb.tensor_ptr(tensor=base) - offset = base.layout(*offsets) + # indexing + val = sb.declare(type=base.dtype, hint="val") + if isinstance(base, GlobalTensor): + sb.index_global(dst=val, tensor=base, indices=offsets) + else: + sb.index_shared(dst=val, tensor=base, indices=offsets) self.current_scope.append(sb.flush_stmts()) - return ptr[offset] + return val else: # slicing + sliced_tensor: Union[GlobalTensor, SharedTensor] if isinstance(base, GlobalTensor): - raise NotImplementedError("Global tensor slicing is not implemented yet.") + sliced_tensor = sb.slice_global( + tensor=base, + offsets=offsets, + slice_dims=slice_dims, + slice_shape=[base.shape[dim] for dim in slice_dims], + ) else: - sliced_tensor = sb.shared_slice( + sliced_tensor = sb.slice_shared( tensor=base, offsets=offsets, slice_dims=slice_dims, slice_shape=[base.shape[dim] for dim in slice_dims], ) - self.current_scope.append(sb.flush_stmts()) - return sliced_tensor + self.current_scope.append(sb.flush_stmts()) + return sliced_tensor elif isinstance(base, RegisterTensor): raise TilusProgramError(self, expr, "Tilus Script does not support indexing/slicing on RegisterTensor.") else: @@ -1185,3 +1195,10 @@ def visit_Return(self, stmt: ast.Return) -> None: if stmt.value is not None: raise TilusProgramError(self, stmt, "Return statement in Tilus Script does not support returning a value.") self.current_scope.append(ReturnStmt()) + + def visit_Slice(self, expr: ast.Slice) -> slice: + return slice( + self.visit(expr.lower) if expr.lower is not None else None, + self.visit(expr.upper) if expr.upper is not None else None, + self.visit(expr.step) if expr.step is not None else None, + ) From 3e50dffe4917a1e30d8a14bde5a4d00b46f7f7ce Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 18:59:33 +0000 Subject: [PATCH 3/7] fix Signed-off-by: Yaoyao Ding --- python/tilus/lang/transpiler.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/tilus/lang/transpiler.py b/python/tilus/lang/transpiler.py index 47b10474..a09a960b 100644 --- a/python/tilus/lang/transpiler.py +++ b/python/tilus/lang/transpiler.py @@ -882,10 +882,6 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any: indices = [indices] else: indices = list(indices) - while len(indices) < len(base.shape): - indices.append(0) - if len(indices) > len(base.shape): - raise TilusProgramError(self, expr, "Too many indices for tensor of shape {}.".format(base.shape)) offsets = [] slice_dims = [] for dim, idx in enumerate(indices): @@ -903,6 +899,15 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any: else: offsets.append(idx) + if len(offsets) < len(base.shape): + dim = len(offsets) + while len(offsets) < len(base.shape): + offsets.append(0) + slice_dims.append(dim) + dim += 1 + if len(indices) > len(base.shape): + raise TilusProgramError(self, expr, "Too many indices for tensor of shape {}.".format(base.shape)) + sb = StmtBuilder() if len(slice_dims) == 0: # indexing From 10a377d208bb519db1f4c776ae9587ce96a78438 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 15:08:46 -0400 Subject: [PATCH 4/7] add script to sign Signed-off-by: Yaoyao Ding --- scripts/sign-commits.sh | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 scripts/sign-commits.sh diff --git a/scripts/sign-commits.sh b/scripts/sign-commits.sh new file mode 100644 index 00000000..dac5f873 --- /dev/null +++ b/scripts/sign-commits.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# This script signs off all unsigned commits in the current branch that are not in the main branch, +# and shows the newly signed commits. + +set -e + +MAIN_BRANCH="main" + +git fetch origin + +BASE=$(git merge-base HEAD origin/$MAIN_BRANCH) + +# List commits after the common ancestor missing "Signed-off-by" +UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do + if ! git show --quiet --format=%B $commit | grep -q "Signed-off-by:"; then + echo $commit + fi +done) + +if [ -z "$UNSIGNED_COMMITS" ]; then + echo "No unsigned commits to sign off." + exit 0 +fi + +# Rebase with signoff +git rebase --signoff $BASE + +echo "Newly signed commits:" +for commit in $UNSIGNED_COMMITS; do + git log --format="* %h %s" -n 1 $commit +done \ No newline at end of file From 546b6b6c8292625fdb7a409d777ca681fa2b88de Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 05:39:59 +0000 Subject: [PATCH 5/7] wip Signed-off-by: Yaoyao Ding --- python/tilus/ir/analyzers/scalar_analyzer.py | 14 +++++++++++++- python/tilus/ir/tools/printer.py | 10 +++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/tilus/ir/analyzers/scalar_analyzer.py b/python/tilus/ir/analyzers/scalar_analyzer.py index 4fc4669b..2fc3cf39 100644 --- a/python/tilus/ir/analyzers/scalar_analyzer.py +++ b/python/tilus/ir/analyzers/scalar_analyzer.py @@ -25,7 +25,8 @@ import tilus.logging from tilus.ir.func import Analysis, Function -from tilus.ir.stmt import AssignStmt, DeclareStmt, ForStmt, LetStmt +from tilus.ir.instructions import SharedIndexInst +from tilus.ir.stmt import AssignStmt, DeclareStmt, ForStmt, LetStmt, InstStmt from tilus.ir.tools import IRPrinter, collect from tilus.utils import gcd @@ -148,6 +149,10 @@ def is_constant(self) -> bool: def empty_set() -> ScalarSet: return ScalarSet(lower_bound=0, upper_bound=-1) + @staticmethod + def universal_set() -> ScalarSet: + return ScalarSet(lower_bound=None, upper_bound=None, divisibility=1) + def __eq__(self, other: ScalarSet) -> bool: if self.is_empty() and other.is_empty(): return True @@ -463,6 +468,13 @@ def analyze_scalar(func: Function) -> Function: else: raise NotImplementedError() + # we are not interested in variables used in the following places + for stmt in collect(func, types=[InstStmt]): + assert isinstance(stmt, InstStmt) + if isinstance(stmt.inst, SharedIndexInst): + if stmt.inst.dst in variables: + del variables[variables.index(stmt.inst.dst)] + # initialize the scalar set of variables defined in the function body to be empty set for var in variables: var2set[var] = ScalarSet(lower_bound=0, upper_bound=-1) # empty set diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py index cbfdae0d..983a2982 100644 --- a/python/tilus/ir/tools/printer.py +++ b/python/tilus/ir/tools/printer.py @@ -274,15 +274,15 @@ def visit_ReturnStmt(self, stmt: ReturnStmt) -> Any: return NewLine() + Text("return") def visit_DeclareStmt(self, stmt: DeclareStmt) -> Doc: - return ( + ret = ( NewLine() + Text("declare ") + self.visit(stmt.var) + ": " - + self.printer(stmt.var.type) - + " = " - + self.visit(stmt.init) - ) + + self.printer(stmt.var.type)) + if stmt.init is not None: + ret += " = " + self.visit(stmt.init) + return ret def visit_LetStmt(self, stmt: LetStmt) -> Doc: doc = Doc() From d4834be34319279d666909de103dfa804d2b9c75 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 19:11:43 -0400 Subject: [PATCH 6/7] use stmt to implement Signed-off-by: Yaoyao Ding --- python/tilus/backends/codegen.py | 23 ++++++++-- python/tilus/backends/emitters/gmem.py | 12 +----- python/tilus/backends/emitters/smem.py | 12 +----- python/tilus/ir/analyzers/scalar_analyzer.py | 10 +---- python/tilus/ir/builders/stmt_builder.py | 43 ++++++++----------- python/tilus/ir/functors/functor.py | 37 ++++++++++++---- python/tilus/ir/instructions/__init__.py | 2 - python/tilus/ir/instructions/generic.py | 38 ---------------- .../inference/inference_rules/empty_rule.py | 2 - .../inference/validation_rules/always_ok.py | 2 - python/tilus/ir/stmt.py | 10 ++++- python/tilus/ir/tools/printer.py | 36 +++++++++++----- python/tilus/lang/transpiler.py | 34 ++++++++++++--- 13 files changed, 130 insertions(+), 131 deletions(-) diff --git a/python/tilus/backends/codegen.py b/python/tilus/backends/codegen.py index 49e18885..8aa6efee 100644 --- a/python/tilus/backends/codegen.py +++ b/python/tilus/backends/codegen.py @@ -47,7 +47,8 @@ LetStmt, ReturnStmt, SeqStmt, - TensorPtrStmt, + TensorElemPtrStmt, + TensorElemValueStmt, WhileStmt, ) from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedTensor, Tensor @@ -545,19 +546,33 @@ def visit_LetStmt(self, stmt: LetStmt) -> None: def visit_AssignStmt(self, stmt: AssignStmt) -> None: self.builder.assign(stmt.var, value=stmt.value) - def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> None: + def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> None: if stmt.space in ["generic", "global"]: - self.builder.declare(stmt.ptr_var, self.tensor2var[stmt.tensor]) + if stmt.space == "generic": + assert isinstance(stmt.tensor, (GlobalTensor, SharedTensor)) + else: + assert isinstance(stmt.tensor, GlobalTensor) + ptr = self.tensor2var[stmt.tensor] + if stmt.indices is not None: + ptr = ptr + stmt.tensor.layout(*stmt.indices) + self.builder.declare(stmt.ptr_var, ptr) elif stmt.space == "local": raise NotImplementedError("Local tensor pointer is not supported yet.") elif stmt.space == "shared": if not isinstance(stmt.tensor, SharedTensor): raise ValueError("Expected a SharedTensor for shared tensor pointer, got: {}".format(stmt.tensor)) shared_tensor: SharedTensor = stmt.tensor - self.builder.declare(stmt.ptr_var, self.shared_tensor_shared_space_addr[shared_tensor]) + addr = self.shared_tensor_shared_space_addr[shared_tensor] + if stmt.indices is not None: + addr = addr + shared_tensor.layout(*stmt.indices) * shared_tensor.dtype.nbytes + self.builder.declare(stmt.ptr_var, addr) else: raise ValueError("Unknown tensor pointer space: {}".format(stmt.space)) + def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> None: + assert isinstance(stmt.tensor, (GlobalTensor, SharedTensor)) + self.builder.declare(stmt.var, init=self.tensor2var[stmt.tensor][stmt.tensor.layout(*stmt.indices)]) + def visit_ReturnStmt(self, stmt: ReturnStmt) -> None: self.builder.ret() diff --git a/python/tilus/backends/emitters/gmem.py b/python/tilus/backends/emitters/gmem.py index 9e259644..2a36f1b3 100644 --- a/python/tilus/backends/emitters/gmem.py +++ b/python/tilus/backends/emitters/gmem.py @@ -16,7 +16,7 @@ from tilus.backends.codegen import BaseInstEmitter, register_emitter from tilus.ir import GlobalTensor -from tilus.ir.instructions import AllocateGlobalInst, GlobalIndexInst, GlobalSliceInst, GlobalViewInst +from tilus.ir.instructions import AllocateGlobalInst, GlobalSliceInst, GlobalViewInst from tilus.utils import cdiv @@ -37,16 +37,6 @@ def emit(self, inst: AllocateGlobalInst) -> None: self.assign(var, ptr) -@register_emitter(GlobalIndexInst) -class GlobalIndexInstEmitter(BaseInstEmitter): - def emit(self, inst: GlobalIndexInst) -> None: - dst = inst.dst - tensor = inst.inputs[0].as_global_tensor() - var = self.get_or_allocate_var(tensor) - offset = tensor.layout(*inst.indices) - self.assign(dst, value=var[offset]) - - @register_emitter(GlobalSliceInst) class GlobalSliceInstEmitter(BaseInstEmitter): def emit(self, inst: GlobalSliceInst) -> None: diff --git a/python/tilus/backends/emitters/smem.py b/python/tilus/backends/emitters/smem.py index 69f7d7f1..e7ec65f3 100644 --- a/python/tilus/backends/emitters/smem.py +++ b/python/tilus/backends/emitters/smem.py @@ -18,7 +18,7 @@ from hidet.ir.type import tensor_pointer_type from tilus.backends.codegen import BaseInstEmitter, register_emitter -from tilus.ir.instructions import AllocateSharedInst, FreeSharedInst, SharedIndexInst, SharedSliceInst +from tilus.ir.instructions import AllocateSharedInst, FreeSharedInst, SharedSliceInst from tilus.ir.tensor import SharedTensor @@ -62,13 +62,3 @@ def emit(self, inst: SharedSliceInst) -> None: tp=int32, init=self.shared_tensor_shared_space_addr[shared_input] + slice_offset * shared_input.dtype.nbytes, ) - - -@register_emitter(SharedIndexInst) -class SharedIndexInstEmitter(BaseInstEmitter): - def emit(self, inst: SharedIndexInst) -> None: - dst = inst.dst - tensor = inst.shared_input - var = self.get_or_allocate_var(tensor) - offset = tensor.layout(*inst.indices) - self.assign(dst, value=var[offset]) diff --git a/python/tilus/ir/analyzers/scalar_analyzer.py b/python/tilus/ir/analyzers/scalar_analyzer.py index 2fc3cf39..5a362253 100644 --- a/python/tilus/ir/analyzers/scalar_analyzer.py +++ b/python/tilus/ir/analyzers/scalar_analyzer.py @@ -25,8 +25,7 @@ import tilus.logging from tilus.ir.func import Analysis, Function -from tilus.ir.instructions import SharedIndexInst -from tilus.ir.stmt import AssignStmt, DeclareStmt, ForStmt, LetStmt, InstStmt +from tilus.ir.stmt import AssignStmt, DeclareStmt, ForStmt, LetStmt from tilus.ir.tools import IRPrinter, collect from tilus.utils import gcd @@ -468,13 +467,6 @@ def analyze_scalar(func: Function) -> Function: else: raise NotImplementedError() - # we are not interested in variables used in the following places - for stmt in collect(func, types=[InstStmt]): - assert isinstance(stmt, InstStmt) - if isinstance(stmt.inst, SharedIndexInst): - if stmt.inst.dst in variables: - del variables[variables.index(stmt.inst.dst)] - # initialize the scalar set of variables defined in the function body to be empty set for var in variables: var2set[var] = ScalarSet(lower_bound=0, upper_bound=-1) # empty set diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 9bf0b9d3..dcfd8292 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -51,7 +51,6 @@ ExitInst, FormatPrintInst, FreeSharedInst, - GlobalIndexInst, GlobalSliceInst, GlobalViewInst, LoadGlobalGenericInst, @@ -64,7 +63,6 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, - SharedIndexInst, SharedSliceInst, SqueezeInst, StoreGlobalGenericInst, @@ -90,7 +88,8 @@ InstStmt, SeqStmt, Stmt, - TensorPtrStmt, + TensorElemPtrStmt, + TensorElemValueStmt, WhileStmt, ) from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedLayout, SharedTensor, Tensor @@ -299,13 +298,25 @@ def assign(self, var: Var, value: Expr) -> None: self.append(AssignStmt(var, value)) def tensor_ptr(self, tensor: Tensor, space: str = "generic") -> Var: + return self.tensor_element_ptr(tensor, indices=None, space=space) + + def tensor_element_ptr( + self, tensor: Tensor, indices: Optional[Sequence[Expr | int]] = None, space: str = "generic" + ) -> Var: if space in ["generic", "global"]: - ptr_var = Var("v", type=~tensor.dtype) + ptr_var = Var("ptr", type=~tensor.dtype) else: - ptr_var = Var("v", int32) - self.append(TensorPtrStmt(ptr_var, tensor, space=space)) + ptr_var = Var("ptr", int32) + if indices is not None: + indices = tuple(as_expr(e) for e in indices) + self.append(TensorElemPtrStmt(ptr_var, tensor, indices=indices, space=space)) return ptr_var + def tensor_element_value(self, tensor: Tensor, indices: Sequence[Expr | int]) -> Var: + var = Var("val", type=tensor.dtype) + self.append(TensorElemValueStmt(var, tensor, indices=tuple(as_expr(e) for e in indices))) + return var + def append(self, inst_or_stmt: Union[Instruction, Stmt]) -> None: if isinstance(inst_or_stmt, Instruction): stmt: Stmt = InstStmt(inst_or_stmt) @@ -386,16 +397,6 @@ def slice_global( self.append(inst) return inst.global_output - def index_global( - self, - dst: Var, - tensor: GlobalTensor, - indices: Sequence[Expr | int], - ) -> None: - indices_ = [as_expr(index) for index in indices] - inst = GlobalIndexInst.create(dst=dst, tensor=tensor, indices=indices_) - self.append(inst) - def assign_register(self, output: RegisterTensor, x: RegisterTensor) -> None: inst = AssignInst.create(output, x) self.append(inst) @@ -771,16 +772,6 @@ def slice_shared( self.append(inst) return inst.shared_output - def index_shared( - self, - dst: Var, - tensor: SharedTensor, - indices: Sequence[Expr | int], - ) -> None: - indices_ = [as_expr(index) for index in indices] - inst = SharedIndexInst.create(dst=dst, tensor=tensor, indices=indices_) - self.append(inst) - def load_shared( self, src: SharedTensor, diff --git a/python/tilus/ir/functors/functor.py b/python/tilus/ir/functors/functor.py index d7c9d641..65cb8f4e 100644 --- a/python/tilus/ir/functors/functor.py +++ b/python/tilus/ir/functors/functor.py @@ -34,7 +34,8 @@ ReturnStmt, SeqStmt, Stmt, - TensorPtrStmt, + TensorElemPtrStmt, + TensorElemValueStmt, WhileStmt, ) from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedLayout, SharedTensor @@ -105,8 +106,10 @@ def visit(self, node): ret = self.visit_LetStmt(node) elif isinstance(node, AssignStmt): ret = self.visit_AssignStmt(node) - elif isinstance(node, TensorPtrStmt): - ret = self.visit_TensorPtrStmt(node) + elif isinstance(node, TensorElemPtrStmt): + ret = self.visit_TensorElemPtrStmt(node) + elif isinstance(node, TensorElemValueStmt): + ret = self.visit_TensorElemValueStmt(node) # scalar expression and type elif isinstance(node, Expr): ret = self.visit_Expr(node) @@ -205,7 +208,10 @@ def visit_LetStmt(self, stmt: LetStmt) -> Any: def visit_AssignStmt(self, stmt: AssignStmt) -> Any: raise NotImplementedError() - def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Any: + def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Any: + raise NotImplementedError() + + def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Any: raise NotImplementedError() # tensors and layouts @@ -349,12 +355,21 @@ def visit_AssignStmt(self, stmt: AssignStmt) -> Stmt: else: return AssignStmt(stmt.var, value) - def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Stmt: + def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Stmt: tensor = self.visit(stmt.tensor) - if tensor is stmt.tensor: + indices = self.visit(stmt.indices) + if tensor is stmt.tensor and indices is stmt.indices: return stmt else: - return TensorPtrStmt(stmt.ptr_var, tensor, stmt.space) + return TensorElemPtrStmt(stmt.ptr_var, tensor, indices, stmt.space) + + def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Stmt: + tensor = self.visit(stmt.tensor) + indices = self.visit(stmt.indices) + if tensor is stmt.tensor and indices is stmt.indices: + return stmt + else: + return TensorElemValueStmt(stmt.var, tensor, indices) def visit_WhileStmt(self, stmt: WhileStmt) -> Stmt: cond = self.visit(stmt.cond) @@ -494,9 +509,15 @@ def visit_AssignStmt(self, stmt: AssignStmt) -> None: self.visit(stmt.var) self.visit(stmt.value) - def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> None: + def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> None: self.visit(stmt.ptr_var) self.visit(stmt.tensor) + self.visit(stmt.indices) + + def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> None: + self.visit(stmt.var) + self.visit(stmt.tensor) + self.visit(stmt.indices) # values diff --git a/python/tilus/ir/instructions/__init__.py b/python/tilus/ir/instructions/__init__.py index db45ce62..8322bc8a 100644 --- a/python/tilus/ir/instructions/__init__.py +++ b/python/tilus/ir/instructions/__init__.py @@ -43,7 +43,6 @@ ExitInst, FormatPrintInst, FreeSharedInst, - GlobalIndexInst, GlobalSliceInst, GlobalViewInst, LoadGlobalGenericInst, @@ -56,7 +55,6 @@ ReduceInst, RepeatInst, RepeatInterleaveInst, - SharedIndexInst, SharedSliceInst, ShuffleDownInst, ShuffleUpInst, diff --git a/python/tilus/ir/instructions/generic.py b/python/tilus/ir/instructions/generic.py index 015f3ac4..3aedbc6f 100644 --- a/python/tilus/ir/instructions/generic.py +++ b/python/tilus/ir/instructions/generic.py @@ -95,25 +95,6 @@ def create( ) -@dataclass(frozen=True, eq=False) -class GlobalIndexInst(Instruction): - dst: Var - indices: tuple[Expr, ...] - - @staticmethod - def create( - dst: Var, - tensor: GlobalTensor, - indices: Sequence[Expr], - ) -> GlobalIndexInst: - return GlobalIndexInst( - output=None, - inputs=(tensor,), - dst=dst, - indices=tuple(indices), - ) - - @dataclass(frozen=True, eq=False) class LoadSharedInst(Instruction): @staticmethod @@ -149,25 +130,6 @@ def create( ) -@dataclass(frozen=True, eq=False) -class SharedIndexInst(Instruction): - dst: Var - indices: tuple[Expr, ...] - - @staticmethod - def create( - dst: Var, - tensor: SharedTensor, - indices: Sequence[Expr], - ) -> SharedIndexInst: - return SharedIndexInst( - output=None, - inputs=(tensor,), - dst=dst, - indices=tuple(indices), - ) - - @dataclass(frozen=True, eq=False) class LoadGlobalGenericInst(Instruction): ptr: Var diff --git a/python/tilus/ir/layout/inference/inference_rules/empty_rule.py b/python/tilus/ir/layout/inference/inference_rules/empty_rule.py index 1452f19a..80505a5b 100644 --- a/python/tilus/ir/layout/inference/inference_rules/empty_rule.py +++ b/python/tilus/ir/layout/inference/inference_rules/empty_rule.py @@ -19,14 +19,12 @@ FreeSharedInst, GlobalViewInst, PrintTensorInst, - SharedIndexInst, StoreGlobalInst, ) from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule from tilus.ir.tensor import Tensor -@register_rule(SharedIndexInst) @register_rule(PrintTensorInst) @register_rule(FormatPrintInst) @register_rule(FreeSharedInst) diff --git a/python/tilus/ir/layout/inference/validation_rules/always_ok.py b/python/tilus/ir/layout/inference/validation_rules/always_ok.py index c5b64151..761f2f8c 100644 --- a/python/tilus/ir/layout/inference/validation_rules/always_ok.py +++ b/python/tilus/ir/layout/inference/validation_rules/always_ok.py @@ -25,7 +25,6 @@ LoadSharedGenericInst, LoadSharedInst, PrintTensorInst, - SharedIndexInst, SharedSliceInst, StoreGlobalGenericInst, StoreGlobalInst, @@ -34,7 +33,6 @@ from tilus.ir.layout.inference.rule import LayoutValidationRule, register_rule -@register_rule(SharedIndexInst) @register_rule(CopyAsyncGenericInst) @register_rule(CopyAsyncInst) @register_rule(StoreGlobalGenericInst) diff --git a/python/tilus/ir/stmt.py b/python/tilus/ir/stmt.py index 30a11382..64613b42 100644 --- a/python/tilus/ir/stmt.py +++ b/python/tilus/ir/stmt.py @@ -116,12 +116,20 @@ class EvaluateStmt(Stmt): @dataclass(frozen=True, eq=False) -class TensorPtrStmt(Stmt): +class TensorElemPtrStmt(Stmt): ptr_var: Var tensor: Tensor + indices: Optional[tuple[Expr, ...]] space: str # 'generic', 'shared', 'global', 'local' +@dataclass(frozen=True, eq=False) +class TensorElemValueStmt(Stmt): + var: Var + tensor: Tensor + indices: tuple[Expr, ...] + + @dataclass(frozen=True, eq=False) class InstStmt(Stmt): inst: Instruction diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py index 983a2982..33c8de37 100644 --- a/python/tilus/ir/tools/printer.py +++ b/python/tilus/ir/tools/printer.py @@ -36,7 +36,8 @@ LetStmt, ReturnStmt, SeqStmt, - TensorPtrStmt, + TensorElemPtrStmt, + TensorElemValueStmt, WhileStmt, ) from tilus.ir.tensor import GlobalLayout, GlobalTensor, RegisterTensor, SharedLayout, SharedTensor, Tensor @@ -274,12 +275,7 @@ def visit_ReturnStmt(self, stmt: ReturnStmt) -> Any: return NewLine() + Text("return") def visit_DeclareStmt(self, stmt: DeclareStmt) -> Doc: - ret = ( - NewLine() - + Text("declare ") - + self.visit(stmt.var) - + ": " - + self.printer(stmt.var.type)) + ret = NewLine() + Text("declare ") + self.visit(stmt.var) + ": " + self.printer(stmt.var.type) if stmt.init is not None: ret += " = " + self.visit(stmt.init) return ret @@ -302,17 +298,35 @@ def visit_LetStmt(self, stmt: LetStmt) -> Doc: def visit_AssignStmt(self, stmt: AssignStmt) -> Doc: return NewLine() + self.visit(stmt.var) + " = " + self.visit(stmt.value) - def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Doc: - return ( + def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Doc: + doc = Doc() + doc += ( NewLine() + self.visit(stmt.ptr_var) + ": " + self.printer(stmt.ptr_var.type) + " = " - + "addr(" + + "~" + + self.visit(stmt.tensor) + ) + if stmt.indices is not None: + doc += "[" + self.visit(stmt.indices) + "]" + return doc + + def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Any: + doc = Doc() + doc += ( + NewLine() + + self.visit(stmt.var) + + ": " + + self.printer(stmt.var.type) + + " = " + self.visit(stmt.tensor) - + ")" + + "[" + + self.visit(stmt.indices) + + "]" ) + return doc def visit_Instruction(self, inst: Instruction) -> Doc: doc = Doc() diff --git a/python/tilus/lang/transpiler.py b/python/tilus/lang/transpiler.py index a09a960b..60225b27 100644 --- a/python/tilus/lang/transpiler.py +++ b/python/tilus/lang/transpiler.py @@ -911,13 +911,9 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any: sb = StmtBuilder() if len(slice_dims) == 0: # indexing - val = sb.declare(type=base.dtype, hint="val") - if isinstance(base, GlobalTensor): - sb.index_global(dst=val, tensor=base, indices=offsets) - else: - sb.index_shared(dst=val, tensor=base, indices=offsets) + var = sb.tensor_element_value(base, indices) self.current_scope.append(sb.flush_stmts()) - return val + return var else: # slicing sliced_tensor: Union[GlobalTensor, SharedTensor] @@ -1137,6 +1133,32 @@ def visit_If(self, stmt: ast.If) -> None: def visit_UnaryOp( self, expr: ast.UnaryOp ) -> Union[RegisterTensor, hidet_ir.Node, hidet_ir.BaseType, float, int, str]: + if ( + isinstance(expr.op, ast.Invert) + and isinstance(expr.operand, ast.Subscript) + and isinstance(expr.operand.value, ast.Name) + ): + # handle the following syntax specially + # ~tensor[i, j, ...] + # which gets the address of an element in global/shared tensor + buf = self.visit(expr.operand.value) + if isinstance(buf, (GlobalTensor, SharedTensor)): + indices = self.visit(expr.operand.slice) + if not isinstance(indices, Sequence): + indices = [indices] + if len(indices) != len(buf.shape): + raise HidetProgramError( + self, + expr.operand, + "Index dimension {} does not match tensor shape {}.".format(len(indices), buf.shape), + ) + sb = StmtBuilder() + ptr = sb.tensor_element_ptr(buf, indices, space="generic") + self.current_scope.append(sb.flush_stmts()) + return ptr + elif isinstance(buf, RegisterTensor): + raise ValueError("Can not addressing the element of a RegisterTensor.") + value = self.visit(expr.operand) if isinstance(value, RegisterTensor): if isinstance(expr.op, ast.UAdd): From 7627b5944b593acaa1d0f03f4cd4f53f64c6464b Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 19:26:16 -0400 Subject: [PATCH 7/7] fix Signed-off-by: Yaoyao Ding --- scripts/sign-commits.sh | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) mode change 100644 => 100755 scripts/sign-commits.sh diff --git a/scripts/sign-commits.sh b/scripts/sign-commits.sh old mode 100644 new mode 100755 index dac5f873..de7fd239 --- a/scripts/sign-commits.sh +++ b/scripts/sign-commits.sh @@ -10,22 +10,22 @@ git fetch origin BASE=$(git merge-base HEAD origin/$MAIN_BRANCH) -# List commits after the common ancestor missing "Signed-off-by" -UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do - if ! git show --quiet --format=%B $commit | grep -q "Signed-off-by:"; then - echo $commit - fi -done) - -if [ -z "$UNSIGNED_COMMITS" ]; then - echo "No unsigned commits to sign off." - exit 0 -fi +## List commits after the common ancestor missing "Signed-off-by" +#UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do +# if ! git show --quiet --format=%B $commit | grep -q "Signed-off-by:"; then +# echo $commit +# fi +#done) +# +#if [ -z "$UNSIGNED_COMMITS" ]; then +# echo "No unsigned commits to sign off." +# exit 0 +#fi # Rebase with signoff git rebase --signoff $BASE -echo "Newly signed commits:" -for commit in $UNSIGNED_COMMITS; do - git log --format="* %h %s" -n 1 $commit -done \ No newline at end of file +#echo "Newly signed commits:" +#for commit in $UNSIGNED_COMMITS; do +# git log --format="* %h %s" -n 1 $commit +#done \ No newline at end of file