diff --git a/docs/src/reference/collective.md b/docs/src/reference/collective.md index e1a25ce37..c89e6f880 100644 --- a/docs/src/reference/collective.md +++ b/docs/src/reference/collective.md @@ -36,6 +36,7 @@ MPI.Neighbor_allgatherv! ```@docs MPI.Scatter! MPI.Scatter +MPI.scatter MPI.Scatterv! ``` diff --git a/src/collective.jl b/src/collective.jl index 70c201a22..3dded32cc 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -167,6 +167,52 @@ Scatter(sendbuf, T, comm; root::Integer=Cint(0)) = Scatter(sendbuf, ::Type{T}, root::Integer, comm::Comm) where {T} = Scatter!(sendbuf, Ref{T}(), root, comm)[] +""" + scatter(objs::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0) + +Sends the `j`-th element of `objs` in the `root` process to rank `j-1` and returns it. On `root`, `objs` is expected to be a `Comm_size(comm)`-element vector. On the other ranks, it is ignored and can be `nothing`. + +This method can handle arbitrary data. + +# See also + +- [`Scatter!`](@ref) +""" +function scatter(objs::Union{AbstractVector, Nothing}, comm::Comm; root::Integer=0) + isroot = Comm_rank(comm) == root + + if isroot + if length(objs) != Comm_size(comm) + throw(ArgumentError("Length of argument objs ($(length(objs))) != number of ranks in comm ($(Comm_size(comm))).")) + end + + sendbuffer = IOBuffer() + counts = Vector{Int64}(undef, length(objs)) + + last_pos = 0 + for (i, obj) in enumerate(objs) + Serialization.serialize(sendbuffer, i == root + 1 ? nothing : obj) + counts[i] = position(sendbuffer) - last_pos + last_pos = position(sendbuffer) + end + + count = Scatter(counts, Int64, comm; root = root) + sendbuf = VBuffer(take!(sendbuffer), counts) + + Scatterv!(sendbuf, IN_PLACE, comm; root = root) + return objs[root + 1] + else + count = Scatter(nothing, Int64, comm; root = root) + + data = Array{UInt8}(undef, count) + recvbuf = Buffer(data) + + Scatterv!(nothing, recvbuf, comm; root = root) + return MPI.deserialize(recvbuf.data) + end +end + + """ Scatterv!(sendbuf, recvbuf, comm::Comm; root::Integer=0) diff --git a/test/test_scatter.jl b/test/test_scatter.jl index 880fcb316..75e30c07c 100644 --- a/test/test_scatter.jl +++ b/test/test_scatter.jl @@ -26,14 +26,31 @@ for T in MPITestTypes end @test Array(B)[1] == T(rank+1) + B = MPI.scatter(A, comm; root = root) + @test B == T(rank+1) + # Test throwing if isroot B = ArrayType{T}(undef, 0) @test_throws DivideError MPI.Scatter!(A, B, comm; root=root) B = ArrayType{T}(undef, 8) @test_throws AssertionError MPI.Scatter!(A, B, comm; root=root) + + wrong_length = ArrayType{T}(undef, size-1) + @test_throws ArgumentError MPI.scatter(wrong_length, comm; root=root) end end + +objs = ["test", 1, Array{Int}, [1,"test"]] +objs_sized = [objs[mod1(i, length(objs))] for i = 1:size] + +B = MPI.scatter(objs_sized, comm; root = root) +@test B == objs_sized[rank+1] +objs_gathered = MPI.gather(B, comm; root = root) +if isroot + @test objs_gathered == objs_sized +end + MPI.Finalize() @test MPI.Finalized()