Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements for IdDict #51091

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 36 additions & 12 deletions base/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ IdDict{Any, String} with 3 entries:
```
"""
mutable struct IdDict{K,V} <: AbstractDict{K,V}
# NOTE make sure to sync the struct definition with `jl_id_dict_t` in julia.h
ht::Memory{Any}
count::Int
ndel::Int
Expand Down Expand Up @@ -83,18 +84,35 @@ function sizehint!(d::IdDict, newsz)
rehash!(d, newsz)
end

function setindex!(d::IdDict{K,V}, @nospecialize(val), @nospecialize(key)) where {K, V}
# Look up the slot index for `key`, inserting the key if it is not present yet.
# Returns the index where the key is stored, or `-pos` if the key was not
# present and has just been inserted at `pos`.
function ht_keyindex2!(d::IdDict{K,V}, @nospecialize(key)) where {K, V}
!isa(key, K) && throw(KeyTypeError(K, key))
if !(val isa V) # avoid a dynamic call
val = convert(V, val)::V
end
return ccall(:jl_eqtable_keyindex, Cssize_t, (Any, Any), d, key)
end

@propagate_inbounds function _setindex!(d::IdDict{K,V}, val::V, keyindex::Int) where {K, V}
d.ht[keyindex+1] = val
d.count += 1

if d.ndel >= ((3*length(d.ht))>>2)
rehash!(d, max((length(d.ht)%UInt)>>1, 32))
d.ndel = 0
end
inserted = RefValue{Cint}(0)
d.ht = ccall(:jl_eqtable_put, Memory{Any}, (Any, Any, Any, Ptr{Cint}), d.ht, key, val, inserted)
d.count += inserted[]
return nothing
end

function setindex!(d::IdDict{K,V}, @nospecialize(val), @nospecialize(key)) where {K, V}
if !(val isa V) # avoid a dynamic call
val = convert(V, val)::V
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be a discrepancy here in the state of the IdDict during this convert call, or if this convert fails. The ht_keyindex2! call has fully inserted the key, but not the value. For Dict, neither has occurred yet, and all insertion and accounting happens at once in _setindex!. But in IdDict here, it looks like half of the insertion occurs in ht_keyindex2! (effectively inserting the key in the deleted state), but then the final accounting update for it doesn't happen until after the _setindex! call later in this method. Do we update ndel over this call so the value is consistent?

An age counter may also need to be added to the IdDict so that it can detect concurrent modifications that happen during this convert call (alternatively, move this convert call back to the top of the function).

end
keyindex = ht_keyindex2!(d, key)
if keyindex >= 0
@inbounds d.ht[keyindex+1] = val
else
@inbounds _setindex!(d, val, -keyindex)
end
return d
end

Expand Down Expand Up @@ -153,16 +171,22 @@ end

# Number of key/value pairs currently stored, as tracked in the `count` field.
function length(d::IdDict)
    return d.count
end

# An `IdDict` is empty exactly when its length is zero.
function isempty(d::IdDict)
    return iszero(length(d))
end

# Build a new dictionary of the same concrete type from the entries of `d`.
function copy(d::IdDict)
    return typeof(d)(d)
end

function get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V}
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token)
if val === secret_table_token
keyindex = ht_keyindex2!(d, key)

if keyindex < 0
# If convert call fails we need the key to be deleted
d.ndel += 1
val = isa(default, V) ? default : convert(V, default)::V
setindex!(d, val, key)
return val
else
d.ndel -= 1
@inbounds _setindex!(d, val, -keyindex)
return val::V
else
return @inbounds d.ht[keyindex+1]::V
end
end

Expand Down
40 changes: 29 additions & 11 deletions src/iddict.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#define keyhash(k) jl_object_id_(jl_typetagof(k), k)
#define h2index(hv, sz) (size_t)(((hv) & ((sz)-1)) * 2)

static inline int jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val);
static inline ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val);

JL_DLLEXPORT jl_genericmemory_t *jl_idtable_rehash(jl_genericmemory_t *a, size_t newsz)
{
Expand All @@ -22,14 +22,16 @@ JL_DLLEXPORT jl_genericmemory_t *jl_idtable_rehash(jl_genericmemory_t *a, size_t
newa = jl_alloc_memory_any(newsz);
for (i = 0; i < sz; i += 2) {
if (ol[i + 1] != NULL) {
jl_table_assign_bp(&newa, ol[i], ol[i + 1]);
jl_table_assign_bp(&newa, ol[i], ol[i + 1], 1);
}
}
JL_GC_POP();
return newa;
}

static inline int jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val)
// returns where a key is stored, or -pos if the key was not present and was inserted at pos
// result is 1-indexed
static inline ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val)
{
// pa points to a **un**rooted address
uint_t hv;
Expand Down Expand Up @@ -61,9 +63,11 @@ static inline int jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, j
}
if (jl_egal(key, k2)) {
if (jl_atomic_load_relaxed(&tab[index + 1]) != NULL) {
jl_atomic_store_release(&tab[index + 1], val);
jl_gc_wb(a, val);
return 0;
if (insert_val == 1) {
jl_atomic_store_release(&tab[index + 1], val);
jl_gc_wb(a, val);
}
return index+1;
}
// `nothing` is our sentinel value for deletion, so need to keep searching if it's also our search key
assert(key == jl_nothing);
Expand All @@ -82,9 +86,11 @@ static inline int jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, j
if (empty_slot != -1) {
jl_atomic_store_release(&tab[empty_slot], key);
jl_gc_wb(a, key);
jl_atomic_store_release(&tab[empty_slot + 1], val);
jl_gc_wb(a, val);
return 1;
if (insert_val == 1) {
jl_atomic_store_release(&tab[empty_slot + 1], val);
jl_gc_wb(a, val);
}
return -(empty_slot+1);
}

/* table full */
Expand Down Expand Up @@ -143,9 +149,9 @@ inline _Atomic(jl_value_t*) *jl_table_peek_bp(jl_genericmemory_t *a, jl_value_t
JL_DLLEXPORT
jl_genericmemory_t *jl_eqtable_put(jl_genericmemory_t *h, jl_value_t *key, jl_value_t *val, int *p_inserted)
{
int inserted = jl_table_assign_bp(&h, key, val);
ssize_t result = jl_table_assign_bp(&h, key, val, 1);
if (p_inserted)
*p_inserted = inserted;
*p_inserted = (result < 0);
return h;
}

Expand Down Expand Up @@ -191,6 +197,18 @@ size_t jl_eqtable_nextind(jl_genericmemory_t *t, size_t i)
return i;
}

// Find the slot of `key` in the table of `d`, inserting the key (without a
// value) when it is not present yet.
//
// Returns the 1-indexed position at which the key is stored, or the negated
// position when the key was not present and has just been inserted there.
// The value slot is never written by this function (`insert_val` is 0), so the
// caller is expected to store the value itself.
JL_DLLEXPORT
ssize_t jl_eqtable_keyindex(jl_id_dict_t *d, jl_value_t *key)
{
    // `jl_table_assign_bp` receives the table by address and may hand back a
    // different table object through it.
    jl_genericmemory_t *table = d->ht;
    ssize_t slot = jl_table_assign_bp(&table, key, NULL, 0);

    // Publish the (possibly replaced) table back into the dictionary and
    // notify the GC of the new reference held by `d`.
    jl_atomic_store_release((_Atomic(jl_genericmemory_t*)*)&d->ht, table);
    jl_gc_wb(d, table);

    return slot;
}

#undef hash_size
#undef max_probe
#undef h2index
9 changes: 9 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,14 @@ typedef struct {
uint8_t fully_covers;
} jl_method_match_t;

// The following mirrors IdDict in "base/iddict.jl"
// Keep the field order and types in sync with the Julia struct definition:
// C code reads and writes `ht` directly through this layout.
typedef struct {
JL_DATA_TYPE
// backing table; corresponds to `ht::Memory{Any}` on the Julia side
jl_genericmemory_t *ht;
// corresponds to `count::Int` on the Julia side
// NOTE(review): the Julia field is a signed `Int` while this is `size_t` --
// presumably fine because both are word-sized; confirm.
size_t count;
// corresponds to `ndel::Int` on the Julia side
size_t ndel;
} jl_id_dict_t;

// constants and type objects -------------------------------------------------

#define JL_SMALL_TYPEOF(XX) \
Expand Down Expand Up @@ -1945,6 +1953,7 @@ JL_DLLEXPORT jl_genericmemory_t *jl_eqtable_put(jl_genericmemory_t *h JL_ROOTING
JL_DLLEXPORT jl_value_t *jl_eqtable_get(jl_genericmemory_t *h JL_PROPAGATES_ROOT, jl_value_t *key, jl_value_t *deflt) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_eqtable_pop(jl_genericmemory_t *h, jl_value_t *key, jl_value_t *deflt, int *found);
jl_value_t *jl_eqtable_getkey(jl_genericmemory_t *h JL_PROPAGATES_ROOT, jl_value_t *key, jl_value_t *deflt) JL_NOTSAFEPOINT;
JL_DLLEXPORT ssize_t jl_eqtable_keyindex(jl_id_dict_t *d, jl_value_t *key);

// system information
JL_DLLEXPORT int jl_errno(void) JL_NOTSAFEPOINT;
Expand Down
85 changes: 85 additions & 0 deletions test/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,81 @@ end
@test_throws ArgumentError IdDict([1, 2, 3, 4])
# test rethrow of error in ctor
@test_throws DomainError IdDict((sqrt(p[1]), sqrt(p[2])) for p in zip(-1:2, -1:2))

h = IdDict()
Zentrik marked this conversation as resolved.
Show resolved Hide resolved
for i=1:500
get!(h, i, i+1)
end
for i=1:500
@test (h[i] == i+1)
end

h = IdDict()
for i=1:500
h[i] = i+1
end
for i=1:500
@test (h[i] == i+1)
end
for i=1:2:500
delete!(h, i)
end
for i=1:2:500
h[i] = i+1
end
for i=1:500
@test (h[i] == i+1)
end
for i=1:500
delete!(h, i)
end
@test isempty(h)
h[77] = 100
@test h[77] == 100
for i=1:500
h[i] = i+1
end
for i=1:2:500
delete!(h, i)
end
for i=501:1000
h[i] = i+1
end
for i=2:2:499
@test h[i] == i+1
end
for i=500:1000
@test h[i] == i+1
end

h = IdDict{Any,Any}("a" => 3)
@test h["a"] == 3
h["a","b"] = 4
@test h["a","b"] == h[("a","b")] == 4
h["a","b","c"] = 4
@test h["a","b","c"] == h[("a","b","c")] == 4

@testset "eltype, keytype and valtype" begin
@test eltype(h) == Pair{Any,Any}
@test keytype(h) == Any
@test valtype(h) == Any

td = IdDict{AbstractString,Float64}()
@test eltype(td) == Pair{AbstractString,Float64}
@test keytype(td) == AbstractString
@test valtype(td) == Float64
@test keytype(IdDict{AbstractString,Float64}) === AbstractString
@test valtype(IdDict{AbstractString,Float64}) === Float64
end
# test rethrow of error in ctor
@test_throws DomainError IdDict((sqrt(p[1]), sqrt(p[2])) for p in zip(-1:2, -1:2))

h = IdDict()
h[1] = 2
h[1] = 4
@test h[1] == 4
@test length(h) == 1
@test length(keys(h)) == 1
end

@testset "issue 30165, get! for IdDict" begin
Expand Down Expand Up @@ -1540,3 +1615,13 @@ end
@test valtype(D{K, V}) == V
end
end

# Check that mutating an IdDict from inside the `get!` callback works: the
# callback inserts the very key that is being looked up before returning.
let d = IdDict()
    function f()
        d[1] = 4
        return -2
    end
    # `get!` must return the callback's result ...
    @test get!(f, d, 1) === -2
    # ... the key must be counted exactly once ...
    @test length(d) == 1
    # ... and the stored value must be the callback's result, not the value
    # written by the callback itself.
    @test d[1] === -2
end