-
Notifications
You must be signed in to change notification settings - Fork 9
/
EntangledReplicas.jl
49 lines (45 loc) · 1.75 KB
/
EntangledReplicas.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
An implementation of [`replicas`](@ref) for distributed PT.

Each process stores only the replicas it hosts (`locals`). The mapping from
chain indices to replica indices is held in a distributed array
(`chain_to_replica_global_indices`). That array's entangler carries the MPI
communicator and the load-balancing information.

Contains:
$FIELDS
"""
@auto struct EntangledReplicas # implements the informal interface in replica.jl
    """
    The subset of replicas hosted in this process.
    """
    # NOTE(review): left untyped; presumably `@auto` turns this into a type
    # parameter so the field stays concretely typed — confirm.
    locals
    """
    A specialized distributed array that
    maps chain indices to replica indices (global indices).
    This corresponds to the mapping ``\\boldsymbol{j}`` in line 2 of
    Algorithm 5 in [Syed et al, 2021](https://rss.onlinelibrary.wiley.com/doi/10.1111/rssb.12464).
    Its `entangler` field holds the MPI communicator and load-balancing data.
    """
    chain_to_replica_global_indices::PermutedDistributedArray{Int}
end
# Accessors for `EntangledReplicas`. All MPI-specific state lives in the
# entangler owned by the distributed chain-to-replica index array, so the
# load and communicator accessors reach through it.
function entangler(r::EntangledReplicas)
    # the entangler encapsulates the MPI details
    index_array = r.chain_to_replica_global_indices
    return index_array.entangler
end

function load(r::EntangledReplicas)
    # load-balancing information kept by the entangler
    ent = entangler(r)
    return ent.load
end

function locals(r::EntangledReplicas)
    # the replicas hosted by this process
    return r.locals
end

function communicator(r::EntangledReplicas)
    ent = entangler(r)
    return ent.communicator
end
"""
$SIGNATURES
Create distributed replicas.
See [`create_replicas`](@ref).

The steps are:
- Build an `Entangler` sized to the number of chains.
- Ask its load information which global replica indices this process owns.
- Create only those replicas.
- Record each local replica's chain-to-replica assignment in a
  `PermutedDistributedArray`.
"""
@provides replicas function create_entangled_replicas(inputs::Inputs, shared::Shared, source)
    # object encapsulating the MPI details, sized to the total chain count
    ent = Entangler(n_chains(inputs))
    # global replica indices assigned to this process
    owned = my_global_indices(ent.load)
    index_map = PermutedDistributedArray(owned, ent)
    local_replicas = _create_locals(owned, inputs, shared, source)
    # fill index_map with each local replica's (chain => replica_index) entry
    init_permutation!(local_replicas, index_map)
    return EntangledReplicas(local_replicas, index_map)
end
"""
    init_permutation!(locals, chain_to_replica_global_indices)

Write the initial chain-to-replica mapping for the replicas hosted locally.

For every `replica` in `locals`, entry `replica.chain` of
`chain_to_replica_global_indices` is set to `replica.replica_index`. All
entries are written in one batched `permuted_set!` call, and its result is
returned.
"""
function init_permutation!(locals, chain_to_replica_global_indices)
    # Typed comprehensions preallocate from `length(locals)` and keep the
    # element type `Int` (converting as `push!` into `Int[]` did), including
    # when `locals` is empty.
    indices = Int[replica.chain for replica in locals]
    new_values = Int[replica.replica_index for replica in locals]
    # NOTE(review): presumably `permuted_set!` is collective across processes,
    # so every process must reach this call — confirm.
    return permuted_set!(chain_to_replica_global_indices, indices, new_values)
end