diff --git a/examples/08_gemm_atomics_all_reduce/benchmark.py b/examples/08_gemm_all_reduce_atomics/benchmark.py similarity index 77% rename from examples/08_gemm_atomics_all_reduce/benchmark.py rename to examples/08_gemm_all_reduce_atomics/benchmark.py index 31de4fa3..c214e4d5 100755 --- a/examples/08_gemm_atomics_all_reduce/benchmark.py +++ b/examples/08_gemm_all_reduce_atomics/benchmark.py @@ -60,18 +60,11 @@ def parse_args(): # Best to try 1, 6 or 8 parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M") - parser.add_argument("--two_tiles", type=str, default="True", help="Use two tiles") - parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") - parser.add_argument("--num_warps", type=int, default=8, help="Number of warps") - parser.add_argument("--waves_per_eu", type=int, default=0, help="Waves per execution unit") - parser.add_argument("--mfmaInstrSize", type=int, default=16, help="MFMA instruction size") - parser.add_argument("--kpack", type=int, default=2, help="K packing size") parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") # For All Scatter, use: 288 # For One Shot, use: 256 - parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM") - parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs") + parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for GEMM") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -107,7 +100,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype) B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T - C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) args["M"] = args["m"] args["N"] = args["n"] @@ -134,19 +126,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - if args["gemm_sms"] >= args["total_sms"]: - print(f"Invalid number of GEMM SMs. 
{args['gemm_sms']} >= {args['total_sms']}") - exit(1) - - tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) - - locks = shmem.zeros((args["gemm_sms"],), device="cuda", dtype=torch.int32) - - P = shmem.zeros( - (args["gemm_sms"], args["BLK_M"] * args["BLK_N"]), - device="cuda", - dtype=torch.float32, - ) bias = None gemm_stream = torch.cuda.Stream() @@ -165,11 +144,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Timestamps timestamps = Timestamps(num_tiles=total_tiles) - def preamble(): - shmem.barrier() - tile_completed.zero_() - shmem.barrier() - def run_experiment(): nonlocal local_C nonlocal global_C @@ -190,9 +164,6 @@ def run_experiment(): local_C, global_C, bias, - P, - locks, - tile_completed, rank, world_size, args["gemm_sms"], @@ -200,14 +171,8 @@ def run_experiment(): args["BLK_N"], args["BLK_K"], args["gsize_m"], - args["two_tiles"], - args["num_stages"], - args["num_warps"], - args["waves_per_eu"], - args["mfmaInstrSize"], - args["kpack"], shmem.get_heap_bases(), - cu_count, + "gfx942", args["trace_tiles"], timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, @@ -228,25 +193,15 @@ def run_experiment(): # Warmup run_experiment() - shmem.barrier() - preamble() shmem.barrier() for k in ["gemm"]: kernel_timing[k]["ms"] = 0 kernel_timing[k]["experiments"] = 0 - if not is_triton_interpret_set(): - gemm_registers = matmul.streamk_registers - gemm_spills = matmul.streamk_spills - - json_writer.add_field("gemm_registers", gemm_registers) - json_writer.add_field("gemm_spills", gemm_spills) - if args["validate"]: shmem.info("Validating...") - - matmul.set_debug(False) + matmul.set_debug(True) # Validate global result success = validate_gemm(A, B, global_C, shmem, atol=2) passed_str = "passed" if success else "failed" @@ -254,18 +209,28 @@ def run_experiment(): # Wait for all to finish validation shmem.barrier() - json_writer.add_field("success", success) shmem.info("Validation completed") + json_writer.add_field("success", success) + + if not is_triton_interpret_set(): + gemm_registers = matmul.streamk_registers + gemm_spills = matmul.streamk_spills + + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + if args["benchmark"]: + matmul.set_debug(False) shmem.info("Benchmarking...") perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) + triton_ms = iris.do_bench(run_experiment, shmem.barrier) triton_tflops = perf(triton_ms) - shmem.info(f"tile matmul + all_reduce (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") + algo_string = "all_reduce" + shmem.info(f"tile matmul + {algo_string} (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") - json_writer.add_field("triton_tflops", triton_tflops) - json_writer.add_field("triton_ms", triton_ms) + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) for k in ["gemm"]: json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) @@ -280,7 +245,8 @@ def run_experiment(): if args["trace_tiles"] and rank == 0: gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 - filename = f"gemm_all_reduce_tiles_trace_rank{rank}.json" + algo_string = "all_reduce" + filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json" timestamps.to_json(filename, gpu_freq) shmem.barrier() diff --git 
a/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py b/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py new file mode 100644 index 00000000..1b69df0d --- /dev/null +++ b/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import sys +import os + +import iris + + +@triton.jit() +def persistent_gemm_all_reduce( + A, + B, + C, + c_global, + bias_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_cm_global, + stride_cn_global, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, NUM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + tl.assume(pid_m > 0) + tl.assume(pid_n > 0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + # Accumulator registers with C results + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + 
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + # Add compiler hints + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Calculate the "global" offset of C based on the rank. + # Note how each GPU is producing the entire output but partial-K. + global_offset = rm[:, None] * stride_cm_global + rn[None, :] * stride_cn_global + + # Timestamp for GEMM before store + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + # Store data to the global result using puts + for remote_rank in range(world_size): + if remote_rank == cur_rank: + # For the current rank, we can use store + tl.atomic_add(c_global + global_offset, c, mask=sub_mask) + else: + iris.atomic_add( + c_global + global_offset, + c, + cur_rank, + remote_rank, + heap_bases, + mask=sub_mask, + ) + + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) diff --git a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py b/examples/08_gemm_all_reduce_atomics/matmul_wrapper.py similarity index 51% rename from examples/08_gemm_atomics_all_reduce/matmul_wrapper.py rename to examples/08_gemm_all_reduce_atomics/matmul_wrapper.py index ba55286e..54aaf430 100644 --- a/examples/08_gemm_atomics_all_reduce/matmul_wrapper.py +++ b/examples/08_gemm_all_reduce_atomics/matmul_wrapper.py @@ -9,7 +9,7 @@ # from streamk_kernel import streamk_gemm # from streamk_kernel_atomic import streamk_gemm -from gemm_atomics_all_reduce import persistent_gemm_all_reduce +from gemm_all_reduce_atomics import persistent_gemm_all_reduce from examples.common.utils import is_triton_interpret_set import iris @@ -20,8 +20,6 @@ class matmul(torch.autograd.Function): _debug = True - _num_xcds = iris.hip.get_num_xcc() - @staticmethod def set_debug(debug: bool): matmul._debug = debug @@ -35,87 +33,49 @@ def _call( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, - P: torch.Tensor, - locks: torch.Tensor, - tile_completed: torch.Tensor, rank: int, world_size: int, - total_programs_streamk: int, + num_sms: int, BLK_M: int, BLK_N: int, BLK_K: int, gsize_m: int, - two_tiles: bool, - num_stages: int, - num_warps: int, - waves_per_eu: int, - mfmaInstrSize: int, - kpack: int, heap_bases_ptr: torch.Tensor = None, - cu_count: int = None, + arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): - # assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" # checks constraints assert a.shape[1] == b.shape[0], "incompatible dimensions" M, K = a.shape _, N = b.shape - num_xcds = matmul._num_xcds + num_xcds = iris.hip.get_num_xcc() + + # TODO: Use arch-specific values. 
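+        # Note: the `arch` argument is accepted but not consumed yet (default "gfx942");
+        # the launch constants below are hard-coded defaults until arch-specific tuning lands.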
+ num_stages = 2 + num_warps = 8 + waves_per_eu = 0 + mfma = 16 + kpack = 1 total_blocks_M = triton.cdiv(M, BLK_M) total_blocks_N = triton.cdiv(N, BLK_N) iters_per_tile = triton.cdiv(K, BLK_K) total_tiles = total_blocks_M * total_blocks_N even_k = K % BLK_K == 0 - - if total_programs_streamk > 0: # GEMM - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile GEMM + data-parallel from original paper - # if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - # total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_full_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - print("total_remainder_iters_streamk=", total_partial_tiles_streamk) use_bias = False # compute grid (work to do per SM on the first wave) - grids = total_programs_streamk stride_bias = bias.stride(0) if use_bias else 0 - kk = gemm_kernel[(grids,)]( + kk = gemm_kernel[(num_sms,)]( a, b, c, c_global, bias, - P, - locks, - tile_completed, M, N, K, @@ -132,15 +92,14 @@ def _call( BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, GROUP_SIZE_M=gsize_m, - NUM_SMS=total_programs_streamk, - STREAMK_TILES=total_tiles_streamk, + NUM_SMS=num_sms, NUM_XCDS=num_xcds, BIAS=use_bias, EVEN_K=even_k, num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, + matrix_instr_nonkdim=mfma, kpack=kpack, heap_bases=heap_bases_ptr, cur_rank=rank, @@ -153,9 +112,6 @@ def _call( if matmul._debug and not is_triton_interpret_set(): matmul.streamk_registers = kk.n_regs matmul.streamk_spills = kk.n_spills - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) return c @@ -167,24 +123,15 @@ def forward( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, - P: torch.Tensor, - locks: torch.Tensor, - tile_completed: torch.Tensor, rank: int, world_size: int, - grid: int, - BLK_M=128, - BLK_N=128, - BLK_K=32, - gsize_m=1, - two_tiles=True, - num_stages=3, - num_warps=4, - waves_per_eu=2, - mfmaInstrSize=16, - kpack=1, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, heap_bases_ptr: torch.Tensor = None, - cu_count: int = None, + arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, @@ -195,24 +142,15 @@ def forward( c=c, c_global=c_global, bias=bias, - P=P, - locks=locks, - tile_completed=tile_completed, rank=rank, world_size=world_size, - total_programs_streamk=grid, + num_sms=num_sms, 
BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, gsize_m=gsize_m, - two_tiles=two_tiles, - num_warps=num_warps, - num_stages=num_stages, - waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, - kpack=kpack, heap_bases_ptr=heap_bases_ptr, - cu_count=cu_count, + arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, diff --git a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py b/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py deleted file mode 100644 index e692f210..00000000 --- a/examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py +++ /dev/null @@ -1,252 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -import triton -import triton.language as tl -from examples.common.utils import read_realtime - -import sys -import os - -import iris - - -@triton.jit -def tile_id_to_index_range( - tile_id, - M, - N, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - - tile_in_group = tile_id % num_pid_in_group - pid_m = first_pid_m + (tile_in_group % group_size_m) - pid_n = tile_in_group // group_size_m - - rm_start = pid_m * BLOCK_SIZE_M - rn_start = pid_n * BLOCK_SIZE_N - - # clamp to the maximum valid index (M-1, N-1) - max_m = M - 1 - max_n = N - 1 - - # generate indices - rm = rm_start + tl.arange(0, BLOCK_SIZE_M) - rn = rn_start + tl.arange(0, BLOCK_SIZE_N) - - rm = tl.minimum(rm, max_m) - rn = tl.minimum(rn, max_n) - - return rm, rn, rm_start, rn_start - - -@triton.jit -def offset_for_tile(local_tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M_local, N_local): - rm, rn, rm_start, rn_start = tile_id_to_index_range( - local_tile_id, M_local, N_local, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - c_mask = (rm[:, None] < M_local) & (rn[None, :] < N_local) - return rm, rn, c_mask, rm_start, rn_start - - -@triton.jit -def extract_submask_and_offset( - rm, - rn, - mask, - rm_start, - rn_start, - start_row, - start_col, - SUB_BLOCK_SIZE_M: tl.constexpr, - SUB_BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - stride_cm_local: tl.constexpr, - stride_cn_local: tl.constexpr, -): - # Create indices for the sub-block - sub_rm = tl.arange(0, SUB_BLOCK_SIZE_M) + start_row - sub_rn = tl.arange(0, SUB_BLOCK_SIZE_N) + start_col - - # Create a 2D grid of indices for the sub-block - sub_rm_2d = sub_rm[:, None] # Shape: (SUB_BLOCK_SIZE_M, 1) - sub_rn_2d = sub_rn[None, :] # Shape: (1, SUB_BLOCK_SIZE_N) - - # Compute the sub-mask - sub_mask = (sub_rm_2d < BLOCK_SIZE_M) & (sub_rn_2d < BLOCK_SIZE_N) - - # Compute the sub-offset relative to the start of the tile - sub_offset = ((rm_start + sub_rm_2d) * stride_cm_local) + ((rn_start + sub_rn_2d) * stride_cn_local) - - return sub_mask, sub_offset - - -@triton.jit() -def persistent_gemm_all_reduce( - A, - B, - C, - c_global, - bias_ptr, - P, - locks, - tile_completed, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_cm_global, - stride_cn_global, - stride_bias, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - NUM_SMS: tl.constexpr, - 
STREAMK_TILES: tl.constexpr, - NUM_XCDS: tl.constexpr, - BIAS: tl.constexpr, - EVEN_K: tl.constexpr, - heap_bases: tl.tensor, - cur_rank: tl.constexpr, - world_size: tl.constexpr, - NOTIFY_REMOTES: tl.constexpr = False, - COLLECT_TIMESTAMPS: tl.constexpr = False, - mm_begin_timestamp_ptr: tl.tensor = None, - mm_end_timestamp_ptr: tl.tensor = None, -): - pid = tl.program_id(0) - - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - total_tiles = num_pid_m * num_pid_n - - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_cm > 0) - tl.assume(stride_cn > 0) - - acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 - - for tile_id in range(pid, total_tiles, NUM_SMS): - if COLLECT_TIMESTAMPS: - timestamp = read_realtime() - tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - - rk = tl.arange(0, BLOCK_SIZE_K) - rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - - tl.assume(pid_m > 0) - tl.assume(pid_n > 0) - - loop_k = tl.cdiv(K, BLOCK_SIZE_K) - if not EVEN_K: - loop_k -= 1 - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for k in range(0, loop_k): - a = tl.load(tl.multiple_of(A_BASE, (1, 16))) - b = tl.load(tl.multiple_of(B_BASE, (16, 1))) - acc += tl.dot(a, b) - A_BASE += BLOCK_SIZE_K * stride_ak - B_BASE += BLOCK_SIZE_K * stride_bk - - if not EVEN_K: - k = loop_k - rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - A_BASE = tl.multiple_of(A_BASE, (1, 16)) - B_BASE = tl.multiple_of(B_BASE, (16, 1)) - a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) - b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) - acc += tl.dot(a, b) - - c = acc.to(C.type.element_ty) - # rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - # rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - # rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - # rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - # c_mask = (rm[:, None] < M) & (rn[None, :] < N) - # C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - # tl.store(C_, c, c_mask) - - rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N) - - # Calculate the number of sub-tiles in each dimension - num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M) - num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N) - total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n - - for sub_tile_idx in range(0, total_sub_tiles): - # Calculate start_row and start_col for the current sub-tile - start_row = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M - 
start_col = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N - - # Translate to global - sub_mask, global_offset = extract_submask_and_offset( - rm, - rn, - mask, - rm_start, - rn_start, - start_row, - start_col, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - stride_cm_global, - stride_cn_global, - ) - - # Store data to the global result using puts - for remote_rank in range(world_size): - if remote_rank == cur_rank: - # For the current rank, we can use store - tl.atomic_add(c_global + global_offset, c, mask=sub_mask) - else: - iris.atomic_add( - c_global + global_offset, - c, - cur_rank, - remote_rank, - heap_bases, - mask=sub_mask, - ) - - if COLLECT_TIMESTAMPS: - timestamp = read_realtime() - tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) diff --git a/examples/15_gemm_all_reduce_ring_based/benchmark.py b/examples/15_gemm_all_reduce_ring_based/benchmark.py new file mode 100755 index 00000000..505f865b --- /dev/null +++ b/examples/15_gemm_all_reduce_ring_based/benchmark.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import ( + JSONWriter, + Timestamps, + is_triton_interpret_set, +) + +import iris + +from matmul_wrapper import matmul +from examples.common.validation import validate_gemm +from gemm_all_reduce_ring_based import persistent_all_reduce + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in matrix B") + parser.add_argument("-k", type=int, default=36864, help="Common dimension between matrices A and B") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + # For All Scatter, use: 256x64x64 + # For One Shot, use: 256x256x64 + parser.add_argument("--BLK_M", type=int, default=128, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=128, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") + + # Best to try 1, 6 or 8 + parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + + # For All Scatter, use: 288 + # For One Shot, use: 256 + parser.add_argument("--gemm_sms", type=int, default=256, help="Number of SMs for GEMM") + parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of 
ranks/processes") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + num_xcds = iris.hip.get_num_xcc() + + # GEMM + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "int8": + datatype = torch.int8 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})." + assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})." + + A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype) + B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T + + args["M"] = args["m"] + args["N"] = args["n"] + args["K"] = args["k"] + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + # Splitting + rows_per_gpu = args["k"] // world_size + args["k"] = rows_per_gpu + start_row = rank * rows_per_gpu + end_row = start_row + rows_per_gpu + local_B = B[start_row:end_row, :] + local_A = A[:, start_row:end_row] + + for key, value in args.items(): + json_writer.add_field(key, value) + + C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) + local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=torch.float32) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + flags = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + ring_buffer = shmem.zeros_like(C, dtype=torch.float32) + + bias = None + + gemm_stream = torch.cuda.Stream() + comm_stream = torch.cuda.Stream() + + json_writer.add_field("gemm_sms", args["gemm_sms"]) + json_writer.add_field("comm_sms", args["comm_sms"]) + + kernel_timing = { + "gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + "communication": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + locks.zero_() + flags.zero_() + ring_buffer.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal local_C + nonlocal C + nonlocal kernel_timing + nonlocal ring_buffer + + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + Communication") + torch.cuda.nvtx.range_push("GEMM") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm"]["start_event"].record() + local_C = matmul.apply( + local_A, + local_B, + local_C, + bias, + locks, + rank, + world_size, + args["gemm_sms"], + args["BLK_M"], + args["BLK_N"], + 
args["BLK_K"], + args["gsize_m"], + shmem.get_heap_bases(), + "gfx942", + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm"]["end_event"].record() + kernel_timing["gemm"]["experiments"] += 1 + + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("Communication") + with torch.cuda.stream(comm_stream): + kernel_timing["communication"]["start_event"].record() + ar = persistent_all_reduce[(args["comm_sms"],)]( + C, + local_C, + ring_buffer, + locks, + flags, + args["M"], + args["N"], + C.stride(0), + C.stride(1), + args["BLK_M"], + args["BLK_N"], + args["gsize_m"], + args["comm_sms"], + num_xcds, + shmem.get_heap_bases(), + rank, + world_size, + ) + kernel_timing["communication"]["end_event"].record() + kernel_timing["communication"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + shmem.barrier() + + for k in ["gemm", "communication"]: + ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + torch.cuda.nvtx.range_pop() + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + + shmem.barrier() + preamble() + shmem.barrier() + + for k in ["gemm", "communication"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if args["validate"]: + shmem.info("Validating...") + matmul.set_debug(True) + # Validate global result + success = validate_gemm(A, B, C, shmem, atol=2) + passed_str = "passed" if success else "failed" + shmem.info(f"Final C validation {passed_str}.") + + # Wait for all to finish validation + shmem.barrier() + shmem.info("Validation completed") + + json_writer.add_field("success", success) + + if not is_triton_interpret_set(): + gemm_registers = matmul.get_matmul_registers() + gemm_spills = matmul.get_matmul_spills() + + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + + if args["benchmark"]: + matmul.set_debug(False) + shmem.info("Benchmarking...") + perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) + triton_tflops = perf(triton_ms) + algo_string = "all_reduce" + shmem.info(f"tile matmul + {algo_string} (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") + + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) + + for k in ["gemm", "communication"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + algo_string = "all_reduce" + filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py b/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py new file mode 100644 index 00000000..1323d287 --- /dev/null +++ 
b/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import sys +import os + +import iris + + +@triton.jit() +def persistent_gemm( + A, + B, + local_C, + bias_ptr, + locks, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if local_C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, NUM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + tl.assume(pid_m > 0) + tl.assume(pid_n > 0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + # Add compiler hints + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N) + mask = 
(rm[:, None] < M) & (rn[None, :] < N) + + # Calculate the "global" offset of C based on the rank. + # Note how each GPU is producing the entire output but partial-K. + offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn + + # Timestamp for GEMM before store + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + # Write fully-reduced tile to local result buffer (no remote writes) + tl.store(local_C + offset, acc, mask=mask, cache_modifier=".wt") + tl.debug_barrier() + tl.store(locks + tile_id, 1, cache_modifier=".wt") + + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + +@triton.jit() +def persistent_all_reduce( + C, + local_C, + ring_buffer, + locks, + flags, + M, + N, + stride_cm_global, + stride_cn_global, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, +): + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + # Precompute once + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + num_groups = tl.cdiv(total_tiles, world_size) + + next_rank = (cur_rank + 1) % world_size + prev_rank = (cur_rank + world_size - 1) % world_size + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + # Persistent across *groups* now (not individual tiles): + for g in range(pid, num_groups, COMM_SMS): + group_base = g * world_size + group_size = tl.minimum(world_size, total_tiles - group_base) # tail-safe + + # ---- Reduce-Scatter over this group of up to 'group_size' tiles ---- + for s in range(0, group_size): + # Tile index this rank handles at step s + idx = group_base + ((cur_rank + group_size - s) % group_size) + + # Map linear tile idx -> (pid_m, pid_n) using existing swizzle + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = idx // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((idx % num_pid_in_group) % group_size_m) + pid_n = (idx % num_pid_in_group) // group_size_m + + # Offsets/masks for this tile + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + goff = rm[:, None] * stride_cm_global + rn[None, :] * stride_cn_global + + if s == 0: + # First touch of this traveling tile on this rank: + # wait for local GEMM, seed acc from local partial, and forward. + while tl.atomic_cas(locks + idx, 0, 0, sem="acquire", scope="gpu") != 1: + pass + acc = tl.load(local_C + goff, mask=sub_mask, other=0).to(acc_dtype) + + if group_size > 1: + # Wait for NEXT rank to be ready (its flag should be 0, meaning it finished previous step) + while ( + iris.atomic_cas(flags + idx, 0, 0, cur_rank, next_rank, heap_bases, sem="acquire", scope="sys") + != 0 + ): + pass + # Send to NEXT and signal that tile 'idx' is ready for neighbor + iris.store(ring_buffer + goff, acc, cur_rank, next_rank, heap_bases, mask=sub_mask) + tl.debug_barrier() # Wait for all stores to complete before releasing the lock. 
+ iris.atomic_xchg(flags + idx, 1, cur_rank, next_rank, heap_bases, sem="release", scope="sys") + else: + # Receive the traveling accumulator for this tile from PREV + while tl.atomic_cas(flags + idx, 0, 0, sem="acquire", scope="sys") != 1: + pass + recv = tl.load(ring_buffer + goff, mask=sub_mask, other=0).to(acc_dtype) + + # Wait for all to complete before releasing the lock. + # This one can technically be moved lower (closer to recv = tl.load), + # However, doing it much later allows for the two individual loads to issue and much-much + # later reset the lock. + tl.debug_barrier() + tl.atomic_xchg(flags + idx, 0, sem="release", scope="sys") # clear local flag + + # Fold in our local partial (wait if GEMM not done yet) + while tl.atomic_cas(locks + idx, 0, 0, sem="acquire", scope="gpu") != 1: + pass + + part = tl.load(local_C + goff, mask=sub_mask, other=0).to(acc_dtype) + acc = recv + part + + # Forward unless this is the last hop for this tile + if s < group_size - 1: + # Wait for NEXT rank to be ready (its flag should be 0, meaning it finished previous step) + while ( + iris.atomic_cas(flags + idx, 0, 0, cur_rank, next_rank, heap_bases, sem="acquire", scope="sys") + != 0 + ): + pass + iris.store(ring_buffer + goff, acc, cur_rank, next_rank, heap_bases, mask=sub_mask) + tl.debug_barrier() + iris.atomic_xchg(flags + idx, 1, cur_rank, next_rank, heap_bases, sem="release", scope="sys") + else: + # Last hop for tile idx on this rank: we own the fully reduced acc + c = acc.to(C.type.element_ty) + + # All-scatter when the results are ready. + # TODO: Technically, we commonly use an all-gather operation at the end as a separate loop? + for rank in range(world_size): + if rank == cur_rank: + tl.store(C + goff, c, mask=sub_mask) + else: + iris.store(C + goff, c, cur_rank, rank, heap_bases, mask=sub_mask) diff --git a/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py b/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py new file mode 100644 index 00000000..3dd2707e --- /dev/null +++ b/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import random +import sys +import os + +# from streamk_kernel import streamk_gemm +# from streamk_kernel_atomic import streamk_gemm +from gemm_all_reduce_ring_based import persistent_gemm + +from examples.common.utils import is_triton_interpret_set +import iris + +gemm_kernel = persistent_gemm + + +class matmul(torch.autograd.Function): + _debug = True + _registers = None + _spills = None + + _num_xcds = iris.hip.get_num_xcc() + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def get_matmul_registers(): + if matmul._debug: + return matmul._registers + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def get_matmul_spills(): + if matmul._debug: + return matmul._spills + else: + raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") + + @staticmethod + def _call( + a: torch.Tensor, + b: torch.Tensor, + ring_buffer: torch.Tensor, + bias: torch.Tensor, + locks: torch.Tensor, + rank: int, + world_size: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + heap_bases_ptr: torch.Tensor = None, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + # assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + num_xcds = matmul._num_xcds + + # TODO: Use arch-specific values. + num_stages = 2 + num_warps = 8 + waves_per_eu = 0 + mfma = 16 + kpack = 1 + + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + total_tiles = total_blocks_M * total_blocks_N + even_k = K % BLK_K == 0 + use_bias = False + + # compute grid (work to do per SM on the first wave) + stride_bias = bias.stride(0) if use_bias else 0 + kk = gemm_kernel[(num_sms,)]( + a, + b, + ring_buffer, + bias, + locks, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + ring_buffer.stride(0), + ring_buffer.stride(1), + stride_bias, + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=num_sms, + NUM_XCDS=num_xcds, + BIAS=use_bias, + EVEN_K=even_k, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfma, + kpack=kpack, + heap_bases=heap_bases_ptr, + cur_rank=rank, + world_size=world_size, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp_ptr=mm_begin_timestamp, + mm_end_timestamp_ptr=mm_end_timestamp, + ) + + # if matmul._debug and not is_triton_interpret_set(): + matmul._registers = kk.n_regs + matmul._spills = kk.n_spills + + return ring_buffer + + @staticmethod + def forward( + ctx, + a: torch.Tensor, + b: torch.Tensor, + ring_buffer: torch.Tensor, + bias: torch.Tensor, + locks: torch.Tensor, + rank: int, + world_size: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + heap_bases_ptr: torch.Tensor = None, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + matmul._call( + a=a, + b=b, + ring_buffer=ring_buffer, + bias=bias, + locks=locks, + rank=rank, + world_size=world_size, + num_sms=num_sms, + BLK_M=BLK_M, + BLK_N=BLK_N, + BLK_K=BLK_K, + gsize_m=gsize_m, + heap_bases_ptr=heap_bases_ptr, + arch=arch, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp=mm_begin_timestamp, + mm_end_timestamp=mm_end_timestamp, + ) + return ring_buffer diff --git a/examples/16_all_reduce_ring_based/all_reduce_ring_based.py b/examples/16_all_reduce_ring_based/all_reduce_ring_based.py new file mode 100644 index 00000000..333151a0 --- /dev/null +++ b/examples/16_all_reduce_ring_based/all_reduce_ring_based.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
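+# Ring all-reduce kernel: each rank starts from its own partial tile, forwards the data
+# it received to the next rank for world_size - 1 steps, and accumulates every received
+# partial locally, so every rank ends up holding the full sum of all partials.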
+ +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import sys +import os + +import iris + + +@triton.jit() +def persistent_all_reduce( + partials, + ring_buffer, + output, + flags, + M, + N, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, +): + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (COMM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + # Ring topology + next_rank = (cur_rank + 1) % world_size + prev_rank = (cur_rank + world_size - 1) % world_size + + acc_dtype = tl.float32 if output.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, COMM_SMS): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + # Begin: See the if segment for explanation: + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + mask = (rm[:, None] < M) & (rn[None, :] < N) + offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn + # End: masks/offset calculations. + + # Initialize accumulator with local partial result from ring_buffer + acc = tl.load(partials + offset, mask=mask).to(acc_dtype) + + # Each rank sends its LOCAL partial result (not accumulated) around the ring + # while accumulating received partial results from other ranks. + # + # Initial: Each rank has computed a partial-K GEMM result in 'acc' + # Goal: Sum all partial results from all ranks + # + # Algorithm: Use ring_buffer to pass data around, accumulate locally + # - send_data: What we send (starts as our partial result) + # - acc: Running sum of all partial results received so far + + # Initialize: First, write our partial result to ring_buffer for sending + send_data = acc + + # Step loop: send to next, wait/recv from prev, add. 
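+        # Flag protocol used by this loop (derived from the calls below):
+        #   flags[tile_id] == 0 -> the owner has consumed the previous step; its ring_buffer
+        #                          slot for this tile may be overwritten
+        #   flags[tile_id] == 1 -> fresh data for this tile is waiting in the owner's ring_buffer
+        # Remote polling and signalling go through iris.atomic_cas / iris.atomic_xchg with system scope.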
+ for _step in range(0, world_size - 1): + # 1a) Wait for NEXT rank to be ready (its lock should be 0, meaning it finished previous step) + # This prevents overwriting data that hasn't been consumed yet + while ( + iris.atomic_cas(flags + tile_id, 0, 0, cur_rank, next_rank, heap_bases, sem="acquire", scope="sys") != 0 + ): + pass + + # 1b) Send our current accumulator tile to NEXT rank's ring buffer + iris.store(ring_buffer + offset, send_data, cur_rank, next_rank, heap_bases, mask=mask) + + tl.debug_barrier() + # Signal "ready" by setting NEXT rank's flag for this tile to 1 + iris.atomic_xchg(flags + tile_id, 1, cur_rank, next_rank, heap_bases, sem="release", scope="sys") + + # 2) Wait for PREV rank to signal our local flag for this tile + while tl.atomic_cas(flags + tile_id, 0, 0, sem="acquire", scope="sys") != 1: + pass + + # 3) Consume the received tile from our LOCAL ring buffer (prev wrote here) + recv_tile = tl.load(ring_buffer + offset, mask=mask, other=tl.zeros_like(acc)) + + # 4) Accumulate received data and prepare to forward it in next iteration + acc += recv_tile # tl.load(ring_buffer + offset, mask=mask) + send_data = recv_tile # Forward what we just received (not the accumulated sum) + + # 5) Reset our local flag to 0 (done consuming this step) + tl.atomic_xchg(flags + tile_id, 0, sem="release", scope="sys") + + # Write fully-reduced tile to local result buffer (no remote writes) + o = acc.to(output.type.element_ty) + tl.store(output + offset, o, mask=mask) diff --git a/examples/16_all_reduce_ring_based/benchmark.py b/examples/16_all_reduce_ring_based/benchmark.py new file mode 100755 index 00000000..5aba7de5 --- /dev/null +++ b/examples/16_all_reduce_ring_based/benchmark.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
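+# Standalone benchmark for the ring all-reduce kernel: each rank seeds `partial` with
+# rank-specific random data, launches persistent_all_reduce, optionally validates the
+# result against torch.distributed.all_reduce, and reports achieved bandwidth.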
+ +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import sys +import os +import argparse +import json + +from examples.common.utils import ( + JSONWriter, + Timestamps, + is_triton_interpret_set, +) + +import iris + +from all_reduce_ring_based import persistent_all_reduce + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in input/output matrix") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in input/output matrix") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + # For All Scatter, use: 256x64x64 + # For One Shot, use: 256x256x64 + parser.add_argument("--BLK_M", type=int, default=128, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=128, help="Block size N") + + # Best to try 1, 6 or 8 + parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + + # For All Scatter, use: 288 + # For One Shot, use: 256 + parser.add_argument("--num_sms", type=int, default=48, help="Number of SMs for All-Reduce kernel") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + cu_count = shmem.get_cu_count() + num_xcds = iris.hip.get_num_xcc() + + # datatypes + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "int8": + datatype = torch.int8 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + args["M"] = args["m"] + args["N"] = args["n"] + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Initialize partial with random data for each rank + # In all_reduce, each rank has a partial result that needs to be summed across all ranks + torch.manual_seed(123 + rank) # Different seed per rank for different data + partial = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=datatype) + 
partial.copy_(torch.randn((args["M"], args["N"]), device="cuda", dtype=datatype)) + + output = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=datatype) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + flags = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) + ring_buffer = shmem.zeros_like(partial, dtype=torch.float32) + comm_stream = torch.cuda.Stream() + + json_writer.add_field("num_sms", args["num_sms"]) + + kernel_timing = { + "communication": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def preamble(): + shmem.barrier() + flags.zero_() + ring_buffer.zero_() + shmem.barrier() + + def run_experiment(): + nonlocal output + nonlocal partial + nonlocal kernel_timing + nonlocal ring_buffer + + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + torch.cuda.nvtx.range_push("Communication") + with torch.cuda.stream(comm_stream): + kernel_timing["communication"]["start_event"].record() + ar = persistent_all_reduce[(args["num_sms"],)]( + partial, + ring_buffer, + output, + flags, + args["M"], + args["N"], + output.stride(0), + output.stride(1), + args["BLK_M"], + args["BLK_N"], + args["gsize_m"], + args["num_sms"], + num_xcds, + shmem.get_heap_bases(), + rank, + world_size, + ) + kernel_timing["communication"]["end_event"].record() + kernel_timing["communication"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + shmem.barrier() + + for k in ["communication"]: + ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + + shmem.barrier() + preamble() + shmem.barrier() + + for k in ["communication"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if args["validate"]: + shmem.info("Validating...") + + # Run the experiment once to populate output + run_experiment() + shmem.barrier() + + # Create a reference result using torch.distributed.all_reduce + # Save original partial values for reference computation + partial_copy = partial.clone() + expected_output = partial_copy.clone() + + # Use NCCL all_reduce to compute the expected result + dist.all_reduce(expected_output, op=dist.ReduceOp.SUM) + + # Compare the output from our kernel with the expected result + success = torch.allclose(output, expected_output, atol=2) + max_diff = torch.max(torch.abs(output - expected_output)).item() + + if success: + shmem.info(f"Final validation passed. Max difference: {max_diff}") + else: + shmem.info(f"Final validation failed. 
Max difference: {max_diff}") + + # Wait for all to finish validation + shmem.barrier() + shmem.info("Validation completed") + + json_writer.add_field("success", success) + + if args["benchmark"]: + shmem.info("Benchmarking...") + # Calculate bandwidth instead of FLOPS since there's no GEMM + # All-reduce moves 2 * (world_size - 1) / world_size * data_size bytes + data_size_bytes = ( + args["M"] * args["N"] * 2 + if datatype == torch.float16 or datatype == torch.bfloat16 + else args["M"] * args["N"] * 4 + ) + perf = lambda ms: (2 * (world_size - 1) / world_size * data_size_bytes * 1e-9) / (ms * 1e-3) # GB/s + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) + bandwidth_gbps = perf(triton_ms) + algo_string = "all_reduce" + shmem.info(f"{algo_string} (grid={total_tiles}): {triton_ms:.3f} ms {bandwidth_gbps:.3f} GB/s") + + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + + for k in ["communication"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + algo_string = "all_reduce" + filename = f"comm_tiles_{algo_string}_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main()
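For orientation, the traveling-data schedule implemented by persistent_all_reduce in examples/16_all_reduce_ring_based/all_reduce_ring_based.py can be emulated on the host: every rank sends its current buffer to the next rank, accumulates whatever arrives from the previous rank, and forwards the received data rather than the running sum. The sketch below is a single-process PyTorch illustration under that reading of the kernel; ring_all_reduce_reference and the list-of-tensors setup are hypothetical stand-ins for the per-rank Iris heap buffers and are not part of this patch.

import torch


def ring_all_reduce_reference(partials):
    """partials: one tensor per rank, all the same shape; returns each rank's reduced copy."""
    world_size = len(partials)
    acc = [p.clone().float() for p in partials]   # running sum held by each rank
    send = [p.clone().float() for p in partials]  # buffer each rank forwards on the next step
    for _step in range(world_size - 1):
        # Every rank writes its send buffer into the next rank's ring buffer;
        # equivalently, rank r receives the previous rank's send buffer.
        ring = [send[(r - 1) % world_size] for r in range(world_size)]
        for r in range(world_size):
            recv = ring[r]
            acc[r] = acc[r] + recv  # accumulate the received partial
            send[r] = recv          # forward what was received, not the running sum
    return acc  # after world_size - 1 steps every rank holds the full sum


# Usage: four ranks with random partials; the result matches a direct sum.
parts = [torch.randn(8, 8) for _ in range(4)]
reduced = ring_all_reduce_reference(parts)
assert torch.allclose(reduced[0], sum(p.float() for p in parts), atol=1e-5)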