Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix enq_work behavior when single-threaded #48702

Merged
merged 5 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -767,22 +767,33 @@ end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
if t.sticky || Threads.threadpoolsize() == 1

# Sticky tasks go into their thread's work queue.
if t.sticky
tid = Threads.threadid(t)
if tid == 0
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
# Issue #41324
# t.sticky && tid == 0 is a task that needs to be co-scheduled with
# the parent task. If the parent (current_task) is not sticky we must
# set it to be sticky.
# XXX: Ideally we would be able to unset this
current_task().sticky = true
# The task is not yet stuck to a thread. Stick it to the current
# thread and do the same to the parent task (the current task) so
# that the tasks are correctly co-scheduled (issue #41324).
# XXX: Ideally we would be able to unset this.
tid = Threads.threadid()
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
current_task().sticky = true
end
push!(workqueue_for(tid), t)
else
Partr.multiq_insert(t, t.priority)
tid = 0
tp = Threads.threadpool(t)
if Threads.threadpoolsize(tp) == 1
# There's only one thread in the task's assigned thread pool;
# use its work queue.
tid = (tp === :default) ? 1 : Threads.threadpoolsize(:default)+1
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
push!(workqueue_for(tid), t)
else
# Otherwise, put the task in the multiqueue.
Partr.multiq_insert(t, t.priority)
tid = 0
end
end
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
return t
Expand Down
33 changes: 19 additions & 14 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,7 @@ See also `BLAS.get_num_threads` and `BLAS.set_num_threads` in the [`LinearAlgebr
man-linalg) standard library, and `nprocs()` in the [`Distributed`](@ref man-distributed)
standard library and [`Threads.maxthreadid()`](@ref).
"""
function nthreads(pool::Symbol)
if pool === :default
tpid = Int8(0)
elseif pool === :interactive
tpid = Int8(1)
else
error("invalid threadpool specified")
end
return _nthreads_in_pool(tpid)
end
nthreads(pool::Symbol) = threadpoolsize(pool)

function _nthreads_in_pool(tpid::Int8)
p = unsafe_load(cglobal(:jl_n_threads_per_pool, Ptr{Cint}))
Expand All @@ -66,15 +57,25 @@ Returns the number of threadpools currently configured.
nthreadpools() = Int(unsafe_load(cglobal(:jl_n_threadpools, Cint)))

"""
Threads.threadpoolsize()
Threads.threadpoolsize(pool::Symbol = :default) -> Int

Get the number of threads available to the Julia default worker-thread pool.
Get the number of threads available to the default thread pool (or to the
specified thread pool).

See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
[`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
[`Distributed`](@ref man-distributed) standard library.
"""
threadpoolsize() = Threads._nthreads_in_pool(Int8(0))
function threadpoolsize(pool::Symbol = :default)
if pool === :default
tpid = Int8(0)
elseif pool === :interactive
tpid = Int8(1)
else
error("invalid threadpool specified")
end
return _nthreads_in_pool(tpid)
end

function threading_run(fun, static)
ccall(:jl_enter_threaded_region, Cvoid, ())
Expand Down Expand Up @@ -343,7 +344,11 @@ macro spawn(args...)
let $(letargs...)
local task = Task($thunk)
task.sticky = false
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, $tpid)
local tpid_actual = $tpid
if _nthreads_in_pool(tpid_actual) == 0
tpid_actual = Int8(0)
end
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, tpid_actual)
if $(Expr(:islocal, var))
put!($var, task)
end
Expand Down
11 changes: 3 additions & 8 deletions test/threadpool_use.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ using Base.Threads
@test nthreadpools() == 2
@test threadpool() === :default
@test threadpool(2) === :interactive
dtask() = @test threadpool(current_task()) === :default
itask() = @test threadpool(current_task()) === :interactive
dt1 = @spawn dtask()
dt2 = @spawn :default dtask()
it = @spawn :interactive itask()
wait(dt1)
wait(dt2)
wait(it)
@test fetch(Threads.@spawn Threads.threadpool()) === :default
@test fetch(Threads.@spawn :default Threads.threadpool()) === :default
@test fetch(Threads.@spawn :interactive Threads.threadpool()) === :interactive