Skip to content

add a set of helper functions for working with mesh-distributed tensors#2

Merged
SUNMMIO-jlou merged 3 commits intoSUNMMIO:mainfrom
xiaoyao-NKU:main
Dec 12, 2025
Merged

add a set of helper functions for working with mesh-distributed tensors#2
SUNMMIO-jlou merged 3 commits intoSUNMMIO:mainfrom
xiaoyao-NKU:main

Conversation

@xiaoyao-NKU
Copy link
Copy Markdown
Collaborator

mesh_tensor_functions(mesh_shape: dict[str, int] = {"x": 4, "y": 4}):
Purpose: Factory that returns a set of helper functions for working with mesh-distributed tensors.
Inputs: mesh_shape — mapping from mesh axis names (e.g., "x", "y") to device counts.
Returns: A dict containing three helpers: annotate_mesh_tensor_info, mesh_tensor_copy, and get_tile_shape. These helpers share internal metadata parsed from annotations.

annotate_mesh_tensor_info(mesh_tensor_info: dict):
Purpose: Validate and register per-buffer mesh layout metadata for later use by the other helpers and passes.
Inputs: mesh_tensor_info — mapping from buffer objects to metadata dicts. Each metadata dict must contain at least block_shape, program_id, and sharding.
Behavior: Deep-copies and stores metadata keyed by buffer.data; returns a T.func_attr dict with the stored metadata so it can be attached to a function for downstream passes.

get_tile_shape(buffer: tir.Buffer):
Purpose: Compute the shape of the given buffer on a single mesh tile (device).
Inputs: buffer — a tvm.tir.Buffer whose global shape and recorded sharding info are used.
Behavior: Looks up the buffer's stored mesh metadata, reads which tensor dims map to mesh "x" and "y", and computes per-tile sizes by ceil-dividing the global sizes by the corresponding mesh dimension sizes; returns a tuple giving the per-tile shape.

mesh_tensor_copy(src: tir.Buffer, dst: tir.Buffer, *, src_coord: tuple[int] | None = None, dst_coord: tuple[int] | None = None):
Purpose: Perform a block-aware copy from src to dst, optionally selecting blocks by integer block coordinates.
Inputs: src, dst — source and destination tvm.tir.Buffer; optional src_coord/dst_coord — integer block coordinates.
Behavior: If block coordinates are provided, converts them to element offsets using the buffer's recorded block_shape and slices the buffers accordingly; then issues a T.copy to perform the copy. Raises a ValueError when metadata for a buffer is missing.

xiaoyao-NKU and others added 3 commits December 9, 2025 15:56
add mesh_tensor functions: annotate_mesh_tensor_info, mesh_tensor_copy
Copy link
Copy Markdown
Collaborator

@JiaqiGuoSunlune JiaqiGuoSunlune left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, it looks good to me. We can start with this feature to enable the sharding and descriptor-like copy.

To enrich the feature, we may be need to support 1) replication in addition to sharding; 2) Add unit cases to this syntax suger to assert that we get desired transform (e.g., lowering to correct T.copy and T.ceildiv (for getting tile shape))

@SUNMMIO-jlou
Copy link
Copy Markdown
Collaborator

It would be great if we can add unit tests into the test folder related to any check-in.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants