Skip to content

Commit

Permalink
minor test update
Browse files Browse the repository at this point in the history
  • Loading branch information
Omastto1 committed Mar 19, 2021
1 parent 5312ba2 commit de466a5
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 47 deletions.
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
DiscreteValueIteration = "4b033969-44f6-5439-a48b-c11fa3648068"
FiniteHorizonPOMDPs = "8a13bbfe-798e-11e9-2f1c-eba9ee5ef093"
POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755"
POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
Expand Down
4 changes: 2 additions & 2 deletions test/custom1dtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fhsolver = FiniteHorizonSolver()
# MDPs initialization
mdp = CustomFHExample(no_states, _horizon, actions, actionCost, actionsImpact, reward_states, reward, discount_factor, noise)

# check implementation of required methods
# check implementation of required methods
# @POMDPLinter.show_requirements FiniteHorizonPOMDPs.solve(fhsolver, mdp)

# initialize the solver
Expand All @@ -34,7 +34,7 @@ FHPolicy = FiniteHorizonPOMDPs.solve(fhsolver, mdp);
# Compare resulting policies
@test all((FiniteHorizonPOMDPs.action(FHPolicy, s) == action(VIPolicy, s) for s in states(mdp)))

# Compare FHMDP and IHMDP states
# Compare FHMDP and IHMDP states
fh_states = Iterators.flatten([FiniteHorizonPOMDPs.stage_states(mdp, i) for i=1:FiniteHorizonPOMDPs.horizon(mdp) + 1])
ih_states = states(mdp)

Expand Down
2 changes: 1 addition & 1 deletion test/fixhorizon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ end
@testset "solver" begin
fhgw = fixhorizon(SimpleGridWorld(), 3)
# TODO: Change test to pass without boolean value
# @test FiniteHorizonPOMDPs.solve(fhgw)
# @test test_solver(FiniteHorizonSolver(), fhgw)
end
7 changes: 2 additions & 5 deletions test/fixhorizon1dtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mdp = FHExample(no_states, actions, actionCost, actionsImpact, reward_states, re
# Wrap mdp into finite horizon wrapper
fhex = fixhorizon(mdp, _horizon)

# check implementation of required methods
# check implementation of required methods
# @POMDPLinter.show_requirements FiniteHorizonPOMDPs.solve(fhsolver, fhex)

# initialize the solver
Expand All @@ -38,13 +38,10 @@ FHPolicy = FiniteHorizonPOMDPs.solve(fhsolver, fhex);
# Compare resulting policies
@test all((FiniteHorizonPOMDPs.action(FHPolicy, s) == action(VIPolicy, s) for s in states(fhex)))

# Compare FHMDP and IHMDP states
# Compare FHMDP and IHMDP states
fh_states = Iterators.flatten([FiniteHorizonPOMDPs.stage_states(fhex, i) for i=1:fhex.horizon + 1])
ih_states = states(fhex)

z = zip(fh_states, ih_states)

@test all((fh == ih for (fh, ih) in z))



46 changes: 12 additions & 34 deletions test/instances/Pyramid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using FiniteHorizonPOMDPs
#####################
# MDP and State types
#####################
struct PyramidState
struct PyramidState
position::Int64
epoch::Int64
end
Expand All @@ -31,57 +31,35 @@ FiniteHorizonPOMDPs.horizon(mdp::PyramidMDP) = mdp.horizon
# changed elements of POMDPs interface
###################################

# Creates (horizon(mdp) - 1) * mdp.no_states states to be evaluated and mdp.no_states sink states
function POMDPs.states(mdp::PyramidMDP)::Array{PyramidState}
mdp_states = PyramidState[]
for e=1:horizon(mdp) + 1
for i=1:e
push!(mdp_states, PyramidState(i, e))
end
end

return mdp_states
# Creates (horizon(mdp) + 1) * mdp.no_states states to be evaluated and mdp.no_states sink states
function POMDPs.states(mdp::PyramidMDP)
return [PyramidState(i, e) for e in 1:horizon(mdp) + 1 for i in 1:e]
end

POMDPs.stateindex(mdp::PyramidMDP, ss::PyramidState)::Int64 = sum(1:ss.epoch - 1) + ss.position
POMDPs.isterminal(mdp::PyramidMDP, ss::PyramidState) = FiniteHorizonPOMDPs.stage(mdp, ss) > horizon(mdp) || [ss.epoch, ss.position] in mdp.reward_states
POMDPs.isterminal(mdp::PyramidMDP, ss::PyramidState) = FiniteHorizonPOMDPs.stage(mdp, ss) > horizon(mdp) || (ss.epoch, ss.position) in mdp.reward_states

# returns transition distributions - works only for 1D Gridworld with possible moves to left and to right
function POMDPs.transition(mdp::PyramidMDP, ss::PyramidState, a::Symbol)::SparseCat{Vector{PyramidState},Vector{Float64}}
sp = PyramidState[]
prob = Float64[]

# add original transition target and probability
position = ss.position + mdp.actionsImpact[a]
push!(sp, PyramidState(position, ss.epoch + 1))
push!(prob, 1. - mdp.noise)

# add noise transition target and probability
noise_action = a == :l ? :r : :l
position = ss.position + mdp.actionsImpact[noise_action]
push!(sp, PyramidState(position, ss.epoch + 1))
push!(prob, mdp.noise)
function POMDPs.transition(mdp::PyramidMDP, ss::PyramidState, a::Symbol)::SparseCat
sp = ( PyramidState(ss.position + mdp.actionsImpact[a], ss.epoch + 1),
PyramidState(ss.position + mdp.actionsImpact[a == :l ? :r : :l], ss.epoch + 1))
prob = (1. - mdp.noise, mdp.noise)

return SparseCat(sp, prob)
end


POMDPs.actions(mdp::PyramidMDP)::Vector{Symbol} = mdp.actions
POMDPs.actions(mdp::PyramidMDP, ss::PyramidState) = mdp.actions
POMDPs.actionindex(mdp::PyramidMDP, a::Symbol)::Int64 = findall(x->x==a, POMDPs.actions(mdp))[1]
POMDPs.actionindex(mdp::PyramidMDP, a::Symbol)::Int64 = findfirst(x->x==a, POMDPs.actions(mdp))

###############################
# FiniteHorizonPOMDPs interface
###############################
FiniteHorizonPOMDPs.stage(mdp::PyramidMDP, ss::PyramidState) = ss.epoch

function FiniteHorizonPOMDPs.stage_states(mdp::PyramidMDP, stage::Int)
mdp_states = PyramidState[]
for i=1:stage
push!(mdp_states, PyramidState(i, stage))
end

return mdp_states
return (PyramidState(i, stage) for i in 1:stage)
end

FiniteHorizonPOMDPs.stage_stateindex(mdp::PyramidMDP, ss::PyramidState) = ss.position
Expand All @@ -90,7 +68,7 @@ FiniteHorizonPOMDPs.stage_stateindex(mdp::PyramidMDP, ss::PyramidState) = ss.pos
# Forwarded parts of POMDPs interface
###############################
function isreward(mdp::PyramidMDP, ss::PyramidState)::Bool
return [ss.epoch, ss.position] in mdp.reward_states
return (ss.epoch, ss.position) in mdp.reward_states
end

function POMDPs.reward(mdp::PyramidMDP, ss::PyramidState, a::Symbol, sp::PyramidState)::Float64
Expand Down
8 changes: 3 additions & 5 deletions test/pyramidtest.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using POMDPModelTools

_horizon = 10
actions = [:l, :r]
actionCost = 1.
Expand All @@ -14,7 +12,7 @@ fhsolver = FiniteHorizonSolver()
# MDPs initialization
mdp = PyramidMDP(_horizon, actions, actionCost, actionsImpact, reward_states, reward, discount_factor, noise)

# check implementation of required methods
# check implementation of required methods
# @POMDPLinter.show_requirements FiniteHorizonPOMDPs.solve(fhsolver, mdp)

# initialize the solver
Expand All @@ -33,10 +31,10 @@ FHPolicy = FiniteHorizonPOMDPs.solve(fhsolver, mdp);
# Compare resulting policies
@test all((FiniteHorizonPOMDPs.action(FHPolicy, s) == action(VIPolicy, s) for s in states(mdp)))

# Compare FHMDP and IHMDP states
# Compare FHMDP and IHMDP states
fh_states = Iterators.flatten([FiniteHorizonPOMDPs.stage_states(mdp, i) for i=1:FiniteHorizonPOMDPs.horizon(mdp) + 1])
ih_states = states(mdp)

z = zip(fh_states, ih_states)

@test all((fh == ih for (fh, ih) in z))
@test all((fh == ih for (fh, ih) in z))

0 comments on commit de466a5

Please sign in to comment.