Skip to content

Commit

Permalink
Merge pull request #48702 from JuliaLang/kp/fix-48644
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Mar 9, 2023
2 parents 386b1b6 + 55422d9 commit f288f88
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 31 deletions.
29 changes: 20 additions & 9 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -767,22 +767,33 @@ end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
if t.sticky || Threads.threadpoolsize() == 1

# Sticky tasks go into their thread's work queue.
if t.sticky
tid = Threads.threadid(t)
if tid == 0
# Issue #41324
# t.sticky && tid == 0 is a task that needs to be co-scheduled with
# the parent task. If the parent (current_task) is not sticky we must
# set it to be sticky.
# XXX: Ideally we would be able to unset this
current_task().sticky = true
# The task is not yet stuck to a thread. Stick it to the current
# thread and do the same to the parent task (the current task) so
# that the tasks are correctly co-scheduled (issue #41324).
# XXX: Ideally we would be able to unset this.
tid = Threads.threadid()
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
current_task().sticky = true
end
push!(workqueue_for(tid), t)
else
Partr.multiq_insert(t, t.priority)
tid = 0
tp = Threads.threadpool(t)
if Threads.threadpoolsize(tp) == 1
# There's only one thread in the task's assigned thread pool;
# use its work queue.
tid = (tp === :default) ? 1 : Threads.threadpoolsize(:default)+1
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
push!(workqueue_for(tid), t)
else
# Otherwise, put the task in the multiqueue.
Partr.multiq_insert(t, t.priority)
tid = 0
end
end
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
return t
Expand Down
33 changes: 19 additions & 14 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,7 @@ See also `BLAS.get_num_threads` and `BLAS.set_num_threads` in the [`LinearAlgebr
man-linalg) standard library, and `nprocs()` in the [`Distributed`](@ref man-distributed)
standard library and [`Threads.maxthreadid()`](@ref).
"""
function nthreads(pool::Symbol)
if pool === :default
tpid = Int8(0)
elseif pool === :interactive
tpid = Int8(1)
else
error("invalid threadpool specified")
end
return _nthreads_in_pool(tpid)
end
nthreads(pool::Symbol) = threadpoolsize(pool)

function _nthreads_in_pool(tpid::Int8)
p = unsafe_load(cglobal(:jl_n_threads_per_pool, Ptr{Cint}))
Expand All @@ -66,15 +57,25 @@ Returns the number of threadpools currently configured.
nthreadpools() = Int(unsafe_load(cglobal(:jl_n_threadpools, Cint)))

"""
Threads.threadpoolsize()
Threads.threadpoolsize(pool::Symbol = :default) -> Int
Get the number of threads available to the Julia default worker-thread pool.
Get the number of threads available to the default thread pool (or to the
specified thread pool).
See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
[`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
[`Distributed`](@ref man-distributed) standard library.
"""
threadpoolsize() = Threads._nthreads_in_pool(Int8(0))
function threadpoolsize(pool::Symbol = :default)
if pool === :default
tpid = Int8(0)
elseif pool === :interactive
tpid = Int8(1)
else
error("invalid threadpool specified")
end
return _nthreads_in_pool(tpid)
end

function threading_run(fun, static)
ccall(:jl_enter_threaded_region, Cvoid, ())
Expand Down Expand Up @@ -343,7 +344,11 @@ macro spawn(args...)
let $(letargs...)
local task = Task($thunk)
task.sticky = false
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, $tpid)
local tpid_actual = $tpid
if _nthreads_in_pool(tpid_actual) == 0
tpid_actual = Int8(0)
end
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, tpid_actual)
if $(Expr(:islocal, var))
put!($var, task)
end
Expand Down
11 changes: 3 additions & 8 deletions test/threadpool_use.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ using Base.Threads
@test nthreadpools() == 2
@test threadpool() === :default
@test threadpool(2) === :interactive
dtask() = @test threadpool(current_task()) === :default
itask() = @test threadpool(current_task()) === :interactive
dt1 = @spawn dtask()
dt2 = @spawn :default dtask()
it = @spawn :interactive itask()
wait(dt1)
wait(dt2)
wait(it)
@test fetch(Threads.@spawn Threads.threadpool()) === :default
@test fetch(Threads.@spawn :default Threads.threadpool()) === :default
@test fetch(Threads.@spawn :interactive Threads.threadpool()) === :interactive

0 comments on commit f288f88

Please sign in to comment.