|
27 | 27 | enter_nested_block, nested_block, PhiState, LoopVarState, |
28 | 28 | TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState, |
29 | 29 | ListValue, TiledViewValue, ClosureValue, MemoryEffect, attribute, operand, |
30 | | - BlockRestriction, FormattedStringValue, RawArrayMemoryValue, DataclassValue, DataclassInfo |
| 30 | + BlockRestriction, FormattedStringValue, RawArrayMemoryValue, DataclassValue, DataclassInfo, |
| 31 | + IndexSliceValue |
31 | 32 | ) |
32 | 33 | from .type import PointerTy |
33 | 34 | from . import hir, hir_stubs |
|
59 | 60 | typeof_pyval, dtype_registry, loose_type_of_pyval, get_constant_value, get_dataclass_info, |
60 | 61 | ) |
61 | 62 | from .type import ( |
62 | | - PartitionViewTy, StridedViewTy, TupleTy, TileTy, NoneType, BoundMethodTy, ArrayTy, |
| 63 | + PartitionViewTy, StridedViewTy, GatherScatterViewTy, TupleTy, TileTy, NoneType, |
| 64 | + BoundMethodTy, ArrayTy, |
63 | 65 | ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type, |
64 | 66 | NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType, |
65 | 67 | ClosureTy, LiveCapturedScope, TokenTy, TiledViewTy, FormattedStringTy, |
66 | | - StringFormat, FormattedPiece, RawArrayMemoryTy, DataclassTy |
| 68 | + StringFormat, FormattedPiece, RawArrayMemoryTy, DataclassTy, IndexSliceTy |
67 | 69 | ) |
68 | 70 | from cuda.tile._datatype import ( |
69 | 71 | DType, is_integral, is_float, is_signed, is_boolean, |
@@ -2375,6 +2377,29 @@ def _materialize_tiled_view(array: Var, |
2375 | 2377 | return _make_partition_view(array, tile_shape, order, padding_mode) |
2376 | 2378 |
|
2377 | 2379 |
|
| 2380 | +@dataclass(eq=False) |
| 2381 | +class MakeGatherScatterView(Operation, opcode="make_gather_scatter_view"): |
| 2382 | + array: Var = operand() |
| 2383 | + |
| 2384 | + @override |
| 2385 | + def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value: |
| 2386 | + gs_view_ty = self.result_var.get_type() |
| 2387 | + return bc.encode_MakeGatherScatterViewOp(ctx.builder, |
| 2388 | + typeid(ctx.type_table, gs_view_ty), |
| 2389 | + ctx.get_value(self.array)) |
| 2390 | + |
| 2391 | + |
| 2392 | +def make_gather_scatter_view(array: Var, tile_shape: Sequence[int], |
| 2393 | + sparse_dim: int, |
| 2394 | + padding_mode: PaddingMode) -> Var: |
| 2395 | + array_ty = array.get_type() |
| 2396 | + assert isinstance(array_ty, ArrayTy) |
| 2397 | + view_ty = GatherScatterViewTy(array_ty, tuple(tile_shape), sparse_dim, padding_mode) |
| 2398 | + ret = add_operation(MakeGatherScatterView, view_ty, array=array) |
| 2399 | + ret.set_aggregate(array.get_aggregate()) |
| 2400 | + return ret |
| 2401 | + |
| 2402 | + |
2378 | 2403 | @dataclass(eq=False) |
2379 | 2404 | class TileLoad(Operation, opcode="tile_load", memory_effect=MemoryEffect.LOAD): |
2380 | 2405 | latency: Optional[int] = attribute() |
@@ -4826,6 +4851,128 @@ def tiled_view_atomic_rmw_impl(int_mode: Optional[AtomicRMWMode], |
4826 | 4851 | view=view, index=index_items, update=update) |
4827 | 4852 |
|
4828 | 4853 |
|
| 4854 | +@impl(ct.Slice) |
| 4855 | +def slice_index_constructor_impl(start: Var, length: Var) -> Var: |
| 4856 | + start_ty = require_signed_integer_0d_tile_type(start) |
| 4857 | + length_ty = require_signed_integer_0d_tile_type(length) |
| 4858 | + res_type = IndexSliceTy(start_ty, length_ty) |
| 4859 | + res_loose_type = IndexSliceTy(start.get_loose_type(), length.get_loose_type()) |
| 4860 | + return make_aggregate(IndexSliceValue(start, length), res_type, res_loose_type) |
| 4861 | + |
| 4862 | + |
| 4863 | +def _parse_advanced_index(indices: Var, ndim: int) -> tuple[int, tuple[int, ...], tuple[Var, ...]]: |
| 4864 | + """Unpack, classify, validate, and build the gather scatter view index. |
| 4865 | +
|
| 4866 | + Returns (sparse_dim, tile_shape, gs_index). |
| 4867 | + """ |
| 4868 | + require_tuple_type(indices) |
| 4869 | + items = list(indices.get_aggregate().items) |
| 4870 | + if len(items) != ndim: |
| 4871 | + raise TileTypeError( |
| 4872 | + f"load_advanced/store_advanced index length {len(items)} does not " |
| 4873 | + f"match array rank {ndim}") |
| 4874 | + |
| 4875 | + sparse_dims: list[int] = [] |
| 4876 | + tile_shape: list[int] = [] |
| 4877 | + gs_index: list[Var] = [] |
| 4878 | + |
| 4879 | + for dim, item in enumerate(items): |
| 4880 | + item_ty = item.get_type() |
| 4881 | + if isinstance(item_ty, TileTy): |
| 4882 | + if item_ty.ndim != 1: |
| 4883 | + raise TileTypeError( |
| 4884 | + f"Sparse index at dim {dim} must be a 1D integer tile, " |
| 4885 | + f"got {item_ty.ndim}D") |
| 4886 | + if not is_integral(item_ty.dtype): |
| 4887 | + raise TileTypeError( |
| 4888 | + f"Sparse index at dim {dim} must be an integer tile, " |
| 4889 | + f"got dtype {item_ty.dtype}") |
| 4890 | + sparse_dims.append(dim) |
| 4891 | + tile_shape.append(item_ty.shape[0]) |
| 4892 | + gs_index.append(item) |
| 4893 | + elif isinstance(item_ty, IndexSliceTy): |
| 4894 | + length_var = item.get_aggregate().length |
| 4895 | + if not length_var.is_constant(): |
| 4896 | + raise TileTypeError( |
| 4897 | + f"ct.Slice length at dim {dim} must be a compile-time constant " |
| 4898 | + f"in load_advanced/store_advanced") |
| 4899 | + length_val = length_var.get_constant() |
| 4900 | + if not isinstance(length_val, int) or length_val <= 0: |
| 4901 | + raise TileTypeError( |
| 4902 | + f"ct.Slice length at dim {dim} must be a positive integer, got {length_val}") |
| 4903 | + tile_shape.append(length_val) |
| 4904 | + gs_index.append(item.get_aggregate().start) |
| 4905 | + else: |
| 4906 | + raise TileTypeError( |
| 4907 | + f"load_advanced/store_advanced index at dim {dim} must be a " |
| 4908 | + f"1D integer Tile (sparse dim) or ct.Slice(start, length) " |
| 4909 | + f"(dense dim), got type {item_ty}") |
| 4910 | + |
| 4911 | + if len(sparse_dims) == 0: |
| 4912 | + raise TileTypeError( |
| 4913 | + "load_advanced/store_advanced: exactly one index must be a 1D " |
| 4914 | + "integer Tile (the sparse dim); none found") |
| 4915 | + if len(sparse_dims) > 1: |
| 4916 | + raise TileTypeError( |
| 4917 | + f"load_advanced/store_advanced: exactly one index must be a 1D " |
| 4918 | + f"integer Tile (the sparse dim); found {len(sparse_dims)} at " |
| 4919 | + f"dims {sparse_dims}") |
| 4920 | + |
| 4921 | + for dim, n in enumerate(tile_shape): |
| 4922 | + if not _is_power_of_2(n): |
| 4923 | + raise TileTypeError( |
| 4924 | + f"Index at dim {dim} has size {n}; must be a power of two") |
| 4925 | + |
| 4926 | + return sparse_dims[0], tuple(tile_shape), tuple(gs_index) |
| 4927 | + |
| 4928 | + |
| 4929 | +@impl(ct.load_advanced, min_version=BytecodeVersion.V_13_3) |
| 4930 | +def load_advanced_impl(array: Var, indices: Var, padding_mode: Var, |
| 4931 | + latency: Var, allow_tma: Var) -> Var: |
| 4932 | + array_ty = require_array_type(array) |
| 4933 | + if array_ty.ndim < 2: |
| 4934 | + raise TileTypeError( |
| 4935 | + "load_advanced requires a 2D or higher-rank array; " |
| 4936 | + "use ct.gather() for 1D arrays") |
| 4937 | + sparse_dim, tile_shape, gs_index = _parse_advanced_index(indices, array_ty.ndim) |
| 4938 | + padding_mode_val = require_constant_enum(padding_mode, PaddingMode) |
| 4939 | + latency_val = require_optional_constant_int(latency) |
| 4940 | + allow_tma_val = require_optional_constant_bool(allow_tma) |
| 4941 | + _check_load_store_hints(latency_val, allow_tma_val) |
| 4942 | + |
| 4943 | + view = make_gather_scatter_view(array, tile_shape, sparse_dim, padding_mode_val) |
| 4944 | + result, _token = add_operation(TileLoad, (make_tile_ty(array_ty.dtype, tile_shape), TokenTy()), |
| 4945 | + view=view, index=gs_index, |
| 4946 | + latency=latency_val, allow_tma=allow_tma_val) |
| 4947 | + return result |
| 4948 | + |
| 4949 | + |
| 4950 | +@impl(ct.store_advanced, min_version=BytecodeVersion.V_13_3) |
| 4951 | +def store_advanced_impl(array: Var, indices: Var, tile: Var, |
| 4952 | + latency: Var, allow_tma: Var): |
| 4953 | + array_ty = require_array_type(array) |
| 4954 | + if array_ty.ndim < 2: |
| 4955 | + raise TileTypeError( |
| 4956 | + "store_advanced requires a 2D or higher-rank array; " |
| 4957 | + "use ct.scatter() for 1D arrays") |
| 4958 | + sparse_dim, tile_shape, gs_index = _parse_advanced_index(indices, array_ty.ndim) |
| 4959 | + tile_ty = require_tile_type(tile) |
| 4960 | + if tile_ty.shape != tile_shape: |
| 4961 | + raise TileTypeError( |
| 4962 | + f"Tile shape {tile_ty.shape} does not match the shape implied by " |
| 4963 | + f"indices {tile_shape}") |
| 4964 | + tile = _implicit_cast(tile, array_ty.dtype, |
| 4965 | + "Stored tile dtype is incompatible with array dtype") |
| 4966 | + latency_val = require_optional_constant_int(latency) |
| 4967 | + allow_tma_val = require_optional_constant_bool(allow_tma) |
| 4968 | + _check_load_store_hints(latency_val, allow_tma_val) |
| 4969 | + |
| 4970 | + view = make_gather_scatter_view(array, tile_shape, sparse_dim, PaddingMode.UNDETERMINED) |
| 4971 | + [_token] = add_operation(TileStore, (TokenTy(),), |
| 4972 | + view=view, index=gs_index, tile=tile, |
| 4973 | + latency=latency_val, allow_tma=allow_tma_val) |
| 4974 | + |
| 4975 | + |
4829 | 4976 | def store_var(local_idx: int, value: Var, loc: Loc | None = None): |
4830 | 4977 | scope = Scope.get_current() |
4831 | 4978 | new_var = scope.local.redefine(local_idx, loc or Builder.get_current().loc) |
|
0 commit comments