27 changes: 21 additions & 6 deletions python/tilus/backends/codegen.py
@@ -47,7 +47,8 @@
LetStmt,
ReturnStmt,
SeqStmt,
TensorPtrStmt,
TensorElemPtrStmt,
TensorElemValueStmt,
WhileStmt,
)
from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedTensor, Tensor
@@ -472,8 +473,8 @@ def visit_Function(self, func: Function) -> IRModule:
if self.smem_workspace:
self.free_shared_value(self.smem_workspace)
self.smem_workspace = None
if self.smem_allocator.allocated != 0:
raise ValueError("Shared memory is not properly allocated/freed")
# if self.smem_allocator.allocated != 0:
# raise ValueError("Shared memory is not properly allocated/freed")
if self.smem_allocator.maximum_allocated > get_current_target().properties.shared_memory_per_block:
raise CodeGenerationFailed(
"Request shared memory {} bytes, but the device only allows {} bytes.".format(
@@ -545,19 +546,33 @@ def visit_LetStmt(self, stmt: LetStmt) -> None:
def visit_AssignStmt(self, stmt: AssignStmt) -> None:
self.builder.assign(stmt.var, value=stmt.value)

def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> None:
def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> None:
if stmt.space in ["generic", "global"]:
self.builder.declare(stmt.ptr_var, self.tensor2var[stmt.tensor])
if stmt.space == "generic":
assert isinstance(stmt.tensor, (GlobalTensor, SharedTensor))
else:
assert isinstance(stmt.tensor, GlobalTensor)
ptr = self.tensor2var[stmt.tensor]
if stmt.indices is not None:
ptr = ptr + stmt.tensor.layout(*stmt.indices)
self.builder.declare(stmt.ptr_var, ptr)
elif stmt.space == "local":
raise NotImplementedError("Local tensor pointer is not supported yet.")
elif stmt.space == "shared":
if not isinstance(stmt.tensor, SharedTensor):
raise ValueError("Expected a SharedTensor for shared tensor pointer, got: {}".format(stmt.tensor))
shared_tensor: SharedTensor = stmt.tensor
self.builder.declare(stmt.ptr_var, self.shared_tensor_shared_space_addr[shared_tensor])
addr = self.shared_tensor_shared_space_addr[shared_tensor]
if stmt.indices is not None:
addr = addr + shared_tensor.layout(*stmt.indices) * shared_tensor.dtype.nbytes
self.builder.declare(stmt.ptr_var, addr)
else:
raise ValueError("Unknown tensor pointer space: {}".format(stmt.space))

def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> None:
assert isinstance(stmt.tensor, (GlobalTensor, SharedTensor))
self.builder.declare(stmt.var, init=self.tensor2var[stmt.tensor][stmt.tensor.layout(*stmt.indices)])

def visit_ReturnStmt(self, stmt: ReturnStmt) -> None:
self.builder.ret()

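
Note the contrast between the two branches of `visit_TensorElemPtrStmt`: in the generic/global case the pointer is typed and the layout offset is added in elements, while the shared case produces a raw byte address, so the element offset must be scaled by the element size. A minimal sketch of the shared-space arithmetic, with hypothetical stand-in names rather than the tilus API:

```python
# Sketch of the shared-space address computation in visit_TensorElemPtrStmt:
# the layout maps indices to a linear *element* offset, scaled by the element
# size because shared-space addresses are byte offsets. Illustrative only.
from typing import Callable, Sequence

def shared_elem_addr(
    base_addr: int,                  # shared-space byte address of the tensor
    layout: Callable[..., int],      # indices -> linear element offset
    indices: Sequence[int],
    dtype_nbytes: int,               # element size in bytes
) -> int:
    return base_addr + layout(*indices) * dtype_nbytes

# Example: element (2, 3) of a 4x8 row-major fp16 tile at byte address 128.
row_major = lambda i, j: i * 8 + j
assert shared_elem_addr(128, row_major, (2, 3), dtype_nbytes=2) == 128 + 19 * 2
```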
13 changes: 12 additions & 1 deletion python/tilus/backends/emitters/gmem.py
@@ -15,7 +15,8 @@
from hidet.ir.expr import Expr

from tilus.backends.codegen import BaseInstEmitter, register_emitter
from tilus.ir.instructions import AllocateGlobalInst, GlobalViewInst
from tilus.ir import GlobalTensor
from tilus.ir.instructions import AllocateGlobalInst, GlobalSliceInst, GlobalViewInst
from tilus.utils import cdiv


@@ -34,3 +35,13 @@ def emit(self, inst: AllocateGlobalInst) -> None:
)
var = self.get_or_allocate_var(tensor)
self.assign(var, ptr)


@register_emitter(GlobalSliceInst)
class GlobalSliceInstEmitter(BaseInstEmitter):
def emit(self, inst: GlobalSliceInst) -> None:
input_tensor: GlobalTensor = inst.global_input
output_tensor: GlobalTensor = inst.global_output
slice_offset = input_tensor.layout(*inst.offsets)
output_var = self.get_or_allocate_var(output_tensor)
self.assign(output_var, ~self.tensor2var[input_tensor][slice_offset])
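
The emitter assigns the address of the input element at the slice origin, so a global slice is a zero-copy view over the same storage. A NumPy analogue of the same pointer arithmetic (a stand-in for illustration, not the tilus runtime):

```python
# NumPy analogue of GlobalSliceInstEmitter: the sliced tensor's base pointer
# is the input pointer advanced by layout(*offsets) elements; no data moves.
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(4, 6)  # a "global" tensor
offsets = (1, 2)                                   # slice origin
view = x[1:, 2:]                                   # shares storage with x

# For a row-major layout, layout(*offsets) = 1 * 6 + 2 = 8 elements.
elem_offset = offsets[0] * 6 + offsets[1]
assert view.ctypes.data == x.ctypes.data + elem_offset * x.itemsize
```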
4 changes: 4 additions & 0 deletions python/tilus/ir/analyzers/scalar_analyzer.py
@@ -148,6 +148,10 @@ def is_constant(self) -> bool:
def empty_set() -> ScalarSet:
return ScalarSet(lower_bound=0, upper_bound=-1)

@staticmethod
def universal_set() -> ScalarSet:
return ScalarSet(lower_bound=None, upper_bound=None, divisibility=1)

def __eq__(self, other: ScalarSet) -> bool:
if self.is_empty() and other.is_empty():
return True
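
For context, `empty_set()` encodes emptiness as lower > upper, while the new `universal_set()` is unbounded on both sides with divisibility 1, i.e. it carries no information. A hypothetical stand-in for `ScalarSet` showing that convention:

```python
# Stand-in for ScalarSet (not the tilus class), showing the encoding
# convention used by empty_set() and the new universal_set().
from dataclasses import dataclass
from typing import Optional

@dataclass
class ScalarSetSketch:
    lower_bound: Optional[int]   # None means unbounded below
    upper_bound: Optional[int]   # None means unbounded above
    divisibility: int = 1        # 1 carries no divisibility information

    def is_empty(self) -> bool:
        return (
            self.lower_bound is not None
            and self.upper_bound is not None
            and self.lower_bound > self.upper_bound
        )

assert ScalarSetSketch(0, -1).is_empty()           # empty_set()
assert not ScalarSetSketch(None, None).is_empty()  # universal_set()
```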
47 changes: 40 additions & 7 deletions python/tilus/ir/builders/stmt_builder.py
@@ -51,6 +51,7 @@
ExitInst,
FormatPrintInst,
FreeSharedInst,
GlobalSliceInst,
GlobalViewInst,
LoadGlobalGenericInst,
LoadGlobalInst,
@@ -87,7 +88,8 @@
InstStmt,
SeqStmt,
Stmt,
TensorPtrStmt,
TensorElemPtrStmt,
TensorElemValueStmt,
WhileStmt,
)
from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedLayout, SharedTensor, Tensor
@@ -285,22 +287,36 @@ def brk(self):
stmt = BreakStmt()
self._stack[-1].append(stmt)

def declare(self, type: BaseType, init: Optional[Expr | float | int] = None) -> Var:
var = Var("v", type=type)
def declare(self, type: BaseType, init: Optional[Expr | float | int] = None, hint: Optional[str] = None) -> Var:
if hint is None:
hint = "v"
var = Var(hint, type=type)
self.append(DeclareStmt(var, as_expr(init) if init is not None else None))
return var

def assign(self, var: Var, value: Expr) -> None:
self.append(AssignStmt(var, value))

def tensor_ptr(self, tensor: Tensor, space: str = "generic") -> Var:
return self.tensor_element_ptr(tensor, indices=None, space=space)

def tensor_element_ptr(
self, tensor: Tensor, indices: Optional[Sequence[Expr | int]] = None, space: str = "generic"
) -> Var:
if space in ["generic", "global"]:
ptr_var = Var("v", type=~tensor.dtype)
ptr_var = Var("ptr", type=~tensor.dtype)
else:
ptr_var = Var("v", int32)
self.append(TensorPtrStmt(ptr_var, tensor, space=space))
ptr_var = Var("ptr", int32)
if indices is not None:
indices = tuple(as_expr(e) for e in indices)
self.append(TensorElemPtrStmt(ptr_var, tensor, indices=indices, space=space))
return ptr_var

def tensor_element_value(self, tensor: Tensor, indices: Sequence[Expr | int]) -> Var:
var = Var("val", type=tensor.dtype)
self.append(TensorElemValueStmt(var, tensor, indices=tuple(as_expr(e) for e in indices)))
return var

def append(self, inst_or_stmt: Union[Instruction, Stmt]) -> None:
if isinstance(inst_or_stmt, Instruction):
stmt: Stmt = InstStmt(inst_or_stmt)
@@ -364,6 +380,23 @@ def allocate_global(
self.append(inst)
return inst.global_output

def slice_global(
self,
tensor: GlobalTensor,
offsets: Sequence[Expr | int],
slice_dims: Sequence[int],
slice_shape: Sequence[Expr | int],
) -> GlobalTensor:
offsets_ = [as_expr(offset) for offset in offsets]
inst = GlobalSliceInst.create(
tensor=tensor,
offsets=offsets_,
dims=slice_dims,
shape=slice_shape,
)
self.append(inst)
return inst.global_output

def assign_register(self, output: RegisterTensor, x: RegisterTensor) -> None:
inst = AssignInst.create(output, x)
self.append(inst)
@@ -722,7 +755,7 @@ def free_shared(self, shared_value: SharedTensor) -> None:
inst = FreeSharedInst.create(shared_value)
self.append(inst)

def shared_slice(
def slice_shared(
self,
tensor: SharedTensor,
offsets: Sequence[Expr | int],
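
Earlier in this file, the old `tensor_ptr` entry point is kept as a thin wrapper over the more general `tensor_element_ptr(indices=None)`, so existing callers are unaffected. A minimal sketch of that delegation pattern with stand-in types (not the tilus builder):

```python
# Sketch of the delegation pattern in tensor_ptr/tensor_element_ptr: the
# legacy entry point forwards to the generalized method with indices=None.
from typing import Optional, Sequence

class BuilderSketch:
    def tensor_ptr(self, tensor: str, space: str = "generic") -> str:
        # Kept for backward compatibility; same behavior as before.
        return self.tensor_element_ptr(tensor, indices=None, space=space)

    def tensor_element_ptr(
        self,
        tensor: str,
        indices: Optional[Sequence[int]] = None,
        space: str = "generic",
    ) -> str:
        loc = "base" if indices is None else f"element {tuple(indices)}"
        return f"{space} pointer to {tensor} {loc}"

b = BuilderSketch()
assert b.tensor_ptr("t") == "generic pointer to t base"
assert b.tensor_element_ptr("t", (2, 3), "shared") == "shared pointer to t element (2, 3)"
```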
37 changes: 29 additions & 8 deletions python/tilus/ir/functors/functor.py
@@ -34,7 +34,8 @@
ReturnStmt,
SeqStmt,
Stmt,
TensorPtrStmt,
TensorElemPtrStmt,
TensorElemValueStmt,
WhileStmt,
)
from tilus.ir.tensor import GlobalTensor, RegisterTensor, SharedLayout, SharedTensor
@@ -105,8 +106,10 @@ def visit(self, node):
ret = self.visit_LetStmt(node)
elif isinstance(node, AssignStmt):
ret = self.visit_AssignStmt(node)
elif isinstance(node, TensorPtrStmt):
ret = self.visit_TensorPtrStmt(node)
elif isinstance(node, TensorElemPtrStmt):
ret = self.visit_TensorElemPtrStmt(node)
elif isinstance(node, TensorElemValueStmt):
ret = self.visit_TensorElemValueStmt(node)
# scalar expression and type
elif isinstance(node, Expr):
ret = self.visit_Expr(node)
@@ -205,7 +208,10 @@ def visit_LetStmt(self, stmt: LetStmt) -> Any:
def visit_AssignStmt(self, stmt: AssignStmt) -> Any:
raise NotImplementedError()

def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Any:
def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Any:
raise NotImplementedError()

def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Any:
raise NotImplementedError()

# tensors and layouts
@@ -349,12 +355,21 @@ def visit_AssignStmt(self, stmt: AssignStmt) -> Stmt:
else:
return AssignStmt(stmt.var, value)

def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> Stmt:
def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> Stmt:
tensor = self.visit(stmt.tensor)
if tensor is stmt.tensor:
indices = self.visit(stmt.indices)
if tensor is stmt.tensor and indices is stmt.indices:
return stmt
else:
return TensorPtrStmt(stmt.ptr_var, tensor, stmt.space)
return TensorElemPtrStmt(stmt.ptr_var, tensor, indices, stmt.space)

def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> Stmt:
tensor = self.visit(stmt.tensor)
indices = self.visit(stmt.indices)
if tensor is stmt.tensor and indices is stmt.indices:
return stmt
else:
return TensorElemValueStmt(stmt.var, tensor, indices)

def visit_WhileStmt(self, stmt: WhileStmt) -> Stmt:
cond = self.visit(stmt.cond)
@@ -494,9 +509,15 @@ def visit_AssignStmt(self, stmt: AssignStmt) -> None:
self.visit(stmt.var)
self.visit(stmt.value)

def visit_TensorPtrStmt(self, stmt: TensorPtrStmt) -> None:
def visit_TensorElemPtrStmt(self, stmt: TensorElemPtrStmt) -> None:
self.visit(stmt.ptr_var)
self.visit(stmt.tensor)
self.visit(stmt.indices)

def visit_TensorElemValueStmt(self, stmt: TensorElemValueStmt) -> None:
self.visit(stmt.var)
self.visit(stmt.tensor)
self.visit(stmt.indices)

# values

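
The rewriter methods above follow an identity-preserving convention: if visiting the children returns the very same objects, the original statement is returned, so unchanged subtrees are shared rather than rebuilt. A self-contained sketch of the pattern:

```python
# Identity-preserving rewrite, as in visit_TensorElemPtrStmt above: rebuild a
# node only when some child actually changed (checked with `is`).
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class Pair:
    left: Any
    right: Any

def rewrite(pair: Pair, f: Callable[[Any], Any]) -> Pair:
    left, right = f(pair.left), f(pair.right)
    if left is pair.left and right is pair.right:
        return pair            # nothing changed: reuse the original node
    return Pair(left, right)   # otherwise rebuild with the new children

p = Pair("a", "b")
assert rewrite(p, lambda x: x) is p            # unchanged subtree is reused
assert rewrite(p, str.upper) == Pair("A", "B")
```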
7 changes: 7 additions & 0 deletions python/tilus/ir/inst.py
@@ -60,6 +60,13 @@ def shared_input(self) -> SharedTensor:
assert isinstance(x, SharedTensor)
return x

@property
def global_input(self) -> GlobalTensor:
assert len(self.inputs) == 1
x = self.inputs[0]
assert isinstance(x, GlobalTensor)
return x

@property
def attributes(self) -> dict[str, Any]:
attrs = {}
1 change: 1 addition & 0 deletions python/tilus/ir/instructions/__init__.py
@@ -43,6 +43,7 @@
ExitInst,
FormatPrintInst,
FreeSharedInst,
GlobalSliceInst,
GlobalViewInst,
LoadGlobalGenericInst,
LoadGlobalInst,
25 changes: 24 additions & 1 deletion python/tilus/ir/instructions/generic.py
@@ -72,6 +72,29 @@ def create(dst: GlobalTensor, x: RegisterTensor, offsets: Sequence[Expr], dims:
return StoreGlobalInst(output=None, inputs=(dst, x), offsets=tuple(offsets), dims=tuple(dims))


@dataclass(frozen=True, eq=False)
class GlobalSliceInst(Instruction):
offsets: tuple[Expr, ...]
dims: Optional[tuple[int, ...]]

@staticmethod
def create(
tensor: GlobalTensor,
offsets: Sequence[Expr],
dims: Sequence[int],
shape: Sequence[Expr | int],
) -> GlobalSliceInst:
from tilus.ir.layout.global_layout import global_slice

output = GlobalTensor.create(dtype=tensor.dtype, layout=global_slice(tensor.layout, offsets, dims, shape))
return GlobalSliceInst(
output=output,
inputs=(tensor,),
offsets=tuple(offsets),
dims=tuple(dims) if len(dims) < len(tensor.shape) else None,
)


@dataclass(frozen=True, eq=False)
class LoadSharedInst(Instruction):
@staticmethod
@@ -103,7 +126,7 @@ def create(
output=output,
inputs=(tensor,),
offsets=tuple(offsets),
dims=tuple(dims) if len(dims) < len(tensor.shape) else None,
dims=tuple(dims) if len(dims) < len(tensor.shape) else tuple(range(len(tensor.shape))),
)


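
The two `create` methods normalize `dims` differently when every dimension is sliced: `GlobalSliceInst` collapses a full slice to None, while this PR changes `SharedSliceInst` to record the explicit full range instead. A small sketch of the two conventions:

```python
# The dims conventions in the create() methods above: GlobalSliceInst stores
# None for a full slice; SharedSliceInst (after this PR) stores the explicit
# full range of dimensions.
def normalize_dims_global(dims, rank):
    return tuple(dims) if len(dims) < rank else None

def normalize_dims_shared(dims, rank):
    return tuple(dims) if len(dims) < rank else tuple(range(rank))

assert normalize_dims_global([0], rank=2) == (0,)
assert normalize_dims_global([0, 1], rank=2) is None
assert normalize_dims_shared([0, 1], rank=2) == (0, 1)
```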
40 changes: 40 additions & 0 deletions python/tilus/ir/layout/global_layout.py
@@ -218,3 +218,43 @@ def f_offset(axes: Sequence[Var]) -> Expr:
return sum([axes[i] * strides[i] for i in range(len(shape))], start=int32.zero)

return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)


def global_slice(
layout: GlobalLayout, offsets: Sequence[Expr | int], dims: Sequence[int], shape: Sequence[Expr | int]
) -> GlobalLayout:
"""Create a sliced global layout from an existing layout.

This function creates a new global layout by slicing an existing global layout. The slicing is defined by the
specified offsets, dimensions to slice, and the shape of the resulting layout. The new layout retains the mapping
function of the original layout, adjusted for the specified offsets and dimensions.

Parameters
----------
layout: GlobalLayout
The original global layout to be sliced.
offsets: Sequence[Expr | int]
The offsets for each dimension of the original layout. It should have the same length as the original layout's
shape.
dims: Sequence[int]
The dimensions to be sliced from the original layout. Each dimension should be a valid index in the original
layout's shape.
shape: Sequence[Expr | int]
The shape of the resulting sliced global layout. It should have the same length as the number of dimensions
specified in `dims`.

Returns
-------
ret: GlobalLayout
A new global layout that represents the sliced version of the original layout, with the specified shape and
adjusted mapping function.
"""
assert len(dims) == len(shape) <= len(layout.shape) == len(offsets)

def f_offset(axes: Sequence[Var]) -> Expr:
indices = list(offsets)
for dim, axis in zip(dims, axes):
indices[dim] = axis + offsets[dim]
return layout(*indices) - layout(*offsets) # type: ignore[arg-type]

return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)
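
A worked example of the `global_slice` semantics described in the docstring, using a plain-Python stand-in for the layout (not the tilus `GlobalLayout`): slicing a row-major 4x6 layout at offsets (1, 2) over both dimensions yields offsets relative to the slice origin.

```python
# Plain-Python worked example of the f_offset above:
# f_offset(axes) = layout(offsets + axes) - layout(offsets), so the sliced
# layout's offsets are relative to the slice origin.
row_major_4x6 = lambda i, j: i * 6 + j          # original layout
offsets, dims, shape = (1, 2), (0, 1), (2, 3)   # slice origin and new shape

def sliced(i: int, j: int) -> int:
    indices = [offsets[0] + i, offsets[1] + j]
    return row_major_4x6(*indices) - row_major_4x6(*offsets)

assert sliced(0, 0) == 0                        # slice origin maps to 0
assert sliced(1, 2) == row_major_4x6(2, 4) - 8  # == 1 * 6 + 2 == 8
```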
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
allocate_shared,
assign,
cp_async,
elementwise_binary,
Expand Down