diff --git a/python/oneflow/comm/__init__.py b/python/oneflow/comm/__init__.py
index fd2828acfb9..9ace8b4b444 100644
--- a/python/oneflow/comm/__init__.py
+++ b/python/oneflow/comm/__init__.py
@@ -13,5 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-from oneflow.comm.primitive import all_reduce
+from oneflow.comm.comm_ops import all_reduce
+from oneflow.comm.comm_ops import all_gather
 from oneflow._C import send, recv
diff --git a/python/oneflow/comm/primitive.py b/python/oneflow/comm/comm_ops.py
similarity index 50%
rename from python/oneflow/comm/primitive.py
rename to python/oneflow/comm/comm_ops.py
index 1e17d482f86..cf5683bca72 100644
--- a/python/oneflow/comm/primitive.py
+++ b/python/oneflow/comm/comm_ops.py
@@ -43,9 +43,9 @@ def all_reduce(tensor):
         tensor([[2, 3],
                 [4, 5]], device='cuda:1', dtype=oneflow.int64)
         >>> out = flow.comm.all_reduce(input)
-        >>> out
-        tensor([[3, 5],
-                [7, 9]], device='cuda:0', dtype=oneflow.int64)
+        >>> out.numpy()
+        array([[3, 5],
+               [7, 9]])
     """
     assert isinstance(tensor, flow._oneflow_internal.Tensor)
     assert tensor.device.index == flow.env.get_local_rank()
@@ -57,3 +57,50 @@ def all_reduce(tensor):
     ).to_consistent(placement=placement, sbp=flow.sbp.broadcast)
 
     return tensor.to_local()
+
+
+def all_gather(tensor_list, tensor):
+    """
+    Gathers tensors from the whole group into a list.
+
+    Args:
+        tensor_list (list[Tensor]): Output list. It should contain
+            correctly-sized tensors to be used for output of the collective.
+        tensor (Tensor): Tensor to be gathered from the current process.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> # We have 1 process group, 2 ranks.
+        >>> import oneflow as flow
+
+        >>> input = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
+        >>> input # doctest: +ONLY_CHECK_RANK_0
+        tensor([[1, 2],
+                [3, 4]], device='cuda:0', dtype=oneflow.int64)
+        >>> input # doctest: +ONLY_CHECK_RANK_1
+        tensor([[2, 3],
+                [4, 5]], device='cuda:1', dtype=oneflow.int64)
+        >>> tensor_list = [flow.zeros(2, 2, dtype=flow.int64) for _ in range(2)]
+        >>> flow.comm.all_gather(tensor_list, input)
+        >>> tensor_list # doctest: +ONLY_CHECK_RANK_0
+        [tensor([[1, 2],
+                [3, 4]], device='cuda:0', dtype=oneflow.int64), tensor([[2, 3],
+                [4, 5]], device='cuda:0', dtype=oneflow.int64)]
+        >>> tensor_list # doctest: +ONLY_CHECK_RANK_1
+        [tensor([[1, 2],
+                [3, 4]], device='cuda:1', dtype=oneflow.int64), tensor([[2, 3],
+                [4, 5]], device='cuda:1', dtype=oneflow.int64)]
+    """
+    assert isinstance(tensor, flow._oneflow_internal.Tensor)
+    assert isinstance(tensor_list, list)
+    assert tensor.device.index == flow.env.get_local_rank()
+    assert tensor.is_local
+    tensor = tensor.expand([1] + list(tensor.shape))
+    device_type = tensor.device.type
+    tensor = tensor.to_consistent(
+        placement=flow.env.all_device_placement(device_type), sbp=flow.sbp.split(0)
+    )
+    for i in range(tensor.shape[0]):
+        tensor_list[i] = tensor[i].to_local()
diff --git a/python/oneflow/framework/unittest.py b/python/oneflow/framework/unittest.py
index 975035c0a99..d04b077301c 100644
--- a/python/oneflow/framework/unittest.py
+++ b/python/oneflow/framework/unittest.py
@@ -376,14 +376,17 @@ def __init__(self, check_flags):
         self._check_flags = check_flags
 
     def check_output(self, want, got, optionflags):
-        target_rank_list = [bool(flag & optionflags) for flag in self._check_flags]
-        if (
-            any(target_rank_list)
-            and target_rank_list.index(True) == oneflow.env.get_rank()
-        ):
+        # no rank flag set: fall back to the default doctest check
+        if optionflags == 0:
             return super(CondSkipChecker, self).check_output(want, got, optionflags)
-        else:
+
+        target_rank_list = [bool(flag & optionflags) for flag in self._check_flags]
+        # an invalid flag is rejected before this point, so any(target_rank_list) is True
+        # skip the check on every rank except the target rank
+        if target_rank_list.index(True) != oneflow.env.get_rank():
             return True
+        else:
+            return super(CondSkipChecker, self).check_output(want, got, optionflags)
 
 
 def check_multi_rank_docstr(module):
diff --git a/python/oneflow/test/modules/test_sync_allreduce.py b/python/oneflow/test/modules/test_sync_comm_ops.py
similarity index 65%
rename from python/oneflow/test/modules/test_sync_allreduce.py
rename to python/oneflow/test/modules/test_sync_comm_ops.py
index b9c85959528..1607a1e3330 100644
--- a/python/oneflow/test/modules/test_sync_allreduce.py
+++ b/python/oneflow/test/modules/test_sync_comm_ops.py
@@ -35,9 +35,29 @@ def test_all_reduce_2n2d(test_case):
         out = flow.comm.all_reduce(input)
         test_case.assertTrue(np.allclose(out.numpy(), np_arr * 4))
 
+
+class TestAllGather(flow.unittest.TestCase):
     @flow.unittest.skip_unless_1n2d()
+    def test_all_gather_1n2d(test_case):
+        if flow.env.get_rank() == 0:
+            np_arr = np.array([[2, 3], [4, 5]])
+        elif flow.env.get_rank() == 1:
+            np_arr = np.array([[1, 2], [3, 4]])
+        input = flow.tensor(np_arr, device="cuda", dtype=flow.int32)
+        tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(2)]
+        flow.comm.all_gather(tensor_list, input)
+        test_case.assertTrue(
+            np.allclose(tensor_list[0].numpy(), np.array([[2, 3], [4, 5]]))
+        )
+        test_case.assertTrue(
+            np.allclose(tensor_list[1].numpy(), np.array([[1, 2], [3, 4]]))
+        )
+
+
+@flow.unittest.skip_unless_1n2d()
+class TestDocs(flow.unittest.TestCase):
     def test_docs(test_case):
-        oneflow.framework.unittest.check_multi_rank_docstr(oneflow.comm.primitive)
+        oneflow.framework.unittest.check_multi_rank_docstr(oneflow.comm.comm_ops)
 
 
 if __name__ == "__main__":
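
Reviewer note: the snippet below is a minimal usage sketch of the new flow.comm.all_gather API, not part of the patch. It assumes a single node with two CUDA devices and that the script (hypothetically named gather_demo.py) is started with OneFlow's distributed launcher; adjust the launch command to your environment.

    # gather_demo.py -- minimal sketch exercising the new flow.comm.all_gather API.
    # Assumed launch command (one node, two ranks):
    #   python3 -m oneflow.distributed.launch --nproc_per_node 2 gather_demo.py
    import oneflow as flow

    # Each rank contributes a different local tensor on its own GPU;
    # "cuda" resolves to the local rank's device, as in the docstring example.
    rank = flow.env.get_local_rank()
    input = flow.tensor([[1, 2], [3, 4]], device="cuda") + rank

    # The output list must be pre-sized: one correctly-shaped tensor per rank.
    tensor_list = [flow.zeros(2, 2, dtype=flow.int64) for _ in range(2)]
    flow.comm.all_gather(tensor_list, input)

    # After the collective, every rank holds every rank's contribution:
    # tensor_list[i] is the tensor that rank i passed in.
    print(rank, [t.numpy() for t in tensor_list])

On both ranks the printed list should contain [[1, 2], [3, 4]] followed by [[2, 3], [4, 5]], mirroring the doctest added in comm_ops.py.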