diff --git a/python/tilus/backends/codegen.py b/python/tilus/backends/codegen.py
index e254ad6a..8aa6efee 100644
--- a/python/tilus/backends/codegen.py
+++ b/python/tilus/backends/codegen.py
@@ -47,7 +47,8 @@
     LetStmt,
     ReturnStmt,
     SeqStmt,
-    TensorPtrStmt,
+    TensorElemPtrStmt,
+    TensorElemValueStmt,
     WhileStmt,
 )
 from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedTensor, Tensor
@@ -472,8 +473,8 @@ def visit_Function(self, func: Function) -> IRModule:
         if self.smem_workspace:
             self.free_shared_value(self.smem_workspace)
             self.smem_workspace = None
-        if self.smem_allocator.allocated != 0:
-            raise ValueError("Shared memory is not properly allocated/freed")
+        # if self.smem_allocator.allocated != 0:
+        #     raise ValueError("Shared memory is not properly allocated/freed")
         if self.smem_allocator.maximum_allocated > get_current_target().properties.shared_memory_per_block:
             raise CodeGenerationFailed(
                 "Request shared memory {} bytes, but the device only allows {} bytes.".format(
@@ -545,19 +546,33 @@ def visit_LetStmt(self, stmt: LetStmt) -> None:
 
     def visit_AssignStmt(self, stmt: AssignStmt) -> None:
         self.builder.assign(stmt.var, value=stmt.value)
 
-    def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> None:
+    def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> None:
         if stmt.space in ["generic", "global"]:
-            self.builder.declare(stmt.ptr_var, self.tensor2var[stmt.tensor])
+            if stmt.space == "generic":
+                assert isinstance(stmt.tensor, (GlobalTensor, SharedTensor))
+            else:
+                assert isinstance(stmt.tensor, GlobalTensor)
+            ptr = self.tensor2var[stmt.tensor]
+            if stmt.indices is not None:
+                ptr = ptr + stmt.tensor.layout(*stmt.indices)
+            self.builder.declare(stmt.ptr_var, ptr)
         elif stmt.space == "local":
             raise NotImplementedError("Local tensor pointer is not supported yet.")
         elif stmt.space == "shared":
             if not isinstance(stmt.tensor, SharedTensor):
                 raise ValueError("Expected a SharedTensor for shared tensor pointer, got: {}".format(stmt.tensor))
             shared_tensor: SharedTensor = stmt.tensor
-            self.builder.declare(stmt.ptr_var, self.shared_tensor_shared_space_addr[shared_tensor])
+            addr = self.shared_tensor_shared_space_addr[shared_tensor]
+            if stmt.indices is not None:
+                addr = addr + shared_tensor.layout(*stmt.indices) * shared_tensor.dtype.nbytes
+            self.builder.declare(stmt.ptr_var, addr)
         else:
             raise ValueError("Unknown tensor pointer space: {}".format(stmt.space))
 
+    def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> None:
+        assert isinstance(stmt.tensor, (GlobalTensor, SharedTensor))
+        self.builder.declare(stmt.var, init=self.tensor2var[stmt.tensor][stmt.tensor.layout(*stmt.indices)])
+
     def visit_ReturnStmt(self, stmt: ReturnStmt) -> None:
         self.builder.ret()
diff --git a/python/tilus/backends/emitters/gmem.py b/python/tilus/backends/emitters/gmem.py
index facea33a..2a36f1b3 100644
--- a/python/tilus/backends/emitters/gmem.py
+++ b/python/tilus/backends/emitters/gmem.py
@@ -15,7 +15,8 @@
 from hidet.ir.expr import Expr
 
 from tilus.backends.codegen import BaseInstEmitter, register_emitter
-from tilus.ir.instructions import AllocateGlobalInst, GlobalViewInst
+from tilus.ir import GlobalTensor
+from tilus.ir.instructions import AllocateGlobalInst, GlobalSliceInst, GlobalViewInst
 from tilus.utils import cdiv
 
 
@@ -34,3 +35,16 @@ def emit(self, inst: AllocateGlobalInst) -> None:
         )
         var = self.get_or_allocate_var(tensor)
         self.assign(var, ptr)
+
+
+@register_emitter(GlobalSliceInst)
+class GlobalSliceInstEmitter(BaseInstEmitter):
+    def emit(self, inst: GlobalSliceInst) -> None:
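+        # Editor's note: the sliced tensor aliases the input's buffer. Its base
+        # pointer is the input pointer advanced by layout(*offsets), and the
+        # layout built by global_slice is defined relative to that base.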
+        input_tensor: GlobalTensor = inst.global_input
+        output_tensor: GlobalTensor = inst.global_output
+        slice_offset = input_tensor.layout(*inst.offsets)
+        output_var = self.get_or_allocate_var(output_tensor)
+        self.assign(output_var, ~self.tensor2var[input_tensor][slice_offset])
diff --git a/python/tilus/ir/analyzers/scalar_analyzer.py b/python/tilus/ir/analyzers/scalar_analyzer.py
index 4fc4669b..5a362253 100644
--- a/python/tilus/ir/analyzers/scalar_analyzer.py
+++ b/python/tilus/ir/analyzers/scalar_analyzer.py
@@ -148,6 +148,10 @@ def is_constant(self) -> bool:
     def empty_set() -> ScalarSet:
         return ScalarSet(lower_bound=0, upper_bound=-1)
 
+    @staticmethod
+    def universal_set() -> ScalarSet:
+        return ScalarSet(lower_bound=None, upper_bound=None, divisibility=1)
+
     def __eq__(self, other: ScalarSet) -> bool:
         if self.is_empty() and other.is_empty():
             return True
diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py
index 3ebd2ac4..dcfd8292 100644
--- a/python/tilus/ir/builders/stmt_builder.py
+++ b/python/tilus/ir/builders/stmt_builder.py
@@ -51,6 +51,7 @@
     ExitInst,
     FormatPrintInst,
     FreeSharedInst,
+    GlobalSliceInst,
     GlobalViewInst,
     LoadGlobalGenericInst,
     LoadGlobalInst,
@@ -87,7 +88,8 @@
     InstStmt,
     SeqStmt,
     Stmt,
-    TensorPtrStmt,
+    TensorElemPtrStmt,
+    TensorElemValueStmt,
     WhileStmt,
 )
 from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedLayout, SharedTensor, Tensor
@@ -285,8 +287,10 @@ def brk(self):
         stmt = BreakStmt()
         self._stack[-1].append(stmt)
 
-    def declare(self, type: BaseType, init: Optional[Expr | float | int] = None) -> Var:
-        var = Var("v", type=type)
+    def declare(self, type: BaseType, init: Optional[Expr | float | int] = None, hint: Optional[str] = None) -> Var:
+        if hint is None:
+            hint = "v"
+        var = Var(hint, type=type)
         self.append(DeclareStmt(var, as_expr(init) if init is not None else None))
         return var
 
@@ -294,13 +298,25 @@ def assign(self, var: Var, value: Expr) -> None:
         self.append(AssignStmt(var, value))
 
     def tensor_ptr(self, tensor: Tensor, space: str = "generic") -> Var:
+        return self.tensor_element_ptr(tensor, indices=None, space=space)
+
+    def tensor_element_ptr(
+        self, tensor: Tensor, indices: Optional[Sequence[Expr | int]] = None, space: str = "generic"
+    ) -> Var:
         if space in ["generic", "global"]:
-            ptr_var = Var("v", type=~tensor.dtype)
+            ptr_var = Var("ptr", type=~tensor.dtype)
         else:
-            ptr_var = Var("v", int32)
-        self.append(TensorPtrStmt(ptr_var, tensor, space=space))
+            ptr_var = Var("ptr", int32)
+        if indices is not None:
+            indices = tuple(as_expr(e) for e in indices)
+        self.append(TensorElemPtrStmt(ptr_var, tensor, indices=indices, space=space))
         return ptr_var
 
+    def tensor_element_value(self, tensor: Tensor, indices: Sequence[Expr | int]) -> Var:
+        var = Var("val", type=tensor.dtype)
+        self.append(TensorElemValueStmt(var, tensor, indices=tuple(as_expr(e) for e in indices)))
+        return var
+
     def append(self, inst_or_stmt: Union[Instruction, Stmt]) -> None:
         if isinstance(inst_or_stmt, Instruction):
             stmt: Stmt = InstStmt(inst_or_stmt)
@@ -364,6 +380,27 @@ def allocate_global(
         self.append(inst)
         return inst.global_output
 
+    def slice_global(
+        self,
+        tensor: GlobalTensor,
+        offsets: Sequence[Expr | int],
+        slice_dims: Sequence[int],
+        slice_shape: Sequence[Expr | int],
+    ) -> GlobalTensor:
+        offsets_ = [as_expr(offset) for offset in offsets]
+        inst = GlobalSliceInst.create(
+            tensor=tensor,
+            offsets=offsets_,
+            dims=slice_dims,
+            shape=slice_shape,
+        )
+        self.append(inst)
+        return inst.global_output
+
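+    # Editor's usage sketch: for a GlobalTensor g of shape [M, N],
+    #   row_i = self.slice_global(g, offsets=[i, 0], slice_dims=[1], slice_shape=[g.shape[1]])
+    # keeps only dim 1, yielding a shape-[N] view of row i of g.
+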
     def assign_register(self, output: RegisterTensor, x: RegisterTensor) -> None:
         inst = AssignInst.create(output, x)
         self.append(inst)
@@ -722,7 +759,7 @@ def free_shared(self, shared_value: SharedTensor) -> None:
         inst = FreeSharedInst.create(shared_value)
         self.append(inst)
 
-    def shared_slice(
+    def slice_shared(
         self,
         tensor: SharedTensor,
         offsets: Sequence[Expr | int],
diff --git a/python/tilus/ir/functors/functor.py b/python/tilus/ir/functors/functor.py
index d7c9d641..65cb8f4e 100644
--- a/python/tilus/ir/functors/functor.py
+++ b/python/tilus/ir/functors/functor.py
@@ -34,7 +34,8 @@
     ReturnStmt,
     SeqStmt,
     Stmt,
-    TensorPtrStmt,
+    TensorElemPtrStmt,
+    TensorElemValueStmt,
     WhileStmt,
 )
 from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedLayout, SharedTensor
@@ -105,8 +106,10 @@ def visit(self, node):
             ret = self.visit_LetStmt(node)
         elif isinstance(node, AssignStmt):
             ret = self.visit_AssignStmt(node)
-        elif isinstance(node, TensorPtrStmt):
-            ret = self.visit_TensorPtrStmt(node)
+        elif isinstance(node, TensorElemPtrStmt):
+            ret = self.visit_TensorElemPtrStmt(node)
+        elif isinstance(node, TensorElemValueStmt):
+            ret = self.visit_TensorElemValueStmt(node)
         # scalar expression and type
         elif isinstance(node, Expr):
             ret = self.visit_Expr(node)
@@ -205,7 +208,10 @@ def visit_LetStmt(self, stmt: LetStmt) -> Any:
     def visit_AssignStmt(self, stmt: AssignStmt) -> Any:
         raise NotImplementedError()
 
-    def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Any:
+    def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Any:
+        raise NotImplementedError()
+
+    def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Any:
         raise NotImplementedError()
 
     # tensors and layouts
@@ -349,12 +355,21 @@ def visit_AssignStmt(self, stmt: AssignStmt) -> Stmt:
         else:
             return AssignStmt(stmt.var, value)
 
-    def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Stmt:
+    def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Stmt:
         tensor = self.visit(stmt.tensor)
-        if tensor is stmt.tensor:
+        indices = self.visit(stmt.indices)
+        if tensor is stmt.tensor and indices is stmt.indices:
             return stmt
         else:
-            return TensorPtrStmt(stmt.ptr_var, tensor, stmt.space)
+            return TensorElemPtrStmt(stmt.ptr_var, tensor, indices, stmt.space)
+
+    def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Stmt:
+        tensor = self.visit(stmt.tensor)
+        indices = self.visit(stmt.indices)
+        if tensor is stmt.tensor and indices is stmt.indices:
+            return stmt
+        else:
+            return TensorElemValueStmt(stmt.var, tensor, indices)
 
     def visit_WhileStmt(self, stmt: WhileStmt) -> Stmt:
         cond = self.visit(stmt.cond)
@@ -494,9 +509,15 @@ def visit_AssignStmt(self, stmt: AssignStmt) -> None:
         self.visit(stmt.var)
         self.visit(stmt.value)
 
-    def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> None:
+    def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> None:
         self.visit(stmt.ptr_var)
         self.visit(stmt.tensor)
+        self.visit(stmt.indices)
+
+    def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> None:
+        self.visit(stmt.var)
+        self.visit(stmt.tensor)
+        self.visit(stmt.indices)
 
     # values
diff --git a/python/tilus/ir/inst.py b/python/tilus/ir/inst.py
index 106274e5..d448da53 100644
--- a/python/tilus/ir/inst.py
+++ b/python/tilus/ir/inst.py
@@ -60,6 +60,13 @@ def shared_input(self) -> SharedTensor:
         assert isinstance(x, SharedTensor)
         return x
 
+    @property
+    def global_input(self) -> GlobalTensor:
+        assert len(self.inputs) == 1
+        x = self.inputs[0]
+        assert isinstance(x, GlobalTensor)
+        return x
+
     @property
     def attributes(self) -> dict[str, Any]:
         attrs = {}
diff --git a/python/tilus/ir/instructions/__init__.py b/python/tilus/ir/instructions/__init__.py
index 10d7e7e1..8322bc8a 100644
--- a/python/tilus/ir/instructions/__init__.py
+++ b/python/tilus/ir/instructions/__init__.py
@@ -43,6 +43,7 @@
     ExitInst,
     FormatPrintInst,
     FreeSharedInst,
+    GlobalSliceInst,
     GlobalViewInst,
     LoadGlobalGenericInst,
     LoadGlobalInst,
diff --git a/python/tilus/ir/instructions/generic.py b/python/tilus/ir/instructions/generic.py
index 9342bebf..3aedbc6f 100644
--- a/python/tilus/ir/instructions/generic.py
+++ b/python/tilus/ir/instructions/generic.py
@@ -72,6 +72,29 @@ def create(dst: GlobalTensor, x: RegisterTensor, offsets: Sequence[Expr], dims:
         return StoreGlobalInst(output=None, inputs=(dst, x), offsets=tuple(offsets), dims=tuple(dims))
 
 
+@dataclass(frozen=True, eq=False)
+class GlobalSliceInst(Instruction):
+    offsets: tuple[Expr, ...]
+    dims: Optional[tuple[int, ...]]
+
+    @staticmethod
+    def create(
+        tensor: GlobalTensor,
+        offsets: Sequence[Expr],
+        dims: Sequence[int],
+        shape: Sequence[Expr | int],
+    ) -> GlobalSliceInst:
+        from tilus.ir.layout.global_layout import global_slice
+
+        output = GlobalTensor.create(dtype=tensor.dtype, layout=global_slice(tensor.layout, offsets, dims, shape))
+        return GlobalSliceInst(
+            output=output,
+            inputs=(tensor,),
+            offsets=tuple(offsets),
+            dims=tuple(dims) if len(dims) < len(tensor.shape) else None,
+        )
+
+
 @dataclass(frozen=True, eq=False)
 class LoadSharedInst(Instruction):
     @staticmethod
@@ -103,7 +126,7 @@ def create(
             output=output,
             inputs=(tensor,),
             offsets=tuple(offsets),
-            dims=tuple(dims) if len(dims) < len(tensor.shape) else None,
+            dims=tuple(dims) if len(dims) < len(tensor.shape) else tuple(range(len(tensor.shape))),
         )
 
 
diff --git a/python/tilus/ir/layout/global_layout.py b/python/tilus/ir/layout/global_layout.py
index 53dbe161..6df144d7 100644
--- a/python/tilus/ir/layout/global_layout.py
+++ b/python/tilus/ir/layout/global_layout.py
@@ -218,3 +218,50 @@ def f_offset(axes: Sequence[Var]) -> Expr:
         return sum([axes[i] * strides[i] for i in range(len(shape))], start=int32.zero)
 
     return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)
+
+
+def global_slice(
+    layout: GlobalLayout, offsets: Sequence[Expr | int], dims: Sequence[int], shape: Sequence[Expr | int]
+) -> GlobalLayout:
+    """Create a sliced global layout from an existing layout.
+
+    This function creates a new global layout by slicing an existing global layout. The slicing is defined by the
+    specified offsets, the dimensions to keep, and the shape of the resulting layout. The new layout retains the
+    mapping function of the original layout, adjusted for the specified offsets and dimensions.
+
+    Parameters
+    ----------
+    layout: GlobalLayout
+        The original global layout to be sliced.
+    offsets: Sequence[Expr | int]
+        The offsets for each dimension of the original layout. It should have the same length as the original
+        layout's shape.
+    dims: Sequence[int]
+        The dimensions of the original layout that are kept in the slice. Each dimension should be a valid index
+        into the original layout's shape.
+    shape: Sequence[Expr | int]
+        The shape of the resulting sliced global layout. It should have the same length as the number of dimensions
+        specified in `dims`.
+
+    Returns
+    -------
+    ret: GlobalLayout
+        A new global layout that represents the sliced version of the original layout, with the specified shape and
+        adjusted mapping function.
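+
+    Example
+    -------
+    For a row-major layout ``g`` of shape ``[M, N]`` (editor's illustration),
+    ``global_slice(g, offsets=[i, 0], dims=[1], shape=[N])`` returns a layout of
+    shape ``[N]`` mapping ``j`` to ``g(i, j) - g(i, 0)``, i.e. row ``i`` of the
+    original layout, relative to the slice origin ``(i, 0)``.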
+    """
+    assert len(dims) == len(shape) <= len(layout.shape) == len(offsets)
+
+    def f_offset(axes: Sequence[Var]) -> Expr:
+        indices = list(offsets)
+        for dim, axis in zip(dims, axes):
+            indices[dim] = axis + offsets[dim]
+        return layout(*indices) - layout(*offsets)  # type: ignore[arg-type]
+
+    return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)
diff --git a/python/tilus/ir/layout/inference/inference_rules/__init__.py b/python/tilus/ir/layout/inference/inference_rules/__init__.py
index e84d5a96..90ffb326 100644
--- a/python/tilus/ir/layout/inference/inference_rules/__init__.py
+++ b/python/tilus/ir/layout/inference/inference_rules/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from . import (
+    allocate_shared,
     assign,
     cp_async,
     elementwise_binary,
diff --git a/python/tilus/ir/layout/inference/inference_rules/allocate_shared.py b/python/tilus/ir/layout/inference/inference_rules/allocate_shared.py
new file mode 100644
index 00000000..f246d7c2
--- /dev/null
+++ b/python/tilus/ir/layout/inference/inference_rules/allocate_shared.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
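+"""Layout inference for AllocateSharedInst: shared tensors without an explicit layout default to row-major."""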
+from tilus.ir.instructions import AllocateSharedInst
+from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule
+from tilus.ir.layout.shared_layout import SharedLayout, shared_row_major
+from tilus.ir.tensor import SharedTensor
+
+
+@register_rule(AllocateSharedInst)
+class AllocateSharedRule(LayoutInferenceRule):
+    @staticmethod
+    def inference(ctx: LayoutInferenceContext, inst: AllocateSharedInst) -> dict[SharedTensor, SharedLayout]:
+        tensor = inst.shared_output
+
+        if tensor.optional_layout is not None:
+            return {}
+        else:
+            return {tensor: shared_row_major(*tensor.shape)}
diff --git a/python/tilus/ir/layout/inference/inference_rules/empty_rule.py b/python/tilus/ir/layout/inference/inference_rules/empty_rule.py
index 6ea0be19..80505a5b 100644
--- a/python/tilus/ir/layout/inference/inference_rules/empty_rule.py
+++ b/python/tilus/ir/layout/inference/inference_rules/empty_rule.py
@@ -15,7 +15,6 @@
 from tilus import RegisterLayout, SharedLayout
 from tilus.ir.instructions import (
     AllocateRegisterInst,
-    AllocateSharedInst,
     FormatPrintInst,
     FreeSharedInst,
     GlobalViewInst,
@@ -30,7 +29,6 @@
 @register_rule(FormatPrintInst)
 @register_rule(FreeSharedInst)
 @register_rule(AllocateRegisterInst)
-@register_rule(AllocateSharedInst)
 @register_rule(StoreGlobalInst)
 class EmptyRule(LayoutInferenceRule):
     @staticmethod
diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py
index 03baed08..7eb289d9 100644
--- a/python/tilus/ir/layout/inference/order.py
+++ b/python/tilus/ir/layout/inference/order.py
@@ -17,6 +17,7 @@
 from tilus.ir.layout.inference.rule import LayoutInferenceRule
 from tilus.utils import initialize
 
+from .inference_rules.allocate_shared import AllocateSharedRule
 from .inference_rules.assign import AssignRule
 from .inference_rules.cp_async import CopyAsyncRule
 from .inference_rules.elementwise_binary import BinaryRule
@@ -52,6 +53,7 @@
     [CopyAsyncRule],
     [LoadSharedInferRegisterRule],
     [LoadSharedInferRowMajorSharedRule],
+    [AllocateSharedRule],
 ]
 
 rule2order: dict[Type[LayoutInferenceRule], int] = {}
diff --git a/python/tilus/ir/layout/inference/rule.py b/python/tilus/ir/layout/inference/rule.py
index b156b261..4ec6103d 100644
--- a/python/tilus/ir/layout/inference/rule.py
+++ b/python/tilus/ir/layout/inference/rule.py
@@ -126,7 +126,7 @@ def get_inference_rules(inst: Type[Instruction] | Instruction) -> list[Type[Layo
             _inference_rules[inst_cls] = _inference_rules[parent_cls]
             break
     else:
-        raise ValueError(f"No layout inference rule registered for {bold(inst_cls.__name__)}")
+        return []
 
     return _inference_rules[inst_cls].copy()
diff --git a/python/tilus/ir/stmt.py b/python/tilus/ir/stmt.py
index 30a11382..64613b42 100644
--- a/python/tilus/ir/stmt.py
+++ b/python/tilus/ir/stmt.py
@@ -116,12 +116,20 @@ class EvaluateStmt(Stmt):
 
 
 @dataclass(frozen=True, eq=False)
-class TensorPtrStmt(Stmt):
+class TensorElemPtrStmt(Stmt):
     ptr_var: Var
     tensor: Tensor
+    indices: Optional[tuple[Expr, ...]]
     space: str  # 'generic', 'shared', 'global', 'local'
 
 
+@dataclass(frozen=True, eq=False)
+class TensorElemValueStmt(Stmt):
+    var: Var
+    tensor: Tensor
+    indices: tuple[Expr, ...]
+
+
 @dataclass(frozen=True, eq=False)
 class InstStmt(Stmt):
     inst: Instruction
diff --git a/python/tilus/ir/tensor.py b/python/tilus/ir/tensor.py
index e1ec97c6..bb2c3a2b 100644
--- a/python/tilus/ir/tensor.py
+++ b/python/tilus/ir/tensor.py
@@ -371,7 +371,7 @@ class SharedTensor(Tensor):
     shape: tuple[int, ...]
     optional_layout: Optional[SharedLayout]
 
-    def __getitem__(self, index: int | Expr) -> SharedTensor:
+    def __getitem__(self, indices: tuple[Expr | int, ...] | Expr | int) -> SharedTensor | Expr:
         raise RuntimeError("shared_tensor[...] could only be used in Tilus Script.")
 
     @staticmethod
diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py
index cbfdae0d..33c8de37 100644
--- a/python/tilus/ir/tools/printer.py
+++ b/python/tilus/ir/tools/printer.py
@@ -36,7 +36,8 @@
     LetStmt,
     ReturnStmt,
     SeqStmt,
-    TensorPtrStmt,
+    TensorElemPtrStmt,
+    TensorElemValueStmt,
     WhileStmt,
 )
 from tilus.ir.tensor import GlobalLayout, GlobalTensor, RegisterTensor, SharedLayout, SharedTensor, Tensor
@@ -274,15 +275,10 @@ def visit_ReturnStmt(self, stmt: ReturnStmt) -> Any:
         return NewLine() + Text("return")
 
     def visit_DeclareStmt(self, stmt: DeclareStmt) -> Doc:
-        return (
-            NewLine()
-            + Text("declare ")
-            + self.visit(stmt.var)
-            + ": "
-            + self.printer(stmt.var.type)
-            + " = "
-            + self.visit(stmt.init)
-        )
+        ret = NewLine() + Text("declare ") + self.visit(stmt.var) + ": " + self.printer(stmt.var.type)
+        if stmt.init is not None:
+            ret += " = " + self.visit(stmt.init)
+        return ret
 
     def visit_LetStmt(self, stmt: LetStmt) -> Doc:
         doc = Doc()
@@ -302,17 +298,35 @@
     def visit_AssignStmt(self, stmt: AssignStmt) -> Doc:
         return NewLine() + self.visit(stmt.var) + " = " + self.visit(stmt.value)
 
-    def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Doc:
-        return (
+    def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Doc:
+        doc = Doc()
+        doc += (
             NewLine()
             + self.visit(stmt.ptr_var)
             + ": "
             + self.printer(stmt.ptr_var.type)
             + " = "
-            + "addr("
+            + "~"
+            + self.visit(stmt.tensor)
+        )
+        if stmt.indices is not None:
+            doc += "[" + self.visit(stmt.indices) + "]"
+        return doc
+
+    def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Any:
+        doc = Doc()
+        doc += (
+            NewLine()
+            + self.visit(stmt.var)
+            + ": "
+            + self.printer(stmt.var.type)
+            + " = "
             + self.visit(stmt.tensor)
-            + ")"
+            + "["
+            + self.visit(stmt.indices)
+            + "]"
         )
+        return doc
 
     def visit_Instruction(self, inst: Instruction) -> Doc:
         doc = Doc()
diff --git a/python/tilus/lang/script.py b/python/tilus/lang/script.py
index 9256a2a0..cba5f931 100644
--- a/python/tilus/lang/script.py
+++ b/python/tilus/lang/script.py
@@ -631,7 +631,7 @@ def store_shared(
         if dims is None:
             assert len(src.shape) == len(dst.shape)
             dims = list(range(len(src.shape)))
-        dst = self._builder.shared_slice(dst, offsets=offsets, slice_dims=dims, slice_shape=src.shape)
+        dst = self._builder.slice_shared(dst, offsets=offsets, slice_dims=dims, slice_shape=src.shape)
         self._builder.store_shared(dst=dst, src=src)
 
     def free_shared(self, tensor: SharedTensor) -> None:
diff --git a/python/tilus/lang/transpiler.py b/python/tilus/lang/transpiler.py
index e9ac86fd..60225b27 100644
--- a/python/tilus/lang/transpiler.py
+++ b/python/tilus/lang/transpiler.py
@@ -877,41 +877,67 @@ def visit_Subscript(self, expr: ast.Subscript) -> Any:
 
         if isinstance(base, Sequence):
             return base[indices]
-        elif isinstance(base, GlobalTensor):
+        elif isinstance(base, (GlobalTensor, SharedTensor)):
             if not isinstance(indices, Sequence):
-                indices = (indices,)
-            if (
-                isinstance(indices, Sequence)
-                and len(indices) == len(base.shape)
-                and not any(i is None or isinstance(i, slice) for i in indices)
-            ):
-                sb = StmtBuilder()
-                ptr = sb.tensor_ptr(tensor=base)
-                offset = base.layout(*indices)
-                self.current_scope.append(sb.flush_stmts())
-                return ptr[offset]
+                indices = [indices]
             else:
-                raise TilusProgramError(self, expr, "Tilus Script does not support slicing on GlobalTensor.")
-        elif isinstance(base, RegisterTensor):
-            raise TilusProgramError(self, expr, "Tilus Script does not support indexing/slicing on RegisterTensor.")
-        elif isinstance(base, SharedTensor):
+                indices = list(indices)
+            offsets = []
+            slice_dims = []
+            for dim, idx in enumerate(indices):
+                if isinstance(idx, slice):
+                    if idx.start is not None or idx.stop is not None:
+                        raise TilusProgramError(
+                            self,
+                            expr,
+                            "Global/Shared tensors only support slicing whole dimensions: [..., :, ...], "
+                            "not partial slices like [..., start:, ...] or [..., :end, ...].",
+                        )
+                    offsets.append(0)
+                    slice_dims.append(dim)
+                else:
+                    offsets.append(idx)
+
+            if len(offsets) < len(base.shape):
+                dim = len(offsets)
+                while len(offsets) < len(base.shape):
+                    offsets.append(0)
+                    slice_dims.append(dim)
+                    dim += 1
+            if len(indices) > len(base.shape):
+                raise TilusProgramError(self, expr, "Too many indices for tensor of shape {}.".format(base.shape))
+
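+            # Editor's note on the resulting semantics, for a tensor g of shape [M, N]:
+            #   g[i, j] yields the element value, g[i, :] yields a shape-[N] slice,
+            #   and g[i] behaves like g[i, :] (missing dims become whole-dim slices).
+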
             sb = StmtBuilder()
-            if isinstance(indices, (hidet_ir.Expr, int)):
-                offsets = [as_expr(indices)]
-                for i in range(len(base.shape) - 1):
-                    offsets.append(as_expr(0))
-                sliced_tensor = sb.shared_slice(
-                    tensor=base,
-                    offsets=offsets,
-                    slice_dims=range(1, len(base.shape)),
-                    slice_shape=base.shape[1:],
-                )
+            if len(slice_dims) == 0:
+                # indexing
+                var = sb.tensor_element_value(base, indices)
                 self.current_scope.append(sb.flush_stmts())
-                return sliced_tensor
+                return var
             else:
-                raise TilusProgramError(
-                    self, expr, "Tilus Script does not support slicing on SharedTensor with subscript syntax."
-                )
+                # slicing
+                sliced_tensor: Union[GlobalTensor, SharedTensor]
+                if isinstance(base, GlobalTensor):
+                    sliced_tensor = sb.slice_global(
+                        tensor=base,
+                        offsets=offsets,
+                        slice_dims=slice_dims,
+                        slice_shape=[base.shape[dim] for dim in slice_dims],
+                    )
+                else:
+                    sliced_tensor = sb.slice_shared(
+                        tensor=base,
+                        offsets=offsets,
+                        slice_dims=slice_dims,
+                        slice_shape=[base.shape[dim] for dim in slice_dims],
+                    )
+                self.current_scope.append(sb.flush_stmts())
+                return sliced_tensor
+        elif isinstance(base, RegisterTensor):
+            raise TilusProgramError(self, expr, "Tilus Script does not support indexing/slicing on RegisterTensor.")
         else:
             raise NotImplementedError()
 
@@ -1110,6 +1136,35 @@ def visit_If(self, stmt: ast.If) -> None:
     def visit_UnaryOp(
         self, expr: ast.UnaryOp
     ) -> Union[RegisterTensor, hidet_ir.Node, hidet_ir.BaseType, float, int, str]:
+        if (
+            isinstance(expr.op, ast.Invert)
+            and isinstance(expr.operand, ast.Subscript)
+            and isinstance(expr.operand.value, ast.Name)
+        ):
+            # handle the following syntax specially:
+            #   ~tensor[i, j, ...]
+            # which gets the address of an element in a global/shared tensor
+            buf = self.visit(expr.operand.value)
+            if isinstance(buf, (GlobalTensor, SharedTensor)):
+                indices = self.visit(expr.operand.slice)
+                if not isinstance(indices, Sequence):
+                    indices = [indices]
+                if len(indices) != len(buf.shape):
+                    raise HidetProgramError(
+                        self,
+                        expr.operand,
+                        "Index dimension {} does not match tensor shape {}.".format(len(indices), buf.shape),
+                    )
+                sb = StmtBuilder()
+                ptr = sb.tensor_element_ptr(buf, indices, space="generic")
+                self.current_scope.append(sb.flush_stmts())
+                return ptr
+            elif isinstance(buf, RegisterTensor):
+                raise ValueError("Cannot take the address of an element of a RegisterTensor.")
+
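+        # Editor's example, assuming g is a global tensor of shape [M, N]:
+        #     p = ~g[i, j]   # p is a pointer to element (i, j) of g
+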
         value = self.visit(expr.operand)
         if isinstance(value, RegisterTensor):
             if isinstance(expr.op, ast.UAdd):
@@ -1173,3 +1228,10 @@ def visit_Return(self, stmt: ast.Return) -> None:
         if stmt.value is not None:
             raise TilusProgramError(self, stmt, "Return statement in Tilus Script does not support returning a value.")
         self.current_scope.append(ReturnStmt())
+
+    def visit_Slice(self, expr: ast.Slice) -> slice:
+        return slice(
+            self.visit(expr.lower) if expr.lower is not None else None,
+            self.visit(expr.upper) if expr.upper is not None else None,
+            self.visit(expr.step) if expr.step is not None else None,
+        )
diff --git a/scripts/sign-commits.sh b/scripts/sign-commits.sh
new file mode 100755
index 00000000..de7fd239
--- /dev/null
+++ b/scripts/sign-commits.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+# Sign off all commits on the current branch that are not yet in the main branch.
+
+set -e
+
+MAIN_BRANCH="main"
+
+git fetch origin
+
+BASE=$(git merge-base HEAD origin/$MAIN_BRANCH)
+
+# Rebase with signoff
+git rebase --signoff $BASE
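+
+# Tip (editor's note): review the rewritten history with `git log origin/main..HEAD` before force-pushing.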