diff --git a/src/compute.jl b/src/compute.jl
index f31be828d..d7a547dc9 100644
--- a/src/compute.jl
+++ b/src/compute.jl
@@ -25,11 +25,21 @@ runs the scheduler with the specified options.
 Returns a Chunk which references the result.
 """
 function compute(ctx::Context, d::Thunk; options=nothing)
-    if !(:scheduler in keys(PLUGINS))
-        PLUGINS[:scheduler] = get_type(PLUGIN_CONFIGS[:scheduler])
+    scheduler = get!(PLUGINS, :scheduler) do
+        get_type(PLUGIN_CONFIGS[:scheduler])
     end
-    scheduler = PLUGINS[:scheduler]
-    (scheduler).compute_dag(ctx, d; options=options)
+    res = scheduler.compute_dag(ctx, d; options=options)
+    if ctx.log_file !== nothing
+        if ctx.log_sink isa LocalEventLog
+            logs = get_logs!(ctx.log_sink)
+            open(ctx.log_file, "w") do io
+                Dagger.show_plan(io, logs, d)
+            end
+        else
+            @warn "Context log_sink not set to LocalEventLog, skipping"
+        end
+    end
+    res
 end
 
 function debug_compute(ctx::Context, args...; profile=false, options=nothing)
diff --git a/src/processor.jl b/src/processor.jl
index eda92570b..3a581cc6b 100644
--- a/src/processor.jl
+++ b/src/processor.jl
@@ -235,10 +235,11 @@ default_enabled(proc::ThreadProc) = true
 "A context represents a set of processors to use for an operation."
 mutable struct Context
     procs::Vector{Processor}
+    proc_lock::ReentrantLock
     log_sink::Any
+    log_file::Union{String,Nothing}
     profile::Bool
     options
-    proc_lock::ReentrantLock
 end
 
 """
@@ -253,16 +254,19 @@ number of threads.
 It is also possible to create a Context from a vector of [`OSProc`](@ref),
 or equivalently the underlying process ids can also be passed directly
 as a `Vector{Int}`.
+
+Special fields include:
+- `log_sink`: A log sink object to use, if any.
+- `log_file::Union{String,Nothing}`: Path to a logfile. If specified, at
+scheduler termination, logs will be collected, combined with input thunks, and
+written out in DOT format to this location.
+- `profile::Bool`: Whether or not to perform profiling with the Profile stdlib.
 """
-function Context(xs)
-    Context(xs, NoOpLog(), false, nothing) # By default don't log events
-end
-Context(xs, log_sink, profile, options) = Context(xs, log_sink, profile, options, ReentrantLock())
+Context(procs::Vector{P}=Processor[OSProc(w) for w in workers()];
+        proc_lock=ReentrantLock(), log_sink=NoOpLog(), log_file=nothing,
+        profile=false, options=nothing) where {P<:Processor} =
+    Context(procs, proc_lock, log_sink, log_file, profile, options)
 Context(xs::Vector{Int}) = Context(map(OSProc, xs))
-function Context()
-    procs = [OSProc(w) for w in workers()]
-    Context(procs)
-end
 procs(ctx::Context) = lock(ctx) do
     copy(ctx.procs)
 end
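
(For context, a minimal usage sketch of the `Context` keyword constructor and `log_file` field introduced above; this is not part of the diff, and the output path and thunks are hypothetical.)

```julia
using Dagger

# Collect events with a LocalEventLog and have the scheduler write the
# DOT-format plan to `log_file` when computation finishes.
ctx = Context(; log_sink=Dagger.LocalEventLog(), log_file="/tmp/plan.dot")

a = delayed(+)(1, 2)   # hypothetical example thunks
b = delayed(*)(a, 3)
compute(ctx, b)        # on return, /tmp/plan.dot holds the rendered DAG
```
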
""" Base.@kwdef struct SchedulerOptions single::Int = 0 @@ -375,8 +378,8 @@ function start_state(deps::Dict, node_order) state end -@noinline function do_task(thunk_id, f, data, send_result, persist, cache, options, ids, logsink) - ctx = Context(Processor[], logsink, false, nothing) +@noinline function do_task(thunk_id, f, data, send_result, persist, cache, options, ids, log_sink) + ctx = Context(Processor[]; log_sink=log_sink) proc = OSProc() fetched = map(Iterators.zip(data,ids)) do (x, id) @dbg timespan_start(ctx, :comm, (thunk_id, id), (f, id)) @@ -407,10 +410,10 @@ end result_meta end -@noinline function async_apply(p::OSProc, thunk_id, f, data, chan, send_res, persist, cache, options, ids, logsink) +@noinline function async_apply(p::OSProc, thunk_id, f, data, chan, send_res, persist, cache, options, ids, log_sink) @async begin try - put!(chan, remotecall_fetch(do_task, p.pid, thunk_id, f, data, send_res, persist, cache, options, ids, logsink)) + put!(chan, remotecall_fetch(do_task, p.pid, thunk_id, f, data, send_res, persist, cache, options, ids, log_sink)) catch ex bt = catch_backtrace() put!(chan, (p, thunk_id, CapturedException(ex, bt))) diff --git a/src/ui/graph.jl b/src/ui/graph.jl index 21c329728..7dbabbedb 100644 --- a/src/ui/graph.jl +++ b/src/ui/graph.jl @@ -146,6 +146,7 @@ function write_dag(io, logs::Vector, t=nothing) c = write_node(io, arg, c, name) push!(nodes, name) end + arg_c += 1 end argnodemap[id] = nodes end diff --git a/test/ui.jl b/test/ui.jl index 5fd977462..925ef2ded 100644 --- a/test/ui.jl +++ b/test/ui.jl @@ -44,4 +44,18 @@ end logs = Dagger.get_logs!(log) plan = Dagger.show_plan(logs, j) end + +@testset "Automatic Plan Rendering" begin + x = compute(rand(Blocks(2,2),4,4)) + mktemp() do path, io + ctx = Context(;log_sink=Dagger.LocalEventLog(),log_file=path) + compute(ctx, x * x) + plan = String(read(io)) + @test occursin("digraph {", plan) + @test occursin("Comm:", plan) + @test occursin("Move:", plan) + @test occursin("Compute:", plan) + @test endswith(plan, "}\n") + end +end end