forked from hidet-org/hidet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[IR] [Primitives] Add thread cluster on sm_90 (hidet-org#145)
Allow access to cluster attributes inside Hidet kernels, and launch kernels with distributed shared memory. See the CUDA docs on distributed shared memory (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#distributed-shared-memory) and thread block clusters (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#thread-block-clusters), and the cluster group API (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cluster-group-cg). This works towards supporting hidet-org#102 by adding a cluster rank primitive to Hidet. See `test_cluster.py` for example usage. To run the tests on Hopper machines, use `pytest --hopper`.
- Loading branch information
Showing
17 changed files
with
265 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from collections import namedtuple | ||
from typing import Union | ||
|
||
from hidet.ir.expr import Expr | ||
from hidet.ir.primitives.func import call_primitive_func, register_primitive_function | ||
from hidet.ir.primitives.vars import lookup_primitive_variable, register_primitive_variable | ||
from hidet.ir.type import DataType, FuncType, PointerType, VoidType, data_type | ||
from hidet.utils.py import initialize | ||
from hidet.ir.dtypes import i32 | ||
|
||
_cluster_fields = ["thread_rank", "block_rank", "dim_threads", "dim_blocks"] | ||
|
||
|
||
@initialize()
def register_cuda_cluster_functions():
    """Register the cluster-group primitive variables and functions.

    Registers, in order:

    * one ``i32`` primitive variable per entry of ``_cluster_fields``, named
      ``cooperative_groups::this_cluster().<field>()``;
    * the ``this_cluster.sync`` primitive function (no arguments, void result);
    * one ``this_cluster.map_shared_rank_<dtype>`` primitive function per
      supported element type, taking a pointer and an ``i32`` and returning a
      pointer of the same element type.

    Presumably invoked once at import time by the ``@initialize()`` decorator
    -- confirm against ``hidet.utils.py``.
    """
    # NOTE(review): every field is registered as i32. In the CUDA cooperative
    # groups API, dim_threads()/dim_blocks() appear to return dim3 rather than
    # an integer -- confirm that i32 is intended for those two fields.
    for field in _cluster_fields:
        register_primitive_variable(name=f"cooperative_groups::this_cluster().{field}()", dtype=i32)

    sync_signature = FuncType([], VoidType())
    register_primitive_function(
        name="this_cluster.sync",
        func_or_type=sync_signature,
        codegen_name="cooperative_groups::this_cluster().sync",
    )

    # All typed variants share one codegen name; only the registered name and
    # the pointer element type differ.
    element_type_names = ('int8', 'uint8', 'uint32', 'int32', 'float16', 'float32', 'bool')
    for type_name in element_type_names:
        element_type = data_type(type_name)
        # The registered name embeds the string form of the DataType object,
        # not the raw type-name string.
        register_primitive_function(
            name=f"this_cluster.map_shared_rank_{element_type}",
            func_or_type=FuncType([PointerType(element_type), i32], PointerType(element_type)),
            codegen_name="cooperative_groups::this_cluster().map_shared_rank",
        )
|
||
|
||
def cluster_sync():
    """Build a call to the ``this_cluster.sync`` primitive function.

    The primitive is called with an empty argument list.

    Returns
    -------
    The result of ``call_primitive_func`` for ``this_cluster.sync``.
    """
    no_arguments = []
    return call_primitive_func("this_cluster.sync", no_arguments)
|
||
|
||
def cluster_map_shared_rank(addr: Expr, rank: Union[Expr, int], dtype: Union[DataType, str]):
    """Build a call to the typed ``this_cluster.map_shared_rank_<dtype>`` primitive.

    Parameters
    ----------
    addr: Expr
        Pointer expression passed as the first argument of the primitive.
    rank: Union[Expr, int]
        Rank passed as the second argument of the primitive.
    dtype: Union[DataType, str]
        Element type selecting which typed variant to call. Either a
        ``DataType`` or a type name accepted by ``data_type``.

    Returns
    -------
    The result of ``call_primitive_func`` for the selected primitive.
    """
    # The primitives are registered under names built from DataType objects,
    # so normalise a string dtype the same way before formatting. Otherwise a
    # string argument can yield a name that differs from the registered one.
    element_type = data_type(dtype)
    func_name = f"this_cluster.map_shared_rank_{element_type}"
    return call_primitive_func(func_name, [addr, rank])
|
||
|
||
# Namespace-like handle for the cluster group: one attribute per entry of
# `_cluster_fields` (the registered primitive variable of that name), plus the
# `sync` and `map_shared_rank` helper functions.
this_cluster = namedtuple("this_cluster", field_names=_cluster_fields + ["sync", "map_shared_rank"])(
    **{
        field: lookup_primitive_variable(f"cooperative_groups::this_cluster().{field}()")
        for field in _cluster_fields
    },
    sync=cluster_sync,
    map_shared_rank=cluster_map_shared_rank,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.