Skip to content

Commit

Permalink
options: Add dispatch-based options
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Jul 4, 2022
1 parent 22d30d8 commit dbc4b50
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 87 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
1 change: 1 addition & 0 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using UUIDs
import ContextVariablesX

using Requires
using MacroTools

const PLUGINS = Dict{Symbol,Any}()
const PLUGIN_CONFIGS = Dict{Symbol,String}(
Expand Down
45 changes: 45 additions & 0 deletions src/options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,48 @@ function get_options(key::Symbol, default)
opts = get_options()
return haskey(opts, key) ? getproperty(opts, key) : default
end

# Dispatch-based setters

"""
    default_option(::Val{name}, Tf, Targs...) where name = value

Define `value` as the default for option `name` whenever Dagger is preparing to
execute a function of type `Tf` with argument types `Targs`. Users and
libraries may add methods to set default values for their tasks;
[`@option`](@ref) is a more convenient way to write such methods.

Only the argument *types* are passed, never the task's actual argument values,
as it may not always be possible or efficient to gather all Dagger task
arguments on one worker.

This function may be executed within the scheduler, so methods should generally
be very cheap to execute. If a method throws an error, the scheduler will use
whatever the global default value is for that option instead.
"""
function default_option(::Val{name}, Tf, Targs...) where {name}
    # Generic fallback: no default is registered for this option/signature.
    return nothing
end

# Passing only the option name is a usage error: a function type is required.
function default_option(::Val)
    throw(ArgumentError("default_option requires a function type and any argument types"))
end

"""
    @option name myfunc(A, B, C) = value

A convenience macro for defining [`default_option`](@ref). For example:

```julia
Dagger.@option single mylocalfunc(Int) = 1
```

The above call will set the `single` option to `1` for any Dagger task calling
`mylocalfunc` with an `Int` argument.
"""
macro option(name, ex)
    # Accept the option name either bare (`single`) or quoted (`:single`).
    # A bare Symbol must not be spliced as-is: it would become an identifier,
    # i.e. `Val{single}` (a variable lookup) instead of `Val{:single}`.
    optname = name isa QuoteNode ? name.value : name
    optname isa Symbol || throw(ArgumentError(
        "@option expects an option name such as `single` or `:single`, got: $name"))

    # `@capture` returns `false` when `ex` does not have the expected shape;
    # fail loudly here instead of on `nothing` captures further down.
    @capture(ex, f_(args__) = value_) || throw(ArgumentError(
        "@option expects `name func(ArgTypes...) = value`, got: $ex"))

    # One fresh type variable per argument: each argument is matched as
    # `::Type{T}` with `T <: <user-supplied type>`, since `default_option`
    # receives argument types rather than values.
    argsyms = [gensym() for _ in args]
    _args = [:(::$Type{$T}) for T in argsyms]
    argsubs = [:($T <: $(esc(A))) for (T, A) in zip(argsyms, args)]

    quote
        Dagger.default_option(::$Val{$(QuoteNode(optname))}, ::Type{$typeof($(esc(f)))}, $(_args...)) where {$(argsubs...)} = $(esc(value))
    end
end
119 changes: 53 additions & 66 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,13 @@ If this returns a `Chunk`, all thunks will be skipped, and the `Chunk` will be
returned. If `nothing` is returned, restoring is skipped, and the scheduler
will execute as usual. If this function throws an error, restoring will be
skipped, and the error will be displayed.
- `round_robin::Bool=false`: Whether to schedule in round-robin mode, which
spreads load instead of the default behavior of filling processors to capacity.
"""
Base.@kwdef struct SchedulerOptions
    # `nothing` means "unset"; consumers test `single !== nothing` rather than
    # comparing against a `0` sentinel.
    single::Union{Int,Nothing} = nothing
    proclist = nothing
    # May be `nothing`; readers coerce with `something(allow_errors, false)`.
    allow_errors::Union{Bool,Nothing} = false
    checkpoint = nothing
    restore = nothing
    round_robin::Bool = false
end

"""
Expand Down Expand Up @@ -213,11 +210,11 @@ device must support `MemPool.CPURAMResource`. When `nothing`, uses
`MemPool.GLOBAL_DEVICE[]`.
"""
Base.@kwdef struct ThunkOptions
single::Int = 0
single::Union{Int,Nothing} = nothing
proclist = nothing
time_util::Dict{Type,Any} = Dict{Type,Any}()
alloc_util::Dict{Type,UInt64} = Dict{Type,UInt64}()
allow_errors::Bool = false
time_util::Union{Dict{Type,Any},Nothing} = nothing
alloc_util::Union{Dict{Type,UInt64},Nothing} = nothing
allow_errors::Union{Bool,Nothing} = nothing
checkpoint = nothing
restore = nothing
storage::Union{Chunk,Nothing} = nothing
Expand All @@ -232,20 +229,50 @@ include("eager.jl")
Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`.
"""
function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions)
    # A per-thunk value wins whenever it is set; `nothing` means "unset", in
    # which case the scheduler-wide value is used. Plain ternaries are used
    # (not `something`) because both sides may legitimately be `nothing`.
    single = topts.single !== nothing ? topts.single : sopts.single
    allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors
    proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist

    # The remaining fields exist only on `ThunkOptions` and are carried over.
    return ThunkOptions(single,
                        proclist,
                        topts.time_util,
                        topts.alloc_util,
                        allow_errors,
                        topts.checkpoint,
                        topts.restore,
                        topts.storage)
end
# Merging with no per-thunk options: carry over the scheduler-wide `single`,
# `proclist` and `allow_errors`, leaving everything else at its default.
# The keyword constructor (from `Base.@kwdef`) is used because `ThunkOptions`
# declares more than five fields, so a five-argument positional call would not
# match its default constructor.
Base.merge(sopts::SchedulerOptions, ::Nothing) =
    ThunkOptions(; single=sopts.single,
                   proclist=sopts.proclist,
                   time_util=nothing,
                   alloc_util=nothing,
                   allow_errors=sopts.allow_errors)
"""
    populate_defaults(opts::ThunkOptions, Tf, Targs) -> ThunkOptions

Returns a `ThunkOptions` with default values filled in for a function of type
`Tf` with argument types `Targs`, if the option was previously unspecified in
`opts`.
"""
function populate_defaults(opts::ThunkOptions, Tf, Targs)
    # Keep an explicitly-set option; otherwise ask `Dagger.default_option` for
    # a default keyed on the option name, function type and argument types.
    # The result may still be `nothing` when no default is registered.
    function maybe_default(opt::Symbol)
        old_opt = getproperty(opts, opt)
        if old_opt !== nothing
            return old_opt
        else
            return Dagger.default_option(Val(opt), Tf, Targs...)
        end
    end
    # Arguments are positional and listed in the same order as the
    # `ThunkOptions(...)` call in `Base.merge`.
    return ThunkOptions(
        maybe_default(:single),
        maybe_default(:proclist),
        maybe_default(:time_util),
        maybe_default(:alloc_util),
        maybe_default(:allow_errors),
        maybe_default(:checkpoint),
        maybe_default(:restore),
        maybe_default(:storage),
    )
end

function cleanup(ctx)
Expand Down Expand Up @@ -474,7 +501,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options)
timespan_finish(ctx, :handle_fault, 0, 0)
return # effectively `continue`
else
if ctx.options.allow_errors || unwrap_weak_checked(state.thunk_dict[thunk_id]).options.allow_errors
if something(ctx.options.allow_errors, false) ||
something(unwrap_weak_checked(state.thunk_dict[thunk_id]).options.allow_errors, false)
thunk_failed = true
else
throw(res)
Expand Down Expand Up @@ -541,7 +569,7 @@ function scheduler_exit(ctx, state::ComputeState, options)
end

function procs_to_use(ctx, options=ctx.options)
return if options.single !== 0
return if options.single !== nothing
@assert options.single in vcat(1, workers()) "Sch option `single` must specify an active worker ID."
OSProc[OSProc(options.single)]
else
Expand Down Expand Up @@ -621,7 +649,9 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
@goto fallback
end

local_procs, costs = estimate_task_costs(state, local_procs, task)
inputs = collect_task_inputs(state, task)
opts = populate_defaults(opts, chunktype(task.f), map(chunktype, inputs))
local_procs, costs = estimate_task_costs(state, local_procs, task, inputs)
scheduled = false

# Move our corresponding ThreadProc to be the last considered
Expand Down Expand Up @@ -707,9 +737,6 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
push!(get!(()->Vector{Tuple{Thunk,<:Any,<:Any}}(), to_fire, (gproc, proc)), (task, est_time_util, est_alloc_util))

# Proceed to next entry to spread work
if !ctx.options.round_robin
@warn "Round-robin mode is always on"
end
state.procs_cache_list[] = state.procs_cache_list[].next
@goto pop_task

Expand Down Expand Up @@ -781,46 +808,6 @@ function remove_dead_proc!(ctx, state, proc, options=ctx.options)
state.procs_cache_list[] = nothing
end

"""
    pop_with_affinity!(ctx, tasks, proc)

Remove and return one task from `tasks` for `proc`, or return `nothing` when no
task qualifies. `tasks` is scanned from its last element to its first, twice:

1. The first task whose affinity processors include `proc` and which is not
   restricted from `proc` is taken.
2. Otherwise, the first unrestricted task that either has no affinity
   processors at all, or whose affinity processors contain none of
   `procs(ctx)`, is taken.
"""
function pop_with_affinity!(ctx, tasks, proc)
    # TODO: use the size
    ntasks = length(tasks)
    affinity_procs = Vector(undef, ntasks)

    # Pass 1: prefer a task that has affinity for `proc`, remembering each
    # task's affinity processors for the second pass.
    for idx in ntasks:-1:1
        candidate = tasks[idx]
        candidate_procs = first.(affinity(candidate))
        if proc in candidate_procs && !isrestricted(candidate, proc)
            deleteat!(tasks, idx)
            return candidate
        end
        affinity_procs[idx] = candidate_procs
    end

    # Pass 2: use up tasks without affinities, letting the procs with the
    # respective affinities pick up other tasks.
    for idx in ntasks:-1:1
        candidate_procs = affinity_procs[idx]
        if isempty(candidate_procs) && !isrestricted(tasks[idx], proc)
            candidate = tasks[idx]
            deleteat!(tasks, idx)
            return candidate
        end
        # No proc is ever going to ask for this task, so take it here.
        if !any(p -> p in candidate_procs, procs(ctx)) && !isrestricted(tasks[idx], proc)
            candidate = tasks[idx]
            deleteat!(tasks, idx)
            return candidate
        end
    end

    return nothing
end

function finish_task!(ctx, state, node, thunk_failed)
pop!(state.running, node)
delete!(state.running_on, node)
Expand Down Expand Up @@ -914,12 +901,12 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
toptions = thunk.options !== nothing ? thunk.options : ThunkOptions()
options = merge(ctx.options, toptions)
propagated = get_propagated_options(thunk)
@assert (options.single == 0) || (gproc.pid == options.single)
@assert (options.single === nothing) || (gproc.pid == options.single)
# TODO: Set `sch_handle.tid.ref` to the right `DRef`
sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...)

# TODO: De-dup common fields (log_sink, uid, etc.)
push!(to_send, Any[thunk.id, time_util, alloc_util, fn_type(thunk.f), data, thunk.get_result,
push!(to_send, Any[thunk.id, time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result,
thunk.persist, thunk.cache, thunk.meta, options,
propagated, ids,
(log_sink=ctx.log_sink, profile=ctx.profile),
Expand Down
40 changes: 19 additions & 21 deletions src/sch/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,10 @@ function report_catch_error(err, desc=nothing)
write(stderr, iob)
end

# Type used to represent `x` in a signature: a `Chunk` reports its `chunktype`
# field, anything else reports its own type.
function fn_type(x::Chunk)
    return x.chunktype
end
function fn_type(x)
    return typeof(x)
end

# Generic `chunktype` method: the type of the value itself.
function chunktype(x)
    return typeof(x)
end
"""
    signature(task::Thunk, state)

Return a `Vector{Any}` holding `chunktype(task.f)` followed by the `chunktype`
of each of `task`'s inputs, as resolved by `collect_task_inputs`.
"""
function signature(task::Thunk, state)
    sig = Any[chunktype(task.f)]
    # Record the *type* of every input, not the input value itself, so that the
    # signature only depends on what kind of arguments the task receives.
    # NOTE(review): assumes `chunktype` has a `Chunk`-specific method elsewhere
    # that reports the wrapped type — confirm.
    append!(sig, map(chunktype, collect_task_inputs(state, task)))
    return sig
end

Expand All @@ -283,7 +278,7 @@ function can_use_proc(task, gproc, proc, opts, scope)
end

# Check against single
if opts.single != 0
if opts.single !== nothing
if gproc.pid != opts.single
@debug "Rejected $proc: gproc.pid != single"
return false
Expand All @@ -303,7 +298,7 @@ end
function has_capacity(state, p, gp, time_util, alloc_util, sig)
T = typeof(p)
# FIXME: MaxUtilization
est_time_util = round(UInt64, if haskey(time_util, T)
est_time_util = round(UInt64, if time_util !== nothing && haskey(time_util, T)
time_util[T] * 1000^3
else
get(state.signature_time_cost, sig, 1000^3)
Expand Down Expand Up @@ -364,29 +359,32 @@ function impute_sum(xs)
total + nothing_count * total / something_count
end

"Collects all arguments for `task`, converting Thunk inputs to Chunks."
function collect_task_inputs(state, task)
    resolved = Any[]
    foreach(task.inputs) do raw
        arg = unwrap_weak_checked(raw)
        # Inputs that are themselves tasks are replaced by their entry in
        # `state.cache`; everything else is passed through unchanged.
        push!(resolved, istask(arg) ? state.cache[arg] : arg)
    end
    return resolved
end

"""
Estimates the cost of scheduling `task` on each processor in `procs`. Considers
current estimated per-processor compute pressure, and transfer costs for each
`Chunk` argument to `task`. Returns `(procs, costs)`, with `procs` sorted in
order of ascending cost.
"""
function estimate_task_costs(state, procs, task)
function estimate_task_costs(state, procs, task, inputs)
tx_rate = state.transfer_rate[]

# Find all Chunks
chunks = Chunk[]
for input in task.inputs
input = unwrap_weak_checked(input)
input_raw = istask(input) ? state.cache[input] : input
if input_raw isa Chunk
push!(chunks, input_raw)
for input in inputs
if input isa Chunk
push!(chunks, input)
end
end
#=
inputs = map(@nospecialize(input)->istask(input) ? state.cache[input] : input,
map(@nospecialize(x)->unwrap_weak_checked(x), task.inputs))
chunks = filter(@nospecialize(t)->isa(t, Chunk), inputs)
=#

# Estimate network transfer costs based on data size
# N.B. `affinity(x)` really means "data size of `x`"
Expand Down

0 comments on commit dbc4b50

Please sign in to comment.