Skip to content

Commit

Permalink
reset process_mpi test
Browse files Browse the repository at this point in the history
  • Loading branch information
yjjiang11 committed Dec 8, 2022
1 parent 278cf18 commit eebef08
Showing 1 changed file with 45 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,51 @@ def config(self):
self.dtype = "float32"
self.shape = (2, 10, 5)

def test_create_process_group_mpi(self):
group = init_process_group()
pg = group.process_group
# test allreduce sum
test_allreduce_sum(pg, self.shape, self.dtype)

# test allreduce max
test_allreduce_max(pg, self.shape, self.dtype)

# test allreduce min
test_allreduce_min(pg, self.shape, self.dtype)

# test allreduce prod
test_allreduce_prod(pg, self.shape, self.dtype)

# test broadcast
test_broadcast(pg, self.shape, self.dtype)

# test barrier
test_barrair(pg)

# test allgather
test_allgather(pg, self.shape, self.dtype)

# test alltoall
test_all2all(pg, self.shape, self.dtype)

# test Reduce
test_reduce_sum(pg, self.shape, self.dtype)

# test reduce max
test_reduce_max(pg, self.shape, self.dtype)

# test reduce min
test_reduce_min(pg, self.shape, self.dtype)

# test reduce product
test_reduce_prod(pg, self.shape, self.dtype)

# test Scatter
test_scatter(pg, self.shape, self.dtype)

# test send recv.
test_send_recv(pg, group, self.shape, self.dtype)


if __name__ == "__main__":
unittest.main()

0 comments on commit eebef08

Please sign in to comment.