diff --git a/python/paddle/fluid/tests/unittests/collective/process_group_mpi.py b/python/paddle/fluid/tests/unittests/collective/process_group_mpi.py index 6ae0519d2d7e2..f2fc9c498b4e8 100644 --- a/python/paddle/fluid/tests/unittests/collective/process_group_mpi.py +++ b/python/paddle/fluid/tests/unittests/collective/process_group_mpi.py @@ -443,6 +443,51 @@ def config(self): self.dtype = "float32" self.shape = (2, 10, 5) + def test_create_process_group_mpi(self): + group = init_process_group() + pg = group.process_group + # test allreduce sum + test_allreduce_sum(pg, self.shape, self.dtype) + + # test allreduce max + test_allreduce_max(pg, self.shape, self.dtype) + + # test allreduce min + test_allreduce_min(pg, self.shape, self.dtype) + + # test allreduce prod + test_allreduce_prod(pg, self.shape, self.dtype) + + # test broadcast + test_broadcast(pg, self.shape, self.dtype) + + # test barrier + test_barrair(pg) + + # test allgather + test_allgather(pg, self.shape, self.dtype) + + # test alltoall + test_all2all(pg, self.shape, self.dtype) + + # test Reduce + test_reduce_sum(pg, self.shape, self.dtype) + + # test reduce max + test_reduce_max(pg, self.shape, self.dtype) + + # test reduce min + test_reduce_min(pg, self.shape, self.dtype) + + # test reduce product + test_reduce_prod(pg, self.shape, self.dtype) + + # test Scatter + test_scatter(pg, self.shape, self.dtype) + + # test send recv. + test_send_recv(pg, group, self.shape, self.dtype) + if __name__ == "__main__": unittest.main()