/
stepthrough.jl
213 lines (182 loc) · 8.21 KB
/
stepthrough.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# StepSimulator
# maintained by @zsunberg
# Simulator that produces a step-by-step iterator rather than a finished result.
#   rng       - random number generator handed on to the iterator
#   max_steps - step limit; `nothing` means the `simulate` methods pick a fallback
#   spec      - step output specification, resolved with `convert_spec` at simulate time
struct StepSimulator <: Simulator
    rng::AbstractRNG
    max_steps::Union{Nothing,Any}
    spec
end

# Keyword-friendly constructor: only `spec` is required.
StepSimulator(spec; rng=Random.default_rng(), max_steps=nothing) = StepSimulator(rng, max_steps, spec)
# Build the step iterator for an MDP.
#
# `init_state` defaults to a sample from `initialstate(mdp)` drawn with `sim.rng`.
# The simulator's spec is resolved against the MDP type by `convert_spec`.
# Returns an `MDPSimIterator`.
function simulate(sim::StepSimulator, mdp::MDP{S}, policy::Policy, init_state::S=rand(sim.rng, initialstate(mdp))) where {S}
    symtuple = convert_spec(sim.spec, typeof(mdp))
    # `MDPSimIterator` takes `max_steps::Int`, so the "no limit" fallback must be a
    # native `Int`; `typemax(Int64)` would not dispatch on 32-bit platforms.
    max_steps = something(sim.max_steps, typemax(Int))
    return MDPSimIterator(symtuple, mdp, policy, sim.rng, init_state, max_steps)
end
# POMDP entry point without an explicit initial distribution: the distribution
# returned by `initialstate(pomdp)` is forwarded as `dist` to the method that
# takes one. The updater defaults to `updater(policy)`.
function simulate(sim::StepSimulator, pomdp::POMDP, policy::Policy, bu::Updater=updater(policy))
    return simulate(sim, pomdp, policy, bu, initialstate(pomdp))
end
# Build the step iterator for a POMDP.
#
# `dist` is turned into the initial belief with `initialize_belief(bu, dist)`.
# Note that the default initial state `is` is sampled from `initialstate(pomdp)`
# (using `sim.rng`), not from `dist`.
# Returns a `POMDPSimIterator`.
function simulate(sim::StepSimulator, pomdp::POMDP, policy::Policy, bu::Updater, dist::Any, is=rand(sim.rng, initialstate(pomdp)))
    initial_belief = initialize_belief(bu, dist)
    symtuple = convert_spec(sim.spec, typeof(pomdp))
    # `POMDPSimIterator` takes `max_steps::Int`, so the "no limit" fallback must be a
    # native `Int`; `typemax(Int64)` would not dispatch on 32-bit platforms.
    max_steps = something(sim.max_steps, typemax(Int))
    return POMDPSimIterator(symtuple, pomdp, policy, bu, sim.rng, initial_belief, is, max_steps)
end
# Iterator over the steps of an MDP simulation.
# The output specification is stored as the first type parameter (`SPEC`) rather
# than as a field; `out_tuple` reads it back from the type.
struct MDPSimIterator{SPEC, M<:MDP, P<:Policy, RNG<:AbstractRNG, S}
    mdp::M
    policy::P
    rng::RNG
    init_state::S
    max_steps::Int
end

# Convenience constructor: lifts `spec` into the type domain and infers the
# remaining type parameters from the arguments.
function MDPSimIterator(spec::Union{Tuple, Symbol}, mdp::MDP, policy::Policy, rng::AbstractRNG, init_state, max_steps::Int)
    IterType = MDPSimIterator{spec, typeof(mdp), typeof(policy), typeof(rng), typeof(init_state)}
    return IterType(mdp, policy, rng, init_state, max_steps)
end

# The number of steps is not known ahead of time.
function Base.IteratorSize(::Type{<:MDPSimIterator})
    return Base.SizeUnknown()
end
# Advance the MDP simulation by one step.
# The iteration state is `(t, s)`: the step counter and the current state.
# Returns `nothing` when `s` is terminal or `t` exceeds `it.max_steps`; otherwise
# returns the requested step output and the next iteration state `(t+1, sp)`.
function Base.iterate(it::MDPSimIterator, is::Tuple{Int, S}=(1, it.init_state)) where S
    t, s = is
    if isterminal(it.mdp, s) || t > it.max_steps
        return nothing
    end
    a, ai = action_info(it.policy, s)
    out = @gen(:sp,:r,:info)(it.mdp, s, a, it.rng)
    generated = NamedTuple{(:sp,:r,:info)}(out)
    # Everything that can be requested by a spec for this step.
    step = merge(generated, (t=t, s=s, a=a, action_info=ai))
    return (out_tuple(it, step), (t+1, step.sp))
end
# Iterator over the steps of a POMDP simulation.
# The output specification is stored as the first type parameter (`SPEC`) rather
# than as a field; `out_tuple` reads it back from the type.
struct POMDPSimIterator{SPEC, M<:POMDP, P<:Policy, U<:Updater, RNG<:AbstractRNG, B, S}
    pomdp::M
    policy::P
    updater::U
    rng::RNG
    init_belief::B
    init_state::S
    max_steps::Int
end

# Convenience constructor: lifts `spec` into the type domain and infers the
# remaining type parameters from the arguments.
function POMDPSimIterator(spec::Union{Tuple,Symbol}, pomdp::POMDP, policy::Policy, up::Updater, rng::AbstractRNG, init_belief, init_state, max_steps::Int)
    IterType = POMDPSimIterator{spec,
                                typeof(pomdp),
                                typeof(policy),
                                typeof(up),
                                typeof(rng),
                                typeof(init_belief),
                                typeof(init_state)}
    return IterType(pomdp, policy, up, rng, init_belief, init_state, max_steps)
end

# The number of steps is not known ahead of time.
function Base.IteratorSize(::Type{<:POMDPSimIterator})
    return Base.SizeUnknown()
end
# Advance the POMDP simulation by one step.
# The iteration state is `(t, s, b)`: step counter, true state and current belief.
# Returns `nothing` when `s` is terminal or `t` exceeds `it.max_steps`; otherwise
# returns the requested step output and the next iteration state `(t+1, sp, bp)`.
function Base.iterate(it::POMDPSimIterator, is::Tuple{Int,S,B} = (1, it.init_state, it.init_belief)) where {S,B}
    t, s, b = is
    if isterminal(it.pomdp, s) || t > it.max_steps
        return nothing
    end
    # The policy acts on the belief, while the generative model uses the true state.
    a, ai = action_info(it.policy, b)
    out = @gen(:sp,:o,:r,:info)(it.pomdp, s, a, it.rng)
    generated = NamedTuple{(:sp,:o,:r,:info)}(out)
    bp, ui = update_info(it.updater, b, a, generated.o)
    # Everything that can be requested by a spec for this step.
    step = merge(generated, (t=t, b=b, s=s, a=a, action_info=ai, bp=bp, update_info=ui))
    return (out_tuple(it, step), (t+1, step.sp, step.bp))
end
# Pick the requested entries out of the full step record `all`.
# `spec` is the iterator's first type parameter: a Tuple selects several fields
# (via `NamedTupleTools.select`), a Symbol returns that single field's value.
function out_tuple(it::Union{MDPSimIterator{spec}, POMDPSimIterator{spec}}, all::NamedTuple) where spec
    spec isa Tuple && return NamedTupleTools.select(all, spec)
    @assert isa(spec, Symbol) "Invalid specification: $spec is not a Symbol or Tuple."
    return all[spec]
end
# Resolve `spec` for a POMDP type, checking it against the symbols a POMDP step
# can provide (delegates to the `Set{Symbol}` method).
function convert_spec(spec, T::Type{M}) where {M<:POMDP}
    recognized = Set((:s, :a, :sp, :o, :r, :info, :bp, :b, :action_info, :update_info, :t))
    return convert_spec(spec, recognized)
end

# Resolve `spec` for an MDP type, checking it against the symbols an MDP step
# can provide (delegates to the `Set{Symbol}` method).
function convert_spec(spec, T::Type{M}) where {M<:MDP}
    recognized = Set((:s, :a, :sp, :r, :info, :action_info, :t))
    return convert_spec(spec, recognized)
end
# Normalize `spec` with the single-argument `convert_spec` and warn about
# questionable symbols. Nothing is rejected here: the converted spec is returned
# unchanged even when warnings are emitted.
#
# Warnings:
#   - `:ai` without `:action_info` in the spec -> deprecation warning
#   - `:ui` without `:update_info` in the spec -> deprecation warning
#   - any other symbol not in `recognized`     -> "unrecognized symbol" warning
function convert_spec(spec, recognized::Set{Symbol})
    conv = convert_spec(spec)
    # Treat a single Symbol as a one-element tuple so both forms share the loop.
    convtpl = isa(conv, Tuple) ? conv : tuple(conv)
    for s in convtpl
        if s == :ai && !(:action_info in convtpl)
            @warn("Using :ai to access the action info in a history is deprecated. Use :action_info instead.") # XXX get rid of in v0.4 or greater
        elseif s == :ui && !(:update_info in convtpl)
            @warn("Using :ui to access the update info in a history is deprecated. Use :update_info instead.") # XXX get rid of in v0.4 or greater
        elseif !(s in recognized)
            @warn("unrecognized symbol $s in step iteration specification $spec.")
        end
    end
    return conv
end
# Parse a string specification such as "s,a,o,r" or "(s, a)".
#
# Surrounding whitespace and one layer of enclosing parentheses are ignored, the
# remainder is split on commas, and each piece is trimmed and converted to a Symbol.
# Returns a single Symbol when there is exactly one entry, otherwise a tuple of
# Symbols.
function convert_spec(spec::String)
    # Trim whitespace *before* the parentheses so that e.g. " (s,a) " parses the
    # same as "(s,a)" instead of leaving "(" and ")" inside the symbol names.
    inner = strip(strip(spec), ['(', ')'])
    syms = Symbol.(strip.(split(inner, ',')))
    if length(syms) == 1
        return first(syms)
    else
        return tuple(syms...)
    end
end
# A tuple specification is already in canonical form; just check that every
# entry is a Symbol and hand it back unchanged.
function convert_spec(spec::Tuple)
    foreach(spec) do s
        @assert isa(s, Symbol)
    end
    return spec
end
# A single Symbol is already a valid specification.
function convert_spec(spec::Symbol)
    return spec
end

"""
    CompleteSpec()

Placeholder for a complete step output specification.

When passed to `convert_spec` together with an MDP or POMDP type, it resolves to
`default_spec` for that type, i.e. all DDNNodes plus all known possible outputs in each step.
"""
struct CompleteSpec end

function convert_spec(::CompleteSpec, T::Type{M}) where M <: MDP
    return default_spec(T)
end

function convert_spec(::CompleteSpec, T::Type{M}) where M <: POMDP
    return default_spec(T)
end
# Full step output specification for a problem instance (forwards to its type).
function default_spec(m::Union{MDP,POMDP})
    return default_spec(typeof(m))
end

# Every output an MDP step can provide.
function default_spec(T::Type{M}) where M <: MDP
    return (:s, :a, :sp, :r, :info, :t, :action_info)
end

# Every output a POMDP step can provide.
function default_spec(T::Type{M}) where M <: POMDP
    return (:s, :a, :sp, :o, :r, :info, :t, :action_info, :b, :bp, :update_info)
end

"""
    stepthrough(problem, policy, [spec])
    stepthrough(problem, policy, [spec], [rng=rng], [max_steps=max_steps])
    stepthrough(mdp::MDP, policy::Policy, [init_state], [spec]; [kwargs...])
    stepthrough(pomdp::POMDP, policy::Policy, [up::Updater, [initial_belief, [initial_state]]], [spec]; [kwargs...])

Create a simulation iterator. This is intended to be used with for loop syntax to output the results of each step *as the simulation is being run*.

Example:

    pomdp = BabyPOMDP()
    policy = RandomPolicy(pomdp)

    for (s, a, o, r) in stepthrough(pomdp, policy, "s,a,o,r", max_steps=10)
        println("in state \$s")
        println("took action \$a")
        println("received observation \$o and reward \$r")
    end

The optional `spec` argument can be a string, tuple of symbols, or single symbol and follows the same pattern as [`eachstep`](@ref) called on a `SimHistory` object.

Under the hood, this function creates a `StepSimulator` with `spec` and returns a `[PO]MDPSimIterator` by calling simulate with all of the arguments except `spec`. All keyword arguments are passed to the `StepSimulator` constructor.
"""
function stepthrough end # for documentation
# MDP without an explicit initial state: `simulate` chooses the initial state.
# All keyword arguments go to the `StepSimulator` constructor.
function stepthrough(mdp::MDP,
                     policy::Policy,
                     spec::Union{String, Tuple, Symbol}=default_spec(mdp);
                     kwargs...)
    return simulate(StepSimulator(spec; kwargs...), mdp, policy)
end
# MDP with an explicit initial state of the MDP's state type `S`.
# All keyword arguments go to the `StepSimulator` constructor.
function stepthrough(mdp::MDP{S},
                     policy::Policy,
                     init_state::S,
                     spec::Union{String, Tuple, Symbol}=default_spec(mdp);
                     kwargs...) where {S}
    return simulate(StepSimulator(spec; kwargs...), mdp, policy, init_state)
end
# POMDP entry point with flexible positional arguments.
#
# If the last positional argument is a String, Tuple or Symbol it is taken as the
# `spec` and removed before the rest is forwarded to `simulate`; otherwise
# `default_spec(pomdp)` is used. With exactly three extra arguments, a trailing
# value that is also an instance of the POMDP's state type could be either the
# initial state or the spec, so that case is rejected with an error.
# All keyword arguments go to the `StepSimulator` constructor.
function stepthrough(pomdp::POMDP, policy::Policy, args...; kwargs...)
    has_spec = !isempty(args) && last(args) isa Union{String, Tuple, Symbol}
    if has_spec
        spec = last(args)
        if spec isa statetype(pomdp) && length(args) == 3
            error("Ambiguity between `initial_state` and `spec` arguments in stepthrough. Please explicitly specify the initial state and spec.")
        end
        sim_args = args[1:end-1]
    else
        spec = default_spec(pomdp)
        sim_args = args
    end
    sim = StepSimulator(spec; kwargs...)
    return simulate(sim, pomdp, policy, sim_args...)
end