Skip to content

Commit

Permalink
Move code for calculating key index into Julia
Browse files Browse the repository at this point in the history
Added inbounds as dict.jl seems to use them as well
  • Loading branch information
Zentrik committed Oct 31, 2023
1 parent 2056a56 commit 7f20833
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 29 deletions.
19 changes: 13 additions & 6 deletions base/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ IdDict{Any, String} with 3 entries:
```
"""
mutable struct IdDict{K,V} <: AbstractDict{K,V}
# NOTE make sure to sync the struct definition with `jl_id_dict_t` in julia.h
ht::Memory{Any}
count::Int
ndel::Int
Expand Down Expand Up @@ -86,14 +85,22 @@ end

function ht_keyindex!(d::IdDict{K, V}, @nospecialize(key)) where {K, V}
!isa(key, K) && throw(ArgumentError("$(limitrepr(key)) is not a valid key for type $K"))
keyindex = ccall(:jl_eqtable_keyindex, Cssize_t, (Any, Any), d, key)
# keyindex - where a key is stored, or -pos if the key was not present and was inserted at pos

ht = d.ht
t = @_gc_preserve_begin ht

ref = Ref{Ptr{Any}}(pointer_from_objref(ht))
# # keyindex - where a key is stored, or -pos if the key was not present and was inserted at pos
keyindex = ccall(:jl_table_assign_bp, Cssize_t, (Ptr{Ptr{Any}}, Any, Any, Cint), ref, key, C_NULL, 0)
d.ht = unsafe_pointer_to_objref(ref[])

@_gc_preserve_end t

return abs(keyindex), keyindex < 0
end

function _setindex!(d::IdDict{K, V}, val::V, key::K, keyindex::Int, inserted::Bool) where {K, V}
d.ht[keyindex+1] = val
@inbounds d.ht[keyindex+1] = val
d.count += inserted

if d.ndel >= ((3*length(d.ht))>>2)
Expand Down Expand Up @@ -179,7 +186,7 @@ function get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where
_setindex!(d, val, key, keyindex, inserted)
return val::V
else
return d.ht[keyindex+1]::V
return @inbounds d.ht[keyindex+1]::V
end
end

Expand All @@ -203,7 +210,7 @@ function get!(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V
_setindex!(d, val, key, keyindex, inserted)
return val::V
else
return d.ht[keyindex+1]::V
return @inbounds d.ht[keyindex+1]::V
end
end

Expand Down
17 changes: 3 additions & 14 deletions src/iddict.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#define keyhash(k) jl_object_id_(jl_typeof(k), k)
#define h2index(hv, sz) (size_t)(((hv) & ((sz)-1)) * 2)

static inline ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val);
JL_DLLEXPORT
inline ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val);

JL_DLLEXPORT jl_genericmemory_t *jl_idtable_rehash(jl_genericmemory_t *a, size_t newsz)
{
Expand All @@ -32,7 +33,7 @@ JL_DLLEXPORT jl_genericmemory_t *jl_idtable_rehash(jl_genericmemory_t *a, size_t

// returns where a key is stored, or -pos if the key was not present and was inserted at pos
// result is 1-indexed
static inline ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val)
inline ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val)
{
// pa points to a **un**rooted address
uint_t hv;
Expand Down Expand Up @@ -198,17 +199,5 @@ size_t jl_eqtable_nextind(jl_genericmemory_t *t, size_t i)
return i;
}

JL_DLLEXPORT
ssize_t jl_eqtable_keyindex(jl_id_dict_t *d, jl_value_t *key)
{
jl_genericmemory_t *h = d->ht;

ssize_t index = jl_table_assign_bp(&h, key, NULL, 0);
jl_atomic_store_release((_Atomic(jl_genericmemory_t*)*)&d->ht, h);
jl_gc_wb(d, h);

return index;
}

#undef hash_size
#undef max_probe
10 changes: 1 addition & 9 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -701,14 +701,6 @@ typedef struct {
uint8_t fully_covers;
} jl_method_match_t;

// The following mirrors IdDict in "base/iddict.jl"
typedef struct {
JL_DATA_TYPE
jl_genericmemory_t *ht;
size_t count;
size_t ndel;
} jl_id_dict_t;

// constants and type objects -------------------------------------------------

#define JL_SMALL_TYPEOF(XX) \
Expand Down Expand Up @@ -1922,7 +1914,7 @@ JL_DLLEXPORT jl_genericmemory_t *jl_eqtable_put(jl_genericmemory_t *h JL_ROOTING
JL_DLLEXPORT jl_value_t *jl_eqtable_get(jl_genericmemory_t *h JL_PROPAGATES_ROOT, jl_value_t *key, jl_value_t *deflt) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_eqtable_pop(jl_genericmemory_t *h, jl_value_t *key, jl_value_t *deflt, int *found);
jl_value_t *jl_eqtable_getkey(jl_genericmemory_t *h JL_PROPAGATES_ROOT, jl_value_t *key, jl_value_t *deflt) JL_NOTSAFEPOINT;
JL_DLLEXPORT ssize_t jl_eqtable_keyindex(jl_id_dict_t *d, jl_value_t *key);
JL_DLLEXPORT ssize_t jl_table_assign_bp(jl_genericmemory_t **pa, jl_value_t *key, jl_value_t *val, int insert_val);

// system information
JL_DLLEXPORT int jl_errno(void) JL_NOTSAFEPOINT;
Expand Down

0 comments on commit 7f20833

Please sign in to comment.