From 5fabfefff4f36885faa3eee7dfdddaa6246da71b Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Thu, 21 Aug 2025 13:31:06 +0200
Subject: [PATCH 01/13] initial latency test

---
 tests/examples/test_load_latency.py | 98 +++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 tests/examples/test_load_latency.py

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
new file mode 100644
index 00000000..989a8be1
--- /dev/null
+++ b/tests/examples/test_load_latency.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+import pytest
+import torch
+import triton
+import triton.language as tl
+import numpy as np
+import iris
+
+from examples.common.utils import read_realtime
+
+
+@triton.jit()
+def ping_pong(
+    data,
+    result,
+    len,
+    iter,
+    skip,
+    flag: tl.tensor,
+    curr_rank,
+    BLOCK_SIZE: tl.constexpr,
+    heap_bases: tl.tensor,
+    mm_begin_timestamp_ptr: tl.tensor = None,
+    mm_end_timestamp_ptr: tl.tensor = None,
+):
+    peer = (curr_rank + 1) % 2
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+    data_mask = offsets < len
+    flag_mask = offsets < 1
+    time_stmp_mask = offsets < 1
+
+    for i in range(iter + skip):
+        if (i == skip):
+            start = read_realtime();
+            tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
+        if curr_rank == (i + 1) % 2:
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+                pass
+            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
+            tl.store(flag + offsets, i + 1, mask=flag_mask)
+            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
+        else:
+            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
+            tl.store(flag + offsets, i + 1, mask=flag_mask)
+            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+                pass
+    stop = read_realtime();
+    tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.int32,
+        # torch.float16,
+        # torch.bfloat16,
+        # torch.float32,
+    ],
+)
+@pytest.mark.parametrize(
+    "heap_size",
+    [
+        (1 << 33),
+    ],
+)
+def test_load_bench(dtype, heap_size):
+    shmem = iris.iris(heap_size)
+    num_ranks = shmem.get_num_ranks()
+    heap_bases = shmem.get_heap_bases()
+    cur_rank = shmem.get_rank()
+    assert num_ranks == 2
+
+    BLOCK_SIZE = 1
+    BUFFER_LEN = 64*1024
+
+    iter = 200
+    skip = 20
+    mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
+
+    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
+    result_buffer = shmem.zeros_like(source_buffer)
+    flag = shmem.ones(1, dtype=dtype)
+
+    grid = lambda meta: (1,)
+    ping_pong[grid](source_buffer, result_buffer, BUFFER_LEN, skip, iter, flag, cur_rank, BLOCK_SIZE, heap_bases,mm_begin_timestamp, mm_end_timestamp)
+    shmem.barrier()
+    begin_val = mm_begin_timestamp.cpu().item()
+    end_val = mm_end_timestamp.cpu().item()
+    with open(f'timestamps_{cur_rank}.txt', 'w') as f:
+        f.write(f"mm_begin_timestamp: {begin_val}\n")
+        f.write(f"mm_end_timestamp: {end_val}\n")
From a1874955632d23d4bd176c212fc7332308482a63 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Thu, 21 Aug 2025 11:31:28 +0000
Subject: [PATCH 02/13] Apply Ruff auto-fixes

---
 tests/examples/test_load_latency.py | 32 ++++++++++++++++++++---------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
index 989a8be1..42fcdbb3 100644
--- a/tests/examples/test_load_latency.py
+++ b/tests/examples/test_load_latency.py
@@ -29,14 +29,14 @@ def ping_pong(
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
-
+
     data_mask = offsets < len
     flag_mask = offsets < 1
     time_stmp_mask = offsets < 1
 
     for i in range(iter + skip):
-        if (i == skip):
-            start = read_realtime();
+        if i == skip:
+            start = read_realtime()
             tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
         if curr_rank == (i + 1) % 2:
             while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
@@ -50,7 +50,7 @@ def ping_pong(
             iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
             while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
                 pass
-    stop = read_realtime();
+    stop = read_realtime()
     tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
 
 
@@ -66,7 +66,7 @@ def ping_pong(
 @pytest.mark.parametrize(
     "heap_size",
     [
-        (1 << 33),
+        (1 << 33),
     ],
 )
 def test_load_bench(dtype, heap_size):
@@ -77,22 +77,34 @@ def test_load_bench(dtype, heap_size):
     assert num_ranks == 2
 
     BLOCK_SIZE = 1
-    BUFFER_LEN = 64*1024
+    BUFFER_LEN = 64 * 1024
 
     iter = 200
     skip = 20
     mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
     result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)
 
     grid = lambda meta: (1,)
-    ping_pong[grid](source_buffer, result_buffer, BUFFER_LEN, skip, iter, flag, cur_rank, BLOCK_SIZE, heap_bases,mm_begin_timestamp, mm_end_timestamp)
+    ping_pong[grid](
+        source_buffer,
+        result_buffer,
+        BUFFER_LEN,
+        skip,
+        iter,
+        flag,
+        cur_rank,
+        BLOCK_SIZE,
+        heap_bases,
+        mm_begin_timestamp,
+        mm_end_timestamp,
+    )
     shmem.barrier()
     begin_val = mm_begin_timestamp.cpu().item()
     end_val = mm_end_timestamp.cpu().item()
-    with open(f'timestamps_{cur_rank}.txt', 'w') as f:
+    with open(f"timestamps_{cur_rank}.txt", "w") as f:
         f.write(f"mm_begin_timestamp: {begin_val}\n")
         f.write(f"mm_end_timestamp: {end_val}\n")

From ad03093e48a03fb82bedbe3522a80da6d72b429c Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Wed, 27 Aug 2025 08:58:02 +0200
Subject: [PATCH 03/13] initial impl of latency

---
 tests/examples/test_load_latency.py | 149 ++++++++++++++++------------
 1 file changed, 88 insertions(+), 61 deletions(-)

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
index 42fcdbb3..02e1131a 100644
--- a/tests/examples/test_load_latency.py
+++ b/tests/examples/test_load_latency.py
@@ -8,103 +8,130 @@ import triton.language as tl
 import numpy as np
 import iris
 
-from examples.common.utils import read_realtime
+from iris._mpi_helpers import mpi_allgather
 
+# from examples.common.utils import read_realtime
+@triton.jit
+def read_realtime():
+    tmp = tl.inline_asm_elementwise(
+        asm="mov.u64 $0, %globaltimer;",
+        constraints=("=l"),
+        args=[],
+        dtype=tl.int64,
+        is_pure=False,
+        pack=1,
+    )
+    return tmp
+
+@triton.jit()
+def gather_latencies(
+    local_latency,
+    global_latency,
+    curr_rank,
+    num_ranks ,
+    BLOCK_SIZE: tl.constexpr,
+    heap_bases: tl.tensor
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+    latency_mask = offsets < num_ranks
+    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
 
 @triton.jit()
 def ping_pong(
     data,
-    result,
-    len,
-    iter,
+    n_elements,
     skip,
-    flag: tl.tensor,
+    niter,
+    flag,
     curr_rank,
+    peer_rank,
     BLOCK_SIZE: tl.constexpr,
     heap_bases: tl.tensor,
     mm_begin_timestamp_ptr: tl.tensor = None,
     mm_end_timestamp_ptr: tl.tensor = None,
 ):
-    peer = (curr_rank + 1) % 2
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
 
-    data_mask = offsets < len
+    data_mask = offsets < n_elements
     flag_mask = offsets < 1
     time_stmp_mask = offsets < 1
 
-    for i in range(iter + skip):
+    for i in range(niter + skip):
         if i == skip:
             start = read_realtime()
-            tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
-        if curr_rank == (i + 1) % 2:
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+            tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
+        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
+        token_first_done = i + 1
+        token_second_done = i + 2
+        if curr_rank == first_rank:
+            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask)
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                 pass
-            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
-            tl.store(flag + offsets, i + 1, mask=flag_mask)
-            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
         else:
-            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
-            tl.store(flag + offsets, i + 1, mask=flag_mask)
-            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                 pass
-    stop = read_realtime()
-    tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
+            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)
 
+    stop = read_realtime()
+    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
 
-@pytest.mark.parametrize(
-    "dtype",
-    [
-        torch.int32,
-        # torch.float16,
-        # torch.bfloat16,
-        # torch.float32,
-    ],
-)
-@pytest.mark.parametrize(
-    "heap_size",
-    [
-        (1 << 33),
-    ],
-)
-def test_load_bench(dtype, heap_size):
+
+if __name__ == "__main__":
+    dtype = torch.int32
+    heap_size = 1 << 32
     shmem = iris.iris(heap_size)
     num_ranks = shmem.get_num_ranks()
     heap_bases = shmem.get_heap_bases()
     cur_rank = shmem.get_rank()
-    assert num_ranks == 2
 
     BLOCK_SIZE = 1
-    BUFFER_LEN = 64 * 1024
+    BUFFER_LEN = 1
 
     iter = 200
-    skip = 20
-    mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
+    skip = 1
+    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
     result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)
 
     grid = lambda meta: (1,)
-    ping_pong[grid](
-        source_buffer,
-        result_buffer,
-        BUFFER_LEN,
-        skip,
-        iter,
-        flag,
-        cur_rank,
-        BLOCK_SIZE,
-        heap_bases,
-        mm_begin_timestamp,
-        mm_end_timestamp,
-    )
-    shmem.barrier()
-    begin_val = mm_begin_timestamp.cpu().item()
-    end_val = mm_end_timestamp.cpu().item()
-    with open(f"timestamps_{cur_rank}.txt", "w") as f:
-        f.write(f"mm_begin_timestamp: {begin_val}\n")
-        f.write(f"mm_end_timestamp: {end_val}\n")
+    for source_rank in range(num_ranks):
+        for destination_rank in range(num_ranks):
+            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
+                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
+                ping_pong[grid](source_buffer,
+                                BUFFER_LEN,
+                                skip, iter,
+                                flag,
+                                cur_rank, peer_for_me,
+                                BLOCK_SIZE,
+                                heap_bases,
+                                mm_begin_timestamp,
+                                mm_end_timestamp)
+            shmem.barrier()
+
+    for destination_rank in range(num_ranks):
+        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
+
+    latency_matrix = mpi_allgather(local_latency.cpu())
+
+    if cur_rank == 0:
+        with open(f"latency.txt", "w") as f:
+            f.write("    ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
+            for i in range(num_ranks):
+                row_entries = []
+                for j in range(num_ranks):
+                    val = float(latency_matrix[i, j])
+                    row_entries.append(f"{val:0.6f}")
+                line = f"R{i}," + ", ".join(row_entries) + "\n"
+                f.write(line)
\ No newline at end of file
From e72704acf5c623ddb86ece27a300f4812356b95c Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Sun, 31 Aug 2025 19:39:37 +0000
Subject: [PATCH 04/13] Apply Ruff auto-fixes

---
 tests/examples/test_load_latency.py | 74 +++++++++++------------------
 1 file changed, 27 insertions(+), 47 deletions(-)

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
index 02e1131a..88d9fedd 100644
--- a/tests/examples/test_load_latency.py
+++ b/tests/examples/test_load_latency.py
@@ -9,35 +9,8 @@ import numpy as np
 import iris
 
 from iris._mpi_helpers import mpi_allgather
-
-# from examples.common.utils import read_realtime
-@triton.jit
-def read_realtime():
-    tmp = tl.inline_asm_elementwise(
-        asm="mov.u64 $0, %globaltimer;",
-        constraints=("=l"),
-        args=[],
-        dtype=tl.int64,
-        is_pure=False,
-        pack=1,
-    )
-    return tmp
+from examples.common.utils import read_realtime
 
-@triton.jit()
-def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks ,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
-):
-    pid = tl.program_id(0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-
-    latency_mask = offsets < num_ranks
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
 
 @triton.jit()
 def ping_pong(
@@ -66,7 +39,7 @@ def ping_pong(
             start = read_realtime()
             tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
         first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
-        token_first_done = i + 1
+        token_first_done = i + 1
         token_second_done = i + 2
         if curr_rank == first_rank:
             iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
@@ -82,8 +55,9 @@ def ping_pong(
     stop = read_realtime()
     tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
 
+
 if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
     heap_size = 1 << 32
     shmem = iris.iris(heap_size)
     num_ranks = shmem.get_num_ranks()
@@ -96,37 +70,43 @@ def ping_pong(
     iter = 200
     skip = 1
     mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
 
-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
     result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)
 
     grid = lambda meta: (1,)
     for source_rank in range(num_ranks):
         for destination_rank in range(num_ranks):
             if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                 peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](source_buffer,
-                                BUFFER_LEN,
-                                skip, iter,
-                                flag,
-                                cur_rank, peer_for_me,
-                                BLOCK_SIZE,
-                                heap_bases,
-                                mm_begin_timestamp,
-                                mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    cur_rank,
+                    peer_for_me,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
             shmem.barrier()
-
+
     for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
-
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter
+
     latency_matrix = mpi_allgather(local_latency.cpu())
 
     if cur_rank == 0:
-        with open(f"latency.txt", "w") as f:
+        with open("latency.txt", "w") as f:
             f.write("    ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
             for i in range(num_ranks):
                 row_entries = []
@@ -134,4 +114,4 @@ def ping_pong(
                     val = float(latency_matrix[i, j])
                     row_entries.append(f"{val:0.6f}")
                 line = f"R{i}," + ", ".join(row_entries) + "\n"
-                f.write(line)
\ No newline at end of file
+                f.write(line)

From f4adf5fb22040d3e1a499c3f76879378625ef925 Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Sun, 31 Aug 2025 22:03:22 +0200
Subject: [PATCH 05/13] got rid of atomic timestamp store

---
 tests/examples/test_load_latency.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
index 88d9fedd..030529c9 100644
--- a/tests/examples/test_load_latency.py
+++ b/tests/examples/test_load_latency.py
@@ -37,7 +37,7 @@ def ping_pong(
     for i in range(niter + skip):
         if i == skip:
             start = read_realtime()
-            tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
+            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
         first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
         token_first_done = i + 1
         token_second_done = i + 2
@@ -53,7 +53,7 @@ def ping_pong(
             iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)
 
     stop = read_realtime()
-    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
+    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
 
 
 if __name__ == "__main__":

From 4b8bc7a8432a7d1e37d0de1a9071400775a703ce Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Sun, 31 Aug 2025 22:04:56 +0200
Subject: [PATCH 06/13] increase warmup time

---
 tests/examples/test_load_latency.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
index 030529c9..ac2313ac 100644
--- a/tests/examples/test_load_latency.py
+++ b/tests/examples/test_load_latency.py
@@ -67,8 +67,8 @@ def ping_pong(
     BLOCK_SIZE = 1
     BUFFER_LEN = 1
 
-    iter = 200
-    skip = 1
+    iter = 100
+    skip = 10
     mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
     mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

From 5eff8622109199a424e6fa2db4630708fe751777 Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Sun, 31 Aug 2025 22:05:43 +0200
Subject: [PATCH 07/13] cleanup

---
 tests/examples/test_load_latency.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
index ac2313ac..ed42f336 100644
--- a/tests/examples/test_load_latency.py
+++ b/tests/examples/test_load_latency.py
@@ -75,7 +75,6 @@ def ping_pong(
     local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
-    result_buffer = shmem.zeros_like(source_buffer)
     flag = shmem.ones(1, dtype=dtype)
 
     grid = lambda meta: (1,)

From 606853eb1b9f4c53c78dc3ead087221fe4358bde Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Mon, 1 Sep 2025 04:13:47 +0200
Subject: [PATCH 08/13] addressing comments

---
 benchmarks/bench_ping_pong_latency.py | 281 ++++++++++++++++++++++++++
 tests/examples/test_load_latency.py   | 116 -----------
 2 files changed, 281 insertions(+), 116 deletions(-)
 create mode 100644 benchmarks/bench_ping_pong_latency.py
 delete mode 100644 tests/examples/test_load_latency.py

diff --git a/benchmarks/bench_ping_pong_latency.py b/benchmarks/bench_ping_pong_latency.py
new file mode 100644
index 00000000..f78dc0cc
--- /dev/null
+++ b/benchmarks/bench_ping_pong_latency.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
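+#
+# All-pairs ping-pong latency benchmark. For every ordered pair of ranks,
+# the two peers alternate a store-and-flag handshake: the first rank stores
+# its payload and raises the flag to token i + 1, the peer answers and
+# raises it to i + 2. The averaged per-iteration round time is collected
+# into a rank-by-rank latency matrix.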
+
+import json
+import csv
+import argparse
+from pathlib import Path
+import torch
+import triton
+import triton.language as tl
+import iris
+from iris._mpi_helpers import mpi_allgather
+from examples.common.utils import read_realtime
+
+
+@triton.jit()
+def ping_pong(
+    data,
+    n_elements,
+    skip,
+    niter,
+    flag,
+    curr_rank,
+    peer_rank,
+    BLOCK_SIZE: tl.constexpr,
+    heap_bases: tl.tensor,
+    mm_begin_timestamp_ptr: tl.tensor = None,
+    mm_end_timestamp_ptr: tl.tensor = None,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+    data_mask = offsets < n_elements
+    time_stmp_mask = offsets < BLOCK_SIZE
+    flag_mask = offsets < 1
+
+    for i in range(niter + skip):
+        if i == skip:
+            start = read_realtime()
+            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
+        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
+        token_first_done = i + 1
+        token_second_done = i + 2
+        if curr_rank == first_rank:
+            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
+                pass
+        else:
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
+                pass
+            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
+
+    stop = read_realtime()
+    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
+
+
+def torch_dtype_from_str(datatype: str) -> torch.dtype:
+    dtype_map = {
+        "int8": torch.int8,
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "fp32": torch.float32,
+        "int32": torch.int32,
+    }
+    try:
+        return dtype_map[datatype]
+    except KeyError:
+        raise ValueError(f"Unknown datatype: {datatype}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Latency ping-pong benchmark",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "-t",
+        "--datatype",
+        type=str,
+        default="int32",
+        choices=["int8", "fp16", "bf16", "fp32", "int32"],
+        help="Datatype for the message payload",
+    )
+    parser.add_argument(
+        "-p",
+        "--heap_size",
+        type=int,
+        default=1 << 32,
+        help="Iris heap size",
+    )
+    parser.add_argument(
+        "-b",
+        "--block_size",
+        type=int,
+        default=1,
+        help="Block size",
+    )
+    parser.add_argument(
+        "-z",
+        "--buffer_size",
+        type=int,
+        default=1,
+        help="Length of the source buffer (elements)",
+    )
+    parser.add_argument(
+        "-i",
+        "--iter",
+        type=int,
+        default=100,
+        help="Number of timed iterations",
+    )
+    parser.add_argument(
+        "-w",
+        "--num_warmup",
+        type=int,
+        default=10,
+        help="Number of warmup (skip) iterations",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_file",
+        type=str,
+        default=None,
+        help="Optional output filename (if omitted, prints results to terminal). Supports .json, .csv",
+    )
+    return vars(parser.parse_args())
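+
+
+# Example invocation (assuming an MPI-style launcher; the flags are the
+# ones defined in parse_args above):
+#   mpirun -np 2 python benchmarks/bench_ping_pong_latency.py -t int32 -o latency.csv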
+
+
+def _pretty_print_matrix(latency_matrix: torch.Tensor) -> None:
+    num_ranks = latency_matrix.shape[0]
+    col_width = 12
+    header = "SRC\\DST".ljust(col_width) + "".join(f"{j:>12}" for j in range(num_ranks))
+    print("\nLatency matrix (ns per iter):")
+    print(header)
+    for i in range(num_ranks):
+        row = f"R{i}".ljust(col_width)
+        for j in range(num_ranks):
+            row += f"{latency_matrix[i, j].item():12.6f}"
+        print(row)
+
+
+def _write_csv(path: Path, latency_matrix: torch.Tensor) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", newline="") as f:
+        writer = csv.writer(f)
+        num_ranks = latency_matrix.shape[0]
+        writer.writerow([""] + [f"R{j}" for j in range(num_ranks)])
+        for i in range(num_ranks):
+            row = [f"R{i}"] + [f"{latency_matrix[i, j].item():0.6f}" for j in range(num_ranks)]
+            writer.writerow(row)
+
+
+def _write_json(path: Path, latency_matrix: torch.Tensor) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    num_ranks = latency_matrix.shape[0]
+    rows = []
+    for s in range(num_ranks):
+        for d in range(num_ranks):
+            rows.append(
+                {
+                    "source_rank": int(s),
+                    "destination_rank": int(d),
+                    "latency_ns": float(latency_matrix[s, d].item()),
+                }
+            )
+    with path.open("w") as f:
+        json.dump(rows, f, indent=2)
+
+
+def save_results(latency_matrix: torch.Tensor, out: str | None) -> None:
+    if out is None:
+        _pretty_print_matrix(latency_matrix)
+        return
+
+    path = Path(out)
+    ext = path.suffix.lower()
+    if ext == ".json":
+        _write_json(path, latency_matrix)
+    elif ext == ".csv":
+        _write_csv(path, latency_matrix)
+    else:
+        raise ValueError(f"Unsupported output file extension: {out}")
+
+
+
+def print_run_settings(
+    args: dict,
+    num_ranks: int,
+    dtype: torch.dtype,
+    BLOCK_SIZE: int,
+    BUFFER_LEN: int,
+) -> None:
+    elem_size = torch.tensor([], dtype=dtype).element_size()
+    heap_size = args["heap_size"]
+    out = args["output_file"]
+    header = "=" * 72
+    print(header)
+    print("Latency benchmark -- run settings")
+    print(header)
+    print(f"  num_ranks     : {num_ranks}")
+    print(f"  iterations    : {args['iter']} (timed)")
+    print(f"  skip (warmup) : {args['num_warmup']}")
+    print(f"  datatype      : {args['datatype']} (torch dtype: {dtype})")
+    print(f"  element size  : {elem_size} bytes")
+    print(f"  heap size     : {heap_size} ({hex(heap_size)})")
+    print(f"  block size    : {BLOCK_SIZE}")
+    print(f"  buffer len    : {BUFFER_LEN} elements")
+    print(f"  output target : {'' if out is None else out}")
+    print(header)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    dtype = torch_dtype_from_str(args["datatype"])
+    heap_size = args["heap_size"]
+
+    shmem = iris.iris(heap_size)
+    num_ranks = shmem.get_num_ranks()
+    heap_bases = shmem.get_heap_bases()
+    cur_rank = shmem.get_rank()
+
+    BLOCK_SIZE = args["block_size"]
+    BUFFER_LEN = args["buffer_size"]
+
+    niter = args["iter"]
+    skip = args["num_warmup"]
+
+    if cur_rank == 0:
+        print_run_settings(args, num_ranks, dtype, BLOCK_SIZE, BUFFER_LEN)
+    shmem.barrier()
+    try:
+        device_idx = torch.cuda.current_device()
+        device_name = torch.cuda.get_device_name(device_idx)
+    except Exception:
+        device_idx = -1
+        device_name = "unknown CUDA device"
+    print(f"[rank {cur_rank}] ready, device[{device_idx}]: {device_name}")
+
+    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+
+    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
+    flag = shmem.ones(1, dtype=torch.int32)
+
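+    # Sweep all ordered (source, destination) pairs. Only the two ranks of
+    # the active pair launch the kernel; every rank joins the barrier that
+    # closes each pairing.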
+    grid = lambda meta: (1,)
+    for source_rank in range(num_ranks):
+        for destination_rank in range(num_ranks):
+            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
+                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
+                ping_pong[grid](
+                    source_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    niter,
+                    flag,
+                    cur_rank,
+                    peer_for_me,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
+            shmem.barrier()
+
+    mm_begin_cpu = mm_begin_timestamp.cpu().numpy()
+    mm_end_cpu = mm_end_timestamp.cpu().numpy()
+    for destination_rank in range(num_ranks):
+        delta = mm_end_cpu[destination_rank, :] - mm_begin_cpu[destination_rank, :]
+        avg_ns = float(delta.sum() / max(1, delta.size) / max(1, niter))
+        local_latency[destination_rank] = avg_ns
+
+    latency_matrix = mpi_allgather(local_latency.cpu())
+
+    if cur_rank == 0:
+        save_results(latency_matrix, args["output_file"])
+        print("Benchmark complete.")
diff --git a/tests/examples/test_load_latency.py b/tests/examples/test_load_latency.py
deleted file mode 100644
index ed42f336..00000000
--- a/tests/examples/test_load_latency.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python3
-# SPDX-License-Identifier: MIT
-# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
-
-import pytest
-import torch
-import triton
-import triton.language as tl
-import numpy as np
-import iris
-from iris._mpi_helpers import mpi_allgather
-from examples.common.utils import read_realtime
-
-
-@triton.jit()
-def ping_pong(
-    data,
-    n_elements,
-    skip,
-    niter,
-    flag,
-    curr_rank,
-    peer_rank,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor,
-    mm_begin_timestamp_ptr: tl.tensor = None,
-    mm_end_timestamp_ptr: tl.tensor = None,
-):
-    pid = tl.program_id(0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-
-    data_mask = offsets < n_elements
-    flag_mask = offsets < 1
-    time_stmp_mask = offsets < 1
-
-    for i in range(niter + skip):
-        if i == skip:
-            start = read_realtime()
-            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
-        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
-        token_first_done = i + 1
-        token_second_done = i + 2
-        if curr_rank == first_rank:
-            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
-            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask)
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
-                pass
-        else:
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
-                pass
-            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
-            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)
-
-    stop = read_realtime()
-    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
-
-
-if __name__ == "__main__":
-    dtype = torch.int32
-    heap_size = 1 << 32
-    shmem = iris.iris(heap_size)
-    num_ranks = shmem.get_num_ranks()
-    heap_bases = shmem.get_heap_bases()
-    cur_rank = shmem.get_rank()
-
-    BLOCK_SIZE = 1
-    BUFFER_LEN = 1
-
-    iter = 100
-    skip = 10
-    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-
-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
-
-    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
-    flag = shmem.ones(1, dtype=dtype)
-
-    grid = lambda meta: (1,)
-    for source_rank in range(num_ranks):
-        for destination_rank in range(num_ranks):
-            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
-                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](
-                    source_buffer,
-                    BUFFER_LEN,
-                    skip,
-                    iter,
-                    flag,
-                    cur_rank,
-                    peer_for_me,
-                    BLOCK_SIZE,
-                    heap_bases,
-                    mm_begin_timestamp,
-                    mm_end_timestamp,
-                )
-            shmem.barrier()
-
-    for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (
-            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
-        ) / iter
-
-    latency_matrix = mpi_allgather(local_latency.cpu())
-
-    if cur_rank == 0:
-        with open("latency.txt", "w") as f:
-            f.write("    ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
-            for i in range(num_ranks):
-                row_entries = []
-                for j in range(num_ranks):
-                    val = float(latency_matrix[i, j])
-                    row_entries.append(f"{val:0.6f}")
-                line = f"R{i}," + ", ".join(row_entries) + "\n"
-                f.write(line)

From a3e902350fca01b14bf3380af8212d5076b4369f Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Mon, 1 Sep 2025 06:27:21 +0200
Subject: [PATCH 09/13] Fix deadlock

---
 benchmarks/bench_ping_pong_latency.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/bench_ping_pong_latency.py b/benchmarks/bench_ping_pong_latency.py
index f78dc0cc..9628f612 100644
--- a/benchmarks/bench_ping_pong_latency.py
+++ b/benchmarks/bench_ping_pong_latency.py
@@ -45,14 +45,14 @@ def ping_pong(
         token_second_done = i + 2
         if curr_rank == first_rank:
             iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
-            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
+            iris.atomic_xchg(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
             while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                 pass
         else:
             while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                 pass
             iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
-            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
+            iris.atomic_xchg(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
 
     stop = read_realtime()
     tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)

From 003b273c7bd63fa289e Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Mon, 1 Sep 2025 04:27:45 +0000
Subject: [PATCH 10/13] Apply Ruff auto-fixes

---
 benchmarks/bench_ping_pong_latency.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/benchmarks/bench_ping_pong_latency.py b/benchmarks/bench_ping_pong_latency.py
index 9628f612..a7196f09 100644
--- a/benchmarks/bench_ping_pong_latency.py
+++ b/benchmarks/bench_ping_pong_latency.py
@@ -10,7 +10,7 @@ import triton
 import triton.language as tl
 import iris
-from iris._mpi_helpers import mpi_allgather
+from iris._mpi_helpers import mpi_allgather
 from examples.common.utils import read_realtime
@@ -186,7 +186,6 @@ def save_results(latency_matrix: torch.Tensor, out: str | None) -> None:
         raise ValueError(f"Unsupported output file extension: {out}")
 
 
-
 def print_run_settings(
     args: dict,
     num_ranks: int,

From 56ad603b1fe0e21a3f550cf50a0953ab575115d8 Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Thu, 4 Sep 2025 12:28:58 +0200
Subject: [PATCH 11/13] Rewrote latency test

---
 ..._pong_latency.py => bench_load_latency.py} | 30 +++++++------------
 1 file changed, 10 insertions(+), 20 deletions(-)
 rename benchmarks/{bench_ping_pong_latency.py => bench_load_latency.py} (87%)

diff --git a/benchmarks/bench_ping_pong_latency.py b/benchmarks/bench_load_latency.py
similarity index 87%
rename from benchmarks/bench_ping_pong_latency.py
rename to benchmarks/bench_load_latency.py
index a7196f09..ef041044 100644
--- a/benchmarks/bench_ping_pong_latency.py
+++ b/benchmarks/bench_load_latency.py
@@ -15,12 +15,11 @@
 
 
 @triton.jit()
-def ping_pong(
+def load_remote(
     data,
     n_elements,
     skip,
     niter,
-    flag,
     curr_rank,
     peer_rank,
     BLOCK_SIZE: tl.constexpr,
@@ -34,25 +33,18 @@ def load_remote(
 
     data_mask = offsets < n_elements
     time_stmp_mask = offsets < BLOCK_SIZE
-    flag_mask = offsets < 1
 
     for i in range(niter + skip):
         if i == skip:
             start = read_realtime()
             tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
-        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
-        token_first_done = i + 1
-        token_second_done = i + 2
-        if curr_rank == first_rank:
-            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
-            iris.atomic_xchg(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
-                pass
-        else:
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
-                pass
-            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
-            iris.atomic_xchg(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
+
+        # iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
+        from_base = tl.load(heap_bases + curr_rank)
+        to_base = tl.load(heap_bases + peer_rank)
+        offset = tl.cast(data + offsets, tl.uint64) - from_base
+        translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, (data + offsets).dtype)
+        result = tl.load(translated_ptr, mask=data_mask, cache_modifier=".cv", volatile=True)
 
     stop = read_realtime()
     tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
@@ -244,19 +236,17 @@ def print_run_settings(
     local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
-    flag = shmem.ones(1, dtype=torch.int32)
 
     grid = lambda meta: (1,)
     for source_rank in range(num_ranks):
         for destination_rank in range(num_ranks):
-            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
+            if cur_rank in [source_rank, destination_rank]:
                 peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](
+                load_remote[grid](
                     source_buffer,
                     BUFFER_LEN,
                     skip,
                     niter,
-                    flag,
                     cur_rank,
                     peer_for_me,
                     BLOCK_SIZE,

From b301385ec6185ce7bbf61dfa17a1b1e3c2b36076 Mon Sep 17 00:00:00 2001
From: astroC86 <66444189+astroC86@users.noreply.github.com>
Date: Sat, 6 Sep 2025 20:04:01 +0200
Subject: [PATCH 12/13] Fix latency time measurement

---
 benchmarks/bench_load_latency.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/benchmarks/bench_load_latency.py b/benchmarks/bench_load_latency.py
index ef041044..9acf7e92 100644
--- a/benchmarks/bench_load_latency.py
+++ b/benchmarks/bench_load_latency.py
@@ -38,10 +38,10 @@ def load_remote(
         if i == skip:
             start = read_realtime()
             tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
-
+
         # iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
         from_base = tl.load(heap_bases + curr_rank)
-        to_base = tl.load(heap_bases + peer_rank)
+        to_base = tl.load(heap_bases + peer_rank)
         offset = tl.cast(data + offsets, tl.uint64) - from_base
         translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, (data + offsets).dtype)
         result = tl.load(translated_ptr, mask=data_mask, cache_modifier=".cv", volatile=True)
@@ -240,15 +240,14 @@ def print_run_settings(
     grid = lambda meta: (1,)
     for source_rank in range(num_ranks):
         for destination_rank in range(num_ranks):
-            if cur_rank in [source_rank, destination_rank]:
-                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
+            if cur_rank == source_rank:
                 load_remote[grid](
                     source_buffer,
                     BUFFER_LEN,
                     skip,
                     niter,
                     cur_rank,
-                    peer_for_me,
+                    destination_rank,
                     BLOCK_SIZE,
                     heap_bases,
                     mm_begin_timestamp,
@@ -258,13 +257,16 @@ def print_run_settings(
 
     mm_begin_cpu = mm_begin_timestamp.cpu().numpy()
     mm_end_cpu = mm_end_timestamp.cpu().numpy()
+
+    gpu_freq = iris.hip.get_wall_clock_rate(cur_rank)
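+    # Assuming get_wall_clock_rate() reports the wall clock in kHz,
+    # cycles * 1e6 / kHz converts the averaged cycle delta to nanoseconds.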
+
     for destination_rank in range(num_ranks):
         delta = mm_end_cpu[destination_rank, :] - mm_begin_cpu[destination_rank, :]
-        avg_ns = float(delta.sum() / max(1, delta.size) / max(1, niter))
-        local_latency[destination_rank] = avg_ns
+        avg_cc = float(delta.sum() / max(1, delta.size) / max(1, niter))
+        local_latency[destination_rank] = avg_cc * 1e6 / gpu_freq
 
     latency_matrix = mpi_allgather(local_latency.cpu())
 
     if cur_rank == 0:
         save_results(latency_matrix, args["output_file"])
-        print("Benchmark complete.")
+        print("Benchmark complete.")
\ No newline at end of file

From 8620fa3101374afad98d76f3d086e919a834e807 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Sat, 6 Sep 2025 20:02:51 +0000
Subject: [PATCH 13/13] Apply Ruff auto-fixes

---
 benchmarks/bench_load_latency.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/bench_load_latency.py b/benchmarks/bench_load_latency.py
index 9acf7e92..9e048862 100644
--- a/benchmarks/bench_load_latency.py
+++ b/benchmarks/bench_load_latency.py
@@ -269,4 +269,4 @@ def print_run_settings(
 
     if cur_rank == 0:
         save_results(latency_matrix, args["output_file"])
-        print("Benchmark complete.")
\ No newline at end of file
+        print("Benchmark complete.")