Skip to content

Commit

Permalink
Separate foreign threads into a :foreign threadpool (#50912)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Baraldi <baraldigabriel@gmail.com>
Co-authored-by: Dilum Aluthge <dilum@aluthge.com>
(cherry picked from commit 8be469e)
  • Loading branch information
vchuravy authored and KristofferC committed Aug 23, 2023
1 parent f58e1eb commit 13ec3ce
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 7 deletions.
7 changes: 7 additions & 0 deletions base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ end

function multiq_insert(task::Task, priority::UInt16)
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), task)
@assert tpid > -1
heap_p = multiq_size(tpid)
tp = tpid + 1

Expand Down Expand Up @@ -131,6 +132,9 @@ function multiq_deletemin()

tid = Threads.threadid()
tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1
if tp == 0 # Foreign thread
return nothing
end
tpheaps = heaps[tp]

@label retry
Expand Down Expand Up @@ -182,6 +186,9 @@ end
function multiq_check_empty()
tid = Threads.threadid()
tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1
if tp == 0 # Foreign thread
return true
end
for i = UInt32(1):length(heaps[tp])
if heaps[tp][i].ntasks != 0
return false
Expand Down
2 changes: 1 addition & 1 deletion base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ function enq_work(t::Task)
else
@label not_sticky
tp = Threads.threadpool(t)
if Threads.threadpoolsize(tp) == 1
if tp === :foreign || Threads.threadpoolsize(tp) == 1
# There's only one thread in the task's assigned thread pool;
# use its work queue.
tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1
Expand Down
14 changes: 10 additions & 4 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ function _tpid_to_sym(tpid::Int8)
return :interactive
elseif tpid == 1
return :default
elseif tpid == -1
return :foreign
else
throw(ArgumentError("Unrecognized threadpool id $tpid"))
end
Expand All @@ -73,6 +75,8 @@ function _sym_to_tpid(tp::Symbol)
return Int8(0)
elseif tp === :default
return Int8(1)
elseif tp == :foreign
return Int8(-1)
else
throw(ArgumentError("Unrecognized threadpool name `$(repr(tp))`"))
end
Expand All @@ -81,7 +85,7 @@ end
"""
Threads.threadpool(tid = threadid()) -> Symbol
Returns the specified thread's threadpool; either `:default` or `:interactive`.
Returns the specified thread's threadpool; either `:default`, `:interactive`, or `:foreign`.
"""
function threadpool(tid = threadid())
tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1)
Expand All @@ -108,6 +112,8 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
function threadpoolsize(pool::Symbol = :default)
if pool === :default || pool === :interactive
tpid = _sym_to_tpid(pool)
elseif pool == :foreign
error("Threadpool size of `:foreign` is indeterminant")
else
error("invalid threadpool specified")
end
Expand Down Expand Up @@ -151,7 +157,7 @@ function threading_run(fun, static)
else
# TODO: this should be the current pool (except interactive) if there
# are ever more than two pools.
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, _sym_to_tpid(:default))
@assert ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, _sym_to_tpid(:default)) == 1
end
tasks[i] = t
schedule(t)
Expand Down Expand Up @@ -357,10 +363,10 @@ end

function _spawn_set_thrpool(t::Task, tp::Symbol)
tpid = _sym_to_tpid(tp)
if _nthreads_in_pool(tpid) == 0
if tpid == -1 || _nthreads_in_pool(tpid) == 0
tpid = _sym_to_tpid(:default)
end
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid)
@assert ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid) == 1
nothing
end

Expand Down
2 changes: 1 addition & 1 deletion src/partr.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ JL_DLLEXPORT int jl_set_task_tid(jl_task_t *task, int16_t tid) JL_NOTSAFEPOINT

JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSAFEPOINT
{
if (tpid < 0 || tpid >= jl_n_threadpools)
if (tpid < -1 || tpid >= jl_n_threadpools)
return 0;
task->threadpoolid = tpid;
return 1;
Expand Down
2 changes: 1 addition & 1 deletion src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT
if (tid < n)
return (int8_t)i;
}
return 0; // everything else uses threadpool 0 (though does not become part of any threadpool)
return -1; // everything else uses threadpool -1 (does not belong to any threadpool)
}

jl_ptls_t jl_init_threadtls(int16_t tid)
Expand Down

0 comments on commit 13ec3ce

Please sign in to comment.