Skip to content

Commit

Permalink
make default worker pool an AbstractWorkerPool (#49101)
Browse files Browse the repository at this point in the history
Changes [Distributed._default_worker_pool](https://github.com/JuliaLang/julia/blob/5f5d2040511b42ba74bd7529a0eac9cf817ad496/stdlib/Distributed/src/workerpool.jl#L242) to hold an `AbstractWorkerPool` instead of `WorkerPool`. With this, alternate implementations can be plugged in as the default pool. Helps in cases where a cluster is always meant to use a certain custom pool. Lower level calls can then work without having to pass a custom pool reference with every call.
  • Loading branch information
tanmaykm committed Apr 6, 2023
1 parent 1bf65b9 commit def2dda
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
6 changes: 3 additions & 3 deletions stdlib/Distributed/src/pmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct BatchProcessingError <: Exception
end

"""
pgenerate([::WorkerPool], f, c...) -> iterator
pgenerate([::AbstractWorkerPool], f, c...) -> iterator
Apply `f` to each element of `c` in parallel using available workers and tasks.
Expand All @@ -18,14 +18,14 @@ Note that `f` must be made available to all worker processes; see
[Code Availability and Loading Packages](@ref code-availability)
for details.
"""
function pgenerate(p::WorkerPool, f, c)
function pgenerate(p::AbstractWorkerPool, f, c)
if length(p) == 0
return AsyncGenerator(f, c; ntasks=()->nworkers(p))
end
batches = batchsplit(c, min_batch_count = length(p) * 3)
return Iterators.flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b)), batches))
end
pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))
pgenerate(p::AbstractWorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))
pgenerate(f, c) = pgenerate(default_worker_pool(), f, c)
pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...))

Expand Down
15 changes: 13 additions & 2 deletions stdlib/Distributed/src/workerpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,14 @@ perform a `remote_do` on it.
"""
remote_do(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remote_do, f, pool, args...; kwargs...)

const _default_worker_pool = Ref{Union{WorkerPool, Nothing}}(nothing)
const _default_worker_pool = Ref{Union{AbstractWorkerPool, Nothing}}(nothing)

"""
default_worker_pool()
[`WorkerPool`](@ref) containing idle [`workers`](@ref) - used by `remote(f)` and [`pmap`](@ref) (by default).
[`AbstractWorkerPool`](@ref) containing idle [`workers`](@ref) - used by `remote(f)` and [`pmap`](@ref)
(by default). Unless one is explicitly set via `default_worker_pool!(pool)`, the default worker pool is
initialized to a [`WorkerPool`](@ref).
# Examples
```julia-repl
Expand All @@ -267,6 +269,15 @@ function default_worker_pool()
return _default_worker_pool[]
end

"""
default_worker_pool!(pool::AbstractWorkerPool)
Set a [`AbstractWorkerPool`](@ref) to be used by `remote(f)` and [`pmap`](@ref) (by default).
"""
function default_worker_pool!(pool::AbstractWorkerPool)
_default_worker_pool[] = pool
end

"""
remote([p::AbstractWorkerPool], f) -> Function
Expand Down
13 changes: 13 additions & 0 deletions stdlib/Distributed/test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,19 @@ wp = CachingPool(workers())
clear!(wp)
@test length(wp.map_obj2ref) == 0

# default_worker_pool! tests
wp_default = Distributed.default_worker_pool()
try
wp = CachingPool(workers())
Distributed.default_worker_pool!(wp)
@test [1:100...] == pmap(x->x, wp, 1:100)
@test !isempty(wp.map_obj2ref)
clear!(wp)
@test isempty(wp.map_obj2ref)
finally
Distributed.default_worker_pool!(wp_default)
end

# The below block of tests are usually run only on local development systems, since:
# - tests which print errors
# - addprocs tests are memory intensive
Expand Down

0 comments on commit def2dda

Please sign in to comment.