From 22760efcc4b500b7d9482014d6a7ae35e932df1d Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Mon, 13 Sep 2021 19:04:56 -0400 Subject: [PATCH 1/2] Fix dict --- .../src/ConcurrentCollectionsBenchmarks.jl | 1 + .../src/bench_dict_haskey.jl | 64 ++ .../src/bench_dict_histogram.jl | 3 + .../src/bench_dict_migration.jl | 16 +- src/dict.jl | 955 +++++++++--------- src/utils.jl | 29 + .../src/ConcurrentCollectionsTests.jl | 1 - .../src/test_bench_dict_histogram.jl | 10 + .../src/test_dict.jl | 88 +- .../src/test_dict_impl.jl | 18 - 10 files changed, 665 insertions(+), 520 deletions(-) create mode 100644 benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_haskey.jl delete mode 100644 test/ConcurrentCollectionsTests/src/test_dict_impl.jl diff --git a/benchmark/ConcurrentCollectionsBenchmarks/src/ConcurrentCollectionsBenchmarks.jl b/benchmark/ConcurrentCollectionsBenchmarks/src/ConcurrentCollectionsBenchmarks.jl index 9a675e0..64fe5a1 100644 --- a/benchmark/ConcurrentCollectionsBenchmarks/src/ConcurrentCollectionsBenchmarks.jl +++ b/benchmark/ConcurrentCollectionsBenchmarks/src/ConcurrentCollectionsBenchmarks.jl @@ -4,6 +4,7 @@ using BenchmarkTools: Benchmark, BenchmarkGroup include("utils.jl") include("bench_dict_histogram.jl") +include("bench_dict_haskey.jl") include("bench_dict_get_existing.jl") include("bench_dict_migration.jl") include("bench_queue_pushpop.jl") diff --git a/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_haskey.jl b/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_haskey.jl new file mode 100644 index 0000000..5ac1511 --- /dev/null +++ b/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_haskey.jl @@ -0,0 +1,64 @@ +module BenchDictHasKey + +using BenchmarkTools +using ConcurrentCollections + +function generate(; + datasize = 2^13, # `Base.Dict` is better on smaller size + keysize = 50, # expensive isequal; favors ConcurrentDict + nkeys = 100, +) + vs = UInt64.(1:datasize) + ks = string.(vs; pad = keysize) + # ks = vs + # ks = UInt32.(vs) + cdict = ConcurrentDict{eltype(ks),eltype(vs)}(zip(ks, vs)) + ks_100 = ks[1:nkeys] + ks_000 = string.(.-vs[1:nkeys]) + ks_050 = ifelse.(isodd.(vs[1:nkeys]), ks_100, ks_000) + return (; cdict, ks_100, ks_000, ks_050) +end + +const CACHE = Ref{Any}() + +function setup(; cases = [:ks_050, :ks_000], kwargs...) + data = generate(; kwargs...) + (; cdict) = data + dict = Dict(cdict) + CACHE[] = (; dict, data...) 
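+    # Stash the generated data in the global `CACHE` so that the `setup`
+    # blocks of the benchmarks below can retrieve it.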
+ + labelmap = Dict( + :ks_100 => "100% existing", + :ks_050 => "50% existing", # `Base.Dict` is better with 50% hit + :ks_000 => "0% existing", + ) + + suite = BenchmarkGroup() + for ksprop in cases + s1 = suite[labelmap[ksprop]] = BenchmarkGroup() + ks = getproperty(data, ksprop) + s1["base-seq"] = @benchmarkable( + count(k -> haskey(dict, k), ks), + setup = begin + dict = CACHE[].dict::$(typeof(dict)) + ks = CACHE[].$ksprop::$(typeof(ks)) + end, + evals = 1, + ) + s1["cdict-seq"] = @benchmarkable( + count(k -> haskey(dict, k), ks), + setup = begin + dict = CACHE[].cdict::$(typeof(cdict)) + ks = CACHE[].$ksprop::$(typeof(ks)) + end, + evals = 1, + ) + end + return suite +end + +function clear() + CACHE[] = nothing +end + +end # module diff --git a/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_histogram.jl b/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_histogram.jl index 5dac817..fdb8ae8 100644 --- a/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_histogram.jl +++ b/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_histogram.jl @@ -37,6 +37,9 @@ function hist_seq!(dict::ConcurrentDict, data) end function hist_parallel!(dict::ConcurrentDict, data; ntasks = Threads.nthreads()) + # for k in data + # dict[k] = 0 + # end @sync for chunk in Iterators.partition(data, cld(length(data), ntasks)) Threads.@spawn hist_seq!(dict, chunk) end diff --git a/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_migration.jl b/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_migration.jl index 76a9b6b..4c0dbef 100644 --- a/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_migration.jl +++ b/benchmark/ConcurrentCollectionsBenchmarks/src/bench_dict_migration.jl @@ -2,7 +2,8 @@ module BenchDictMigration using BenchmarkTools using ConcurrentCollections -using ConcurrentCollections.Implementations: LINEAR_PROBING_DICT_EXPAND_BASESIZE, migrate! +using ConcurrentCollections.Implementations: + LINEAR_PROBING_DICT_EXPAND_BASESIZE, migrate_serial!, new_slots_and_pairnodes pad16(x) = string(x; pad = 16) @@ -10,7 +11,7 @@ function generate(f = pad16; datasize = LINEAR_PROBING_DICT_EXPAND_BASESIZE[]) vs = UInt64.(1:datasize) ks = f.(vs) dict = ConcurrentDict{eltype(ks),eltype(vs)}(zip(ks, vs)) - return dict.slots + return dict end const CACHE = Ref{Any}() @@ -25,12 +26,15 @@ function setup(; generate_options...) 
suite = BenchmarkGroup() for key in keys(CACHE[]) - SlotsType = typeof(CACHE[][key]) + CacheType = typeof(CACHE[][key]) suite[key] = @benchmarkable( - migrate!(newslots, slots), + migrate_serial!(newslots, newpairnodes, slots, pairnodes), setup = begin - slots = copy(CACHE[][$key]::$SlotsType) - newslots = similar(slots, length(slots) * 2) + dict = CACHE[][$key]::$CacheType + slots = copy(dict.slots) + pairnodes = copy(dict.pairnodes) + newslots, newpairnodes = + new_slots_and_pairnodes(slots, pairnodes, true) end, evals = 1, ) diff --git a/src/dict.jl b/src/dict.jl index 5a3525c..83fce8c 100644 --- a/src/dict.jl +++ b/src/dict.jl @@ -1,145 +1,88 @@ -struct Moved{Key} - key::Key +@enum LPDKeyState::UInt8 LPD_EMPTY LPD_DELETED LPD_MOVED_EMPTY LPD_MOVED LPD_HASKEY +const LPD_NBITS = ceil(Int, log2(maximum(Int.(instances(LPDKeyState))) + 1)) +const LPD_BITMASK = ~(typemax(UInt8) << LPD_NBITS) + +struct KeyInfo{T<:Union{UInt32,UInt64}} + bits::T end -struct Empty end -struct MovedEmpty end -struct Deleted end -struct NoValue end - -const KeyUnion{Key} = Union{ - Key, # data is stored - Moved{Key}, # data is moved - MovedEmpty, # empty slot not usable anymore due to migration - Empty, # empty slot - Deleted, # deleted -} - -const RefKeyUnion{Key} = Union{ - Key, - RefValue{Moved{Key}}, # heap allocate Moved{Key} if Key is heap allocated - MovedEmpty, - Empty, - Deleted, -} - -abstract type AbstractPair{Key,Value} end -stored_key_type(::Type{AbstractPair{Key}}) where {Key} = Key -stored_value_type(::Type{AbstractPair{<:Any,Value}}) where {Value} = Value - -struct InlinedPair{Key,Value,KPad,VPad} <: AbstractPair{Key,Value} - key::IPadder{Inlined{KeyUnion{Key}},KPad} - value::IPadder{Value,VPad} -end - -@inline getkey(pair::InlinedPair) = pair.key.x.x -@inline getvalue(pair::InlinedPair) = pair.value.x - -function inlinedpair_type(::Type{Key}, ::Type{Value}) where {Key,Value} - KPad = padsize_for_cas(Inlined{KeyUnion{Key}}) - VPad = padsize_for_cas(InlinedPair{Key,Value,KPad,0}) - return InlinedPair{Key,Value,KPad,VPad} -end - -@inline InlinedPair{Key,Value}(key::KeyUnion{Key}, value::Value) where {Key,Value} = - inlinedpair_type(Key, Value)(key, value) - -@inline function InlinedPair{Key,Value,KPad,VPad}( - key::KeyUnion{Key}, - value::Value, -) where {Key,Value,KPad,VPad} - k = IPadder{Inlined{KeyUnion{Key}},KPad}(key) - v = IPadder{Value,VPad}(value) - return InlinedPair{Key,Value,KPad,VPad}(k, v) -end - -@inline function InlinedPair{Key,Value,KPad,VPad}( - key::KeyUnion{Key}, -) where {Key,Value,KPad,VPad} - if NoValue <: Value - InlinedPair{Key,Value,KPad,VPad}(key, NoValue()) - elseif Value <: Ref - InlinedPair{Key,Value,KPad,VPad}(key, Value()) +@inline function LPDKeyState(ki::KeyInfo{T}) where {T} + bits = getfield(ki, :bits) + statebits = bits & T(LPD_BITMASK) + state = statebits % UInt8 + if assertion_enabled() + return LPDKeyState(state) + end + if state <= UInt8(LPD_HASKEY) + return LPDKeyState(state) else - InlinedPair{Key,Value,KPad,VPad}(key, zerofill(Value)) + return LPD_HASKEY end end -macro _deref_moved(ex) - quote - x = $(esc(ex)) - if x isa RefValue - y = x[] - if y isa Moved - y - else - x - end - else - x +Base.zero(::Type{KeyInfo{T}}) where {T} = KeyInfo{T}(zero(T)) + +KeyInfo{T}(state::LPDKeyState, keydata::T) where {T} = + setstate(KeyInfo(keydata << LPD_NBITS), state) + +@inline function Base.getproperty(ki::KeyInfo, name::Symbol) + bits = getfield(ki, :bits) + if name === :state + return LPDKeyState(ki) + elseif name === :isempty + if assertion_enabled() + @assert 
iszero(bits) == (LPDKeyState(ki) === LPD_EMPTY) end + return iszero(bits) + elseif name === :isdeleted + return LPDKeyState(ki) === LPD_DELETED + elseif name === :ismovedempty + return LPDKeyState(ki) === LPD_MOVED_EMPTY + elseif name === :ismoved + return LPDKeyState(ki) === LPD_MOVED + elseif name === :haskey + return LPDKeyState(ki) === LPD_HASKEY + elseif name === :keydata + return bits >> LPD_NBITS + end + return getfield(ki, name) +end + +@inline function setstate(ki::KeyInfo{T}, state::LPDKeyState) where {T} + bits = getfield(ki, :bits) + if state === LPD_EMPTY + return KeyInfo(zero(bits)) + else + return KeyInfo((bits & ~T(LPD_BITMASK)) | UInt8(state)) end end -struct BoxedKeyPair{Key,Value,VPad} <: AbstractPair{Key,Value} - key::RefKeyUnion{Key} - value::IPadder{Value,VPad} +@inline function setdata(ki::KeyInfo{T}, keydata::T) where {T} + bits = getfield(ki, :bits) + state = bits & T(LPD_BITMASK) + databits = keydata << LPD_NBITS + return KeyInfo(databits | state) end -@inline getkey(pair::BoxedKeyPair) = @_deref_moved(pair.key) -@inline getvalue(pair::BoxedKeyPair) = pair.value.x - -function boxedkeypair_type(::Type{Key}, ::Type{Value}) where {Key,Value} - P = BoxedKeyPair{Key,Value,8 - sizeof(Value)} - fieldoffset(P, 2) == 8 || @static_error("invalid key type") - sizeof(P) == 16 || @static_error("invalid value size") - return P -end - -@inline BoxedKeyPair{Key,Value}( - key::Union{RefKeyUnion{Key},Moved{Key}}, - value::Value, -) where {Key,Value} = boxedkeypair_type(Key, Value)(key, value) - -@inline BoxedKeyPair{Key,Value,VPad}(key::Moved{Key}, value::Value) where {Key,Value,VPad} = - BoxedKeyPair{Key,Value,VPad}(Ref(Moved(key)), value) - -@inline function BoxedKeyPair{Key,Value,VPad}( - key::RefKeyUnion{Key}, - value::Value, -) where {Key,Value,VPad} - k = key - v = IPadder{Value,VPad}(value) - return BoxedKeyPair{Key,Value,VPad}(k, v) +mutable struct AtomicRef{T} + @atomic value::Union{Nothing,T} end +AtomicRef{T}() where {T} = AtomicRef{T}(nothing) +Base.eltype(::Type{AtomicRef{T}}) where {T} = T -@inline function BoxedKeyPair{Key,Value,VPad}( - key::Union{RefKeyUnion{Key},Moved{Key}}, -) where {Key,Value,VPad} - if NoValue <: Value - BoxedKeyPair{Key,Value,VPad}(key, NoValue()) - elseif Value <: Ref - BoxedKeyPair{Key,Value,VPad}(key, Value()) - else - BoxedKeyPair{Key,Value,VPad}(key, zerofill(Value)) - end -end - -function Base.show(io::IO, pair::AbstractPair) - @nospecialize pair - k = getkey(pair) - v = getvalue(pair) - if get(io, :typeinfo, nothing) !== typeof(pair) - print(io, typeof(pair), "(", k, ", ", v, ")") - return - end - show(io, k) - printstyled(io, " => "; color = :light_black) - show(io, v) +mutable struct PairNode{Key,Value} + slotid::UInt64 + key::Key + @atomic value::Value + @atomic next::Union{PairNode{Key,Value},Nothing} end -mutable struct LinearProbingDict{Key,Value,Slot} <: ConcurrentDict{Key,Value} - slots::Vector{Slot} +mutable struct LinearProbingDict{Key,Value} <: ConcurrentDict{Key,Value} + # TODO: Use `Vector{UInt128}` + @atomic slots::Union{Vector{UInt64},Nothing} + # TODO: Use `PairNode{Key,Value}` if `Value` is too large? 
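+    # One list head per logical slot: each `AtomicRef` holds a chain of
+    # `PairNode`s, and the node whose `slotid` matches the id stored in
+    # `slots` carries the live key/value pair for that slot.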
+ @atomic pairnodes::Vector{AtomicRef{PairNode{Key,Value}}} + slotids::typeof(cacheline_padded_vector(UInt64, 1)) migration::ReentrantLock # TODO: per-thread non-atomic counter for approximating deleted elements nadded::Threads.Atomic{Int} @@ -157,221 +100,103 @@ ConcurrentCollections.ConcurrentDict{Key,Value}() where {Key,Value} = LinearProbingDict{Key,Value}() function LinearProbingDict{Key,Value}() where {Key,Value} - # TODO: handle the case where key, value, and the metadata fits in an UInt - # TODO: use BoxedKeyPair if Value is small - FallbackSlot = RefValue{InlinedPair{Key,Value,0,0}} - if !isinlinable(Key) - Slot = boxedkeypair_type(Key, Value) - elseif aligned_sizeof(fieldtype(inlinedpair_type(Key, Value), 1)) <= 8 # atomic load works - if isinlinable(Value) && aligned_sizeof(inlinedpair_type(Key, Value)) <= 16 - Slot = inlinedpair_type(Key, Value) - else - Slot = something( - cas_compatible(inlinedpair_type(Key, Union{Value,NoValue})), - # cas_compatible(inlinedpair_type(Key, RefValue{Value})), - FallbackSlot, - ) - end - else - Slot = FallbackSlot - end - if !(Slot <: Ref) - @assert Base.allocatedinline(Slot) - end capacity = 4 # TODO: customizable init size? - slots = emptyslots(Slot, capacity) - let diff = pointer(slots, 2) - pointer(slots, 1) - if !ispow2(diff) - error("implementation error: slot size is not a power of 2: $diff") - end - if diff > 16 - error("implementation error: slot size too big: $diff") - end - end - return LinearProbingDict{Key,Value,Slot}( + # TODO: handle the case where key, value, and the metadata fits in an UInt + # TODO: check the availability of CAS2? + # TODO: use `PairNode{Key,Any}` if `Value` is too large + pairnodes = [AtomicRef{PairNode{Key,Value}}() for _ in 1:capacity] + slots = zeros(UInt64, 2 * capacity) + slotids = cacheline_padded_vector(UInt64, Threads.nthreads()) + slotids .= eachindex(slotids) + return LinearProbingDict{Key,Value}( slots, + pairnodes, + slotids, ReentrantLock(), Threads.Atomic{Int}(0), Threads.Atomic{Int}(0), ) end -emptyslots(::Type{Slot}, length::Integer) where {Slot} = - fillempty!(Vector{Slot}(undef, length)) - -fillempty!(slots::AbstractVector{Slot}) where {Slot<:AbstractPair} = - fill!(slots, Slot(Empty())) - -fillempty!(slots::AbstractVector{Slot}) where {P,Slot<:Ref{P}} = - fill!(slots, Slot(P(Empty()))) -# TODO: use undef as empty - -value_uint_type(::Type{Slot}) where {Value,Slot<:AbstractPair{<:Any,Value}} = - uint_for(Value) - -mutable struct InlinedSlotRef{Slot,KUInt,VUInt} - ptr::Ptr{Cvoid} - keyint::KUInt - valueint::VUInt - value_loaded::Bool - - @inline function InlinedSlotRef{Slot}(ptr::Ptr{Cvoid}, keyint::KUInt) where {Slot,KUInt} - VUInt = value_uint_type(Slot) - return new{Slot,KUInt,VUInt}(ptr, keyint, VUInt(0), false) - end -end - -@inline function load_slot( - slots::AbstractVector{Slot}, +@inline function prepare_pairnode!( + pairnodes::Vector{AtomicRef{Node}}, index, -) where {Key,Slot<:AbstractPair{Key}} - ptr = Ptr{Cvoid}(pointer(slots, index)) - if Slot <: InlinedPair - KUInt = uint_for(Inlined{KeyUnion{Key}}) - else - KUInt = UInt - end - keyptr = Ptr{KUInt}(ptr) - keyint = UnsafeAtomics.load(keyptr) - return InlinedSlotRef{Slot}(ptr, keyint) -end - -@inline getkey(slotref::InlinedSlotRef{Slot}) where {Key,Slot<:InlinedPair{Key}} = - from_bytes(Inlined{KeyUnion{Key}}, slotref.keyint).x - -@inline getkey(slotref::InlinedSlotRef{Slot}) where {Key,Slot<:BoxedKeyPair{Key}} = - @_deref_moved(unsafe_pointer_to_objref(Ptr{Ptr{Cvoid}}(slotref.keyint)))::KeyUnion{Key} - -@inline function 
load_valueint( - ::Type{Slot}, - ptr, -) where {Key,Value,Slot<:AbstractPair{Key,Value}} - VUInt = uint_for(Value) - valueptr = Ptr{VUInt}(ptr + fieldoffset(Slot, 2)) - valueint = UnsafeAtomics.load(valueptr) - return valueint -end - -@inline function Base.getindex( - slotref::InlinedSlotRef{Slot}, -) where {Key,Value,Slot<:AbstractPair{Key,Value}} - if slotref.value_loaded - valueint = slotref.valueint - else - valueint = load_valueint(Slot, slotref.ptr) - slotref.valueint = valueint - slotref.value_loaded = true + slotid, + key, + value, +) where {Node<:PairNode} + # node = Node(slotid, key, value, nothing) + ref = pairnodes[index] + head = @atomic ref.value + while true + # Allocating `Node` each time instead of `@atomic node.next = head` + # below is better. It looks like avoiding `@atomic` and optimizing for + # the happy case is better for the performance? + node = Node(slotid, key, value, head) + # @atomic node.next = head + head, success = @atomicreplace ref.value head => node + success && return end - value = from_bytes(Value, valueint) - return value end -@inline value_ref(slotref::InlinedSlotRef) = slotref -# Note on `modify!` design: It looks like (even relaxed) atomic load is not -# eliminated when the value is not used (). -# So, let's pass a `Ref`-like object to `modify!` and so that load is not issued -# when the user does request. - -allocate_slot(::AbstractVector{<:AbstractPair}) = nothing - -@inline cas_slot!(slotref, new_slot, root, key) = - cas_slot!(slotref, new_slot, root, key, NoValue()) - -@inline function cas_slot!( - slotref::InlinedSlotRef{Slot}, - ::Nothing, - root, - key::KeyUnion{Key}, - value, -) where {Key,Value,Slot<:AbstractPair{Key,Value}} - ptr = slotref.ptr - UIntType = uint_for(Slot) - if slotref.value_loaded - oldvalueint = slotref.valueint - else - oldvalueint = load_valueint(Slot, ptr) - end - if value isa NoValue - # TODO: store NoValue when Value <: NoValue - newvalueint = oldvalueint - else - # TODO: handle `Value isa Union` - newvalueint = uint_from(value) +@inline function cleanup_pairnode!(slots, pairnodes, index) + GC.@preserve slots begin + s2ptr = pointer(slots, 2 * index) + slotid = UnsafeAtomics.load(s2ptr) end - handle = nothing - if Slot <: InlinedPair - newkeyint = uint_from(Inlined{KeyUnion{Key}}(key)) - elseif key isa Moved{Key} - handle = Ref{Any}(Ref(key)) - GC.@preserve handle begin - newkeyint = unsafe_load(Ptr{UInt}(pointer_from_objref(handle))) + iszero(slotid) && unreachable() + ref = pairnodes[index] + node = (@atomic ref.value)::PairNode + while true + node.slotid == slotid && break + next = @atomic node.next + old, success = @atomicreplace ref.value node => next + if success + node = next::PairNode + else + node = old::PairNode end - else - newkeyint = UInt(_pointer_from_objref(key)) end - # ns = Slot(key, value) - nu = UIntType(newvalueint) - nu <<= fieldoffset(Slot, 2) * 8 - nu |= newkeyint - ou = UIntType(oldvalueint) - ou <<= fieldoffset(Slot, 2) * 8 - ou |= slotref.keyint - GC.@preserve handle begin - fu = UnsafeAtomics.cas!(Ptr{typeof(nu)}(ptr), ou, nu) + next = @atomic node.next + if next !== nothing + @atomicreplace node.next next => nothing end - # @show ou nu fu - return fu == ou -end - -struct RefSlotRef{R} - ptr::Ptr{Cvoid} - ref::R + return end -@inline function load_slot( - slots::AbstractVector{Slot}, - index, -) where {Key,Value,Slot<:Ref{<:InlinedPair{Key,Value}}} - ptr = Ptr{Cvoid}(pointer(slots, index)) - int = UnsafeAtomics.load(Ptr{UInt}(ptr)) - ref = unsafe_pointer_to_objref(Ptr{Cvoid}(int))::Slot - 
return RefSlotRef(ptr, ref) +@inline function load_pairnode(pairnodes, index, slotid) + ref = pairnodes[index] + node = (@atomic ref.value)::PairNode + while true + node.slotid == slotid && return node + node = (@atomic node.next)::PairNode + end end -@inline getkey(slotref::RefSlotRef) = slotref.ref[].key.x.x - -struct ImmutableRef{T} - x::T +mutable struct ValueRef{Value,Node<:PairNode} + node::Node + isloaded::Bool + value::Value + ValueRef{Value}(node::Node) where {Value,Node} = new{Value,Node}(node, false) end +# Note on `modify!` design: It looks like (even relaxed) atomic load is not +# eliminated when the value is not used (). +# So, let's pass a `Ref`-like object to `modify!` and so that load is not issued +# when the user does not look at the value. -@inline Base.getindex(r::ImmutableRef) = r.x - -@inline value_ref(ref::RefSlotRef) = ImmutableRef(ref.ref[].value.x) - -allocate_slot(::AbstractVector{Slot}) where {Slot<:Ref} = Ref(Slot()) -# One indirection to force heap allocation - -@inline function cas_slot!( - slotref::RefSlotRef{Slot}, - new_slot_ref::Ref{Slot}, - root, - key, - value, -) where {P,Slot<:Ref{P}} - ptr = slotref.ptr - new_slot = new_slot_ref[] - new_slot[] = value isa NoValue ? P(key) : P(key, value) - ref = slotref.ref - GC.@preserve ref new_slot_ref begin - ou = UInt(pointer_from_objref(ref)) - nu = UInt(unsafe_load(Ptr{Ptr{Cvoid}}(pointer_from_objref(new_slot_ref)))) - fu = UnsafeAtomics.cas!(Ptr{typeof(nu)}(ptr), ou, nu) +@inline function Base.getindex(ref::ValueRef) + if !ref.isloaded + node = ref.node + ref.value = @atomic node.value + ref.isloaded = true end - return fu == ou + return ref.value end -make_slot(::Type{P}, k, v) where {P} = P(k, v) -make_slot(::Type{R}, k, v) where {P,R<:Ref{P}} = R(P(k, v)) +@inline function Base.setindex!(ref::ValueRef, x) + ref.value = x + ref.isloaded = true +end function Base.getindex(d::LinearProbingDict{Key}, key) where {Key} y = tryget(d, key) @@ -382,6 +207,16 @@ function Base.getindex(d::LinearProbingDict{Key}, key) where {Key} end end +function Base.haskey(d::LinearProbingDict, key) + @inline f(::Nothing) = nothing + @inline f(_) = Keep(nothing) + y = modify!(f, d, key) + return y !== nothing +end + +Base.get(d::LinearProbingDict, key, default) = + something(ConcurrentCollections.tryget(d, key), default) + function ConcurrentCollections.tryget(d::LinearProbingDict, key) @inline f(::Nothing) = nothing @inline f(x) = Keep(x[]) @@ -420,85 +255,126 @@ end function ConcurrentCollections.modify!( f, - dict::LinearProbingDict{Key,Value,Slot}, + dict::LinearProbingDict{Key,Value}, key, -) where {Key,Value,Slot} +) where {Key,Value} key = convert(Key, key) - GC.@preserve dict key begin - h = reinterpret(Int, hash(key)) + slots, pairnodes = slots_and_pairnodes(dict) + newslotid = UInt64(0) - slots = atomic_getfield(dict, Val(:slots)) + h = hash(key) - if 2 * length_upper_bound(dict) > length(slots) - slots = expand!(dict, slots) - end + # The upper bits of hash that would be stored in `keyinfo.keydata`: + inlinedhash = h >> LPD_NBITS - # TODO: check if the allocation is eliminated for getindex - new_slot = allocate_slot(slots) + if 4 * length_upper_bound(dict) > length(slots) + slots, pairnodes = expand!(dict, slots, pairnodes) + end - while true - c = length(slots) - offset = h & (c - 1) # h % c - nprobes = 0 + while true + c = length(slots) ÷ 2 + offset = reinterpret(Int, h) & (c - 1) # h % c + nprobes = 0 + GC.@preserve slots begin while true index = offset + 1 - slotref = load_slot(slots, index) - sk = getkey(slotref) - 
# @show index sk slotref
+                s1ptr = pointer(slots, 2 * offset + 1)
+                s2ptr = pointer(slots, 2 * offset + 2)
+                keybits = UnsafeAtomics.load(s1ptr)
+                keyinfo = KeyInfo(keybits)

-                if sk isa Union{Moved{Key},MovedEmpty}
-                    slots = finishmove!(dict, slots)
+                if keyinfo.ismoved || keyinfo.ismovedempty
+                    slots, pairnodes = finishmove!(dict, slots, pairnodes)
                     break # restart
-                elseif sk isa Empty
+                elseif keyinfo.isempty
                     reply = f(nothing)::Union{Nothing,Some}
-                    reply === nothing && return reply # optimization
-                    nsk = key
-                elseif sk isa Key
-                    if isequal(sk, key)
-                        vref = value_ref(slotref)
-                        reply = f(vref)::Union{Nothing,Some,Keep,Delete}
-                        nsk = sk
-                    else
-                        @goto probing
+                    reply === nothing && return reply
+                    # Insertion:
+                    if iszero(newslotid)
+                        newslotid = dict.slotids[Threads.threadid()] += Threads.nthreads()
+                        # TODO: Handle wrap-around of slotid? Reset it during migration?
                     end
-                elseif sk isa Deleted
-                    @goto probing
-                else
-                    unexpected(sk)
-                end
-
-                if reply isa Keep
-                    return reply
-                elseif reply isa Union{Nothing,Delete}
-                    if cas_slot!(slotref, new_slot, slots, Deleted())
-                        ndeleted = Threads.atomic_add!(dict.ndeleted, 1) + 1
-                        approx_len = dict.nadded[] - ndeleted
-                        if approx_len < length(slots) ÷ 2
-                            shrink!(dict, slots)
-                        end
+                    prepare_pairnode!(pairnodes, index, newslotid, key, something(reply))
+                    oldslot = Pair(keybits, zero(keybits))
+                    newslot = Pair(KeyInfo{UInt64}(LPD_HASKEY, inlinedhash).bits, newslotid)
+                    s12ptr = Ptr{typeof(oldslot)}(s1ptr)
+                    found = UnsafeAtomics.cas!(s12ptr, oldslot, newslot)
+                    if found === oldslot
+                        Threads.atomic_add!(dict.nadded, 1)
                         return reply
                     end
-                else
-                    if cas_slot!(slotref, new_slot, slots, nsk, something(reply))
-                        if sk isa Empty
-                            Threads.atomic_add!(dict.nadded, 1)
+                    foundinfo = KeyInfo(first(found))
+                    if foundinfo.ismoved | foundinfo.ismovedempty
+                        slots, pairnodes = finishmove!(dict, slots, pairnodes)
+                        break # restart
+                    else
+                        # Failed to insert a new entry. This means another task
+                        # successfully inserted a new slot. The linked list in
+                        # `pairnodes[index]` needs to be cleaned up before we
+                        # continue probing.
+                        cleanup_pairnode!(slots, pairnodes, index)
+
+                        # Retry on CAS failure since this key may have been
+                        # inserted by another task.
+                        continue
+                        # TODO: Check the hash in `found`? If different, there's
+                        # no need to retry.
+ end + elseif keyinfo.haskey + if keyinfo.keydata ==′ inlinedhash + slotid = UnsafeAtomics.load(s2ptr) + node = load_pairnode(pairnodes, index, slotid) + if isequal(node.key, key) + vref = ValueRef{Value}(node) + while true + reply = f(vref)::Union{Keep,Nothing,Delete,Some} + reply isa Keep && return reply + reply isa Union{Nothing,Delete} && break + # Update: + old = vref[] + new = something(reply::Some) + old, success = @atomicreplace node.value old => new + success && return reply + vref[] = old + end + + # Deletion: + oldslot = Pair(keybits, slotid) + newslot = Pair(setstate(keyinfo, LPD_DELETED).bits, slotid) + s12ptr = Ptr{typeof(oldslot)}(s1ptr) + if UnsafeAtomics.cas!(s12ptr, oldslot, newslot) === oldslot + ndeleted = Threads.atomic_add!(dict.ndeleted, 1) + 1 + approx_len = dict.nadded[] - ndeleted + half_len = length(slots) ÷ 4 + if length(slots) > 8 && approx_len < half_len + shrink!(dict, slots, pairnodes) + end + return reply + else + continue # CAS failed; retry + end end - return reply end + # Key doesn't match => continue probing + elseif keyinfo.isdeleted + # => continue probing + else + unexpected(keyinfo) end - # TODO: use the key loaded via CAS - continue # retry - @label probing nprobes += 1 if nprobes > c ÷ 4 - let oldslots = slots - slots = atomic_getfield(dict, Val(:slots)) + let newslots = @atomic dict.slots # Nonblocking check to see if the slots are migrated: - if slots === oldslots + if slots === newslots # @info "expand: length(slots) ≈ 2^$(floor(Int, log2(length(slots))))" # global DICT = dict - slots = expand!(dict, oldslots) + # TODO: Check the approximated table size here. It's + # possible that the hash table needs cleanup but not + # resize (i.e., too many deleted slots). + slots, pairnodes = expand!(dict, slots, pairnodes) + else + slots, pairnodes = slots_and_pairnodes(dict) end end break # restart @@ -510,59 +386,84 @@ function ConcurrentCollections.modify!( end end -expand!(dict, oldslots) = migrate!(dict, oldslots, true) -shrink!(dict, oldslots) = migrate!(dict, oldslots, false) +expand!(dict, oldslots, oldpairnodes) = migrate!(dict, oldslots, oldpairnodes, true) +shrink!(dict, oldslots, oldpairnodes) = migrate!(dict, oldslots, oldpairnodes, false) + +function new_slots_and_pairnodes(slots, pairnodes, expand) + newslots = zeros(eltype(slots), expand ? length(slots) * 2 : length(slots) ÷ 2) + # newslots = Mmap.mmap(Vector{UInt64}, expand ? length(slots) * 2 : length(slots) ÷ 2) + newpairnodes = [eltype(pairnodes)() for _ in 1:length(newslots)÷2] + # TODO: Can refs (and not just nodes) be reused? + return (newslots, newpairnodes) +end -function migrate!(dict, oldslots, expand) +function migrate!(dict::LinearProbingDict, expand::Bool; basesize = nothing) + slots, pairnodes = slots_and_pairnodes(dict) + return migrate!(dict, slots, pairnodes, expand; basesize) +end + +function migrate!(dict, oldslots, oldpairnodes, expand; basesize = nothing) # Since the migration is parallelized, workers running tasks blocked by the - # lock actually will contribute to the forward progress of the eintire + # lock actually will contribute to the forward progress of the entire # system. (Though the OS may suspend this worker thread before the tasks are # spawned.) lock(dict.migration) do - slots = atomic_getfield(dict, Val(:slots)) + slots = (@atomic dict.slots)::Vector{UInt64} + pairnodes = @atomic dict.pairnodes if slots !== oldslots - return slots + return slots, pairnodes end - newslots = similar(slots, expand ? 
length(slots) * 2 : length(slots) ÷ 2)
+        @atomic dict.slots = nothing
+        @assert pairnodes === oldpairnodes
+        (newslots, newpairnodes) = new_slots_and_pairnodes(slots, pairnodes, expand)
         if expand
-            nadded = expand_parallel!(newslots, slots)
+            nadded = expand_parallel!(newslots, newpairnodes, slots, pairnodes, basesize)
         else
-            nadded = migrate!(newslots, slots)
+            nadded = migrate_serial!(newslots, newpairnodes, slots, pairnodes)
         end
         # TODO: parallelize `shrink!`

-        # At this point, no other thread can be mutating the coutners (as they
+        # At this point, no other thread can be mutating the counters (as they
         # will observe `Moved`). Thus, it is safe to update the counter
         # non-atomically:
         dict.ndeleted[] = 0
         dict.nadded[] = nadded

-        # This is the atomic "publlshing" operation that makes the `newslots`
-        # accssible to any tasks (including the ones that are/were not trying to
-        # acquire the `migration` lock).
-        atomic_setfield!(dict, Val(:slots), newslots)
+        # This is the atomic "publishing" operation that makes the `newslots`
+        # accessible to any tasks (including the ones that are/were not trying
+        # to acquire the `migration` lock).
+        @atomic dict.pairnodes = newpairnodes
+        @atomic dict.slots = newslots

-        return newslots
+        return newslots, newpairnodes
     end
 end

-function finishmove!(dict, oldslots)
+function finishmove!(dict, oldslots, oldpairnodes)
     lock(dict.migration) do
-        slots = atomic_getfield(dict, Val(:slots))
+        slots = (@atomic dict.slots)::Vector{UInt64}
+        pairnodes = @atomic dict.pairnodes
         # The caller observed `Moved` which only sets inside the `migration`
         # lock. Thus, the migration should be finished once this lock is
         # acquired:
         @assert oldslots !== slots
-        return slots
+        @assert oldpairnodes !== pairnodes
+        return slots, pairnodes
     end
 end

-function migrate!(newslots, slots)
-    fillempty!(newslots)
-    GC.@preserve newslots slots begin
-        nadded = migrate_impl!(newslots, slots)::Int
+function slots_and_pairnodes(dict)
+    while true
+        pairnodes = @atomic dict.pairnodes
+        slots = @atomic dict.slots
+        if slots === nothing
+            return finishmove!(dict, slots, pairnodes)
+        else
+            if pairnodes === @atomic dict.pairnodes
+                return (slots, pairnodes)
+            end
+        end
     end
-    return nadded
 end

 struct Stopped
@@ -573,43 +474,60 @@ end
 """
     expand_parallel_basecase!(newslots, slots, basesize, start0) -> (nadded, seen_empty)

-Process all clusters started within `start0:(start0 + basesize)` (mod `lengh(slots)`).
+Process all clusters started within `start0:(start0 + basesize)` (mod `length(slots)`).

-That is to say:
+That is to say, _try_ to process `start0:(start0 + basesize)` but make sure to
+avoid stepping into other base cases by using an empty slot to delimit the base
+case boundaries. This uses the property of a linear probing dictionary that
+non-overlapping clusters (i.e., contiguous regions of non-empty slots) are
+mapped to non-overlapping regions when the `slots` array is doubled in size.
+
+More precisely:

 1. Process all clusters started within `start0:(start0 + basesize - 1)`.
-2. If more than one cluster is processed, process a cluster in which the start
-   position of the next chunk `start0 + basesize` (mod `lengh(slots)`) is included.
+2. If more than one cluster is processed, process the cluster that contains the
+   start position of the next chunk `start0 + basesize` (mod `length(slots)`).
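+
+The returned `nadded` is the number of pairs this base case migrated into
+`newslots`, and `seen_empty` records whether an empty slot was observed in this
+chunk; if no chunk observes an empty slot, `expand_parallel!` falls back to a
+serial migration.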
""" -function expand_parallel_basecase!(newslots, slots, basesize, start0) - stop0 = min(start0 + basesize - 1, lastindex(slots)) - stpd = migrate_impl!(nothing, slots, start0, stop0, Val(true)) +function expand_parallel_basecase!( + newslots, + newpairnodes, + slots, + pairnodes, + basesize, + start0, +) + c = length(slots) ÷ 2 + stop0 = min(start0 + basesize - 1, c) + stpd = migrate_serial!(nothing, nothing, slots, pairnodes, start0, stop0, Val(true)) if stpd isa Int @assert stpd == 0 # This chunk does not own any clusters. return (0, false) end + migrate_between(start, stop, flag) = + migrate_serial!(newslots, newpairnodes, slots, pairnodes, start, stop, flag) + # An empty slot is observed. There is at least one cluster started within # this chunk. stpd::Stopped @assert stpd.nadded == 0 - nadded = migrate_impl!(newslots, slots, stpd.i + 1, stop0, Val(false))::Int + nadded = migrate_between(stpd.i + 1, stop0, Val(false))::Int # Process the cluster that includes `start0 + basesize` (if any). next_start = start0 + basesize - if next_start > lastindex(slots) - next_start = firstindex(slots) + if next_start > c + next_start = 1 end chunk_starts = ( - next_start:basesize:lastindex(slots), - firstindex(slots):basesize:next_start-1, # wrap around + next_start:basesize:c, + 1:basesize:next_start-1, # wrap around ) # Using `for half` so that the compiler does not unroll the loop. # TODO: check if it is working for half in 1:2, start in chunk_starts[half] - stop = min(start + basesize - 1, lastindex(slots)) - stpd = migrate_impl!(newslots, slots, start, stop, Val(true)) + stop = min(start + basesize - 1, c) + stpd = migrate_between(start, stop, Val(true)) if stpd isa Stopped nadded += stpd.nadded return (nadded, true) @@ -624,91 +542,142 @@ plus_or((a, b), (c, d)) = (a + c, b | d) # See`BenchDictMigration` for benchmarking this: const LINEAR_PROBING_DICT_EXPAND_BASESIZE = Ref(2^13) -function expand_parallel!(newslots, slots) +function expand_parallel!(newslots, newpairnodes, slots, pairnodes, basesize) + # TODO: Make the default `basesize` configurable? + basesize = something(basesize, LINEAR_PROBING_DICT_EXPAND_BASESIZE[]) @assert length(newslots) > length(slots) - minimum_basesize = LINEAR_PROBING_DICT_EXPAND_BASESIZE[] # TODO: configurable? - length(slots) <= minimum_basesize && return migrate!(newslots, slots) - basesize = min(minimum_basesize, cld(length(slots), Threads.nthreads())) - - fillempty!(newslots) # TODO: parallelize? - nadded, seen_empty = threaded_typed_mapreduce( - Tuple{Int,Bool}, - plus_or, - 1:basesize:lastindex(slots), - ) do start0 - return expand_parallel_basecase!(newslots, slots, basesize, start0) - end + length(slots) <= basesize && + return migrate_serial!(newslots, newpairnodes, slots, pairnodes) + basesize = min(basesize, cld(length(slots), 2 * Threads.nthreads())) + + c = length(slots) ÷ 2 + nadded, seen_empty = + threaded_typed_mapreduce(Tuple{Int,Bool}, plus_or, 1:basesize:c) do start0 + return expand_parallel_basecase!( + newslots, + newpairnodes, + slots, + pairnodes, + basesize, + start0, + ) + end if seen_empty return nadded else - # The `slots` are all non-empty: - return migrate!(newslots, slots) + # The `slots` are all non-empty. 
Fallback to serial migration: + return migrate_serial!(newslots, newpairnodes, slots, pairnodes) + end +end + +migrate_serial!(newslots, newpairnodes, slots, pairnodes) = migrate_serial!( + newslots, + newpairnodes, + slots, + pairnodes, + 1, + length(slots) ÷ 2, + Val(false), +)::Int + +function migrate_serial!( + newslots, + newpairnodes, + slots, + pairnodes, + start, + stop, + stop_on_empty, +) + GC.@preserve newslots slots begin + nadded = unsafe_migrate!( + newslots, + newpairnodes, + slots, + pairnodes, + start, + stop, + stop_on_empty, + ) end + return nadded end -migrate_impl!(newslots::AbstractVector, slots::AbstractVector) = - migrate_impl!(newslots, slots, firstindex(slots), lastindex(slots), Val(false)) - -function migrate_impl!( - newslots::Union{AbstractVector{Slot},Nothing}, - slots::AbstractVector{Slot}, +function unsafe_migrate!( + newslots::Union{AbstractVector{UInt64},Nothing}, + newpairnodes::Union{AbstractVector{R},Nothing}, + slots::AbstractVector{UInt64}, + pairnodes::AbstractVector{R}, start::Int, stop::Int, stop_on_empty::Union{Val{false},Val{true}}, -) where {Slot} +) where {R<:AtomicRef{<:PairNode}} nadded = 0 for i in start:stop - @label reload - slotref = load_slot(slots, i) - sk = getkey(slotref) - if sk isa Deleted - continue - elseif sk isa MovedEmpty - stop_on_empty == Val(true) && return Stopped(i, nadded) - continue - elseif sk isa Empty - # Mark that this slot is not usable anymore - if !cas_slot!(slotref, allocate_slot(slots), slots, MovedEmpty()) - @goto reload + offset = i - 1 + s1ptr = pointer(slots, 2 * offset + 1) + s2ptr = pointer(slots, 2 * offset + 2) + local tryset + let s2ptr = s2ptr + @inline function tryset(keyinfo, newstate) + local slotid = UnsafeAtomics.load(s2ptr) + local oldslot = Pair(keyinfo.bits, slotid) + local newslot = Pair(setstate(keyinfo, newstate).bits, slotid) + local s12ptr = Ptr{typeof(oldslot)}(s1ptr) + return UnsafeAtomics.cas!(s12ptr, oldslot, newslot) === oldslot + end + end + local keyinfo + while true + keybits = UnsafeAtomics.load(s1ptr) + keyinfo = KeyInfo(keybits) + if keyinfo.isdeleted + break # next index + elseif keyinfo.ismovedempty + stop_on_empty == Val(true) && return Stopped(i, nadded) + break # next index + elseif keyinfo.isempty + # Mark that this slot is not usable anymore + if !tryset(keyinfo, LPD_MOVED_EMPTY) + continue + end + stop_on_empty == Val(true) && return Stopped(i, nadded) + break # next index + elseif keyinfo.haskey + if !tryset(keyinfo, LPD_MOVED) + continue + end + else + @assert keyinfo.ismoved end - stop_on_empty == Val(true) && return Stopped(i, nadded) - continue + @goto move end + continue + @label move newslots === nothing && continue - sv = value_ref(slotref)[] - # @show i slotref sk sv - if sk isa Moved - key = sk.key - else - if !cas_slot!(slotref, allocate_slot(slots), slots, Moved(sk), sv) - @goto reload - end - key = sk - end - ns = make_slot(Slot, key, sv) - # TODO: batch allocation + newkeybits = setstate(keyinfo, LPD_HASKEY).bits + slotid = UnsafeAtomics.load(s2ptr) + node = load_pairnode(pairnodes, i, slotid) + key = node.key # Insertion to `newslots` does not have to use atomics since # it's protected by the `.migration` lock. 
- c = length(newslots) + c = length(newslots) ÷ 2 h = reinterpret(Int, hash(key)) offset = h & (c - 1) # h % c nprobes = 0 while true - index = offset + 1 # TODO: non-atomic ordering - slotref = load_slot(newslots, index) - sk = getkey(slotref) - if sk isa Empty - # @assert newslots[index].key.x.x === sk - # @show newslots[index] - # @show index ns - # TODO: create AlignedArray type so that we don't have - # to copy pads for each get/set - @inbounds newslots[index] = ns + local keybits = @inbounds newslots[2*offset+1] + local keyinfo = KeyInfo(keybits) + if keyinfo.isempty + @inbounds newslots[2*offset+1] = newkeybits + @inbounds newslots[2*offset+2] = slotid + ref = newpairnodes[offset+1] + @atomic ref.value = node nadded += 1 break end @@ -727,26 +696,28 @@ Base.IteratorSize(::Type{<:Base.KeySet{<:Any,<:LinearProbingDict}}) = Base.SizeU Base.IteratorSize(::Type{<:Base.ValueIterator{<:LinearProbingDict}}) = Base.SizeUnknown() function Base.iterate(dict::LinearProbingDict) - GC.@preserve dict begin - slots = atomic_getfield(dict, Val(:slots)) - end - return iterate(dict, (slots, firstindex(slots))) + slots, pairnodes = slots_and_pairnodes(dict) + return iterate(dict, (slots, pairnodes, 1)) end -function Base.iterate(::LinearProbingDict, (slots, index)) +function Base.iterate(::LinearProbingDict, (slots, pairnodes, index)) GC.@preserve slots begin - index < firstindex(slots) && return nothing + index < 1 && return nothing while true - index > lastindex(slots) && return nothing - s = load_slot(slots, index) - index += 1 - sk = getkey(s) - sv = value_ref(s)[] - sk isa Union{Empty,MovedEmpty,Deleted} && continue - if sk isa Moved - sk = sk.key + 2 * index > length(slots) && return nothing + offset = index - 1 + s1ptr = pointer(slots, 2 * offset + 1) + s2ptr = pointer(slots, 2 * offset + 2) + keybits = UnsafeAtomics.load(s1ptr) + keyinfo = KeyInfo(keybits) + if keyinfo.haskey | keyinfo.ismoved + slotid = UnsafeAtomics.load(s2ptr) + node = load_pairnode(pairnodes, index, slotid) + key = node.key + value = @atomic node.value + return (key => value), (slots, pairnodes, index + 1) end - return (sk => sv), (slots, index) + index += 1 end end end @@ -843,29 +814,27 @@ describe(map(length, cs)) ``` """ clusters(d::LinearProbingDict) = clusters(d.slots) -function clusters( - slots::AbstractVector{Slot}, -) where {Slot<:Union{AbstractPair,<:Ref{<:AbstractPair}}} +function clusters(slots::AbstractVector{UInt64}) cs = typeof(1:2)[] - i = firstindex(slots) + i = 1 while true while true - i > lastindex(slots) && return cs - slotref = load_slot(slots, i) + 2 * i > length(slots) && return cs + keyinfo = KeyInfo(slots[2*(i-1)+1]) i += 1 - if !(getkey(slotref) isa Union{Empty,MovedEmpty}) + if keyinfo.isempty | keyinfo.ismovedempty break end end start = i - 1 while true - if i > lastindex(slots) + if 2 * i > length(slots) push!(cs, start:i-1) return cs end - slotref = load_slot(slots, i) + keyinfo = KeyInfo(slots[2*(i-1)+1]) i += 1 - if getkey(slotref) isa Union{Empty,MovedEmpty} + if keyinfo.isempty | keyinfo.ismovedempty break end end diff --git a/src/utils.jl b/src/utils.jl index 1653d0c..f4b5569 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -210,6 +210,35 @@ end end end +@inline function check_embeddable(::Type{Storage}, ::Type{Data}) where {Storage,Data} + sizeof(Storage) ≥ sizeof(Data) || @static_error "sizeof(Storage) < sizeof(Data)" + Base.allocatedinline(Storage) || @static_error "!allocatedinline(Storage)" + Base.allocatedinline(Data) || @static_error "!allocatedinline(Data)" + 
isconcretetype(Storage) || @static_error "!isconcretetype(Storage)" +end + +@inline function unsafe_embed(::Type{Storage}, x) where {Storage} + Data = Some{typeof(x)} + check_embeddable(Storage, Data) + ref = Ref{Storage}() + GC.@preserve ref begin + ptr = Ptr{Data}(pointer_from_objref(ref)) + unsafe_store!(ptr, Data(x)) + end + return ref[] +end + +@inline function unsafe_extract(::Type{U}, x::Storage) where {U,Storage} + Data = Some{U} + check_embeddable(Storage, Data) + ref = Ref{Storage}(x) + GC.@preserve ref begin + ptr = Ptr{Data}(pointer_from_objref(ref)) + y = unsafe_load(ptr) + end + return something(y) +end + # Read /sys/devices/system/cpu/cpu0/cache/index0/coherency_line_size? const CACHELINE_SIZE = 64 diff --git a/test/ConcurrentCollectionsTests/src/ConcurrentCollectionsTests.jl b/test/ConcurrentCollectionsTests/src/ConcurrentCollectionsTests.jl index a6562ff..93f2d82 100644 --- a/test/ConcurrentCollectionsTests/src/ConcurrentCollectionsTests.jl +++ b/test/ConcurrentCollectionsTests/src/ConcurrentCollectionsTests.jl @@ -5,7 +5,6 @@ include("test_bench_dict_histogram.jl") include("test_bench_smoke.jl") include("test_crq.jl") include("test_dict.jl") -include("test_dict_impl.jl") include("test_dlcrq.jl") include("test_doctest.jl") include("test_lcrq.jl") diff --git a/test/ConcurrentCollectionsTests/src/test_bench_dict_histogram.jl b/test/ConcurrentCollectionsTests/src/test_bench_dict_histogram.jl index 00d09c5..eb47854 100644 --- a/test/ConcurrentCollectionsTests/src/test_bench_dict_histogram.jl +++ b/test/ConcurrentCollectionsTests/src/test_bench_dict_histogram.jl @@ -27,6 +27,16 @@ function test(datasize, fulldata) global FAILED = (; cdpar, dbase) end =# + @test sort(collect(setdiff(keys(dbase), keys(cdpar)))) == [] + @test sort(collect(setdiff(keys(cdpar), keys(dbase)))) == [] + diffvalues = [] + for (key, expected) in dbase + actual = cdpar[key] + if actual != expected + push!(diffvalues, (; key, actual, expected)) + end + end + @test diffvalues == [] @test Dict(cdpar) == dbase end end diff --git a/test/ConcurrentCollectionsTests/src/test_dict.jl b/test/ConcurrentCollectionsTests/src/test_dict.jl index 0b52c9f..565a777 100644 --- a/test/ConcurrentCollectionsTests/src/test_dict.jl +++ b/test/ConcurrentCollectionsTests/src/test_dict.jl @@ -1,9 +1,93 @@ -module DontTestDict +module TestDict using ConcurrentCollections -using ConcurrentCollections.Implementations: clusters +using ConcurrentCollections.Implementations: + LPDKeyState, + LPD_BITMASK, + LPD_DELETED, + LPD_EMPTY, + LPD_HASKEY, + LPD_MOVED, + LPD_MOVED_EMPTY, + LPD_NBITS, + KeyInfo, + clusters, + migrate!, + setdata, + setstate using Test +function test_keyinfo() + @test KeyInfo(UInt64(0)).state === LPD_EMPTY + @testset for state in instances(LPDKeyState) + @test KeyInfo{UInt64}(state, 0x0123456789abcdef).state === state + if state !== LPD_EMPTY + @test KeyInfo{UInt64}(state, 0x0123456789abcdef).keydata === 0x0123456789abcdef + end + @test setstate(KeyInfo(rand(UInt64)), state).state === state + keydata = rand(UInt64) >> LPD_NBITS + @test setdata(KeyInfo(rand(UInt64)), keydata).keydata === keydata + end +end + +function test_keyinfo_properties() + keyinfo = KeyInfo{UInt64}(rand(UInt64)) + enum_to_property = Dict( + LPD_EMPTY => :isempty, + LPD_DELETED => :isdeleted, + LPD_MOVED_EMPTY => :ismovedempty, + LPD_MOVED => :ismoved, + LPD_HASKEY => :haskey, + ) + properties = collect(values(enum_to_property)) + @testset for state in instances(LPDKeyState), prop in properties + if enum_to_property[state] === prop + @test 
getproperty(setstate(keyinfo, state), prop) + else + @test !getproperty(setstate(keyinfo, state), prop) + end + end +end + +function test_expand_and_shrink(n = 17) + d = ConcurrentDict{Int,Int}() + @testset "expand" begin + @testset for i in 1:n + d[i] = -i + @testset for k in 1:i + @test d[k] == -k + end + end + end + nfull = length(d.slots) + @testset "shrink" begin + @testset for i in n:-1:1 + @test pop!(d, i) == -i + @testset for k in 1:i-1 + @test d[k] == -k + end + end + end + @test length(d.slots) < nfull + return d +end + +function test_parallel_expand(n = 2^10, basesize = 8) + d = ConcurrentDict{Int,Int}(pairs(1:n)) + nslots = length(d.slots) + migrate!(d, true; basesize) + @test nslots < length(d.slots) + diffs = Pair{Int,Int}[] + for k in 1:n + v = d[k] + if v != k + push!(diffs, k => v) + end + end + @test diffs == [] + return d +end + function test_dict() @testset for npairs in [2, 100] test_dict(npairs) diff --git a/test/ConcurrentCollectionsTests/src/test_dict_impl.jl b/test/ConcurrentCollectionsTests/src/test_dict_impl.jl deleted file mode 100644 index fcd7deb..0000000 --- a/test/ConcurrentCollectionsTests/src/test_dict_impl.jl +++ /dev/null @@ -1,18 +0,0 @@ -module TestDictImpl - -using ConcurrentCollections -using ConcurrentCollections.Implementations: BoxedKeyPair, InlinedPair -using Test - -slottype(Key, Value) = eltype(ConcurrentDict{Key,Value}().slots) - -function var"test_slot type"() - @test slottype(Int8, Int) <: InlinedPair{Int8,Int} - @test slottype(Int16, Int) <: InlinedPair{Int16,Int} - @test slottype(Int32, Int) <: InlinedPair{Int32,Int} - @test slottype(Int, Int) <: Ref - @test slottype(Int32, Int32) <: InlinedPair{Int32,Int32} - @test slottype(String, Int) <: BoxedKeyPair{String,Int} -end - -end # module From edafde1b9269b116174f78b11d461455ba26ab42 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Wed, 15 Sep 2021 15:09:16 -0400 Subject: [PATCH 2/2] Store full hash in the slot --- src/dict.jl | 49 +++++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/src/dict.jl b/src/dict.jl index 83fce8c..9c8a329 100644 --- a/src/dict.jl +++ b/src/dict.jl @@ -142,7 +142,8 @@ end @inline function cleanup_pairnode!(slots, pairnodes, index) GC.@preserve slots begin s2ptr = pointer(slots, 2 * index) - slotid = UnsafeAtomics.load(s2ptr) + slot2 = UnsafeAtomics.load(s2ptr) + slotid = slotid_from_slot2(slot2) end iszero(slotid) && unreachable() ref = pairnodes[index] @@ -253,6 +254,13 @@ function ConcurrentCollections.trypop!(d::LinearProbingDict, key) end end +@inline slotid_from_slot2(slot2::UInt64) = slot2 & (typemax(UInt64) >> LPD_NBITS) + +@inline function reconstruct_full_hash(keyinfo::KeyInfo, slot2::UInt64) + hash2 = slot2 & ~(typemax(UInt64) >> LPD_NBITS) + return (keyinfo.keydata << LPD_NBITS) | (hash2 >> (64 - LPD_NBITS)) +end + function ConcurrentCollections.modify!( f, dict::LinearProbingDict{Key,Value}, @@ -264,8 +272,9 @@ function ConcurrentCollections.modify!( h = hash(key) - # The upper bits of hash that would be stored in `keyinfo.keydata`: - inlinedhash = h >> LPD_NBITS + # The upper and lower bits of hash: + hash1 = h >> LPD_NBITS # stored in slot 1 (`keyinfo.keydata`) + hash2 = h << (64 - LPD_NBITS) # stored in slot 2 if 4 * length_upper_bound(dict) > length(slots) slots, pairnodes = expand!(dict, slots, pairnodes) @@ -295,8 +304,9 @@ function ConcurrentCollections.modify!( # TODO: Handle wrap-around of slotid? Reset it during migration? 
end prepare_pairnode!(pairnodes, index, newslotid, key, something(reply)) + slot2 = hash2 | newslotid oldslot = Pair(keybits, zero(keybits)) - newslot = Pair(KeyInfo{UInt64}(LPD_HASKEY, inlinedhash).bits, newslotid) + newslot = Pair(KeyInfo{UInt64}(LPD_HASKEY, hash1).bits, slot2) s12ptr = Ptr{typeof(oldslot)}(s1ptr) found = UnsafeAtomics.cas!(s12ptr, oldslot, newslot) if found === oldslot @@ -321,10 +331,12 @@ function ConcurrentCollections.modify!( # no need to retry. end elseif keyinfo.haskey - if keyinfo.keydata ==′ inlinedhash - slotid = UnsafeAtomics.load(s2ptr) + if keyinfo.keydata ==′ hash1 + slot2 = UnsafeAtomics.load(s2ptr) + slotid = slotid_from_slot2(slot2) + stored_hash = reconstruct_full_hash(keyinfo, slot2) node = load_pairnode(pairnodes, index, slotid) - if isequal(node.key, key) + if stored_hash == h && isequal(node.key, key) vref = ValueRef{Value}(node) while true reply = f(vref)::Union{Keep,Nothing,Delete,Some} @@ -339,8 +351,8 @@ function ConcurrentCollections.modify!( end # Deletion: - oldslot = Pair(keybits, slotid) - newslot = Pair(setstate(keyinfo, LPD_DELETED).bits, slotid) + oldslot = Pair(keybits, slot2) + newslot = Pair(setstate(keyinfo, LPD_DELETED).bits, slot2) s12ptr = Ptr{typeof(oldslot)}(s1ptr) if UnsafeAtomics.cas!(s12ptr, oldslot, newslot) === oldslot ndeleted = Threads.atomic_add!(dict.ndeleted, 1) + 1 @@ -621,9 +633,9 @@ function unsafe_migrate!( local tryset let s2ptr = s2ptr @inline function tryset(keyinfo, newstate) - local slotid = UnsafeAtomics.load(s2ptr) - local oldslot = Pair(keyinfo.bits, slotid) - local newslot = Pair(setstate(keyinfo, newstate).bits, slotid) + local slot2 = UnsafeAtomics.load(s2ptr) + local oldslot = Pair(keyinfo.bits, slot2) + local newslot = Pair(setstate(keyinfo, newstate).bits, slot2) local s12ptr = Ptr{typeof(oldslot)}(s1ptr) return UnsafeAtomics.cas!(s12ptr, oldslot, newslot) === oldslot end @@ -659,15 +671,15 @@ function unsafe_migrate!( newslots === nothing && continue newkeybits = setstate(keyinfo, LPD_HASKEY).bits - slotid = UnsafeAtomics.load(s2ptr) + slot2 = UnsafeAtomics.load(s2ptr) + slotid = slotid_from_slot2(slot2) node = load_pairnode(pairnodes, i, slotid) - key = node.key + h = reconstruct_full_hash(keyinfo, slot2) # Insertion to `newslots` does not have to use atomics since # it's protected by the `.migration` lock. c = length(newslots) ÷ 2 - h = reinterpret(Int, hash(key)) - offset = h & (c - 1) # h % c + offset = reinterpret(Int, h) & (c - 1) # h % c nprobes = 0 while true # TODO: non-atomic ordering @@ -675,7 +687,7 @@ function unsafe_migrate!( local keyinfo = KeyInfo(keybits) if keyinfo.isempty @inbounds newslots[2*offset+1] = newkeybits - @inbounds newslots[2*offset+2] = slotid + @inbounds newslots[2*offset+2] = slot2 ref = newpairnodes[offset+1] @atomic ref.value = node nadded += 1 @@ -711,7 +723,8 @@ function Base.iterate(::LinearProbingDict, (slots, pairnodes, index)) keybits = UnsafeAtomics.load(s1ptr) keyinfo = KeyInfo(keybits) if keyinfo.haskey | keyinfo.ismoved - slotid = UnsafeAtomics.load(s2ptr) + slot2 = UnsafeAtomics.load(s2ptr) + slotid = slotid_from_slot2(slot2) node = load_pairnode(pairnodes, index, slotid) key = node.key value = @atomic node.value