diff --git a/src/fault-handler.jl b/src/fault-handler.jl
index f0996111d..eb688dbf7 100644
--- a/src/fault-handler.jl
+++ b/src/fault-handler.jl
@@ -124,6 +124,9 @@ function handle_fault(ctx, state, thunk, oldproc, chan, node_order)
 
     # Reschedule inputs from deadlist
     newproc = OSProc(rand(workers()))
+    if newproc ∉ procs(ctx)
+        addprocs!(ctx, [newproc])
+    end
     while length(deadlist) > 0
         dt = popfirst!(deadlist)
         if any((input in deadlist) for input in dt.inputs)
diff --git a/src/processor.jl b/src/processor.jl
index 77bab713a..3f1344714 100644
--- a/src/processor.jl
+++ b/src/processor.jl
@@ -238,6 +238,7 @@ mutable struct Context
     log_sink::Any
     profile::Bool
     options
+    proc_lock::ReentrantLock
 end
 
 """
@@ -256,6 +257,7 @@ as a `Vector{Int}`.
 function Context(xs)
     Context(xs, NoOpLog(), false, nothing) # By default don't log events
 end
+Context(xs, log_sink, profile, options) = Context(xs, log_sink, profile, options, ReentrantLock())
 Context(xs::Vector{Int}) = Context(map(OSProc, xs))
 function Context()
     procs = [OSProc(w) for w in workers()]
@@ -271,3 +273,38 @@ Write a log event
 function write_event(ctx::Context, event::Event)
     write_event(ctx.log_sink, event)
 end
+
+"""
+    lock(f, ctx::Context)
+
+Acquire `ctx.proc_lock`, execute `f` with the lock held, and release the lock when `f` returns.
+"""
+Base.lock(f, ctx::Context) = lock(f, ctx.proc_lock)
+
+"""
+    addprocs!(ctx::Context, xs)
+
+Add new workers `xs` to `ctx`.
+
+Workers will typically be assigned new tasks in the next scheduling iteration if scheduling is ongoing.
+
+Workers can be either `Processor`s or the underlying process ids as `Integer`s.
+"""
+addprocs!(ctx::Context, xs::AbstractVector{<:Integer}) = addprocs!(ctx, map(OSProc, xs))
+addprocs!(ctx::Context, xs::AbstractVector{<:Processor}) = lock(ctx) do
+    append!(ctx.procs, xs)
+end
+"""
+    rmprocs!(ctx::Context, xs)
+
+Remove the specified workers `xs` from `ctx`.
+
+Workers will typically finish all their assigned tasks if scheduling is ongoing but will not be assigned new tasks after removal.
+
+Workers can be either `Processor`s or the underlying process ids as `Integer`s.
+"""
+rmprocs!(ctx::Context, xs::AbstractVector{<:Integer}) = rmprocs!(ctx, map(OSProc, xs))
+rmprocs!(ctx::Context, xs::AbstractVector{<:Processor}) = lock(ctx) do
+    filter!(p -> p ∉ xs, ctx.procs)
+end
+
diff --git a/src/scheduler.jl b/src/scheduler.jl
index 9d1370386..c42836ae3 100644
--- a/src/scheduler.jl
+++ b/src/scheduler.jl
@@ -3,7 +3,7 @@ module Sch
 using Distributed
 import MemPool: DRef
 
-import ..Dagger: Context, Processor, Thunk, Chunk, OSProc, order, free!, dependents, noffspring, istask, inputs, affinity, tochunk, @dbg, @logmsg, timespan_start, timespan_end, unrelease, procs, move, choose_processor, execute!
+import ..Dagger: Context, Processor, Thunk, Chunk, OSProc, order, free!, dependents, noffspring, istask, inputs, affinity, tochunk, @dbg, @logmsg, timespan_start, timespan_end, unrelease, procs, move, choose_processor, execute!, rmprocs!, addprocs!
include("fault-handler.jl") @@ -95,13 +95,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) node_order = x -> -get(ord, x, 0) state = start_state(deps, node_order) # start off some tasks - for p in ps - isempty(state.ready) && break - task = pop_with_affinity!(ctx, state.ready, p, false) - if task !== nothing - fire_task!(ctx, task, p, state, chan, node_order) - end - end + worker_state = assign_new_workers!(ctx, procs, state, chan, node_order) @dbg timespan_end(ctx, :scheduler_init, 0, master) # Loop while we still have thunks to execute @@ -117,6 +111,9 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) end end + # Note: worker_state may be different things for different contexts. Don't touch it out here! + worker_state = assign_new_workers!(ctx, ps, state, chan, node_order, worker_state) + if isempty(state.running) # the block above fired only meta tasks continue @@ -128,7 +125,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) @warn "Worker $(proc.pid) died on thunk $thunk_id, rescheduling work" # Remove dead worker from procs list - filter!(p->p.pid!=proc.pid, ctx.procs) + rmprocs!(ctx, [proc]) ps = procs(ctx) handle_fault(ctx, state, state.thunk_dict[thunk_id], proc, chan, node_order) @@ -143,7 +140,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) @dbg timespan_start(ctx, :scheduler, thunk_id, master) immediate_next = finish_task!(state, node, node_order) - if !isempty(state.ready) + if !isempty(state.ready) && !shall_remove_worker(ctx, proc, ps, immediate_next) thunk = pop_with_affinity!(Context(ps), state.ready, proc, immediate_next) if thunk !== nothing fire_task!(ctx, thunk, proc, state, chan, node_order) @@ -154,6 +151,28 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) state.cache[d] end +function assign_new_workers!(ctx, ps, state, chan, node_order, assignedprocs=[]) + ps !== procs(ctx) && return assignedprocs + lock(ctx) do + # Must track individual procs to handle the case when procs are removed + for p in setdiff(ps, assignedprocs) + isempty(state.ready) && break + task = pop_with_affinity!(ctx, state.ready, p, false) + if task !== nothing + fire_task!(ctx, task, p, state, chan, node_order) + end + end + return copy(ps) + end +end + +function shall_remove_worker(ctx, proc, ps, immediate_next) + ps !== procs(ctx) && return false + return lock(ctx) do + proc ∉ procs(ctx) + end +end + function pop_with_affinity!(ctx, tasks, proc, immediate_next) # allow JIT specialization on Pairs mapfirst(c) = first.(c) diff --git a/test/processors.jl b/test/processors.jl index 93ba80548..a3b7d2dab 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -76,4 +76,19 @@ end @everywhere pop!(Dagger.PROCESSOR_CALLBACKS) end end + + @testset "Modify workers in Context" begin + ps = addprocs(4, exeflags="--project") + @everywhere ps using Dagger + + ctx = Context(ps[1:2]) + + Dagger.addprocs!(ctx, ps[3:end]) + @test map(p -> p.pid, procs(ctx)) == ps + + Dagger.rmprocs!(ctx, ps[3:end]) + @test map(p -> p.pid, procs(ctx)) == ps[1:2] + + wait(rmprocs(ps)) + end end diff --git a/test/scheduler.jl b/test/scheduler.jl index 2641a40b6..2eb521d06 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -66,31 +66,96 @@ end end @everywhere (pop!(Dagger.PROCESSOR_CALLBACKS); empty!(Dagger.OSPROC_CACHE)) - @testset "Add new workers" begin - using Distributed - ps1 = addprocs(2, exeflags="--project"); - - @everywhere begin + @testset "Modify workers in running job" begin + # Test that we can add/remove workers 
+        # As this requires asynchrony, a Condition is used to stall the tasks to
+        # ensure workers are actually modified while the scheduler is working.
+
+        setup = quote
             using Dagger, Distributed
             # Condition to guarantee that processing is not completed before we add new workers
+            # Note: c is used in expressions below
             c = Condition()
             function testfun(i)
-                i < 2 && return myid()
+                i < 4 && return myid()
                 wait(c)
                 return myid()
             end
         end
-
-        ts = delayed(vcat)((delayed(testfun)(i) for i in 1:4)...);
-        job = @async collect(Context(ps1), ts);
-
-        ps2 = addprocs(2, exeflags="--project");
-
-        while !istaskdone(job)
-            @everywhere ps1 notify(c)
+
+        @testset "Add new workers" begin
+            ps = []
+            try
+                ps1 = addprocs(2, exeflags="--project");
+                append!(ps, ps1)
+
+                @everywhere vcat(ps1, myid()) $setup
+
+                ts = delayed(vcat)((delayed(testfun)(i) for i in 1:10)...);
+
+                ctx = Context(ps1)
+                job = @async collect(ctx, ts);
+
+                while !istaskstarted(job)
+                    sleep(0.001)
+                end
+
+                # These will not be added to the context, so they should never appear in the output
+                ps2 = addprocs(2, exeflags="--project");
+                append!(ps, ps2)
+
+                ps3 = addprocs(2, exeflags="--project")
+                append!(ps, ps3)
+                @everywhere ps3 $setup
+                Dagger.addprocs!(ctx, ps3)
+                @test length(procs(ctx)) == 4
+
+                while !istaskdone(job)
+                    @everywhere ps1 notify(c)
+                    @everywhere ps3 notify(c)
+                end
+                @test fetch(job) isa Vector
+                @test fetch(job) |> unique |> sort == vcat(ps1, ps3)
+
+            finally
+                wait(rmprocs(ps))
+            end
         end
-        @test fetch(job) |> unique |> sort == ps1
-        wait(rmprocs(vcat(ps1,ps2)))
+        @testset "Remove workers" begin
+            ps = []
+            try
+                ps1 = addprocs(4, exeflags="--project");
+                append!(ps, ps1)
+
+                @everywhere vcat(ps1, myid()) $setup
+
+                ts = delayed(vcat)((delayed(testfun)(i) for i in 1:16)...);
+
+                ctx = Context(ps1)
+                job = @async collect(ctx, ts);
+
+                while !istaskstarted(job)
+                    sleep(0.001)
+                end
+
+                Dagger.rmprocs!(ctx, ps1[3:end])
+                @test length(procs(ctx)) == 2
+
+                while !istaskdone(job)
+                    @everywhere ps1 notify(c)
+                end
+                res = fetch(job)
+                @test res isa Vector
+                # First, all four workers report their IDs without hassle.
+                # Then all four wait on the Condition.
+                # While they wait, ps1[3:end] are removed; once the Condition is notified they finish their current tasks before being removed.
+                @test res[1:8] |> unique |> sort == ps1
+                @test res[9:end] |> unique |> sort == ps1[1:2]
+
+            finally
+                wait(rmprocs(ps))
+            end
+        end
     end
 end
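
A minimal usage sketch of the `addprocs!`/`rmprocs!` API added in src/processor.jl, assuming a Dagger job is already running on a `Context`; the worker counts, the `slowid` helper, and the toy `vcat` graph are illustrative assumptions rather than code from this patch:

    using Distributed, Dagger

    ps1 = addprocs(2, exeflags="--project")
    @everywhere using Dagger
    @everywhere slowid(i) = (sleep(1); myid())   # stand-in for real work

    # Build a context over the initial workers and start a job asynchronously
    ctx = Context(ps1)
    job = @async collect(ctx, delayed(vcat)((delayed(slowid)(i) for i in 1:20)...))

    # Grow the pool while the job runs; the scheduler hands ready tasks
    # to the new workers on its next iteration
    ps2 = addprocs(2, exeflags="--project")
    @everywhere ps2 using Dagger
    @everywhere ps2 slowid(i) = (sleep(1); myid())
    Dagger.addprocs!(ctx, ps2)

    # Shrink the pool again; removed workers finish the tasks they already
    # hold but are not assigned new ones
    Dagger.rmprocs!(ctx, ps2)

    fetch(job)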