Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate foreign threads into a :foreign threadpool #50912

Merged
merged 3 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions base/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ end

function multiq_insert(task::Task, priority::UInt16)
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), task)
@assert tpid > -1
heap_p = multiq_size(tpid)
tp = tpid + 1

Expand Down Expand Up @@ -124,6 +125,9 @@ function multiq_deletemin()

tid = Threads.threadid()
tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1
if tp == 0 # Foreign thread
return nothing
end
tpheaps = heaps[tp]

@label retry
Expand Down Expand Up @@ -175,6 +179,9 @@ end
function multiq_check_empty()
tid = Threads.threadid()
tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1
if tp == 0 # Foreign thread
return true
end
for i = UInt32(1):length(heaps[tp])
if heaps[tp][i].ntasks != 0
return false
Expand Down
2 changes: 1 addition & 1 deletion base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ function enq_work(t::Task)
else
@label not_sticky
tp = Threads.threadpool(t)
if Threads.threadpoolsize(tp) == 1
if tp === :foreign || Threads.threadpoolsize(tp) == 1
# There's only one thread in the task's assigned thread pool;
# use its work queue.
tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1
Expand Down
14 changes: 10 additions & 4 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ function _tpid_to_sym(tpid::Int8)
return :interactive
elseif tpid == 1
return :default
elseif tpid == -1
return :foreign
else
throw(ArgumentError("Unrecognized threadpool id $tpid"))
end
Expand All @@ -73,6 +75,8 @@ function _sym_to_tpid(tp::Symbol)
return Int8(0)
elseif tp === :default
return Int8(1)
elseif tp == :foreign
return Int8(-1)
else
throw(ArgumentError("Unrecognized threadpool name `$(repr(tp))`"))
end
Expand All @@ -81,7 +85,7 @@ end
"""
Threads.threadpool(tid = threadid()) -> Symbol

Returns the specified thread's threadpool; either `:default` or `:interactive`.
Returns the specified thread's threadpool; either `:default`, `:interactive`, or `:foreign`.
"""
function threadpool(tid = threadid())
tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1)
Expand All @@ -108,6 +112,8 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
function threadpoolsize(pool::Symbol = :default)
if pool === :default || pool === :interactive
tpid = _sym_to_tpid(pool)
elseif pool == :foreign
error("Threadpool size of `:foreign` is indeterminant")
else
error("invalid threadpool specified")
end
Expand Down Expand Up @@ -151,7 +157,7 @@ function threading_run(fun, static)
else
# TODO: this should be the current pool (except interactive) if there
# are ever more than two pools.
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, _sym_to_tpid(:default))
@assert ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, _sym_to_tpid(:default)) == 1
end
tasks[i] = t
schedule(t)
Expand Down Expand Up @@ -357,10 +363,10 @@ end

function _spawn_set_thrpool(t::Task, tp::Symbol)
tpid = _sym_to_tpid(tp)
if _nthreads_in_pool(tpid) == 0
if tpid == -1 || _nthreads_in_pool(tpid) == 0
tpid = _sym_to_tpid(:default)
end
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid)
@assert ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid) == 1
nothing
end

Expand Down
2 changes: 1 addition & 1 deletion src/partr.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ JL_DLLEXPORT int jl_set_task_tid(jl_task_t *task, int16_t tid) JL_NOTSAFEPOINT

JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSAFEPOINT
{
if (tpid < 0 || tpid >= jl_n_threadpools)
if (tpid < -1 || tpid >= jl_n_threadpools)
return 0;
task->threadpoolid = tpid;
return 1;
Expand Down
2 changes: 1 addition & 1 deletion src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT
if (tid < n)
return (int8_t)i;
}
return 0; // everything else uses threadpool 0 (though does not become part of any threadpool)
return -1; // everything else uses threadpool -1 (does not belong to any threadpool)
}

jl_ptls_t jl_init_threadtls(int16_t tid)
Expand Down