diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py index c54544a5..08453a36 100644 --- a/tests/unittests/test_copy_cache_modifiers.py +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -28,7 +28,7 @@ def copy_kernel( # Test copy with cache modifiers - copy from current rank to other ranks for target_rank in range(num_ranks): src_data = data + BLOCK_SIZE * cur_rank - dest_data = results + BLOCK_SIZE * target_rank + dest_data = results + BLOCK_SIZE * cur_rank if load_cache_modifier is None and store_cache_modifier is None: iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask) elif load_cache_modifier is None: @@ -96,11 +96,12 @@ def test_copy_cache_modifiers(load_cache_modifier, store_cache_modifier): shmem.barrier() - # Verify results - each rank should have copied its data to all ranks - for i in range(num_ranks): - expected_value = base * (cur_rank + 1) + # Verify results - each rank copies its data to all other ranks + # After barrier, results[rank_id] should contain data from rank_id + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) assert torch.allclose( - results[i], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) ), ( - f"Mismatch at rank {cur_rank}, target {i} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" )