-
Notifications
You must be signed in to change notification settings - Fork 98
/
history_recorder.jl
213 lines (182 loc) · 6.84 KB
/
history_recorder.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# HistoryRecorder
# maintained by @zsunberg
"""
A simulator that records the history of a simulation for later examination.

The simulation is terminated as soon as one of the following holds:
1) a terminal state is reached (as determined by `isterminal()`), or
2) the discount factor is as small as `eps`, or
3) `max_steps` steps have been executed.

Keyword Arguments:
- `rng`: The random number generator for the simulation
- `capture_exception::Bool`: whether to capture an exception and store it in the history, or let it go uncaught, potentially killing the script
- `show_progress::Bool`: show a progress bar for the simulation (requires `max_steps` or `eps` to be specified)
- `eps`: threshold on the discount factor used by criterion 2; ignored if `nothing`
- `max_steps`: maximum number of steps to simulate; ignored if `nothing`

Usage (optional arguments in brackets):

    hr = HistoryRecorder()
    history = simulate(hr, pomdp, policy, [updater [, init_belief [, init_state]]])
"""
mutable struct HistoryRecorder <: Simulator
    # source of randomness used throughout the simulation
    rng::AbstractRNG

    # behavioral switches
    capture_exception::Bool  # store a thrown exception in the history instead of propagating it
    show_progress::Bool      # display a progress bar while simulating

    # optional stopping limits: a value of `nothing` means the limit is ignored
    max_steps::Union{Nothing,Any}
    eps::Union{Nothing,Any}
end
# This is the only stable constructor.
# Every option is a keyword; both stopping limits default to `nothing` (ignored) and
# both boolean switches default to `false`.
function HistoryRecorder(; rng=Random.default_rng(),
                           eps=nothing,
                           max_steps=nothing,
                           capture_exception=false,
                           show_progress=false)
    # Note: the positional field order of the struct differs from the keyword order above.
    return HistoryRecorder(rng, capture_exception, show_progress, max_steps, eps)
end
# Requirements declaration for `simulate(sim, pomdp, policy)` (no updater supplied).
# It lists `updater` for the policy's type, builds that updater, and then refers to the
# requirements of the four-argument `simulate` form.
# NOTE(review): `@POMDP_require`, `@req` and `@subreq` are defined outside this file;
# confirm their exact semantics against the package that provides them.
@POMDP_require simulate(sim::HistoryRecorder, pomdp::POMDP, policy::Policy) begin
@req updater(::typeof(policy))
up = updater(policy)
@subreq simulate(sim, pomdp, policy, up)
end
# Requirements declaration for `simulate(sim, pomdp, policy, bu)` (no initial
# distribution supplied). It lists `initialstate` for the problem type, obtains the
# distribution from `initialstate(pomdp)`, and then refers to the requirements of the
# five-argument `simulate` form.
# NOTE(review): `@req`/`@subreq` semantics come from the macro package — confirm there.
@POMDP_require simulate(sim::HistoryRecorder, pomdp::POMDP, policy::Policy, bu::Updater) begin
@req initialstate(::typeof(pomdp))
dist = initialstate(pomdp)
@subreq simulate(sim, pomdp, policy, bu, dist)
end
# Convenience entry point for POMDPs when no initial distribution is passed.
# The updater defaults to `updater(policy)`; the initial distribution is taken from
# `initialstate(pomdp)` and everything is forwarded to the five-argument method.
function simulate(sim::HistoryRecorder, pomdp::POMDP, policy::Policy, bu::Updater=updater(policy))
    return simulate(sim, pomdp, policy, bu, initialstate(pomdp))
end
# Requirements declaration for the full POMDP form
# `simulate(sim, pomdp, policy, bu, dist)`.
# NOTE(review): `@req` semantics come from the macro package — confirm there.
@POMDP_require simulate(sim::HistoryRecorder, pomdp::POMDP, policy::Policy, bu::Updater, dist::Any) begin
# Problem type and its state / action / observation types, used in the signatures below.
P = typeof(pomdp)
S = statetype(pomdp)
A = actiontype(pomdp)
O = obstype(pomdp)
@req initialize_belief(::typeof(bu), ::typeof(dist))
@req isterminal(::P, ::S)
@req discount(::P)
@req gen(::P, ::S, ::A, ::typeof(sim.rng))
# The belief type is not known statically, so an initial belief is actually built here
# and its type is used for the `action` and `update` signatures.
b = initialize_belief(bu, dist)
B = typeof(b)
@req action(::typeof(policy), ::B)
@req update(::typeof(bu), ::B, ::A, ::O)
end
# Simulate `pomdp` under `policy`, recording every step, and return a `SimHistory`.
#
# Arguments:
# - `sim`: the recorder; supplies `rng`, the optional limits `max_steps` / `eps`, and the
#   `capture_exception` / `show_progress` switches.
# - `bu`: belief updater; the initial belief is `initialize_belief(bu, initialstate_dist)`.
# - `initialstate_dist`: distribution used to initialize the belief.
# - `is`: initial state. Note that the default is sampled from `initialstate(pomdp)`,
#   not from `initialstate_dist`.
#
# Returns `SimHistory(promote_history(history), discount(pomdp), exception, backtrace)`.
#
# Throws (via `error`) if `show_progress` is set but neither `max_steps` nor `eps` is,
# because the progress bar needs a finite step count.
function simulate(sim::HistoryRecorder,
                  pomdp::POMDP{S,A,O},
                  policy::Policy,
                  bu::Updater,
                  initialstate_dist::Any,
                  is::Any=rand(sim.rng, initialstate(pomdp))
                 ) where {S,A,O}

    initial_belief = initialize_belief(bu, initialstate_dist)

    # Step limit: `max_steps` if given, otherwise effectively unbounded.
    max_steps = something(sim.max_steps, typemax(Int))
    # `!==` (identity) is the reliable test for `nothing`; `!=` can be overloaded.
    if sim.eps !== nothing
        # Smallest n with discount^n <= eps, i.e. n >= log(eps)/log(discount).
        max_steps = min(max_steps, ceil(Int, log(sim.eps)/log(discount(pomdp))))
    end

    if sim.show_progress
        if (sim.max_steps === nothing) && (sim.eps === nothing)
            error("If show_progress=true in a HistoryRecorder, you must also specify max_steps or eps.")
        end
        prog = Progress(max_steps; desc="Simulating..." )
    else
        prog = nothing
    end

    it = POMDPSimIterator(default_spec(pomdp),
                          pomdp,
                          policy,
                          bu,
                          sim.rng,
                          initial_belief,
                          is,
                          max_steps)

    # Dispatch on `Val(capture_exception)` selects the capturing or non-capturing collector.
    history, exception, backtrace = collect_history(it, Val(sim.capture_exception), prog)

    if sim.show_progress
        finish!(prog)
    end

    return SimHistory(promote_history(history), discount(pomdp), exception, backtrace)
end
# Requirements declaration for `simulate(sim, mdp, policy)` (no initial state supplied).
# It samples an initial state from `initialstate(mdp)` using the recorder's rng and then
# refers to the requirements of the four-argument MDP form.
# NOTE(review): `@subreq` semantics come from the macro package — confirm there.
@POMDP_require simulate(sim::HistoryRecorder, mdp::MDP, policy::Policy) begin
init_state = rand(sim.rng, initialstate(mdp))
@subreq simulate(sim, mdp, policy, init_state)
end
# Requirements declaration for the full MDP form `simulate(sim, mdp, policy, state)`.
# The fourth argument is named `initialstate` here, which shadows the function of the
# same name inside this block; the block itself does not use it.
# NOTE(review): `@req` semantics come from the macro package — confirm there.
@POMDP_require simulate(sim::HistoryRecorder, mdp::MDP, policy::Policy, initialstate::Any) begin
# Problem type and its state / action types, used in the signatures below.
P = typeof(mdp)
S = statetype(mdp)
A = actiontype(mdp)
@req isterminal(::P, ::S)
@req action(::typeof(policy), ::S)
@req gen(::P, ::S, ::A, ::typeof(sim.rng))
@req discount(::P)
end
# Simulate `mdp` under `policy`, recording every step, and return a `SimHistory`.
#
# Arguments:
# - `sim`: the recorder; supplies `rng`, the optional limits `max_steps` / `eps`, and the
#   `capture_exception` / `show_progress` switches.
# - `init_state`: initial state; by default sampled from `initialstate(mdp)` with `sim.rng`.
#
# Returns `SimHistory(promote_history(history), discount(mdp), exception, backtrace)`.
#
# Throws (via `error`) if `show_progress` is set but neither `max_steps` nor `eps` is,
# because the progress bar needs a finite step count.
function simulate(sim::HistoryRecorder,
                  mdp::MDP{S,A}, policy::Policy,
                  init_state::S=rand(sim.rng, initialstate(mdp))) where {S,A}

    # Step limit: `max_steps` if given, otherwise effectively unbounded.
    max_steps = something(sim.max_steps, typemax(Int))
    # `!==` (identity) is the reliable test for `nothing`; `!=` can be overloaded.
    if sim.eps !== nothing
        # Smallest n with discount^n <= eps, i.e. n >= log(eps)/log(discount).
        max_steps = min(max_steps, ceil(Int, log(sim.eps)/log(discount(mdp))))
    end

    it = MDPSimIterator(default_spec(mdp),
                        mdp,
                        policy,
                        sim.rng,
                        init_state,
                        max_steps)

    if sim.show_progress
        if (sim.max_steps === nothing) && (sim.eps === nothing)
            error("If show_progress=true in a HistoryRecorder, you must also specify max_steps or eps.")
        end
        prog = Progress(max_steps; desc="Simulating..." )
    else
        prog = nothing
    end

    # Dispatch on `Val(capture_exception)` selects the capturing or non-capturing collector.
    history, exception, backtrace = collect_history(it, Val(sim.capture_exception), prog)

    if sim.show_progress
        finish!(prog)
    end

    return SimHistory(promote_history(history), discount(mdp), exception, backtrace)
end
# Exception-capturing collector: iterate `it`, storing each step, and if anything is
# thrown keep the steps gathered so far together with the exception and its backtrace.
# Returns `(history, exception, backtrace)`; the last two are `nothing` on success.
function collect_history(it, cap_ex::Val{true}, prog::Union{Progress,Nothing})
    # capturing part of the history is more important than this having a concrete type
    steps = NamedTuple[]
    caught = nothing
    trace = nothing
    try
        for s in it
            push!(steps, s)
            # advance the progress bar only when one was supplied
            prog === nothing || next!(prog)
        end
    catch err
        caught = err
        trace = catch_backtrace()
    end
    return steps, caught, trace
end
# Non-capturing collector without a progress bar: the history is simply the collected
# iterator, and the exception and backtrace slots are always `nothing`.
function collect_history(it, cap_ex::Val{false}, prog::Nothing)
    steps = collect(it)
    return steps, nothing, nothing
end
# Non-capturing collector with a progress bar: collect every step of `it`, advancing
# `prog` once per step. Exceptions propagate; the exception and backtrace slots of the
# returned tuple are always `nothing`.
function collect_history(it, cap_ex::Val{false}, prog::Progress)
    # Tick the progress bar, then pass the step through unchanged.
    function tick(step)
        next!(prog)
        return step
    end
    steps = collect(tick(step) for step in it)
    return steps, nothing, nothing
end
"""
Promotes all NamedTuples in the history to the same type.
"""
function promote_history(hist::AbstractVector)
    # Returns a vector whose element type is a single concrete NamedTuple type:
    # - `hist` itself when its element type is already concrete,
    # - an empty `NamedTuple{(), Tuple{}}` vector when `hist` is empty,
    # - otherwise a converted copy whose field types are the `promote_type` of the
    #   corresponding field types across all steps.
    # Every step must have the same field names (checked with `@assert`).
    if isconcretetype(eltype(hist))
        return hist
    elseif isempty(hist) # note, from above, also does not have concrete type
        return NamedTuple{(), Tuple{}}[]
    else
        # it would really astound me if this branch was type stable
        # `fieldnames`/`fieldtypes` operate on types, not instances, so go through `typeof`.
        first_type = typeof(first(hist))
        names = fieldnames(first_type)
        types = fieldtypes(first_type)
        for step in hist
            step_type = typeof(step)
            @assert fieldnames(step_type) == names
            types = map(promote_type, types, fieldtypes(step_type))
        end
        newtype = NamedTuple{names, Tuple{types...}}
        return convert(Vector{newtype}, hist)
    end
end