diff --git a/src/workstealing.jl b/src/workstealing.jl index ee18c19..679811e 100644 --- a/src/workstealing.jl +++ b/src/workstealing.jl @@ -30,6 +30,8 @@ Base.@propagate_inbounds function Base.setindex!(A::CircularVector, v, i::Int) A.data[indexof(A, i)] = v end +Base.pointer(A::CircularVector, i::Integer) = pointer(A.data, indexof(A, i)) + function tryresize(A::CircularVector, log2inc::Integer, indices) log2length = A.log2length + log2inc n = 1 << log2length @@ -134,11 +136,26 @@ function ConcurrentCollections.trypopfirst!(deque::WorkStealingDeque) if current_size <= 0 return nothing end - r = Some(buffer[top]) - if @atomicreplace(deque.top, top => top + 1)[2] - return r + if Base.allocatedinline(eltype(buffer)) + r = Some(buffer[top]) + if @atomicreplace(deque.top, top => top + 1)[2] + return r + else + return nothing + end else - return nothing + ptr = UnsafeAtomics.load(Ptr{Ptr{Cvoid}}(pointer(buffer, top)), monotonic) + if @atomicreplace(deque.top, top => top + 1)[2] + # Safety: The above CAS verifies that the slot `buffer[top]` + # contained the valid element. We can now materialize it as a Julia + # value. + GC.@preserve buffer begin + r = Some(unsafe_pointer_to_objref(ptr)) + end + return r + else + return nothing + end end end diff --git a/test/ConcurrentCollectionsTests/src/test_work_stealing_deque.jl b/test/ConcurrentCollectionsTests/src/test_work_stealing_deque.jl index c737c90..773212a 100644 --- a/test/ConcurrentCollectionsTests/src/test_work_stealing_deque.jl +++ b/test/ConcurrentCollectionsTests/src/test_work_stealing_deque.jl @@ -28,18 +28,19 @@ function random_pushpop(xs, ntasks = Threads.nthreads() - 1) deque = WorkStealingDeque{eltype(xs)}() local tasks, zs - done = Threads.Atomic{Bool}(false) try tasks = map(1:ntasks) do _ Threads.@spawn begin local ys = eltype(xs)[] while true - r = trypopfirst!(deque) + local r = trypopfirst!(deque) if r === nothing - done[] && break + GC.safepoint() continue end - push!(ys, something(r)) + local y = something(r) + y == -1 && break + push!(ys, y) end ys end @@ -51,20 +52,25 @@ function random_pushpop(xs, ntasks = Threads.nthreads() - 1) # continue if mod(i, 8) == 0 r = trypop!(deque) + GC.safepoint() r === nothing && continue push!(zs, something(r)) end end finally - done[] = true + for _ in 1:ntasks + push!(deque, -1) + end end return zs, fetch.(tasks) end function test_random_push_pop() - @testset for T in [Int, Any, Int, Any] - test_random_push_pop(T) + @testset for trial in 1:100 + @testset for T in [Int, Any] + test_random_push_pop(T) + end end end @@ -77,8 +83,11 @@ function test_random_push_pop(T::Type, xs = 1:2^10) @test all(allunique, yss) @debug "random_pushpop(xs)" length(zs) length.(yss) ys = sort!(foldl(append!, yss; init = copy(zs))) - @debug "random_pushpop(xs)" setdiff(ys, xs) setdiff(xs, ys) length(xs) length(ys) + @test length(ys) == length(xs) + @test setdiff(ys, xs) == [] + @test setdiff(xs, ys) == [] @test ys == xs + return (; zs, yss) end end # module