tests/unittests/test_copy_cache_modifiers.py: 13 changes (7 additions, 6 deletions)
@@ -28,7 +28,7 @@ def copy_kernel(
     # Test copy with cache modifiers - copy from current rank to other ranks
     for target_rank in range(num_ranks):
         src_data = data + BLOCK_SIZE * cur_rank
-        dest_data = results + BLOCK_SIZE * target_rank
+        dest_data = results + BLOCK_SIZE * cur_rank
         if load_cache_modifier is None and store_cache_modifier is None:
             iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask)
         elif load_cache_modifier is None:
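The one-line kernel change above swaps the destination offset from target_rank to cur_rank, so on every target's results buffer each sender writes into the slot indexed by its own rank. A minimal host-side sketch (hypothetical, plain torch, not using the iris/Triton runtime) of the slot layout this relies on:

# Hypothetical host-side sketch of the destination-slot layout: each source rank
# writes its block into slot cur_rank of every target's results buffer, so slot r
# ends up holding rank r's data on every rank.
import torch

num_ranks, BLOCK_SIZE = 4, 8
results = [torch.zeros(num_ranks, BLOCK_SIZE) for _ in range(num_ranks)]  # one buffer per rank

for cur_rank in range(num_ranks):
    src_block = torch.full((BLOCK_SIZE,), float(cur_rank))  # stand-in for rank cur_rank's data
    for target_rank in range(num_ranks):
        # Old indexing (BLOCK_SIZE * target_rank) made every sender overwrite the same
        # slot on a given target; the fixed indexing keys the slot by the sender's rank.
        results[target_rank][cur_rank] = src_block

# After the all-to-all, slot r on every rank holds rank r's block.
assert all((results[t][r] == r).all() for t in range(num_ranks) for r in range(num_ranks))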
@@ -96,11 +96,12 @@ def test_copy_cache_modifiers(load_cache_modifier, store_cache_modifier):

     shmem.barrier()

-    # Verify results - each rank should have copied its data to all ranks
-    for i in range(num_ranks):
-        expected_value = base * (cur_rank + 1)
+    # Verify results - each rank copies its data to all other ranks
+    # After barrier, results[rank_id] should contain data from rank_id
+    for rank_id in range(num_ranks):
+        expected_value = (rank_id + num_ranks) * (rank_id + 1)
         assert torch.allclose(
-            results[i], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device)
+            results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device)
         ), (
-            f"Mismatch at rank {cur_rank}, target {i} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
+            f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
         )
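The reworked check loops over destination slots rather than over the current rank's copies. The initialization of data is not visible in this hunk, but the new expected_value implies rank r's source block holds the constant (r + num_ranks) * (r + 1); under that assumption, a small standalone sketch of what each slot should contain after the barrier:

# Hedged sketch of the verification math, assuming (as the new expected_value implies)
# that rank r's source block is initialized to the constant (r + num_ranks) * (r + 1).
import torch

num_ranks, BLOCK_SIZE = 4, 8
expected = torch.stack(
    [torch.full((BLOCK_SIZE,), float((r + num_ranks) * (r + 1))) for r in range(num_ranks)]
)
print(expected[:, 0])  # with 4 ranks: tensor([ 4., 10., 18., 28.])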