In [1]:
# Record the Julia version and runtime environment for reproducibility; the
# threading experiments below depend on JULIA_NUM_THREADS (shown in the output).
versioninfo()

Julia Version 1.5.3
Commit 788b2c77c1 (2020-11-09 13:37 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, skylake)
Environment:
  JULIA_NUM_THREADS = 4


In [2]:
using Base.Threads
using Base: AbstractLock, Callable

In [3]:
"""
    ThreadSafeDict{K,V}()
    ThreadSafeDict{K,V,L}()

An `AbstractDict{K,V}` that wraps a plain `Dict{K,V}` together with a lock of type
`L <: AbstractLock`. The two-parameter constructor uses a `ReentrantLock`; the
three-parameter form builds the lock with `L()`, so `L` must have a zero-argument
constructor.
"""
mutable struct ThreadSafeDict{K,V,L<:AbstractLock} <: AbstractDict{K,V}
    dict::Dict{K,V}   # underlying storage; accessors are expected to hold `lock`
    lock::L           # guards access to `dict`

    function ThreadSafeDict{K,V}() where {K,V}
        storage = Dict{K,V}()
        guard = ReentrantLock()
        return new{K,V,ReentrantLock}(storage, guard)
    end

    function ThreadSafeDict{K,V,L}() where {K,V,L<:AbstractLock}
        storage = Dict{K,V}()
        guard = L()
        return new{K,V,L}(storage, guard)
    end
end

In [4]:
# Iteration delegates straight to the wrapped `Dict` WITHOUT taking the lock: a lock
# cannot be held across the separate `iterate` calls a `for` loop makes. Callers must
# therefore not mutate `d` from another thread while iterating over it.
Base.iterate(d::ThreadSafeDict, state...) = iterate(d.dict, state...)

# `length`/`isempty` are single reads, so they can take the lock cheaply. This keeps
# them from racing with a concurrent `setindex!`/`delete!` that is resizing `d.dict`,
# consistent with the other accessors that lock before touching `d.dict`.
Base.length(d::ThreadSafeDict) = lock(() -> length(d.dict), d.lock)
Base.isempty(d::ThreadSafeDict) = lock(() -> isempty(d.dict), d.lock)

In [5]:
# Ask the wrapped `Dict` to reserve room for at least `n` entries. The resize happens
# while `d.lock` is held; the wrapper itself is returned, as `sizehint!` conventionally does.
function Base.sizehint!(d::ThreadSafeDict, n::Integer)
    guard = d.lock
    lock(guard)
    try
        sizehint!(d.dict, n)
    finally
        unlock(guard)
    end
    return d
end

In [6]:
# Lock-free membership test on the wrapped `Dict`; intended to be called only while
# `d.lock` is already held.
_unsafe_haskey(d::ThreadSafeDict, key) = haskey(d.dict, key)

# Thread-safe membership test: performs the lookup while holding `d.lock`.
Base.haskey(d::ThreadSafeDict, key) = lock(() -> _unsafe_haskey(d, key), d.lock)

In [7]:
# get(d, key, default): return the value stored for `key`, or `default` if `key` is
# absent. The lookup happens while holding `d.lock`; delegating to `get` on the
# wrapped `Dict` does it in a single hash lookup instead of `haskey` + `getindex`.
function Base.get(d::ThreadSafeDict, key, default)
    return lock(d.lock) do
        get(d.dict, key, default)
    end
end

# get(default, d, key): return the value stored for `key`, or `default()` if `key`
# is absent. Only the lookup runs under `d.lock`. `default()` is called AFTER the lock
# is released: nothing is inserted here, so there is no atomicity to protect, and
# holding a non-reentrant lock (e.g. `SpinLock`) while `default` touches `d` would
# deadlock.
function Base.get(default::Base.Callable, d::ThreadSafeDict, key)
    # `Some` distinguishes "found, value is `nothing`" from "not found".
    hit = lock(d.lock) do
        haskey(d.dict, key) ? Some(d.dict[key]) : nothing
    end
    return hit === nothing ? default() : something(hit)
end

In [8]:
# get!(d, key, default): return the value stored for `key`; if absent, store
# `default` under `key` and return it. Check and insert happen under a single
# acquisition of `d.lock`, so no other thread can slip in between them.
function Base.get!(d::ThreadSafeDict, key, default)
    return lock(d.lock) do
        _unsafe_haskey(d, key) && return _unsafe_getindex(d, key)
        _unsafe_addindex!(d, default, key)
        default
    end
end

# get!(default, d, key): like above, but a missing value is produced by `default()`.
# NOTE: `default()` runs while `d.lock` is held. That makes "compute once per key"
# atomic across threads, but if `default` re-enters `d` (e.g. a recursive `get!` on
# the same dict) the lock must be reentrant: `ReentrantLock` works, whereas a
# non-reentrant lock such as `SpinLock` deadlocks.
function Base.get!(default::Callable, d::ThreadSafeDict, key)
    return lock(d.lock) do
        _unsafe_haskey(d, key) && return _unsafe_getindex(d, key)
        computed = default()
        _unsafe_addindex!(d, computed, key)
        computed
    end
end

In [9]:
# Lock-free read from the wrapped `Dict`; intended to be called only while `d.lock`
# is already held. Throws `KeyError(key)` if `key` is absent.
_unsafe_getindex(d::ThreadSafeDict, key) = d.dict[key]

# Thread-safe `d[key]`: reads while holding `d.lock`. A missing key raises
# `KeyError(key)`, thrown by the wrapped `Dict`'s own indexing in a single lookup.
Base.getindex(d::ThreadSafeDict, key) = lock(() -> d.dict[key], d.lock)

In [10]:
# Lock-free store into the wrapped `Dict`; intended to be called only while `d.lock`
# is held. Evaluates to `v`, just like the assignment expression `d.dict[key] = v`.
function _unsafe_addindex!(d::ThreadSafeDict, v, key)
    d.dict[key] = v
    return v
end

# Thread-safe `d[key] = v`: stores under `d.lock` and returns the wrapper `d`.
function Base.setindex!(d::ThreadSafeDict, v, key)
    guard = d.lock
    lock(guard)
    try
        _unsafe_addindex!(d, v, key)
    finally
        unlock(guard)
    end
    return d
end

In [11]:
# Remove `key` (no-op if absent) while holding `d.lock`; returns the wrapper `d`.
function Base.delete!(d::ThreadSafeDict, key)
    lock(d.lock) do
        delete!(d.dict, key)
    end
    return d
end

# Remove `key` and return its value while holding `d.lock`. A missing key raises the
# `KeyError` thrown by `pop!` on the wrapped `Dict`.
function Base.pop!(d::ThreadSafeDict, key)
    return lock(d.lock) do
        pop!(d.dict, key)
    end
end

# Remove `key` and return its value, or return `default` if `key` is absent. This
# mirrors the three-argument `pop!` that `Dict` provides. Check and removal happen
# under one lock acquisition, so the operation is atomic with respect to other
# locked accessors.
function Base.pop!(d::ThreadSafeDict, key, default)
    return lock(d.lock) do
        pop!(d.dict, key, default)
    end
end

In [12]:
# Remove every entry while holding `d.lock`; returns the (now empty) wrapper `d`.
function Base.empty!(d::ThreadSafeDict)
    lock(() -> empty!(d.dict), d.lock)
    return d
end

### 実験

In [13]:
using Random
# Random permutation of 1:10. Each @threads iteration below requests key p[i], so
# threads start the recursive computation from different keys.
# NOTE(review): unseeded, so the permutation (and the traces below) differ per run.
p = randperm(10)

10-element Array{Int64,1}:
  8
  9
  4
  7
 10
  6
  3
  1
  5
  2

In [14]:
"""
    f!(d, n)

Return the entry for key `n` in `d`, computing and caching it via `get!` on a miss.
A missing entry is derived from `f!(d, n - 1)` as `(first + 1, threadid())`. The
second component therefore records which thread inserted the entry. The recursion
has no base case of its own: `d` must already hold some key below `n` (the cells
below seed key `1`). Each miss prints a trace line.
"""
function f!(d, n)
    compute = function ()
        println("recursive $(n)->$(n-1) (threadid: $(threadid()))")
        previous, _ = f!(d, n - 1)
        return (previous + 1, threadid())
    end
    return get!(compute, d, n)
end

f! (generic function with 1 method)

#### Thread-Unsafe

In [15]:
# Baseline: a plain Dict with no locking. Values are (n, id of the inserting thread).
d0 = Dict{Int,Tuple{Int,Int}}()

Dict{Int64,Tuple{Int64,Int64}}()

In [16]:
# Seed the base case of f!'s recursion: key 1 is never computed, only looked up.
d0[1] = (1, 1)

(1, 1)

In [17]:
# Each iteration requests key p[i]; a miss triggers f!'s recursive computation.
# d0 has no locking, so threads race. The same key can be computed more than once
# (`3->2` and `2->1` are printed twice below), and concurrent writes may corrupt
# the Dict.
@threads for i in eachindex(p)
    f!(d0, p[i])
end

recursive 7->6 (threadid: 2)
recursive 5->4 (threadid: 4)
recursive 8->7 (threadid: 1)
recursive 4->3 (threadid: 4)
recursive 3->2 (threadid: 3)
recursive 6->5 (threadid: 2)
recursive 3->2 (threadid: 4)
recursive 5->4 (threadid: 2)
recursive 2->1 (threadid: 4)
recursive 7->6 (threadid: 1)
recursive 4->3 (threadid: 2)
recursive 10->9 (threadid: 2)
recursive 2->1 (threadid: 3)
recursive 9->8 (threadid: 2)
recursive 6->5 (threadid: 1)
recursive 8->7 (threadid: 2)
recursive 9->8 (threadid: 1)


In [18]:
# Inspect the result: the second tuple element shows which thread inserted each key.
d0

Dict{Int64,Tuple{Int64,Int64}} with 10 entries:
  7  => (7, 1)
  4  => (4, 2)
  9  => (9, 1)
  10 => (10, 2)
  2  => (2, 3)
  3  => (3, 3)
  5  => (5, 2)
  8  => (8, 2)
  6  => (6, 1)
  1  => (1, 1)

#### Thread-Safe

In [19]:
# Same experiment on the locked wrapper; the two-parameter constructor defaults to a
# ReentrantLock (see the type in the output below).
d1 = ThreadSafeDict{Int,Tuple{Int,Int}}()

ThreadSafeDict{Int64,Tuple{Int64,Int64},ReentrantLock}()

In [20]:
# Seed the base case of f!'s recursion.
d1[1] = (1, 1)

(1, 1)

In [21]:
# Same workload on the locked wrapper. get!(f, d, key) runs `f` while holding the
# ReentrantLock, so the recursive calls inside f! re-acquire it from the same task
# while other threads wait. Every key is computed exactly once (see output below).
@threads for i in eachindex(p)
    f!(d1, p[i])
end

recursive 8->7 (threadid: 1)
recursive 7->6 (threadid: 1)
recursive 6->5 (threadid: 1)
recursive 5->4 (threadid: 1)
recursive 4->3 (threadid: 1)
recursive 3->2 (threadid: 1)
recursive 2->1 (threadid: 1)
recursive 9->8 (threadid: 1)
recursive 10->9 (threadid: 2)


In [22]:
# Inspect the result: each key was inserted once, by whichever thread held the lock.
d1

ThreadSafeDict{Int64,Tuple{Int64,Int64},ReentrantLock} with 10 entries:
  7  => (7, 1)
  4  => (4, 1)
  9  => (9, 1)
  10 => (10, 2)
  2  => (2, 1)
  3  => (3, 1)
  5  => (5, 1)
  8  => (8, 1)
  6  => (6, 1)
  1  => (1, 1)

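#### NG (deadlock: `SpinLock` is not reentrant, so the recursive `get!` inside `f!` blocks on the lock its own caller already holds)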

In [None]:
# Same wrapper, but guarded by Base.Threads.SpinLock, which is not reentrant.
d2 = ThreadSafeDict{Int,Tuple{Int,Int},SpinLock}()

In [None]:
# Seed the base case of f!'s recursion (a single lock/unlock, so this cell completes).
d2[1] = (1, 1)

In [None]:
# Deadlocks: get! holds the non-reentrant SpinLock while running f!'s closure. The
# nested f!(d2, n-1) -> get! call from the same task then spins forever waiting for
# a lock it already owns. (This section has no recorded output.)
@threads for i in eachindex(p)
    f!(d2, p[i])
end