diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index c02e968e23fb6..bcd6316a9d2a5 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -639,6 +639,14 @@ def _helper_test_extra_cuda_context_by_memory(self):
         """
         device = torch.device(f"cuda:{self.rank:d}")
         x = torch.empty((1,), device=device)
+
+        # We need this barrier to ensure that all nodes have completed init_process_group.
+        # If rank=0 gets a mem snapshot before other nodes have finished init_process_group,
+        # then we artificially see a bump in memory usage. As per the following comment,
+        # we are going to be moving away from this function:
+        # https://github.com/pytorch/pytorch/pull/154174#discussion_r2105065931
+        c10d.barrier()
+
         # Rank 0 takes a snapshot before collective -- this snapshot should have
         # included rank 0's own context.
         if self.rank == 0: