/
threadtasks.jl
95 lines (87 loc) · 2.31 KB
/
threadtasks.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
    ThreadTask(p::Ptr{UInt})

Immutable wrapper around the raw pointer `p`. Instances are callable: calling
one runs the polling loop implemented by the `(tt::ThreadTask)()` method,
which reads its state and its work from the memory `p` points to.
"""
struct ThreadTask
    p::Ptr{UInt}
end
"""
    pointer(tt::ThreadTask)

Return the raw `Ptr{UInt}` stored in `tt`.
"""
function Base.pointer(tt::ThreadTask)
    return tt.p
end
"""
    taskpointer(tid)

Return `THREADPOOLPTR[]` advanced by `tid * THREADBUFFERSIZE`. The buffer size
is first converted to the type of `tid` (via `%`), so the multiplication is
carried out in that type.
"""
@inline function taskpointer(tid::T) where {T}
    base = THREADPOOLPTR[]
    offset = tid * (THREADBUFFERSIZE % T)
    return base + offset
end
"""
    _call(p::Ptr{UInt})

Load the function pointer stored `sizeof(UInt)` bytes past `p` and invoke it
through `ccall`, passing `p` itself as the only argument.
"""
@inline function _call(p::Ptr{UInt})
    callback = load(p + sizeof(UInt), Ptr{Cvoid})
    # Tell the optimizer that the loaded pointer is never null; callers must
    # have stored a valid function pointer before the state became `TASK`.
    assume(callback != C_NULL)
    return ccall(callback, Cvoid, (Ptr{UInt},), p)
end
"""
    launch(f, tid::Integer, args...)

Call `f(p, args...)` with `p = taskpointer(tid)`, then atomically set the state
word at `p` to `TASK`. If the state that was replaced is `WAIT`,
`wake_thread!(tid)` is called. Always returns `nothing`.
"""
@inline function launch(f::F, tid::Integer, args::Vararg{Any,K}) where {F,K}
    slot = taskpointer(tid)
    f(slot, args...)
    # The swap has to be a single atomic exchange: with a separate read and
    # write the worker could switch to `WAIT` in between and the wake-up below
    # would be missed.
    previous = _atomic_xchg!(slot, TASK)
    if previous == WAIT
        wake_thread!(tid)
    end
    return nothing
end
# Polling loop run when a `ThreadTask` is called. Never returns normally.
#
# While the state word at `pointer(tt)` reads `TASK`, the stored function is
# invoked through `_call`. Otherwise the loop spins with `pause()`; once more
# than `2^20` consecutive idle iterations have accumulated it tries to move the
# state from `SPIN` to `WAIT` and, on success, parks in `Base.wait()`.
# The idle counter starts at the limit, so the very first idle iteration already
# attempts to park.
function (tt::ThreadTask)()
    p = pointer(tt)
    spin_limit = one(UInt32) << 20
    idle_spins = spin_limit
    GC.@preserve THREADPOOL begin
        while true
            if _atomic_state(p) == TASK
                _call(p)
                idle_spins = zero(UInt32)
            else
                pause()
                idle_spins += one(UInt32)
                if idle_spins > spin_limit
                    idle_spins = zero(UInt32)
                    # Park only if the state is still `SPIN`; if the swap fails
                    # the state changed underneath us and we keep polling.
                    if _atomic_cas_cmp!(p, SPIN, WAIT)
                        Base.wait()
                    end
                end
            end
        end
    end
end
"""
    _sleep(p::Ptr{UInt})

Atomically store `WAIT` into the state word at `p`, then park the current task
with `Base.wait()`. Returns `nothing` once the task is resumed.

`sleep_all_tasks` turns this function into a C-callable pointer, so its
signature must stay `(Ptr{UInt},) -> Cvoid`.
"""
function _sleep(p::Ptr{UInt})
    _atomic_store!(p, WAIT)
    Base.wait()
    return nothing
end
"""
    sleep_all_tasks()

Ask every task in `TASKS` to run `_sleep`, then wait for each one.

First pass: for each index of `TASKS`, write a C-callable pointer to `_sleep`
into that task's buffer (third argument `sizeof(UInt)`, matching where `_call`
loads its function pointer from) and try to switch its state from `SPIN` to
`TASK`. Second pass: call `wait(tid)` on every index.
"""
function sleep_all_tasks()
    sleep_fptr = @cfunction(_sleep, Cvoid, (Ptr{UInt},))
    # Hand out the request to all tasks before blocking on any of them.
    for tid in eachindex(TASKS)
        slot = taskpointer(tid)
        ThreadingUtilities.store!(slot, sleep_fptr, sizeof(UInt))
        _atomic_cas_cmp!(slot, SPIN, TASK)
    end
    for tid in eachindex(TASKS)
        wait(tid)
    end
    return nothing
end
"""
    wake_thread!(tid::Integer)

Push `TASKS[tid]` onto `Base.Workqueues[tid + 1]` and wake that thread through
the runtime's `jl_wakeup_thread`.

`tid` is 1-based into `TASKS`; the tasks are pushed onto work queues
`2:nthreads()`, hence the `tid + 1` queue index. The zero-based thread id passed
to `jl_wakeup_thread` is therefore `tid` itself.
"""
@noinline function wake_thread!(_tid::T) where {T <: Integer}
    tid = _tid % Int
    tidp1 = tid + one(tid)
    # The two `assume`s state exactly the conditions under which the
    # `@inbounds` accesses below are valid. An `assume` that is false at run
    # time is undefined behaviour, so each must match its access precisely.
    #
    # `Base.Workqueues[tidp1]` is in bounds iff 0 <= tid < length(Workqueues).
    assume(unsigned(length(Base.Workqueues)) > unsigned(tid))
    # `TASKS[tid]` is in bounds iff 1 <= tid <= length(TASKS), i.e. iff
    # unsigned(tid - 1) < length(TASKS); `tid == 0` wraps around and is rejected.
    assume(unsigned(length(TASKS)) > unsigned(tid - one(tid)))
    @inbounds push!(Base.Workqueues[tidp1], TASKS[tid])
    ccall(:jl_wakeup_thread, Cvoid, (Int16,), tid % Int16)
end
"""
    checktask(tid)

Inspect `TASKS[tid]`. If the task has failed, print it (`display`, `dump` and a
blank line), call `initialize_task(tid)` and return `true`. Otherwise `yield()`
to the scheduler and return `false`.
"""
@noinline function checktask(tid)
    task = TASKS[tid]
    if !istaskfailed(task)
        # Healthy task: give other tasks a chance to run, then report "ok".
        yield()
        return false
    end
    display(task)
    dump(task)
    println()
    initialize_task(tid)
    return true
end
"""
    wait(tid::Integer)

Wait on the task with 1-based id `tid`; forwards to
`wait(taskpointer(tid), tid)` and returns its result.
"""
@inline function wait(tid::Integer)
    return wait(taskpointer(tid), tid)
end
"""
    wait(p::Ptr{UInt})

Wait on the task whose buffer starts at `p`. The task id is recovered as the
distance of `p` from `THREADPOOLPTR[]` divided by `THREADBUFFERSIZE` (the
inverse of `taskpointer`), then `wait(p, tid)` is called.
"""
@inline function wait(p::Ptr{UInt})
    offset = p - THREADPOOLPTR[]
    return wait(p, offset ÷ (THREADBUFFERSIZE))
end
"""
    wait(p::Ptr{UInt}, tid)

Spin (with `pause()`) for as long as the state word at `p` reads `TASK`.

After more than `0x00010000` spins, every further iteration also calls
`checktask(tid)`; if that reports a failed task the function returns `true`
immediately. Returns `false` once the state is observed to be something other
than `TASK`.
"""
@inline function wait(p::Ptr{UInt}, tid)
    spins = zero(UInt32)
    while _atomic_state(p) == TASK
        pause()
        spins += one(UInt32)
        # The counter is never reset, so once past the threshold the task is
        # checked on every iteration.
        if spins > 0x00010000 && checktask(tid)
            return true
        end
    end
    return false
end