Skip to content
This repository has been archived by the owner on Apr 26, 2023. It is now read-only.

Commit

Permalink
finished implementing stepthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jun 15, 2017
1 parent 4dd6d74 commit be611c2
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 34 deletions.
21 changes: 21 additions & 0 deletions README.md
Expand Up @@ -126,6 +126,27 @@ Within each class directory, each file contains one tool. Each file should clear

Note: by default, since there is no observation before the first action, on the first call to the `do` block, `obs` is `nothing`.

- `stepthrough.jl`: The `stepthrough` function exposes a simulation as an iterator so that the steps can be iterated through with a for loop syntax as follows:
```julia
pomdp = BabyPOMDP()
policy = RandomPolicy(pomdp)

for (s, a, o, r) in stepthrough(pomdp, policy, "s,a,o,r", max_steps=10)
println("in state $s")
println("took action $a")
println("received observation $o and reward $r")
end
```
For more information, see the documentation for the `stepthrough` function.

The `StepSimulator` contained in this file can provide the same functionality with the following syntax:
```julia
sim = StepSimulator("s,a,r,sp")
for (s,a,r,sp) in simulate(sim, problem, policy)
# do something
end
```

### Testing
- `model.jl`: generic functions for testing POMDP models.
- `solver.jl`: standard functions for testing solvers. New solvers should be able to be used with the functions in this file.
Expand Down
5 changes: 5 additions & 0 deletions src/POMDPToolbox.jl
Expand Up @@ -113,6 +113,11 @@ include("simulators/sim.jl")
export HistoryRecorder
include("simulators/history_recorder.jl")

export
StepSimulator,
stepthrough
include("simulators/stepthrough.jl")

# model tools
export uniform_state_distribution
include("model/initial.jl")
Expand Down
16 changes: 16 additions & 0 deletions src/simulators/history_recorder.jl
Expand Up @@ -241,3 +241,19 @@ function simulate{S,A}(sim::HistoryRecorder,

return MDPHistory(sh, ah, rh, discount(mdp), sim.exception, sim.backtrace)
end

# Determine the simulation's starting state: honor an explicitly configured
# initial state if the simulator has one, otherwise sample one from the
# supplied initial state distribution using the simulator's RNG.
function get_initial_state(sim::Simulator, initial_state_dist)
    isnull(sim.initial_state) || return get(sim.initial_state)
    return rand(sim.rng, initial_state_dist)
end

# Determine the starting state when only the problem object is available:
# honor an explicitly configured initial state, otherwise ask the problem
# for one via `initial_state`.
function get_initial_state(sim::Simulator, mdp::Union{MDP,POMDP})
    if !isnull(sim.initial_state)
        return get(sim.initial_state)
    end
    return initial_state(mdp, sim.rng)
end
124 changes: 90 additions & 34 deletions src/simulators/stepthrough.jl
@@ -1,23 +1,19 @@
# StepSimulator
# maintained by @zsunberg

type StepSimulator
# Simulator that exposes each step of a simulation through Julia's iteration
# protocol, so results can be consumed with a `for` loop as the simulation runs.
# `spec` selects which quantities (e.g. :s, :a, :r, :sp) each step yields.
type StepSimulator <: Simulator
    rng::AbstractRNG              # random number source used throughout the simulation
    initial_state::Nullable{Any}  # fixed starting state; null means sample/ask the problem
    max_steps::Nullable{Any}      # step limit; null means effectively unlimited
    spec                          # step specification: String, Tuple of Symbols, or Symbol
end
# Keyword-argument outer constructor; `spec` is the only required argument.
function StepSimulator(spec; rng=Base.GLOBAL_RNG, initial_state=nothing, max_steps=nothing)
    return StepSimulator(rng, initial_state, max_steps, spec)
end

# Create a step iterator for a simulation of `mdp` under `policy`.
# If `init_state` is not supplied, it is drawn via `get_initial_state`.
# A null `max_steps` is treated as effectively unlimited (typemax(Int64)).
# NOTE(review): the rendered diff interleaved the removed pre-commit body here;
# this is the post-commit version using the MDPSimIterator outer constructor.
function simulate{S}(sim::StepSimulator, mdp::MDP{S}, policy::Policy, init_state::S=get_initial_state(sim, mdp))
    symtuple = convert_spec(sim.spec, MDP)
    return MDPSimIterator(symtuple, mdp, policy, sim.rng, init_state, get(sim.max_steps, typemax(Int64)))
end

function simulate(sim::StepSimulator, pomdp::POMDP, policy::Policy, bu::Updater=updater(policy))
Expand All @@ -29,12 +25,7 @@ function simulate(sim::StepSimulator, pomdp::POMDP, policy::Policy, bu::Updater,
initial_state = get_initial_state(sim, dist)
initial_belief = initialize_belief(bu, dist)
symtuple = convert_spec(sim.spec, POMDP)
return POMDPSimIterator{symtuple,
typeof(pomdp),
typeof(policy),
typeof(bu),
typeof(rng)
}
return POMDPSimIterator(symtuple, pomdp, policy, bu, sim.rng, initial_belief, initial_state, get(sim.max_steps, typemax(Int64)))
end

immutable MDPSimIterator{SPEC, M<:MDP, P<:Policy, RNG<:AbstractRNG, S}
Expand All @@ -45,9 +36,13 @@ immutable MDPSimIterator{SPEC, M<:MDP, P<:Policy, RNG<:AbstractRNG, S}
max_steps::Int
end

Base.done{S}(it::MDPSimIterator, is::Tuple{Int, S}) = isterminal(it.mdp, is[2]) || is[1] > max_steps
# Outer constructor: lifts the spec and the concrete argument types into the
# iterator's type parameters (the spec parameter drives `out_tuple` generation).
function MDPSimIterator(spec::Union{Tuple, Symbol}, mdp::MDP, policy::Policy, rng::AbstractRNG, init_state, max_steps::Int)
    IT = MDPSimIterator{spec, typeof(mdp), typeof(policy), typeof(rng), typeof(init_state)}
    return IT(mdp, policy, rng, init_state, max_steps)
end

# Iteration is finished when the current state is terminal or the step count
# has exceeded the simulator's step limit.
Base.done{S}(it::MDPSimIterator, is::Tuple{Int, S}) = isterminal(it.mdp, is[2]) || is[1] > it.max_steps
# The iteration state is a (step number, MDP state) tuple, starting at step 1.
Base.start(it::MDPSimIterator) = (1, it.init_state)
function Base.step(it::MDPSimIterator, is::Tuple{Int, S})
function Base.next{S}(it::MDPSimIterator, is::Tuple{Int, S})
s = is[2]
a = action(it.policy, s)
sp, r = generate_sr(it.mdp, s, a, it.rng)
Expand All @@ -61,16 +56,32 @@ immutable POMDPSimIterator{SPEC, M<:POMDP, P<:Policy, U<:Updater, RNG<:AbstractR
rng::RNG
init_belief::B
init_state::S
max_steps::Int
end
# Outer constructor: encodes the spec and all concrete argument types as
# type parameters of POMDPSimIterator (the spec parameter drives `out_tuple`).
function POMDPSimIterator(spec::Union{Tuple,Symbol}, pomdp::POMDP, policy::Policy, up::Updater, rng::AbstractRNG, init_belief, init_state, max_steps::Int)
    IT = POMDPSimIterator{spec, typeof(pomdp), typeof(policy), typeof(up),
                          typeof(rng), typeof(init_belief), typeof(init_state)}
    return IT(pomdp, policy, up, rng, init_belief, init_state, max_steps)
end

Base.done{S,B}(it::POMDPSimIterator, is::Tuple{Int, S, B}) = isterminal(it.mdp, is[2]) || is[1] > max_steps
# Iteration is finished when the current state is terminal or the step count
# has exceeded the simulator's step limit.
Base.done{S,B}(it::POMDPSimIterator, is::Tuple{Int, S, B}) = isterminal(it.pomdp, is[2]) || is[1] > it.max_steps
# The iteration state is a (step number, state, belief) tuple, starting at step 1.
Base.start(it::POMDPSimIterator) = (1, it.init_state, it.init_belief)
# Advance the POMDP simulation one step: select an action from the belief,
# generate (sp, o, r), update the belief, and return the spec-selected tuple
# together with the next iteration state.
# NOTE(review): the rendered diff interleaved the removed pre-commit lines
# (`Base.step`, `it.mdp`, `updater(...)`) here; this is the post-commit version.
function Base.next{S,B}(it::POMDPSimIterator, is::Tuple{Int, S, B})
    s = is[2]
    b = is[3]
    a = action(it.policy, b)
    sp, o, r = generate_sor(it.pomdp, s, a, it.rng)
    bp = update(it.updater, b, a, o)
    return (out_tuple(it, (s, a, r, sp, b, o, bp)), (is[1]+1, sp, bp))
end

Expand All @@ -89,23 +100,24 @@ sym_to_ind = Dict(sym=>i for (i, sym) in enumerate([:s,:a,:r,:sp,:b,:o,:bp]))
return tuple($(calls...))
end
else
@assert isa(spec, Symbol)
@assert isa(spec, Symbol) "Invalid specification: $spec is not a Symbol or Tuple."
return quote
return all[$(sym_to_ind[spec])]
end
end
end

convert_spec(spec, T::Type{POMDP}) = convert_spec(spec, Set(:sp, :bp, :s, :a, :r, :b, :o))
convert_spec(spec, T::Type{MDP}) = convert_spec(spec, Set(:sp, :s, :a, :r))
# Symbols that are meaningful in a step specification for each problem type.
convert_spec(spec, T::Type{POMDP}) = convert_spec(spec, Set((:sp, :bp, :s, :a, :r, :b, :o)))
convert_spec(spec, T::Type{MDP}) = convert_spec(spec, Set((:sp, :s, :a, :r)))

# Convert `spec` to canonical form (a Tuple of Symbols or a single Symbol)
# and warn about any symbol not meaningful for the problem type.
# Fixes: removes a stale pre-commit loop line left by the rendered diff and
# corrects the "uncrecognized" typo in the warning message.
function convert_spec(spec, recognized::Set{Symbol})
    conv = convert_spec(spec)
    # A single-Symbol spec is wrapped in a tuple so one loop handles both forms.
    for s in (isa(conv, Tuple) ? conv : tuple(conv))
        if !(s in recognized)
            warn("unrecognized symbol $s in step iteration specification $spec.")
        end
    end
    return conv
end

function convert_spec(spec::String)
Expand All @@ -129,18 +141,62 @@ end

convert_spec(spec::Symbol) = spec

function get_initial_state(sim::Simulator, initial_state_dist)
if isnull(sim.initial_state)
return rand(sim.rng, initial_state_dist)
else
return get(sim.initial_state)
"""
    stepthrough(problem, policy, [spec])
    stepthrough(problem, policy, [spec], [rng=rng], [max_steps=max_steps], [initial_state=initial_state])

Create a simulation iterator. This is intended to be used with for loop syntax to output the results of each step *as the simulation is being run*.

Example:

    pomdp = BabyPOMDP()
    policy = RandomPolicy(pomdp)

    for (s, a, o, r) in stepthrough(pomdp, policy, "s,a,o,r", max_steps=10)
        println("in state \$s")
        println("took action \$a")
        println("received observation \$o and reward \$r")
    end

The spec argument can be a string, tuple of symbols, or single symbol and follows the same pattern as `eachstep` called on a `SimHistory` object.

Under the hood, this function creates a `StepSimulator` with `spec` and returns a `[PO]MDPSimIterator` by calling simulate with all of the arguments except `spec`. All keyword arguments are passed to the `StepSimulator` constructor.
"""
function stepthrough end # empty generic function declared so the docstring above has an anchor

# stepthrough for an MDP with no explicit initial state; the default spec
# yields (s, a, r, sp) tuples. Keyword arguments go to the StepSimulator.
function stepthrough(mdp::MDP, policy::Policy, spec::Union{String, Tuple, Symbol}=(:s,:a,:r,:sp); kwargs...)
    return simulate(StepSimulator(spec; kwargs...), mdp, policy)
end

"""
    stepthrough(mdp::MDP, policy::Policy, [init_state], [spec="sarsp"]; [kwargs...])

Step through an mdp simulation. The initial state is optional. If no spec is given, (s, a, r, sp) is used.
"""
function stepthrough{S}(mdp::MDP{S}, policy::Policy, init_state::S,
                        spec::Union{String, Tuple, Symbol}=(:s,:a,:r,:sp);
                        kwargs...)
    # All keyword arguments configure the underlying StepSimulator.
    return simulate(StepSimulator(spec; kwargs...), mdp, policy, init_state)
end

function get_initial_state(sim::Simulator, mdp::Union{MDP,POMDP})
if isnull(sim.initial_state)
return initial_state(mdp, sim.rng)
"""
    stepthrough(pomdp::POMDP, policy::Policy, [up::Updater, [initial_belief]], [spec="ao"]; [kwargs...])

Step through a pomdp simulation. The updater and initial belief are optional. If no spec is given, (a, o) is used.
"""
function stepthrough(pomdp::POMDP, policy::Policy, args...; kwargs...)
    # If the last positional argument looks like a spec, peel it off; the
    # remaining args (updater, initial belief, ...) are forwarded to simulate.
    # Fixes: removes a stale pre-commit line left by the rendered diff and
    # guards last(args) against an empty args tuple.
    spec_included = false
    if !isempty(args) && isa(last(args), Union{String, Tuple, Symbol})
        spec = last(args)
        spec_included = true
    else
        spec = (:a, :o)
    end
    sim = StepSimulator(spec; kwargs...)
    # Bool spec_included acts as 0/1 to drop the spec from the forwarded args.
    return simulate(sim, pomdp, policy, args[1:end-spec_included]...)
end

1 change: 1 addition & 0 deletions test/runtests.jl
Expand Up @@ -7,6 +7,7 @@ include("test_random_solver.jl")
include("test_rollout.jl")
include("test_history_recorder.jl")
include("test_sim.jl")
include("test_stepthrough.jl")
include("test_belief.jl")
include("test_particle.jl")
include("test_solver_test.jl")
Expand Down
61 changes: 61 additions & 0 deletions test/test_stepthrough.jl
@@ -0,0 +1,61 @@
# mdp step simulator and stepthrough
let
    mdp = GridWorld()
    solver = RandomSolver(MersenneTwister(2))
    policy = solve(solver, mdp)

    # Exercise the StepSimulator interface directly.
    sim = StepSimulator("s,sp,r,a", rng=MersenneTwister(3), max_steps=100)
    steps_taken = 0
    for (s, sp, r, a) in simulate(sim, mdp, policy)
        @test isa(s, state_type(mdp))
        @test isa(sp, state_type(mdp))
        @test isa(a, action_type(mdp))
        @test isa(r, Float64)
        steps_taken += 1
    end
    @test steps_taken <= 100

    # Exercise the stepthrough convenience function with a single-symbol spec.
    steps_taken = 0
    for s in stepthrough(mdp, policy, "s", rng=MersenneTwister(4), max_steps=100)
        @test isa(s, state_type(mdp))
        steps_taken += 1
    end
    @test steps_taken <= 100
end

# pomdp step simulator and stepthrough
let
    pomdp = BabyPOMDP()
    policy = FeedWhenCrying()
    up = PrimedPreviousObservationUpdater(true)

    # Exercise the StepSimulator interface with an explicit updater.
    sim = StepSimulator("s,sp,r,a,b", rng=MersenneTwister(3), max_steps=100)
    steps_taken = 0
    for (s, sp, r, a, b) in simulate(sim, pomdp, policy, up)
        @test isa(s, state_type(pomdp))
        @test isa(sp, state_type(pomdp))
        @test isa(a, action_type(pomdp))
        @test isa(r, Float64)
        @test isa(b, Bool)
        steps_taken += 1
    end
    @test steps_taken == 100

    # Exercise stepthrough with a single-symbol spec; BabyPOMDP rewards are nonpositive.
    steps_taken = 0
    for r in stepthrough(pomdp, policy, "r", rng=MersenneTwister(4), max_steps=100)
        @test isa(r, Float64)
        @test r <= 0
        steps_taken += 1
    end
    @test steps_taken == 100
end

# example from stepthrough documentation
# Fix: the original printed the observation ($o) where the action ($a) was intended.
let
    pomdp = BabyPOMDP()
    policy = RandomPolicy(pomdp)

    for (s, a, o, r) in stepthrough(pomdp, policy, "s,a,o,r", max_steps=10)
        println("in state $s")
        println("took action $a")
        println("received observation $o and reward $r")
    end
end

0 comments on commit be611c2

Please sign in to comment.