diff --git a/benchmark/ccl/all_to_all/benchmark_x.py b/benchmark/ccl/all_to_all/benchmark_x.py new file mode 100644 index 000000000..e67260067 --- /dev/null +++ b/benchmark/ccl/all_to_all/benchmark_x.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark: iris.x.all_to_all (Triton) vs iris.x.all_to_all_gluon (Gluon) + +Validates correctness and measures bandwidth for the tile-level all-to-all +primitives across many problem sizes. Supports assembly dumping for +side-by-side comparison. + +Run modes +--------- +Single size (validate + benchmark): + python benchmark_x.py -v -b -m 4096 -n 256 -r 8 + +Sweep across many problem sizes (recommended): + python benchmark_x.py -v -b --sweep -r 8 --output_file results.json + +Dump generated assembly to files: + python benchmark_x.py --dump_asm -m 1024 -n 128 -r 8 + +Generate scatter plot from previous sweep results: + python plot_x_all_to_all.py results.json +""" + +import argparse +import json + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import triton.language as tl + +import iris +import iris.x + +GLUON_AVAILABLE = False +try: + from triton.experimental import gluon + from triton.experimental.gluon import language as gl + import iris.experimental.iris_gluon as iris_gl + + GLUON_AVAILABLE = hasattr(iris.x, "all_to_all_gluon") +except ImportError: + pass + + +# --------------------------------------------------------------------------- +# Problem-size sweep grid +# --------------------------------------------------------------------------- + +# (M, N_per_rank) pairs covering small / medium / large / various aspect-ratios. 
+SWEEP_SIZES = [ + # Small + (128, 64), + (256, 64), + (512, 64), + # Medium + (1024, 128), + (2048, 128), + (1024, 256), + (2048, 256), + # Large + (4096, 128), + (4096, 256), + (4096, 512), + (8192, 256), + (8192, 512), + # Extra-large + (16384, 128), + (16384, 256), +] + + +# --------------------------------------------------------------------------- +# Triton kernel wrapper +# --------------------------------------------------------------------------- + + +@triton.jit +def _triton_kernel( + input_ptr, + output_ptr, + M: tl.constexpr, + N: tl.constexpr, + N_per_rank: tl.constexpr, + stride_in_m: tl.constexpr, + stride_in_n: tl.constexpr, + stride_out_m: tl.constexpr, + stride_out_n: tl.constexpr, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(0) + grid_size = tl.num_programs(0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + for tile_id in range(pid, total_tiles, grid_size): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + tile = iris.x.TileView(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N) + src_view = iris.x.make_tensor_view(input_ptr, M, N, stride_in_m, stride_in_n) + dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_out_m, stride_out_n) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + + iris.x.all_to_all(tile, src_view, dst_view, N_per_rank, ctx) + + +# --------------------------------------------------------------------------- +# Gluon kernel wrapper +# --------------------------------------------------------------------------- + +if GLUON_AVAILABLE: + + @gluon.jit + def _gluon_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + input_ptr, + output_ptr, + M, + N, + N_per_rank: gl.constexpr, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + num_pid_n, + cur_rank: gl.constexpr, + world_size: 
gl.constexpr, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + ): + pid = gl.program_id(0) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + iris.x.all_to_all_gluon( + IrisDeviceCtx, + context_tensor, + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + pid_m, + pid_n, + N_per_rank, + cur_rank, + world_size, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark iris.x all_to_all: Triton vs Gluon", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=4096, help="Number of rows (ignored when --sweep)") + parser.add_argument("-n", type=int, default=256, help="Columns per rank (ignored when --sweep)") + parser.add_argument("--block_size_m", type=int, default=64, help="BLOCK_SIZE_M") + parser.add_argument("--block_size_n", type=int, default=256, help="BLOCK_SIZE_N") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("-v", "--validate", action="store_true", help="Validate output") + parser.add_argument("-b", "--benchmark", action="store_true", help="Run timing loop") + parser.add_argument("--sweep", action="store_true", help="Sweep across many (M, N) problem sizes") + parser.add_argument( + "--dump_asm", + action="store_true", + help="Dump generated AMDGCN assembly for Triton and Gluon kernels to .asm files", + ) + parser.add_argument("--output_file", type=str, default="log_x_all_to_all.json", help="JSON output path") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + return vars(parser.parse_args()) + + +def 
_run_one_size( + M, + N, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + dtype, + shmem, + rank, + ws, + context_tensor, + args, +): + """Run validation and/or benchmark for a single (M, N) problem size. + + Returns a dict with timing / bandwidth / validation results, or None on + the non-zero ranks (data is only collected on rank 0). + """ + total_N = N * ws + element_size = torch.tensor([], dtype=dtype).element_size() + + iris_input = shmem.zeros((M, total_N), dtype=dtype) + iris_output_triton = shmem.zeros((M, total_N), dtype=dtype) + iris_output_gluon = shmem.zeros((M, total_N), dtype=dtype) if GLUON_AVAILABLE else None + + # Fill input: chunk i is filled with value (rank * 10 + i + 1). + for target_rank in range(ws): + iris_input[:, target_rank * N : (target_rank + 1) * N] = float(rank * 10 + target_rank + 1) + + num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_pid_n = (total_N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + total_tiles = num_pid_m * num_pid_n + grid_triton = (total_tiles,) + grid_gluon = (total_tiles,) + + def run_triton(): + _triton_kernel[grid_triton]( + iris_input, + iris_output_triton, + M, + total_N, + N, + iris_input.stride(0), + iris_input.stride(1), + iris_output_triton.stride(0), + iris_output_triton.stride(1), + context_tensor, + rank, + ws, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ) + + def run_gluon(): + if not GLUON_AVAILABLE: + return + _gluon_kernel[grid_gluon]( + iris_gl.IrisDeviceCtx, + context_tensor, + iris_input, + iris_output_gluon, + M, + total_N, + N, + iris_input.stride(0), + iris_input.stride(1), + iris_output_gluon.stride(0), + iris_output_gluon.stride(1), + num_pid_n, + rank, + ws, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_warps=4, + ) + + result = { + "M": M, + "N": N, + "world_size": ws, + "dtype": str(dtype).replace("torch.", ""), + "BLOCK_SIZE_M": BLOCK_SIZE_M, + "BLOCK_SIZE_N": BLOCK_SIZE_N, + "total_bytes": (ws - 1) * M * N * element_size, + } + + # ------------------------------------------------------------------- + # Validation + # 
------------------------------------------------------------------- + if args["validate"]: + # Build expected output: output[:, src*N:(src+1)*N] = src_rank * 10 + rank + 1 + expected = shmem.zeros((M, total_N), dtype=dtype) + for src_rank in range(ws): + expected[:, src_rank * N : (src_rank + 1) * N] = float(src_rank * 10 + rank + 1) + + atol = 0.5 + + # Triton + iris_output_triton.zero_() + shmem.barrier() + run_triton() + torch.cuda.synchronize() + shmem.barrier() + ok_triton = torch.allclose(iris_output_triton, expected, atol=atol) + result["triton_valid"] = bool(ok_triton) + + if rank == 0: + status = "PASS" if ok_triton else "FAIL" + print(f" [Triton] M={M:6d} N_per_rank={N:5d}: validation {status}") + + # Gluon + if GLUON_AVAILABLE: + iris_output_gluon.zero_() + shmem.barrier() + run_gluon() + torch.cuda.synchronize() + shmem.barrier() + ok_gluon = torch.allclose(iris_output_gluon, expected, atol=atol) + result["gluon_valid"] = bool(ok_gluon) + + if rank == 0: + status = "PASS" if ok_gluon else "FAIL" + print(f" [Gluon] M={M:6d} N_per_rank={N:5d}: validation {status}") + + # ------------------------------------------------------------------- + # Benchmark + # ------------------------------------------------------------------- + if args["benchmark"]: + total_bytes = (ws - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + + shmem.barrier() + triton_ms = iris.do_bench(run_triton, shmem.barrier) + bw_triton = total_bytes_gb / (triton_ms * 1e-3) if triton_ms > 0 else 0.0 + result["triton_ms"] = triton_ms + result["triton_bandwidth_gbps"] = bw_triton + + if rank == 0: + print(f" [Triton] M={M:6d} N_per_rank={N:5d}: {triton_ms:8.3f} ms {bw_triton:7.3f} GB/s") + + if GLUON_AVAILABLE: + shmem.barrier() + gluon_ms = iris.do_bench(run_gluon, shmem.barrier) + bw_gluon = total_bytes_gb / (gluon_ms * 1e-3) if gluon_ms > 0 else 0.0 + ratio = (bw_gluon / bw_triton * 100) if bw_triton > 0 else 0.0 + result["gluon_ms"] = gluon_ms + 
result["gluon_bandwidth_gbps"] = bw_gluon + result["gluon_vs_triton_percent"] = ratio + + if rank == 0: + print( + f" [Gluon] M={M:6d} N_per_rank={N:5d}: {gluon_ms:8.3f} ms {bw_gluon:7.3f} GB/s" + f" ({ratio:5.1f}% of Triton)" + ) + + return result + + +def _dump_assembly(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, dtype, shmem, rank, ws, context_tensor): + """Compile kernels and dump AMDGCN assembly to text files. + + Files are written only on rank 0. Both backends are compiled with the + same problem configuration so the resulting assembly is directly comparable. + """ + total_N = N * ws + dummy = shmem.zeros((M, total_N), dtype=dtype) + + num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_pid_n = (total_N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + total_tiles = num_pid_m * num_pid_n + + # Trigger compilation (warmup run). + kk_triton = _triton_kernel[(total_tiles,)]( + dummy, + dummy, + M, + total_N, + N, + dummy.stride(0), + dummy.stride(1), + dummy.stride(0), + dummy.stride(1), + context_tensor, + rank, + ws, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ) + + if rank == 0 and hasattr(kk_triton, "asm") and "amdgcn" in kk_triton.asm: + asm = kk_triton.asm["amdgcn"] + fname = f"triton_all_to_all_M{M}_N{N}_bm{BLOCK_SIZE_M}_bn{BLOCK_SIZE_N}.asm" + with open(fname, "w") as f: + f.write(asm) + n_regs = getattr(kk_triton, "n_regs", "?") + n_spills = getattr(kk_triton, "n_spills", "?") + print(f" [Triton] {fname} ({len(asm):,} chars, {n_regs} VGPRs, {n_spills} spills)") + + if GLUON_AVAILABLE: + kk_gluon = _gluon_kernel[(total_tiles,)]( + iris_gl.IrisDeviceCtx, + context_tensor, + dummy, + dummy, + M, + total_N, + N, + dummy.stride(0), + dummy.stride(1), + dummy.stride(0), + dummy.stride(1), + num_pid_n, + rank, + ws, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + num_warps=4, + ) + if rank == 0 and hasattr(kk_gluon, "asm") and "amdgcn" in kk_gluon.asm: + asm = kk_gluon.asm["amdgcn"] + fname = f"gluon_all_to_all_M{M}_N{N}_bm{BLOCK_SIZE_M}_bn{BLOCK_SIZE_N}.asm" + with open(fname, "w") as f: + 
f.write(asm) + n_regs = getattr(kk_gluon, "n_regs", "?") + n_spills = getattr(kk_gluon, "n_spills", "?") + print(f" [Gluon] {fname} ({len(asm):,} chars, {n_regs} VGPRs, {n_spills} spills)") + + shmem.barrier() + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, init_method=init_url, world_size=world_size, rank=local_rank) + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + + BLOCK_SIZE_M = args["block_size_m"] + BLOCK_SIZE_N = args["block_size_n"] + + # Use Gluon-based iris if available (required for Gluon kernel). + if GLUON_AVAILABLE: + shmem = iris_gl.iris(args["heap_size"]) + else: + shmem = iris.iris(args["heap_size"]) + + rank = shmem.get_rank() + ws = shmem.get_num_ranks() + context_tensor = shmem.get_device_context() + + # Determine problem sizes to run. + sizes = SWEEP_SIZES if args["sweep"] else [(args["m"], args["n"])] + + if rank == 0 and len(sizes) > 1: + print(f"\n=== iris.x all_to_all sweep world_size={ws} dtype={args['datatype']} ===\n") + + all_results = [] + + for M, N in sizes: + if rank == 0 and len(sizes) > 1: + print(f"--- M={M}, N_per_rank={N} ---") + + if args["dump_asm"]: + _dump_assembly(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, dtype, shmem, rank, ws, context_tensor) + continue + + result = _run_one_size(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, dtype, shmem, rank, ws, context_tensor, args) + all_results.append(result) + + # Save results to JSON (rank 0 only). 
+ if rank == 0 and all_results and args["output_file"]: + with open(args["output_file"], "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nResults written to {args['output_file']}") + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29572" + mp.spawn(fn=_worker, args=(num_ranks, init_url, args), nprocs=num_ranks, join=True) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ccl/all_to_all/plot_x_all_to_all.py b/benchmark/ccl/all_to_all/plot_x_all_to_all.py new file mode 100644 index 000000000..fad946232 --- /dev/null +++ b/benchmark/ccl/all_to_all/plot_x_all_to_all.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Scatter-plot: iris.x.all_to_all – Triton vs Gluon bandwidth across problem sizes. + +Reads the JSON output produced by benchmark_x.py (--sweep -b) and creates a +scatter plot with two marker series: + + • Blue circles (◉) → Triton bandwidth (GB/s) + • Orange squares (■) → Gluon bandwidth (GB/s) + +X-axis: total bytes communicated per rank (log scale) +Y-axis: achieved bandwidth in GB/s + +Usage +----- + # After running the benchmark sweep: + python benchmark_x.py -v -b --sweep -r 8 --output_file results.json + python plot_x_all_to_all.py results.json [--output scatter.png] +""" + +import argparse +import json +import sys + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Scatter-plot iris.x all_to_all Triton vs Gluon bandwidth", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("input_json", help="JSON results file from benchmark_x.py --sweep") + parser.add_argument( + "--output", + type=str, + default=None, + help="Output image file (default: .png)", + ) + parser.add_argument("--title", type=str, default="iris.x all_to_all: Triton vs Gluon bandwidth") + 
parser.add_argument("--dpi", type=int, default=150) + parser.add_argument("--figsize", type=int, nargs=2, default=[10, 6]) + return parser.parse_args() + + +def load_results(path: str): + with open(path) as f: + data = json.load(f) + # Support both a single-object and a list-of-objects JSON format. + if isinstance(data, dict): + data = [data] + return data + + +def make_label(row: dict) -> str: + """Short label for a data point (placed near the marker).""" + return f"({row['M']}×{row['N']})" + + +def plot(data, args): + try: + import matplotlib.pyplot as plt + import matplotlib.ticker as ticker + except ImportError: + print( + "ERROR: matplotlib is required for plotting. Install with:\n pip install matplotlib\n", + file=sys.stderr, + ) + sys.exit(1) + + # Sort by total bytes so the X-axis is monotone. + data = sorted(data, key=lambda r: r.get("total_bytes", r["M"] * r["N"])) + + triton_x, triton_y, triton_labels = [], [], [] + gluon_x, gluon_y, gluon_labels = [], [], [] + + for row in data: + x = row.get("total_bytes", 0) + label = make_label(row) + + if "triton_bandwidth_gbps" in row: + triton_x.append(x) + triton_y.append(row["triton_bandwidth_gbps"]) + triton_labels.append(label) + + if "gluon_bandwidth_gbps" in row and row["gluon_bandwidth_gbps"] is not None: + gluon_x.append(x) + gluon_y.append(row["gluon_bandwidth_gbps"]) + gluon_labels.append(label) + + fig, ax = plt.subplots(figsize=tuple(args.figsize)) + + # Two scatter series with distinct markers and colours. 
+ if triton_x: + ax.scatter( + triton_x, + triton_y, + marker="o", + s=80, + color="#2E86AB", + label="Triton", + zorder=3, + ) + ax.plot(triton_x, triton_y, linestyle="--", color="#2E86AB", linewidth=1, alpha=0.5) + for x, y, lbl in zip(triton_x, triton_y, triton_labels): + ax.annotate(lbl, (x, y), textcoords="offset points", xytext=(4, 5), fontsize=6, color="#2E86AB") + + if gluon_x: + ax.scatter( + gluon_x, + gluon_y, + marker="s", + s=80, + color="#E07A5F", + label="Gluon", + zorder=3, + ) + ax.plot(gluon_x, gluon_y, linestyle="--", color="#E07A5F", linewidth=1, alpha=0.5) + for x, y, lbl in zip(gluon_x, gluon_y, gluon_labels): + ax.annotate(lbl, (x, y), textcoords="offset points", xytext=(4, -12), fontsize=6, color="#E07A5F") + + # Extract common metadata from first row for subtitle. + if data: + r0 = data[0] + subtitle = ( + f"world_size={r0.get('world_size', '?')} " + f"dtype={r0.get('dtype', '?')} " + f"BLOCK_M={r0.get('BLOCK_SIZE_M', '?')} " + f"BLOCK_N={r0.get('BLOCK_SIZE_N', '?')}" + ) + ax.set_title(f"{args.title}\n{subtitle}", fontsize=12) + else: + ax.set_title(args.title, fontsize=12) + + ax.set_xscale("log", base=2) + ax.set_xlabel("Total bytes communicated per rank [log₂ scale]", fontsize=11) + ax.set_ylabel("Bandwidth (GB/s)", fontsize=11) + ax.xaxis.set_major_formatter( + ticker.FuncFormatter(lambda v, _: f"{int(v / 2**20)} MiB" if v >= 2**20 else f"{int(v / 2**10)} KiB") + ) + ax.grid(True, alpha=0.3, linestyle="--") + ax.legend(fontsize=11, loc="best") + + plt.tight_layout() + + out = args.output or (args.input_json.rsplit(".", 1)[0] + "_scatter.png") + plt.savefig(out, dpi=args.dpi, bbox_inches="tight") + print(f"Scatter plot saved to: {out}") + + try: + plt.show() + except Exception: + pass + + +def print_table(data): + """Print a plain-text comparison table to stdout.""" + if not data: + print("No data to display.") + return + + header = f"{'M':>7} {'N_per_rank':>10} {'total_bytes':>14} {'Triton (GB/s)':>14} {'Gluon (GB/s)':>13} 
{'ratio%':>7} {'T_valid':>8} {'G_valid':>8}" + print("\n" + "=" * len(header)) + print(header) + print("=" * len(header)) + + for r in sorted(data, key=lambda x: x.get("total_bytes", 0)): + bw_t = r.get("triton_bandwidth_gbps", float("nan")) + bw_g = r.get("gluon_bandwidth_gbps", float("nan")) + ratio = r.get("gluon_vs_triton_percent", float("nan")) + tv = "PASS" if r.get("triton_valid") else ("FAIL" if "triton_valid" in r else "n/a") + gv = "PASS" if r.get("gluon_valid") else ("FAIL" if "gluon_valid" in r else "n/a") + print( + f"{r['M']:>7} {r['N']:>10} {r.get('total_bytes', 0):>14,.0f}" + f" {bw_t:>14.3f} {bw_g:>13.3f} {ratio:>7.1f} {tv:>8} {gv:>8}" + ) + + print("=" * len(header) + "\n") + + +def main(): + args = parse_args() + data = load_results(args.input_json) + print_table(data) + plot(data, args) + + +if __name__ == "__main__": + main() diff --git a/docs/reference/gluon/all_to_all_report.md b/docs/reference/gluon/all_to_all_report.md new file mode 100644 index 000000000..c40b463ff --- /dev/null +++ b/docs/reference/gluon/all_to_all_report.md @@ -0,0 +1,240 @@ +# iris.x all_to_all: Triton vs Gluon – Performance & Assembly Report + +```{note} +This document describes the Gluon port of `iris.x.all_to_all` and how to reproduce +the performance comparison and assembly analysis. Actual numbers require AMD GPUs +(MI300X / MI350X / MI355X recommended) with ROCm 7.0+ and the matching Triton build. +``` + +## Overview + +`iris.x.all_to_all` is a *tile-level* collective that lets a user-written kernel +perform an all-to-all exchange one tile at a time. 
The primitive is provided in +two backends that produce identical results: + +| Backend | Decorator | Context type | Remote-read API | +|---------|-----------|--------------|-----------------| +| Triton | `@triton.jit` | `iris.DeviceContext` | `iris.load(ptr, cur_rank, src_rank, heap_bases, mask)` | +| Gluon | `@gluon.jit` | `IrisDeviceCtx` | `ctx.load(ptr, src_rank, mask)` | + +The Gluon implementation lives in `iris/x/all_to_all.py` (the `all_to_all_gluon` +function) and is exported from `iris.x` when Gluon is available. + +--- + +## Semantic Equivalence + +Both backends implement the same all-to-all algorithm: + +- **Input** `(M, world_size × N_per_rank)`: each rank's chunk `[:, r*N:(r+1)*N]` + holds the data to be sent to rank `r`. +- **Output** `(M, world_size × N_per_rank)`: after the operation, + `output[:, r*N:(r+1)*N]` contains the data that rank `r` sent to the current rank. + +Correctness is validated against PyTorch `dist.all_to_all` in the test file +`tests/x/test_all_to_all_gluon.py` across five shapes and three dtypes. + +--- + +## Key Algorithmic Differences + +### Loop structure + +The **Triton** version iterates over only the source ranks that overlap with the +current output tile (`range(first_src_rank, last_src_rank + 1)`). The loop bounds +are *runtime* values, which means the compiler cannot unroll the loop statically. + +The **Gluon** version iterates over `range(world_size)` where `world_size` is a +`gl.constexpr`. The compiler *fully unrolls* this loop, resolving each branch +(local vs. remote) at compile time. This trades compile time for a potentially +shorter hot path when tiles are well-aligned with rank boundaries. + +### Memory access pattern + +Triton processes whole `BLOCK_SIZE_M × BLOCK_SIZE_N` tiles in a single vectorised +load/store using 2-D index tensors. 
+ +Gluon processes the tile *row by row* (inner `for i in range(BLOCK_SIZE_M)` loop, +also unrolled) with 1-D column index vectors and a +`gl.BlockedLayout([1], [64], [4], [0])` layout hint that maps 256 threads over +`BLOCK_SIZE_N` columns. This matches the access pattern used by +`persistent_all_to_all_gluon` in `iris/ccl/all_to_all.py` and allows the Gluon +compiler to apply traffic-shaping optimisations. + +### RMA call + +| Operation | Triton | Gluon | +|-----------|--------|-------| +| Remote read | `iris.load(ptr, cur_rank, src_rank, heap_bases, mask)` | `ctx.load(ptr, src_rank, mask)` | +| Local read | `tl.load(ptr + offsets, mask)` | `gl.load(ptr + offsets, mask)` | + +`ctx.load` in Gluon internally calls `_translate()` which computes the pointer +offset from the heap base of the remote rank. The Triton `iris.load` does the +same but requires the caller to explicitly pass `heap_bases`. + +--- + +## Running the Benchmark + +### Requirements + +```bash +pip install matplotlib # for scatter plot generation +``` + +### Validate both backends + +```bash +cd benchmark/ccl/all_to_all +python benchmark_x.py -v -r 8 --datatype fp16 +``` + +Expected output: + +``` + [Triton] M= 4096 N_per_rank= 256: validation PASS + [Gluon] M= 4096 N_per_rank= 256: validation PASS +``` + +### Sweep across many problem sizes + +```bash +python benchmark_x.py -v -b --sweep -r 8 \ + --datatype fp16 \ + --output_file results.json +``` + +This runs 14 problem sizes (see the [Problem Size Sweep Grid](#problem-size-sweep-grid) below) and writes a JSON +file with per-size timing and bandwidth for both backends. + +### Generate the scatter plot + +```bash +python plot_x_all_to_all.py results.json --output scatter.png +``` + +The script also prints a plain-text comparison table to stdout. 
+ +### Dump generated assembly + +```bash +python benchmark_x.py --dump_asm -m 4096 -n 256 -r 8 +``` + +This writes two files: + +``` +triton_all_to_all_M4096_N256_bm64_bn256.asm +gluon_all_to_all_M4096_N256_bm64_bn256.asm +``` + +Use `diff` or a merge tool to compare them side by side. + +--- + +## Problem Size Sweep Grid + +The default sweep covers 14 `(M, N_per_rank)` configurations spanning small to +extra-large tensors: + +| Category | M | N per rank | Total bytes per rank (8 GPUs, fp16) | +|-------------|-------|------------|--------------------------------------| +| Small | 128 | 64 | ~112 KiB | +| Small | 256 | 64 | ~224 KiB | +| Small | 512 | 64 | ~448 KiB | +| Medium | 1024 | 128 | ~1.75 MiB | +| Medium | 2048 | 128 | ~3.5 MiB | +| Medium | 1024 | 256 | ~3.5 MiB | +| Medium | 2048 | 256 | ~7 MiB | +| Large | 4096 | 128 | ~7 MiB | +| Large | 4096 | 256 | ~14 MiB | +| Large | 4096 | 512 | ~28 MiB | +| Large | 8192 | 256 | ~28 MiB | +| Large | 8192 | 512 | ~56 MiB | +| Extra-large | 16384 | 128 | ~28 MiB | +| Extra-large | 16384 | 256 | ~56 MiB | + +--- + +## Performance Results + +> **Placeholder** – Run the benchmark on MI300X hardware and paste the JSON +> output here, or embed the scatter plot image. + +``` +python benchmark_x.py -v -b --sweep -r 8 --datatype fp16 --output_file results.json +python plot_x_all_to_all.py results.json +``` + +After running you will have a scatter plot similar to: + +``` + iris.x all_to_all: Triton vs Gluon bandwidth + ───────────────────────────────────────────── + Bandwidth │ ◉ Triton ■ Gluon + (GB/s) │ ■ ◉ ◉ + │ ■ ■ ◉ + │ ■ ◉ ◉ + │ ◉■ ■ + │───────────────────────────────────────────── + Total bytes per rank (log₂ scale) +``` + +--- + +## Assembly Analysis + +The AMDGCN ISA files produced by `--dump_asm` let you compare the quality of code +generated by the two compilation paths. 
+ +### What to look for + +| Metric | Description | +|--------|-------------| +| VGPR count | Fewer VGPRs → more warps can be in flight simultaneously (higher occupancy) | +| Spill count | Non-zero spills hurt performance through stack memory traffic | +| Load/store width | `global_load_dwordx4` / `global_store_dwordx4` instructions indicate 128-bit vectorisation | +| Buffer instructions | `buffer_load_dwordx4` / `buffer_store_dwordx4` indicate coalesced cache-line access | +| Branch instructions | Fewer branches → simpler control flow after loop unrolling | + +### Expected observations (Gluon) + +- The fully-unrolled rank loop in the Gluon version eliminates loop-overhead + branches visible in the Triton output. +- Because `world_size` is a `gl.constexpr`, each `src_rank == cur_rank` branch + is resolved at compile time, resulting in separate code paths without runtime + predication. +- The row-by-row processing in Gluon may show higher VGPR usage compared to the + 2-D tile processing in Triton, depending on the block sizes chosen. + +### Sample diff command + +```bash +diff -u triton_all_to_all_M4096_N256_bm64_bn256.asm \ + gluon_all_to_all_M4096_N256_bm64_bn256.asm | less +``` + +--- + +## Running the Tests + +The functional correctness tests for the Gluon backend are in +`tests/x/test_all_to_all_gluon.py`: + +```bash +pytest tests/x/test_all_to_all_gluon.py -v +``` + +Tests are automatically skipped when Gluon is not available. + +--- + +## Source Files + +| File | Description | +|------|-------------| +| `iris/x/all_to_all.py` | Triton (`all_to_all`) and Gluon (`all_to_all_gluon`) tile-level primitives | +| `iris/x/__init__.py` | Module exports – `all_to_all_gluon` exported when Gluon available | +| `tests/x/test_all_to_all_gluon.py` | Correctness tests vs. 
PyTorch `dist.all_to_all` | +| `benchmark/ccl/all_to_all/benchmark_x.py` | Validation + benchmark sweep + assembly dump | +| `benchmark/ccl/all_to_all/plot_x_all_to_all.py` | Scatter-plot generation from JSON results | diff --git a/iris/x/__init__.py b/iris/x/__init__.py index 7377fbe3b..7e83481c0 100644 --- a/iris/x/__init__.py +++ b/iris/x/__init__.py @@ -71,6 +71,13 @@ from .gather import gather from .all_gather import all_gather from .all_to_all import all_to_all + +try: + from .all_to_all import all_to_all_gluon # noqa: F401 + + _GLUON_ALL_TO_ALL_AVAILABLE = True +except ImportError: + _GLUON_ALL_TO_ALL_AVAILABLE = False from .reduce_scatter import reduce_scatter __all__ = [ @@ -94,3 +101,6 @@ "all_to_all", "reduce_scatter", ] + +if _GLUON_ALL_TO_ALL_AVAILABLE: + __all__.append("all_to_all_gluon") diff --git a/iris/x/all_to_all.py b/iris/x/all_to_all.py index 55530a8cf..debf47ffd 100644 --- a/iris/x/all_to_all.py +++ b/iris/x/all_to_all.py @@ -5,6 +5,7 @@ Tile-level all-to-all primitive for Iris. Performs all-to-all communication where each rank sends and receives data to/from all other ranks. +Provides both Triton (@triton.jit) and Gluon (@gluon.jit) implementations. 
""" import triton @@ -13,6 +14,16 @@ from iris.iris import DeviceContext from .core import Tile, TensorView +# Conditional import for Gluon +try: + from triton.experimental import gluon + from triton.experimental.gluon import language as gl + from iris.experimental.iris_gluon import IrisDeviceCtx as _IrisDeviceCtx # noqa: F401 + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + @triton.jit() def all_to_all( @@ -121,3 +132,137 @@ def all_to_all( mask=combined_mask, ) tl.store(dst_view.ptr + dst_offsets, data, mask=combined_mask) + + +# Gluon implementation +if GLUON_AVAILABLE: + + @gluon.jit + def all_to_all_gluon( + IrisDeviceCtx: gl.constexpr, + context_tensor, + src_ptr, + dst_ptr, + M, + N, + stride_src_m, + stride_src_n, + stride_dst_m, + stride_dst_n, + pid_m, + pid_n, + N_per_rank: gl.constexpr, + cur_rank: gl.constexpr, + world_size: gl.constexpr, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + ): + """ + Gluon tile-level all-to-all for iris.x. + + Gluon port of all_to_all using IrisDeviceCtx. Can be called from + within a @gluon.jit kernel. Iterates over all source ranks with a + compile-time-unrolled loop (world_size is constexpr) and applies + masking to handle tiles that span rank-chunk boundaries. + + Args: + IrisDeviceCtx: IrisDeviceCtx class (constexpr, passed as first arg). + context_tensor: Encoded context tensor from shmem.get_device_context(). + src_ptr: Pointer to source tensor (local rank's input). + dst_ptr: Pointer to destination tensor (local rank's output). + M: Number of rows. + N: Total number of columns (world_size * N_per_rank). + stride_src_m, stride_src_n: Strides for source tensor. + stride_dst_m, stride_dst_n: Strides for destination tensor. + pid_m: Tile row index. + pid_n: Tile column index. + N_per_rank: Number of columns per rank (constexpr). + cur_rank: Current rank (constexpr). + world_size: Total number of ranks (constexpr). + BLOCK_SIZE_M: Block size for M dimension (constexpr). 
+ BLOCK_SIZE_N: Block size for N dimension (constexpr). + + Semantics: + Input: Each rank has (M, world_size * N_per_rank) + Output: Each rank has (M, world_size * N_per_rank) + + rank dst's output columns [src*N:(src+1)*N] receive rank src's + input columns [dst*N:(dst+1)*N]. + + Example: + @gluon.jit + def my_kernel(IrisDeviceCtx: gl.constexpr, context_tensor, ...): + pid_m = ... + pid_n = ... + iris.x.all_to_all_gluon( + IrisDeviceCtx, context_tensor, + src_ptr, dst_ptr, M, N, + stride_src_m, stride_src_n, + stride_dst_m, stride_dst_n, + pid_m, pid_n, N_per_rank, rank, world_size, + BLOCK_SIZE_M, BLOCK_SIZE_N, + ) + """ + ctx = IrisDeviceCtx.initialize(context_tensor) + + # 1-D layout covering BLOCK_SIZE_N elements across 4 warps of 64 threads. + # Mirrors the layout used in persistent_all_to_all_gluon. + col_layout: gl.constexpr = gl.BlockedLayout([1], [64], [4], [0]) + + output_col_start = pid_n * BLOCK_SIZE_N + output_col_end = output_col_start + BLOCK_SIZE_N + + # Destination column indices are the same regardless of source rank. + rn_dst = output_col_start + gl.arange(0, BLOCK_SIZE_N, layout=col_layout) + + # Iterate over all source ranks (loop is unrolled because world_size is constexpr). + for src_rank in range(world_size): + src_chunk_out_start = src_rank * N_per_rank + src_chunk_out_end = (src_rank + 1) * N_per_rank + + # Intersection of this tile's output range with src_rank's chunk. + tile_src_start = tl.maximum(output_col_start, src_chunk_out_start) + tile_src_end = tl.minimum(output_col_end, src_chunk_out_end) + num_cols = tile_src_end - tile_src_start + + # Where the intersection starts within this tile and within src's chunk. + offset_in_tile = tile_src_start - output_col_start + offset_in_src_chunk = tile_src_start - src_chunk_out_start + + # Source column in src_rank's input that maps to this output region. 
+ src_col_offset = cur_rank * N_per_rank + offset_in_src_chunk + + # Source column indices adjusted so that col offset_in_tile aligns correctly. + src_col_base = src_col_offset - offset_in_tile + rn_src = src_col_base + gl.arange(0, BLOCK_SIZE_N, layout=col_layout) + + # Column validity masks. + src_col_valid = ( + (rn_src >= src_col_offset) & (rn_src < src_col_offset + num_cols) & (rn_src >= 0) & (rn_src < N) + ) + dst_col_valid = ( + (rn_dst >= output_col_start + offset_in_tile) + & (rn_dst < output_col_start + offset_in_tile + num_cols) + & (rn_dst < N) + ) + # Also skip this rank entirely when there is no overlap. + col_mask = src_col_valid & dst_col_valid & (num_cols > 0) + + # Process each row in the tile (unrolled since BLOCK_SIZE_M is constexpr). + for i in range(BLOCK_SIZE_M): + row_m = pid_m * BLOCK_SIZE_M + i + + src_offsets = row_m * stride_src_m + rn_src * stride_src_n + dst_offsets = row_m * stride_dst_m + rn_dst * stride_dst_n + + # Combine column mask with row bounds check. + row_col_mask = col_mask & (row_m < M) + + if src_rank == cur_rank: + # Local copy: read from our own input. + data = gl.load(src_ptr + src_offsets, mask=row_col_mask) + gl.store(dst_ptr + dst_offsets, data, mask=row_col_mask) + else: + # Remote read: translate pointer to src_rank's address space. + data = ctx.load(src_ptr + src_offsets, src_rank, mask=row_col_mask) + gl.store(dst_ptr + dst_offsets, data, mask=row_col_mask) diff --git a/tests/x/test_all_to_all_gluon.py b/tests/x/test_all_to_all_gluon.py new file mode 100644 index 000000000..e7d694fb1 --- /dev/null +++ b/tests/x/test_all_to_all_gluon.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for the Gluon tile-level all-to-all primitive (iris.x.all_to_all_gluon). +""" + +import pytest +import torch +import torch.distributed as dist + +# Try to import Gluon; skip all tests if not available. 
# Try to import Gluon; skip all tests if not available.
try:
    from triton.experimental import gluon
    from triton.experimental.gluon import language as gl
    import iris.experimental.iris_gluon as iris_gl
    import iris.x

    GLUON_AVAILABLE = hasattr(iris.x, "all_to_all_gluon")
except ImportError:
    GLUON_AVAILABLE = False


if GLUON_AVAILABLE:

    @gluon.jit
    def x_all_to_all_gluon_kernel(
        IrisDeviceCtx: gl.constexpr,
        context_tensor,
        input_ptr,
        output_ptr,
        M,
        N,
        N_per_rank: gl.constexpr,
        stride_in_m,
        stride_in_n,
        stride_out_m,
        stride_out_n,
        num_pid_n,
        cur_rank: gl.constexpr,
        world_size: gl.constexpr,
        BLOCK_SIZE_M: gl.constexpr,
        BLOCK_SIZE_N: gl.constexpr,
    ):
        """Wrapper kernel: maps program id -> (pid_m, pid_n) tile and calls all_to_all_gluon."""
        pid = gl.program_id(0)
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n

        iris.x.all_to_all_gluon(
            IrisDeviceCtx,
            context_tensor,
            input_ptr,
            output_ptr,
            M,
            N,
            stride_in_m,
            stride_in_n,
            stride_out_m,
            stride_out_n,
            pid_m,
            pid_n,
            N_per_rank,
            cur_rank,
            world_size,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
        )


@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available")
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-3, 1e-3),
        (torch.float32, 1e-5, 1e-5),
        (torch.bfloat16, 1e-3, 1e-3),
    ],
)
@pytest.mark.parametrize(
    "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N",
    [
        (128, 64, 64, 32),  # Small
        (1024, 256, 128, 128),  # Medium
        (2048, 2048, 256, 256),  # Large
        (100, 100, 64, 64),  # Non-aligned dimensions
        (256, 384, 128, 128),  # Non-square
    ],
)
def test_all_to_all_gluon(dtype, atol, rtol, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
    """Test Gluon tile-level all-to-all against torch.distributed's all_to_all.

    Each rank builds a (M, N*world_size) input whose chunk r (columns
    [r*N, (r+1)*N)) holds the constant r+1, runs both dist.all_to_all and the
    iris.x.all_to_all_gluon kernel on it, and asserts the outputs match.

    Args:
        dtype: Element type under test.
        atol, rtol: Tolerances for torch.allclose.
        M: Row count.
        N: Columns *per rank* (total width is N * world_size).
        BLOCK_SIZE_M, BLOCK_SIZE_N: Tile sizes handed to the kernel.
    """
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")

    heap_size = 2**33  # 8 GiB symmetric heap
    shmem = iris_gl.iris(heap_size)
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    # NOTE(review): assumes a one-GPU-per-rank device mapping — confirm the
    # launcher sets CUDA_VISIBLE_DEVICES / local rank accordingly.
    device = f"cuda:{rank}"

    # Build the reference input deterministically: chunk r is filled with r+1.
    # (Previously this allocated torch.randn and then overwrote every column
    # with these fills — the random init was dead work, and it misleadingly
    # suggested randomized validation. torch.empty + fills yields identical
    # contents without the wasted RNG pass.)
    pytorch_input = torch.empty(M, N * world_size, dtype=dtype, device=device)
    for r in range(world_size):
        pytorch_input[:, r * N : (r + 1) * N].fill_(float(r + 1))

    # Reference result via PyTorch's collective.
    shmem.barrier()
    input_chunks = [chunk.contiguous() for chunk in torch.chunk(pytorch_input, world_size, dim=1)]
    output_chunks = [torch.empty(M, N, dtype=dtype, device=device) for _ in range(world_size)]
    dist.all_to_all(output_chunks, input_chunks)
    pytorch_output = torch.cat(output_chunks, dim=1)
    torch.cuda.synchronize()

    # Set up Iris Gluon tensors on the symmetric heap.
    iris_input = shmem.zeros((M, N * world_size), dtype=dtype)
    iris_input.copy_(pytorch_input)
    iris_output = shmem.zeros((M, N * world_size), dtype=dtype)

    context_tensor = shmem.get_device_context()
    shmem.barrier()

    # Launch Gluon kernel — one program per tile.
    total_N = N * world_size
    num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    num_pid_n = (total_N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    total_tiles = num_pid_m * num_pid_n
    grid = (total_tiles,)

    x_all_to_all_gluon_kernel[grid](
        iris_gl.IrisDeviceCtx,
        context_tensor,
        iris_input,
        iris_output,
        M,
        total_N,
        N,  # N_per_rank
        iris_input.stride(0),
        iris_input.stride(1),
        iris_output.stride(0),
        iris_output.stride(1),
        num_pid_n,
        rank,
        world_size,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        num_warps=4,  # matches the 4-warp BlockedLayout inside all_to_all_gluon
    )

    torch.cuda.synchronize()
    shmem.barrier()

    max_diff = torch.abs(iris_output - pytorch_output).max().item()

    try:
        # allclose tolerance is atol + rtol*|ref|, so report both bounds.
        assert torch.allclose(iris_output, pytorch_output, atol=atol, rtol=rtol), (
            f"Max difference: {max_diff} (atol={atol}, rtol={rtol})\n"
            f"Rank {rank}: iris.x.all_to_all_gluon output does not match PyTorch all_to_all"
        )

        # Verify each received chunk directly: output[:, r*N:(r+1)*N] holds
        # rank r's chunk addressed to this rank, i.e. the constant rank+1.
        for src_rank in range(world_size):
            chunk = iris_output[:, src_rank * N : (src_rank + 1) * N]
            expected_value = float(rank + 1)
            assert torch.allclose(chunk, torch.full_like(chunk, expected_value), atol=atol), (
                f"Rank {rank}: chunk from rank {src_rank} should have value {expected_value}"
            )

        if rank == 0:
            print(f"✓ all_to_all_gluon passed: {dtype}, M={M}, N={N}, blocks=({BLOCK_SIZE_M},{BLOCK_SIZE_N})")
    finally:
        # Always tear the heap down, even on assertion failure, so later
        # parametrized cases start from a clean state.
        shmem.barrier()
        del shmem
        import gc

        gc.collect()