Skip to content

Commit

Permalink
Add bf.wait alias for synchronize
Browse files Browse the repository at this point in the history
  • Loading branch information
bichengying committed Nov 6, 2020
1 parent 52199e0 commit 115b909
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion bluefog/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from bluefog.torch.mpi_ops import neighbor_allreduce, neighbor_allreduce_nonblocking
from bluefog.torch.mpi_ops import hierarchical_neighbor_allreduce
from bluefog.torch.mpi_ops import hierarchical_neighbor_allreduce_nonblocking
from bluefog.torch.mpi_ops import poll, synchronize, barrier
from bluefog.torch.mpi_ops import poll, synchronize, wait, barrier

from bluefog.torch.mpi_ops import win_create, win_free
from bluefog.torch.mpi_ops import win_update, win_update_then_collect
Expand Down
28 changes: 21 additions & 7 deletions bluefog/torch/mpi_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,13 +822,12 @@ def pair_gossip_nonblocking(tensor: torch.Tensor, target_rank: int, self_weight:

def poll(handle: int) -> bool:
"""
Polls an allreduce, allgather or broadcast handle to determine whether underlying
nonblocking operation has completed. After `poll()` returns `True`, `synchronize()`
Polls an allreduce, neighbor_allreduce, etc operation handle to determine whether underlying
nonblocking operation has completed. After `poll()` returns `True`, `wait()`
will return without blocking.
Arguments:
handle: A handle returned by an allreduce, allgather, broadcast, neighbor_allgather,
and neighbro_allreduce nonblocking operation.
handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
Returns:
A flag indicating whether the operation has completed.
Expand All @@ -838,12 +837,12 @@ def poll(handle: int) -> bool:

def synchronize(handle: int) -> torch.Tensor:
"""
Synchronizes an nonblocking allreduce, allgather or broadcast operation until
Wait an allreduce, neighbor_allreduce, etc operation until
it's completed. Returns the result of the operation.
It is the same function as `wait()`.
Args:
handle: A handle returned by an allreduce, allgather or broadcast nonblocking
operation.
handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
Returns:
torch.Tensor: An output tensor of the operation.
Expand All @@ -855,6 +854,21 @@ def synchronize(handle: int) -> torch.Tensor:
return output


def wait(handle: int) -> torch.Tensor:
"""
Wait an allreduce, neighbor_allreduce, etc operation until
it's completed. Returns the result of the operation.
It is just alias of `synchronize()` function.
Args:
handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
Returns:
torch.Tensor: An output tensor of the operation.
"""
return synchronize(handle)


def barrier():
"""Barrier function to sychronize all MPI processes.
Expand Down

0 comments on commit 115b909

Please sign in to comment.