diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 1326a9acb61e2..242dcd951545d 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -587,6 +587,14 @@ def _helper_test_extra_cuda_context_by_memory(self):
         """
         device = torch.device("cuda:%d" % self.rank)
         x = torch.empty((1,), device=device)
+
+        # We need this barrier to ensure that all nodes have completed init_process_group.
+        # If rank=0 gets a mem snapshot before other nodes have finished init_process_group,
+        # then we artificially see a bump in memory usage. As per the following comment,
+        # we are going to be moving away from this function:
+        # https://github.com/pytorch/pytorch/pull/154174#discussion_r2105065931
+        c10d.barrier()
+
         # Rank 0 takes a snapshot before collective -- this snapshot should have
         # included rank 0's own context.
         if self.rank == 0: