[IR] [Primitives] Add thread cluster on sm_90 (hidet-org#145)
Allow access to cluster attributes inside Hidet kernels. Launch kernels
with distributed shared memory.

See docs:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#distributed-shared-memory
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#thread-block-clusters

API:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cluster-group-cg

Works towards hidet-org#102 by adding the cluster rank primitive in Hidet.

See `test_cluster.py` for example usage. To run the tests on Hopper machines,
use `pytest --hopper`.
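
For reference, a minimal sketch of what such a kernel can look like in Hidet script. This is an illustration based on the attributes and primitives added in this commit, not the contents of `test_cluster.py`; the kernel name, buffer, and launch sizes are made up.

import hidet
from hidet.lang import attrs, int32
from hidet.lang.cuda import threadIdx, blockIdx, this_cluster

with hidet.script_module() as script_module:

    @hidet.script
    def cluster_rank_kernel(out: ~int32):
        attrs.func_kind = 'cuda_kernel'
        attrs.cuda.grid_dim = 8      # 8 thread blocks in total
        attrs.cuda.cluster_dim = 4   # grouped into clusters of 4 blocks (requires sm_90+)
        attrs.cuda.block_dim = 128   # 128 threads per block
        # record each thread's rank within its cluster (0 .. 4 * 128 - 1)
        out[blockIdx.x * 128 + threadIdx.x] = this_cluster.thread_rank
        this_cluster.sync()

module = script_module.build()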
KTong821 committed Apr 22, 2024
1 parent e9fad7a commit 1444c20
Showing 17 changed files with 265 additions and 17 deletions.
10 changes: 10 additions & 0 deletions python/hidet/backend/codegen.py
@@ -31,6 +31,7 @@
from hidet.ir.compute import TensorNode, ScalarNode
from hidet.ir.functors import ModuleFunctor, StmtFunctor, ExprFunctor, TypeFunctor
from hidet.ir.tools import TypeInfer
from hidet.transforms.generate_launch_func import _normalize_dim3
from hidet.utils.doc import Doc, NewLine, Text, doc_join
from hidet.ir.utils.call_graph import CallGraph
from hidet.utils.namer import Namer
@@ -50,6 +51,7 @@ def __init__(self):
self.require_fp16 = False
self.require_bf16 = False
self.require_tf32 = False
self.require_cooperative_groups = False

def __call__(self, node) -> Doc:
return self.visit(node)
@@ -691,6 +693,8 @@ def require_headers(self) -> Doc:
doc += Text('#include <cuda_fp16.h>') + NewLine()
if self.require_bf16:
doc += Text('#include <cuda_bf16.h>') + NewLine()
if self.require_cooperative_groups:
doc += Text('#include <cooperative_groups.h>') + NewLine()
doc += Text('#include <hidet/runtime/symbols.h>') + NewLine()
doc += Text('#include <hidet/runtime/memory_planner.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/context.h>') + NewLine()
@@ -733,6 +737,12 @@ def visit_Function(self, func: Function) -> Doc:

doc += self(func.ret_type)

if 'cuda.cluster_dim' in func.attrs:
cluster_dims = _normalize_dim3(func.attrs['cuda.cluster_dim'])
doc += f" __cluster_dims__({cluster_dims[0]}, {cluster_dims[1]}, {cluster_dims[2]})"

self.require_cooperative_groups = True

# launch bound for grid worker
if func.kind == 'cuda_kernel':
block_dim = func.attrs['cuda.block_dim']
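The `_normalize_dim3` helper imported above is reused from `generate_launch_func`. Judging from how the result is indexed here and how `launch_kernel` pads dims below, it expands an int or a short sequence into a full `(x, y, z)` triple; a behavioural sketch for illustration only, not the actual implementation:

def normalize_dim3_sketch(dim):
    # pad an int or a short list/tuple of dims to a 3-tuple, e.g. 2 -> (2, 1, 1)
    dims = list(dim) if isinstance(dim, (list, tuple)) else [dim]
    while len(dims) < 3:
        dims.append(1)
    return tuple(dims)

assert normalize_dim3_sketch(2) == (2, 1, 1)
assert normalize_dim3_sketch((4, 2)) == (4, 2, 1)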
4 changes: 3 additions & 1 deletion python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
@@ -472,7 +472,9 @@ def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
func_name = stmt.func_var.name
if func_name in self.func_records:
args = self.process_call(func_name, list(stmt.args))
return LaunchKernelStmt(stmt.func_var, args, stmt.grid_dim, stmt.block_dim, stmt.shared_mem_bytes)
return LaunchKernelStmt(
stmt.func_var, args, stmt.grid_dim, stmt.cluster_dim, stmt.block_dim, stmt.shared_mem_bytes
)
return super().visit_LaunchKernelStmt(stmt)

def visit_TensorElement(self, e: TensorElement):
3 changes: 3 additions & 0 deletions python/hidet/ir/builders/func_builder.py
@@ -27,6 +27,7 @@ def __init__(
label: str = "",
ret_type=VoidType(),
grid_dim=None,
cluster_dim=None,
block_dim=None,
dynamic_smem_bytes=None,
min_blocks=None,
@@ -44,6 +45,8 @@

if grid_dim is not None:
self.attrs['cuda.grid_dim'] = grid_dim
if cluster_dim is not None:
self.attrs['cuda.cluster_dim'] = cluster_dim
if block_dim is not None:
self.attrs['cuda.block_dim'] = block_dim
if dynamic_smem_bytes:
2 changes: 2 additions & 0 deletions python/hidet/ir/func.py
@@ -37,6 +37,8 @@ class Function(Node):
- 'public': this is a packed function that wraps kernel function(s)
'cuda.grid_dim': Union[int, List[int]]
the grid dimension in cuda launch configuration
'cuda.cluster_dim': Union[int, List[int]]
the cluster dimension in cuda launch configuration
'cuda.block_dim': Union[int, List[int]]
the block dimension in cuda launch configuration
'cuda.dynamic_smem_bytes': int
7 changes: 4 additions & 3 deletions python/hidet/ir/functors/stmt_functor.py
@@ -324,16 +324,17 @@ def visit_AsmStmt(self, stmt: AsmStmt):
def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
func_var = self.visit(stmt.func_var)
args = [self.visit(e) for e in stmt.args]
grid_dim = (self.visit(stmt.grid_dim[0]), self.visit(stmt.grid_dim[1]), self.visit(stmt.grid_dim[2]))
block_dim = (self.visit(stmt.block_dim[0]), self.visit(stmt.block_dim[1]), self.visit(stmt.block_dim[2]))
grid_dim = tuple(self.visit(stmt.grid_dim[i]) for i in range(3))
cluster_dim = tuple(self.visit(stmt.cluster_dim[i]) for i in range(3))
block_dim = tuple(self.visit(stmt.block_dim[i]) for i in range(3))
shared_mem_bytes = self.visit(stmt.shared_mem_bytes)
if same_list(
[func_var, *args, *grid_dim, *block_dim, shared_mem_bytes],
[stmt.func_var, *stmt.args, *stmt.grid_dim, *stmt.block_dim, stmt.shared_mem_bytes],
):
return stmt
else:
return LaunchKernelStmt(func_var, args, grid_dim, block_dim, shared_mem_bytes)
return LaunchKernelStmt(func_var, args, grid_dim, cluster_dim, block_dim, shared_mem_bytes)

def visit_BlackBoxStmt(self, stmt: BlackBoxStmt):
exprs = [self.visit(e) for e in stmt.exprs]
1 change: 1 addition & 0 deletions python/hidet/ir/primitives/cuda/__init__.py
@@ -12,6 +12,7 @@
from . import math
from . import mma

from .cluster import this_cluster
from .smem import set_kernel_max_dynamic_smem_bytes
from .sync import syncthreads, syncwarp
from .ldst import lds128, sts128
49 changes: 49 additions & 0 deletions python/hidet/ir/primitives/cuda/cluster.py
@@ -0,0 +1,49 @@
from collections import namedtuple
from typing import Union

from hidet.ir.expr import Expr
from hidet.ir.primitives.func import call_primitive_func, register_primitive_function
from hidet.ir.primitives.vars import lookup_primitive_variable, register_primitive_variable
from hidet.ir.type import DataType, FuncType, PointerType, VoidType, data_type
from hidet.utils.py import initialize
from hidet.ir.dtypes import i32

_cluster_fields = ["thread_rank", "block_rank", "dim_threads", "dim_blocks"]


@initialize()
def register_cuda_cluster_functions():

for suffix in _cluster_fields:
register_primitive_variable(name=f"cooperative_groups::this_cluster().{suffix}()", dtype=i32)

register_primitive_function(
name="this_cluster.sync",
func_or_type=FuncType([], VoidType()),
codegen_name="cooperative_groups::this_cluster().sync",
)

for dtype in ['int8', 'uint8', 'uint32', 'int32', 'float16', 'float32', 'bool']:
dtype = data_type(dtype)

register_primitive_function(
name=f"this_cluster.map_shared_rank_{dtype}",
func_or_type=FuncType([PointerType(dtype), i32], PointerType(dtype)),
codegen_name="cooperative_groups::this_cluster().map_shared_rank",
)


def cluster_sync():
return call_primitive_func("this_cluster.sync", [])


def cluster_map_shared_rank(addr: Expr, rank: Union[Expr, int], dtype: Union[DataType, str]):
func_name = f"this_cluster.map_shared_rank_{dtype}"
return call_primitive_func(func_name, [addr, rank])


this_cluster = namedtuple("this_cluster", field_names=_cluster_fields + ["sync", "map_shared_rank"])(
*[lookup_primitive_variable("cooperative_groups::this_cluster().{}()".format(field)) for field in _cluster_fields],
cluster_sync,
cluster_map_shared_rank,
)
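
Inside a `cuda_kernel` Hidet script function, these primitives lower to the `cooperative_groups::this_cluster()` API. Below is a rough body-only sketch of reading a peer block's shared memory (distributed shared memory); the `dynamic_shared_memory` call and the surrounding kernel scaffolding are assumptions based on Hidet's existing API:

from hidet.lang.cuda import this_cluster, dynamic_shared_memory, threadIdx

# ... inside a hidet.script function with attrs.cuda.cluster_dim set ...
smem = dynamic_shared_memory(byte_offset=0, dtype='float32')  # this block's shared memory
smem[threadIdx.x] = this_cluster.thread_rank
this_cluster.sync()
# map block 0's shared memory in this cluster into our address space
peer_smem = this_cluster.map_shared_rank(smem, 0, 'float32')
value = peer_smem[threadIdx.x]
this_cluster.sync()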
8 changes: 7 additions & 1 deletion python/hidet/ir/schedulers/cuda/scheduler.py
@@ -65,4 +65,10 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode,
fb += BufferStoreStmt(out_param, task_index, value)
func = fb.get()
func_var = self.add_function(func)
return launch_kernel(func_var, args=call_args, grid_dim=grid_dim, block_dim=block_dim)
return launch_kernel(
func_var,
args=call_args,
grid_dim=grid_dim,
cluster_dim=func.get_attr('cluster_dim', default=1),
block_dim=block_dim,
)
11 changes: 7 additions & 4 deletions python/hidet/ir/stmt.py
@@ -326,12 +326,14 @@ def __init__(
func_var: Var,
args: Sequence[Expr],
grid_dim: Tuple[Expr, Expr, Expr],
cluster_dim: Tuple[Expr, Expr, Expr],
block_dim: Tuple[Expr, Expr, Expr],
shared_mem: Expr,
):
self.func_var: Var = func_var
self.args: List[Expr] = list(args)
self.grid_dim: Tuple[Expr, Expr, Expr] = grid_dim
self.cluster_dim: Tuple[Expr, Expr, Expr] = cluster_dim
self.block_dim: Tuple[Expr, Expr, Expr] = block_dim
self.shared_mem_bytes: Expr = shared_mem

@@ -404,17 +406,18 @@ def launch_kernel(
args: Sequence[Expr],
grid_dim: Union[Sequence[Int], Int],
block_dim: Union[Sequence[Int], Int],
cluster_dim: Union[Sequence[Int], Int] = 1,
shared_mem: Optional[Int] = 0,
) -> LaunchKernelStmt:
launch_config: List[Tuple[Expr, Expr, Expr]] = []
for dims in [grid_dim, block_dim]:
for dims in [grid_dim, cluster_dim, block_dim]:
if not isinstance(dims, (list, tuple)):
dims = [dims]
dims = list(dims)
if len(dims) > 3:
raise ValueError('Grid/Block dimension must be 3 or less.')
raise ValueError('Grid/Cluster/Block dimension must be 3 or less.')
while len(dims) < 3:
dims.append(1)
launch_config.append(convert(dims))
grid_dim, block_dim = launch_config
return LaunchKernelStmt(func_var, args, grid_dim, block_dim, convert(shared_mem))
grid_dim, cluster_dim, block_dim = launch_config
return LaunchKernelStmt(func_var, args, grid_dim, cluster_dim, block_dim, convert(shared_mem))
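
A hypothetical call site for illustration; `cluster_dim` is keyword-defaulted to 1, so existing callers that pass only grid and block dims are unaffected (`func_var`, `x_ptr`, and `n` are placeholders):

stmt = launch_kernel(
    func_var,
    args=[x_ptr, n],
    grid_dim=(64, 1, 1),
    block_dim=256,    # normalized to (256, 1, 1)
    cluster_dim=2,    # normalized to (2, 1, 1)
    shared_mem=0,
)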
2 changes: 1 addition & 1 deletion python/hidet/ir/tools/ir_dumper.py
@@ -1098,7 +1098,7 @@ def visit_launch_kernel_stmt(self, node):
shared_smem = self(node[3])

args = [self(v) for v in node[4:]]
return LaunchKernelStmt(fn_var, args, grid_dim, block_dim, shared_smem)
return LaunchKernelStmt(fn_var, args, grid_dim, (1,), block_dim, shared_smem)

def visit_let_expr(self, node):
return Let(self(node[0]), self(node[1]), self(node[2]))
3 changes: 3 additions & 0 deletions python/hidet/lang/attrs/cuda.py
@@ -19,6 +19,9 @@
# The grid dimension of a cuda kernel, specifying the number of thread blocks
grid_dim: Dim3 = 1

# The optional cluster dimension of a cuda kernel, specifying the number of thread blocks per cluster
cluster_dim: Dim3 = 1

# The block dimension of a cuda kernel, specifying the number of threads per block
block_dim: Dim3 = 1

1 change: 1 addition & 0 deletions python/hidet/lang/cuda.py
@@ -16,6 +16,7 @@
from hidet.ir.stmt import DeclareScope
from hidet.ir.layout import DataLayout
from hidet.ir.primitives.cuda.vars import threadIdx, blockIdx, blockDim, gridDim
from hidet.ir.primitives.cuda.cluster import this_cluster
from hidet.ir.primitives.cuda.smem import dynamic_shared_memory, set_kernel_max_dynamic_smem_bytes
from hidet.ir.primitives.cuda.sync import syncthreads, syncthreads_and, syncthreads_count, syncthreads_or, syncwarp
from hidet.ir.primitives.cuda.mma import MmaConfig, mma_sync, ldmatrix
3 changes: 2 additions & 1 deletion python/hidet/lang/transpiler.py
@@ -405,7 +405,7 @@ def process_assign(self, lhs: Union[Attribute, Subscript, Name], rhs, type_annot
namespace = {hidet.lang.attrs: '', hidet.lang.attrs.cuda: 'cuda.'}
if lhs_base in namespace:
attr_name = namespace[lhs_base] + lhs.attr
if attr_name in ['cuda.block_dim', 'cuda.grid_dim', 'cuda.dynamic_smem_bytes']:
if attr_name in ['cuda.block_dim', 'cuda.cluster_dim', 'cuda.grid_dim', 'cuda.dynamic_smem_bytes']:
if isinstance(rhs, (tuple, list)):
rhs = [simplify(v) for v in rhs]
else:
@@ -959,6 +959,7 @@ def visit_Call(self, expr: Call):
func_var=func_var,
args=args,
grid_dim=func.attrs['cuda.grid_dim'],
cluster_dim=func.attrs.get('cuda.cluster_dim', 1),
block_dim=func.attrs['cuda.block_dim'],
shared_mem=func.attrs.get('cuda.dynamic_smem_bytes', 0),
)
8 changes: 8 additions & 0 deletions python/hidet/transforms/check_launch_configuration.py
@@ -18,6 +18,7 @@
from hidet.ir.stmt import LaunchKernelStmt, AssertStmt
from hidet.ir.func import Function
from hidet.transforms.base import Pass, FunctionPass
from hidet.utils.py import prod


class CheckLaunchConfigurationRewriter(IRRewriter):
@@ -49,6 +50,13 @@ def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
stmt.block_dim[2],
)
sb += AssertStmt(False, "Invalid launch configuration")
conditions = [grid_dim % cluster_dim != 0 for grid_dim, cluster_dim in zip(stmt.grid_dim, stmt.cluster_dim)]
with sb.if_then(logical_or(*conditions)):
sb += AssertStmt(False, "Cluster dims must elementwise evenly divide grid dims")

conditions = prod(stmt.cluster_dim) > 8
with sb.if_then(conditions):
sb += AssertStmt(False, "At most 8 thread blocks in a cluster")
with sb.if_then(stmt.shared_mem_bytes > 49152):
# if the shared memory is larger than 48KB, we should call cudaFuncSetAttribute
sb += BlackBoxStmt(
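
As a plain-Python restatement of the two new checks (an illustrative helper, not part of the pass):

from math import prod

def cluster_config_ok(grid_dim, cluster_dim):
    # mirrors the generated assertions: elementwise divisibility and at most 8 blocks per cluster
    return all(g % c == 0 for g, c in zip(grid_dim, cluster_dim)) and prod(cluster_dim) <= 8

assert cluster_config_ok((16, 4, 1), (4, 2, 1))      # divides evenly, 4 * 2 * 1 = 8 blocks per cluster
assert not cluster_config_ok((16, 4, 1), (3, 1, 1))  # 16 % 3 != 0
assert not cluster_config_ok((16, 4, 1), (4, 4, 1))  # 4 * 4 * 1 = 16 blocks per cluster > 8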
1 change: 1 addition & 0 deletions python/hidet/transforms/generate_launch_func.py
@@ -59,6 +59,7 @@ def add_launch_func(ir_module: IRModule, kernel_func: Function):
func_var,
params,
grid_dim=rewrite(_normalize_dim3(kernel_func.get_attr('cuda.grid_dim')), param_remap),
cluster_dim=rewrite(_normalize_dim3(kernel_func.get_attr('cuda.cluster_dim', default=1)), param_remap),
block_dim=rewrite(_normalize_dim3(kernel_func.get_attr('cuda.block_dim')), param_remap),
shared_mem=shared_memory_bytes,
)
19 changes: 13 additions & 6 deletions tests/conftest.py
@@ -11,17 +11,18 @@
# limitations under the License.
import os
import pytest
import shutil
import hidet


def pytest_addoption(parser):
parser.addoption("--clear-cache", action="store_true", help="Clear operator cache before running tests")
parser.addoption("--runslow", action="store_true", help="Run slow tests")
parser.addoption("--hopper", action='store_true', help="Run test that requires sm_90+")


def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
config.addinivalue_line("markers", "hopper: mark test as requiring sm_90+ to run")


def pytest_sessionstart(session):
@@ -43,13 +44,19 @@ def pytest_sessionstart(session):


def pytest_collection_modifyitems(config, items):
keywords = {
"slow": pytest.mark.skip(reason="need --runslow option to run"),
"hopper": pytest.mark.skip(reason="need --hopper option to run"),
}
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
del keywords["slow"]
if config.getoption("--hopper"):
del keywords["hopper"]

for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
for keyword in keywords.keys():
if keyword in item.keywords:
item.add_marker(keywords[keyword])


@pytest.fixture(autouse=True)
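
With these hooks in place, a test that needs sm_90+ hardware is marked like the (hypothetical) example below and is skipped unless pytest is invoked with `--hopper`:

import pytest

@pytest.mark.hopper
def test_cluster_thread_rank():
    # skipped unless pytest is run with --hopper (requires a Hopper GPU)
    ...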
