add a set of helper functions for working with mesh-distributed tensors#2
Merged
SUNMMIO-jlou merged 3 commits intoSUNMMIO:mainfrom Dec 12, 2025
Merged
add a set of helper functions for working with mesh-distributed tensors#2SUNMMIO-jlou merged 3 commits intoSUNMMIO:mainfrom
SUNMMIO-jlou merged 3 commits intoSUNMMIO:mainfrom
Conversation
add mesh_tensor functions: annotate_mesh_tensor_info, mesh_tensor_copy
JiaqiGuoSunlune
approved these changes
Dec 11, 2025
Collaborator
There was a problem hiding this comment.
Overall, it looks good to me. We can start with this feature to enable the sharding and descriptor-like copy.
To enrich the feature, we may be need to support 1) replication in addition to sharding; 2) Add unit cases to this syntax suger to assert that we get desired transform (e.g., lowering to correct T.copy and T.ceildiv (for getting tile shape))
Collaborator
|
It would be great if we can add unit tests into the test folder related to any check-in. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
mesh_tensor_functions(mesh_shape: dict[str, int] = {"x": 4, "y": 4}):
Purpose: Factory that returns a set of helper functions for working with mesh-distributed tensors.
Inputs: mesh_shape — mapping from mesh axis names (e.g., "x", "y") to device counts.
Returns: A dict containing three helpers: annotate_mesh_tensor_info, mesh_tensor_copy, and get_tile_shape. These helpers share internal metadata parsed from annotations.
annotate_mesh_tensor_info(mesh_tensor_info: dict):
Purpose: Validate and register per-buffer mesh layout metadata for later use by the other helpers and passes.
Inputs: mesh_tensor_info — mapping from buffer objects to metadata dicts. Each metadata dict must contain at least block_shape, program_id, and sharding.
Behavior: Deep-copies and stores metadata keyed by buffer.data; returns a T.func_attr dict with the stored metadata so it can be attached to a function for downstream passes.
get_tile_shape(buffer: tir.Buffer):
Purpose: Compute the shape of the given buffer on a single mesh tile (device).
Inputs: buffer — a tvm.tir.Buffer whose global shape and recorded sharding info are used.
Behavior: Looks up the buffer's stored mesh metadata, reads which tensor dims map to mesh "x" and "y", and computes per-tile sizes by ceil-dividing the global sizes by the corresponding mesh dimension sizes; returns a tuple giving the per-tile shape.
mesh_tensor_copy(src: tir.Buffer, dst: tir.Buffer, *, src_coord: tuple[int] | None = None, dst_coord: tuple[int] | None = None):
Purpose: Perform a block-aware copy from src to dst, optionally selecting blocks by integer block coordinates.
Inputs: src, dst — source and destination tvm.tir.Buffer; optional src_coord/dst_coord — integer block coordinates.
Behavior: If block coordinates are provided, converts them to element offsets using the buffer's recorded block_shape and slices the buffers accordingly; then issues a T.copy to perform the copy. Raises a ValueError when metadata for a buffer is missing.