@@ -69,15 +69,25 @@ function add_var!(q, argtup, gcpres, ::Type{T}, argtupname, gcpresname, k) where
6969 end
7070end
7171@generated function _batch_no_reserve (
72- f!:: F , threadmask_tuple:: NTuple{N} , nthread_tuple, torelease_tuple, Nr, Nd, ulen, args:: Vararg{Any,K} ; threadlocal:: Bool = false
73- ) where {F,K,N}
72+ f!:: F , threadmask_tuple:: NTuple{N} , nthread_tuple, torelease_tuple, Nr, Nd, ulen, args:: Vararg{Any,K} ; threadlocal:: Val{thread_local} = Val ( false )
73+ ) where {F,K,N,thread_local }
7474 q = quote
7575 $ (Expr (:meta ,:inline ))
7676 # threads = UnsignedIteratorEarlyStop(threadmask, nthread)
7777 # threads_tuple = map(UnsignedIteratorEarlyStop, threadmask_tuple, nthread_tuple)
7878 # nthread_total = sum(nthread_tuple)
7979 Ndp = Nd + one (Nd)
8080 end
81+ launch_quote = if thread_local
82+ :(launch_batched_thread! (cfunc, tid, argtup, start, stop, i% UInt))
83+ else
84+ :(launch_batched_thread! (cfunc, tid, argtup, start, stop))
85+ end
86+ rem_quote = if thread_local
87+ :(f! (arguments, (start+ one (UInt)) % Int, ulen % Int, (sum (nthread_tuple)+ 1 )% Int))
88+ else
89+ :(f! (arguments, (start+ one (UInt)) % Int, ulen % Int))
90+ end
8191 block = quote
8292 start = zero (UInt)
8393 tid = 0x00000000
92102 tz += 0x00000001
93103 tid += tz
94104 tm >>>= tz
95- if threadlocal
96- launch_batched_thread! (cfunc, tid, argtup, start, stop, i% UInt)
97- else
98- launch_batched_thread! (cfunc, tid, argtup, start, stop)
99- end
105+ $ launch_quote
100106 start = stop
101107 end
102108 Nr -= nthread
103109 end
104- if threadlocal
105- f! (arguments, (start+ one (UInt)) % Int, ulen % Int, (sum (nthread_tuple)+ 1 )% Int)
106- else
107- f! (arguments, (start+ one (UInt)) % Int, ulen % Int)
108- end
110+ $ rem_quote
109111 for (threadmask, nthread, torelease) ∈ zip (threadmask_tuple, nthread_tuple, torelease_tuple)
110112 tm = mask (UnsignedIteratorEarlyStop (threadmask, nthread))
111113 tid = 0x00000000
127129 for k ∈ 1 : K
128130 add_var! (q, argt, gcpr, args[k], :args , :gcp , k)
129131 end
130- push! (q. args, :(arguments = $ argt), :(argtup = Reference (arguments)), :(cfunc = batch_closure (f!, argtup, Val {false} (), Val {threadlocal } ())), gcpr)
132+ push! (q. args, :(arguments = $ argt), :(argtup = Reference (arguments)), :(cfunc = batch_closure (f!, argtup, Val {false} (), Val {$thread_local } ())), gcpr)
131133 push! (q. args, nothing )
132134 q
133135end
@@ -227,15 +229,15 @@ end
227229
228230
229231@inline function batch (
230- f!:: F , (len, nbatches):: Tuple{Vararg{Integer,2}} , args:: Vararg{Any,K} ; threadlocal:: Bool = false
231- ) where {F,K}
232+ f!:: F , (len, nbatches):: Tuple{Vararg{Integer,2}} , args:: Vararg{Any,K} ; threadlocal:: Val{thread_local} = Val { false} ()
233+ ) where {F,K,thread_local }
232234 # threads, torelease = request_threads(Base.Threads.threadid(), nbatches - one(nbatches))
233235 threads, torelease = request_threads (nbatches - one (nbatches))
234236 nthreads = map (length,threads)
235237 nthread = sum (nthreads)
236238 ulen = len % UInt
237239 if nthread % Int32 ≤ zero (Int32)
238- if threadlocal
240+ if thread_local
239241 f! (args, one (Int), ulen % Int, 1 )
240242 else
241243 f! (args, one (Int), ulen % Int)
@@ -246,12 +248,12 @@ end
246248 Nd = Base. udiv_int (ulen, nbatch % UInt) # reasonable for `ulen` to be ≥ 2^32
247249 Nr = ulen - Nd * nbatch
248250
249- _batch_no_reserve (f!, map (mask,threads), nthreads, torelease, Nr, Nd, ulen, args... ; threadlocal= threadlocal )
251+ _batch_no_reserve (f!, map (mask,threads), nthreads, torelease, Nr, Nd, ulen, args... ; threadlocal)
250252end
251253function batch (
252- f!:: F , (len, nbatches, reserve_per_worker):: Tuple{Vararg{Integer,3}} , args:: Vararg{Any,K} ; threadlocal:: Bool = false
253- ) where {F,K}
254- batch (f!, (len, nbatches), args... ; threadlocal= false )
254+ f!:: F , (len, nbatches, reserve_per_worker):: Tuple{Vararg{Integer,3}} , args:: Vararg{Any,K} ; threadlocal:: Val{thread_local} = Val ( false )
255+ ) where {F,K,thread_local }
256+ batch (f!, (len, nbatches), args... ; threadlocal)
255257 # ulen = len % UInt
256258 # if nbatches > 1
257259 # requested_threads = reserve_per_worker*nbatches
0 commit comments