-
Notifications
You must be signed in to change notification settings - Fork 9
/
state.jl
97 lines (74 loc) · 2.98 KB
/
state.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
The state held in each Parallel Tempering [`Replica`](@ref).

This informal interface consists of the functions declared below; each kind of state
provides its own methods for them.
"""
@informal state begin
"""
$SIGNATURES
The names (each a `Symbol`) of the continuous variables in the given [`state`](@ref).
"""
continuous_variables(state) = @abstract
"""
$SIGNATURES
The names (each a `Symbol`) of the discrete (Int) variables in the given [`state`](@ref).
"""
discrete_variables(state) = @abstract
"""
$SIGNATURES
The storage within the [`state`](@ref) of the variable of the given name, typically an `Array`.
"""
variable(state, name::Symbol) = @abstract
"""
$SIGNATURES
Update the state's entry for the variable `name`, at position `index`, with `value`.
"""
update_state!(state, name::Symbol, index, value) = @abstract
"""
$SIGNATURES
Extract a sample for postprocessing. By default, calls `copy()` but many overloads are
defined for different kinds of states.
Typically, this will be a flattened vector (i.e. concatenation of all variables, with discrete
ones converted to Float64) ready for post-processing.
The corresponding un-normalized log density might be appended at the very end.
If the state is transformed (e.g. for HMC), this will create a fresh vector
with an un-transformed (i.e. original parameterization) state in it.
The argument `extractor` is passed from the [`Inputs`](@ref). When it is `nothing`,
this forwards to the two-argument method `extract_sample(state, log_potential)`.
"""
extract_sample(state, log_potential, extractor::Nothing) = extract_sample(state, log_potential)
"""
$SIGNATURES
A list of labels (each a `Symbol`) for the entries of the flattened vectors returned by
[`extract_sample()`](@ref).
The label `:log_density` is used for the un-normalized log density when it is included.
When `extractor` is `nothing`, this forwards to `sample_names(state, log_potential)`.
"""
sample_names(state, log_potential, extractor::Nothing) = sample_names(state, log_potential)
end
# Generic fallback: the sample is a (shallow) `copy` of the state, so the returned
# container is distinct from the one held by the replica. `log_potential` is unused here.
function extract_sample(state, log_potential)
    return copy(state)
end
# Sample labels for a whole PT object. They are computed from its first local replica:
# that replica's state, its log potential (looked up from the shared tempering), and
# the extractor configured in `pt.inputs`.
function sample_names(pt::PT)
    replica = locals(pt.replicas)[1]
    shared = pt.shared
    log_potential = find_log_potential(replica, shared.tempering, shared)
    return sample_names(replica.state, log_potential, pt.inputs.extractor)
end
# Implementations
# Variable-name list used by states that expose a single variable.
const SINGLETON_VAR = [:singleton_variable]
# `nothing` states (e.g. for TestSwapper) expose only the singleton continuous variable
# and no discrete variables. For `StreamState`, the more specific method
# `continuous_variables(::StreamState)` defined further down in this file takes
# precedence over the `continuous_variables` method here.
continuous_variables(state::Union{Nothing, StreamState}) = SINGLETON_VAR # e.g. for TestSwapper
# Typed empty list: variable names are `Symbol`s, so avoid an untyped `Vector{Any}`.
discrete_variables(state::Union{Nothing, StreamState}) = Symbol[]
# A plain `Array` state is treated as one continuous variable, `:singleton_variable`,
# whose storage is the array itself; it has no discrete variables (typed as `Symbol[]`,
# matching the interface's "names, each a `Symbol`").
continuous_variables(state::Array) = SINGLETON_VAR
discrete_variables(state::Array) = Symbol[]
# Write `value` at `index` of an `Array` state. Such states only have the variable
# `:singleton_variable` (whose storage is the array itself); any other `name` is rejected
# with an `ArgumentError`. An explicit check replaces `@assert`, which may be disabled.
# Returns `value`, as the original assignment expression did.
function update_state!(state::Array, name::Symbol, index, value)
    if name !== :singleton_variable
        msg = "Array states only have the variable :singleton_variable, got $(repr(name))"
        throw(ArgumentError(msg))
    end
    state[index] = value
    return value
end
# Sample for an `Array` state: the array flattened with `vec` (column-major, one entry
# per linear index, matching the `param_i` labels from `1:length(state)`), followed by
# the un-normalized log density. `vcat` always allocates, so the sample never aliases
# `state`. Flattening also makes multi-dimensional states work; vectors are unaffected.
extract_sample(state::Array, log_potential) = [vec(state); log_potential(state)]
# Storage of the variable `name` in an `Array` state. The only variable is
# `:singleton_variable`, whose storage is the array itself (returned without copying).
# Any other name throws an `ArgumentError` that says what was requested.
function variable(state::Array, name::Symbol)
    if name === :singleton_variable
        return state
    else
        msg = "Array states only have the variable :singleton_variable, got $(repr(name))"
        throw(ArgumentError(msg))
    end
end
# Declares the generic function `variables` without any methods.
# NOTE(review): no method or caller is visible in this file. Confirm where methods are
# added (presumably elsewhere in the package), or remove if unused.
function variables end
# Labels for an `Array` state: one `:param_i` per element (linear index `i`), then
# `:log_density` for the appended log density. `log_potential` is unused here.
function sample_names(state::Array, log_potential)
    labels = [Symbol("param_$i") for i in 1:length(state)]
    push!(labels, :log_density)
    return labels
end
# For the stream interface, the state is treated as a black box: it reports no
# continuous variables, so that running with the default set of recorders does not crash.
# The list is typed (`Symbol[]`) because variable names are `Symbol`s.
continuous_variables(state::StreamState) = Symbol[]