From eacdba978051dd3f282bf03cda75455e18a85125 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Tue, 28 Sep 2021 17:41:23 -0400 Subject: [PATCH] Add test_random_mutation and fix a bug in shrink --- src/dict.jl | 64 ++++++++++++++++--- .../src/test_dict.jl | 46 +++++++++++++ 2 files changed, 101 insertions(+), 9 deletions(-) diff --git a/src/dict.jl b/src/dict.jl index 318e760..bb0724a 100644 --- a/src/dict.jl +++ b/src/dict.jl @@ -339,7 +339,12 @@ function migrate!(dict::LinearProbingDict, oldslots, expand; basesize = nothing) if expand expand_parallel!(newslots, slots, basesize) else - migrate_serial!(newslots, slots) + if migrate_serial!(newslots, slots, Val(true)) === nothing + # Shrinking the slots have failed and rolled back. Keep using + # the old slots: + @atomic dict.slots = slots + return slots + end end # TODO: parallelize `shrink!` @@ -355,10 +360,13 @@ end function finishmove!(dict::LinearProbingDict, oldslots) lock(dict.migration) do slots = (@atomic dict.slots)::slots_type(dict) - # The caller observed `Moved` which only sets inside the `migration` - # lock. Thus, the migration should be finished once this lock is - # acquired: + # TODO: `oldslots` was used for a sanity check that `slots` is updated. + # However, now that shrink can rollback migration, it's not correct any + # more. Still keeping this code since it might make sense to actually + # not support shrinking. + #= @assert oldslots !== slots + =# return slots end end @@ -417,8 +425,7 @@ function expand_parallel_basecase!(newslots, slots, basesize, start0, ichunk, fi return (0, false) end - migrate_between(start, stop) = - migrate_serial_nofill!(newslots, slots, start, stop, Val(false))::Int + migrate_between(start, stop) = migrate_serial_nofill!(newslots, slots, start, stop)::Int # An empty slot is observed. There is at least one cluster started within # this chunk. @@ -500,8 +507,23 @@ function expand_parallel!(newslots, slots, basesize) end end -function migrate_serial!(newslots, slots) - nadded = migrate_serial_nofill!(newslots, slots, 1, length(slots), Val(false)) +function migrate_serial!( + newslots, + slots, + rollback_on_error::Union{Val{false},Val{true}} = Val(false), +) + nadded = migrate_serial_nofill!( + newslots, + slots, + 1, + length(slots), + Val(false), + rollback_on_error, + ) + if nadded === nothing + rollback_on_error::Val{true} + return nothing + end nadded = nadded::Int fill_undef!(newslots) return nadded @@ -530,7 +552,9 @@ function migrate_serial_nofill!( slots::AbstractVector{Slot}, start::Int, stop::Int, - stop_on_empty::Union{Val{false},Val{true}}, + # TODO: use custom singletong rather than Val: + stop_on_empty::Union{Val{false},Val{true}} = Val(false), + rollback_on_error::Union{Val{false},Val{true}} = Val(false), ) where {K,V,Slot<:DictSlot{K,V}} nadded = 0 for i in start:stop @@ -577,6 +601,10 @@ function migrate_serial_nofill!( nprobes += 1 if nprobes > c + if rollback_on_error === Val(true) + rollback_migration!(slots, start, i) + return nothing + end @static_error "unreachable: too many probings during migration" end end @@ -584,6 +612,24 @@ function migrate_serial_nofill!( return nadded end +function rollback_migration!(slots, start, stop) + for i in start:stop + ref = slots[i] + value = @atomic ref.value + while true + if value isa Moved{Empty} + value, ok = @atomicreplace ref.value value => Empty() + ok && break + # TODO: maybe simple store is OK, since no one should be touching it? + elseif value isa Moved + @static_error("unexpected Moved{Value}") + else + break + end + end + end +end + Base.IteratorSize(::Type{<:LinearProbingDict}) = Base.SizeUnknown() Base.IteratorSize(::Type{<:Base.KeySet{<:Any,<:LinearProbingDict}}) = Base.SizeUnknown() Base.IteratorSize(::Type{<:Base.ValueIterator{<:LinearProbingDict}}) = Base.SizeUnknown() diff --git a/test/ConcurrentCollectionsTests/src/test_dict.jl b/test/ConcurrentCollectionsTests/src/test_dict.jl index 360e2c3..39d9ef6 100644 --- a/test/ConcurrentCollectionsTests/src/test_dict.jl +++ b/test/ConcurrentCollectionsTests/src/test_dict.jl @@ -159,4 +159,50 @@ function test_shrink() end end +function random_mutation!(dict; nkeys = 8, repeat = 2^20, ntasks = Threads.nthreads()) + ks = 1:nkeys + locals = [ + ( + popped = zeros(valtype(dict), nkeys), # sum of popped values + added = zeros(valtype(dict), nkeys), # sum of all inserted values + ) for _ in 1:ntasks + ] + @sync for (; popped, added) in locals + Threads.@spawn begin + for _ in 1:repeat + k = rand(ks) + if rand(Bool) + y = trypop!(dict, k) + if y !== nothing + popped[k] += something(y) + end + else + added[k] += 1 + modify!(dict, k) do ref + Base.@_inline_meta + Some(ref === nothing ? 1 : ref[] + 1) + end + end + end + end + end + return locals +end + +function test_random_mutation(; kwargs...) + dict = ConcurrentDict{Int,Int}() + nkeys = 16 + locals = random_mutation!(dict; kwargs..., nkeys) + actual = zeros(valtype(dict), nkeys) + desired = zeros(valtype(dict), nkeys) + for (k, v) in dict + actual[k] = v + end + for (; popped, added) in locals + actual .+= popped + desired .+= added + end + @test actual == desired +end + end # module