-
Notifications
You must be signed in to change notification settings - Fork 9
/
StreamTarget.jl
117 lines (95 loc) · 3.78 KB
/
StreamTarget.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
A [`target`](@ref) based on running worker processes, one for each replica,
each communicating with Pigeons
using [standard streams](https://en.wikipedia.org/wiki/Standard_streams).
These worker processes can be implemented in an arbitrary programming language.

[`StreamTarget`](@ref) implements [`log_potential`](@ref) and [`explorer`](@ref)
by invoking worker processes via standard stream communication.
Standard streams are less efficient than alternatives such as
protobuf, but they have the advantage of being supported by nearly all
programming languages in existence.
Also, in many practical cases, since the worker
process is invoked only three times per chain per iteration, it is
unlikely to be the bottleneck (overhead is on the order of 0.1 ms).

The worker process should be able to reply to commands of the following forms
(one command per line):

- `log_potential(0.6)` in the worker's `stdin`, to which it should return a response of the form
  `response(-124.23)` in its `stdout`, providing in this example the joint log density at `beta = 0.6`;
- `call_sampler!(0.4)`, signaling that one round of local exploration should be performed
  at `beta = 0.4`, after which the worker should signal it is done with `response()`.
"""
abstract type StreamTarget end # subtypes are expected to provide an `initialization` method
"""
$SIGNATURES
Kill the worker (child) process associated with the
[`StreamState`](@ref) of each replica held by `pt`.
"""
function kill_child_processes(pt)
    # Each replica's state holds the handle of its own worker process.
    foreach(locals(pt.replicas)) do replica
        Expect.kill(replica.state.worker_process)
    end
    return nothing
end
"""
$SIGNATURES
Return a [`StreamState`](@ref) by following these steps:

1. Create a `Cmd` that uses the provided `rng` to set the random seed properly, and that
   includes the target-specific configuration provided by `target`.
2. Create a [`StreamState`](@ref) from the `Cmd` created in step 1 and return it.
"""
function initialization(target::StreamTarget, rng::AbstractRNG, replica_index::Int64)
    # No generic implementation: each concrete StreamTarget supplies its own method.
    @abstract
end
# Internals
#=
Field-less path type: `create_path` returns an instance for
any `StreamTarget`, and `interpolate` dispatches on it.
=#
struct StreamPath end
#=
Only store beta, since the worker process
will take care of path construction.
The stored `beta` is the annealing parameter that gets interpolated
into the `log_potential(...)` and `call_sampler!(...)` requests
sent to the worker.
=#
@auto struct StreamPotential
beta
end
#=
A StreamTarget serves as its own explorer
(`step!` has a method taking a StreamTarget as explorer).
=#
function default_explorer(target::StreamTarget)
    return target
end
#=
Exploration is handed off to the worker process:
look up the log potential for this replica, then
ask the worker to run its sampler for it.
=#
function step!(explorer::StreamTarget, replica, shared)
    potential = find_log_potential(replica, shared.tempering, shared)
    return call_sampler!(potential, replica.state)
end
#=
iid sampling is also handed off to the worker process.
The request is the same one used for exploration; the worker
is relied upon to notice that the annealing parameter is zero.
=#
function sample_iid!(log_potential::StreamPotential, replica, shared)
    return call_sampler!(log_potential, replica.state)
end
#=
The path of a StreamTarget is always the field-less StreamPath;
neither argument is inspected.
=#
function create_path(target::StreamTarget, ::Inputs)
    return StreamPath()
end
#=
Interpolating a StreamPath just wraps `beta` in a StreamPotential.
=#
function interpolate(path::StreamPath, beta)
    return StreamPotential(beta)
end
#=
Evaluate the log potential at `state`: send a `log_potential(beta)`
request to the worker and parse its reply as a Float64.
=#
function (log_potential::StreamPotential)(state::StreamState)
    request = "log_potential($(log_potential.beta))"
    return invoke_worker(state, request, Float64)
end
#=
Send a `call_sampler!(beta)` request to the worker. No return type is
passed to `invoke_worker`, so the content of the reply is not parsed.
=#
function call_sampler!(log_potential::StreamPotential, state::StreamState)
    request = "call_sampler!($(log_potential.beta))"
    return invoke_worker(state, request)
end
# Draw a UInt64 from a split of `rng`, then drop the sign bit (unsigned right
# shift by one) so the result is a non-negative Int64 / Java long.
function java_seed(rng::AbstractRNG)
    raw_bits = rand(split(rng), UInt64)
    return (raw_bits >>> 1) % Int64
end
#=
Simple stdin/stdout text-based protocol:
write `request` as one line to the worker, wait for "response(",
then take what comes before the next ")" as the reply.
With the default `return_type` (Nothing) the reply is discarded and
`nothing` is returned; otherwise the reply is parsed as a `T`.
=#
function invoke_worker(
        state::StreamState,
        request::AbstractString,
        return_type::Type{T} = Nothing
    ) where {T}
    worker = state.worker_process
    println(worker, request)

    # Presumably whatever the worker emitted ahead of the "response(" marker.
    preamble = expect!(worker, "response(")

    # display output for replica 1 to show e.g. info messages;
    # the length check is there because otherwise running on windows
    # spits a lot of empty lines
    show_preamble = state.replica_index == 1 && length(preamble) > 4
    show_preamble && print(preamble)

    payload = expect!(worker, ")")
    if T == Nothing
        return nothing
    else
        return parse(T, payload)
    end
end