From 7596316d5d3070e686afac03a04c1f59c6f7f9f3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 30 Oct 2025 00:53:31 +0000 Subject: [PATCH 1/2] Initial plan From ad2177132698548f2fe8ca349f0630a6a09df4a5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 30 Oct 2025 00:59:51 +0000 Subject: [PATCH 2/2] Fix cache modifiers unittest pointer arithmetic Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/unittests/test_copy_cache_modifiers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py index c54544a5..08453a36 100644 --- a/tests/unittests/test_copy_cache_modifiers.py +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -28,7 +28,7 @@ def copy_kernel( # Test copy with cache modifiers - copy from current rank to other ranks for target_rank in range(num_ranks): src_data = data + BLOCK_SIZE * cur_rank - dest_data = results + BLOCK_SIZE * target_rank + dest_data = results + BLOCK_SIZE * cur_rank if load_cache_modifier is None and store_cache_modifier is None: iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask) elif load_cache_modifier is None: @@ -96,11 +96,12 @@ def test_copy_cache_modifiers(load_cache_modifier, store_cache_modifier): shmem.barrier() - # Verify results - each rank should have copied its data to all ranks - for i in range(num_ranks): - expected_value = base * (cur_rank + 1) + # Verify results - each rank copies its data to all other ranks + # After barrier, results[rank_id] should contain data from rank_id + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) assert torch.allclose( - results[i], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) ), ( - f"Mismatch at rank {cur_rank}, target {i} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" )