From e49189016ad0d86138ea99df14820038bb958a6e Mon Sep 17 00:00:00 2001
From: Zachary Sunberg
Date: Fri, 11 Sep 2020 15:19:38 -0600
Subject: [PATCH] fixed bug with state action reward

---
 Project.toml                        |  5 +++--
 src/POMDPModelTools.jl              |  2 +-
 src/fully_observable_pomdp.jl       |  4 ----
 src/state_action_reward.jl          |  4 ++--
 test/runtests.jl                    | 14 +++++---------
 test/test_fully_observable_pomdp.jl |  6 +++---
 test/test_info.jl                   |  4 ++--
 test/test_reward_model.jl           | 16 ++++++++++++++++
 8 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2a9a677..3a2bd4a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,7 +17,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 Distributions = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23"
 POMDPLinter = "0.1"
 POMDPModels = "0.4.7"
-POMDPs = "0.8, 0.9"
+POMDPs = "0.9"
 UnicodePlots = "1"
 julia = "1"
 
@@ -26,8 +26,9 @@ BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
 POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
 POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
 POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
+DiscreteValueIteration = "4b033969-44f6-5439-a48b-c11fa3648068"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["BeliefUpdaters", "POMDPModels", "POMDPPolicies", "Test", "Pkg", "POMDPSimulators"]
+test = ["BeliefUpdaters", "POMDPModels", "POMDPPolicies", "Test", "Pkg", "POMDPSimulators", "DiscreteValueIteration"]
diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl
index 4b8a0e3..44936fe 100644
--- a/src/POMDPModelTools.jl
+++ b/src/POMDPModelTools.jl
@@ -26,7 +26,7 @@ include("visualization.jl")
 export
     action_info,
     solve_info,
-    update_info,
+    update_info
 include("info.jl")
 
 export
diff --git a/src/fully_observable_pomdp.jl b/src/fully_observable_pomdp.jl
index ce60ca7..46e10a9 100644
--- a/src/fully_observable_pomdp.jl
+++ b/src/fully_observable_pomdp.jl
@@ -43,7 +43,3 @@ POMDPs.convert_a(T::Type{A}, vec::V, pomdp::FullyObservablePOMDP) where {A,V<:Ab
 POMDPs.reward(pomdp::FullyObservablePOMDP, s, a) = reward(pomdp.mdp, s, a)
 POMDPs.initialstate(m::FullyObservablePOMDP) = initialstate(m.mdp)
 POMDPs.initialobs(m::FullyObservablePOMDP, s) = Deterministic(s)
-
-# deprecated in POMDPs v0.9
-POMDPs.initialstate_distribution(pomdp::FullyObservablePOMDP) = initialstate_distribution(pomdp.mdp)
-POMDPs.initialstate(pomdp::FullyObservablePOMDP, rng::AbstractRNG) = initialstate(pomdp.mdp, rng)
diff --git a/src/state_action_reward.jl b/src/state_action_reward.jl
index bc85843..5ca1ad8 100644
--- a/src/state_action_reward.jl
+++ b/src/state_action_reward.jl
@@ -70,7 +70,7 @@ function mean_reward(m::MDP, s, a)
     rsum = 0.0
     wsum = 0.0
     for (sp, w) in weighted_iterator(td)
-        rsum += reward(m, s, a, sp)
+        rsum += w*reward(m, s, a, sp)
         wsum += w
     end
     return rsum/wsum
@@ -87,7 +87,7 @@ function mean_reward(m::POMDP, s, a)
     for (sp, w) in weighted_iterator(td)
         od = observation(m, s, a, sp)
         for (o, ow) in weighted_iterator(od)
-            rsum += reward(m, s, a, sp, o)
+            rsum += ow*w*reward(m, s, a, sp, o)
             wsum += ow*w
         end
     end
diff --git a/test/runtests.jl b/test/runtests.jl
index 093bf56..6046f79 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,6 +9,7 @@ using BeliefUpdaters
 using POMDPPolicies
 import Distributions.Categorical
 using SparseArrays
+using DiscreteValueIteration
 
 @testset "POMDPModelTools" begin
     @testset "ordered" begin
@@ -52,15 +53,10 @@ using SparseArrays
         include("test_obs_weight.jl")
     end
 
-    # require DiscreteValueIteration
-    @warn("skipping value iteration smoke testing - this should be replaced or re-enabled")
-    # @testset "visolve" begin
-    #     POMDPs.add_registry()
-    #     Pkg.add("DiscreteValueIteration")
-    #     using DiscreteValueIteration
-    #     include("test_fully_observable_pomdp.jl")
-    #     include("test_underlying_mdp.jl")
-    # end
+    @testset "visolve" begin
+        include("test_fully_observable_pomdp.jl")
+        include("test_underlying_mdp.jl")
+    end
 
     @testset "vis" begin
         include("test_visualization.jl")
diff --git a/test/test_fully_observable_pomdp.jl b/test/test_fully_observable_pomdp.jl
index b9a412a..071d13f 100644
--- a/test/test_fully_observable_pomdp.jl
+++ b/test/test_fully_observable_pomdp.jl
@@ -9,8 +9,8 @@ let
     @test observations(pomdp) == states(pomdp)
     @test statetype(pomdp) == obstype(pomdp)
 
-    s_po = initialstate(pomdp, MersenneTwister(1))
-    s_mdp = initialstate(mdp, MersenneTwister(1))
+    s_po = rand(MersenneTwister(1), initialstate(pomdp))
+    s_mdp = rand(MersenneTwister(1), initialstate(mdp))
     @test s_po == s_mdp
 
     solver = ValueIterationSolver(max_iterations = 100)
@@ -20,7 +20,7 @@ let
     mdp_policy.util == pomdp_policy.util
     end
 
-    is = initialstate(mdp, MersenneTwister(3))
+    is = rand(MersenneTwister(3), initialstate(mdp))
     for (sp, o, r) in stepthrough(pomdp,
                                   FunctionPolicy(o->:left),
                                   PreviousObservationUpdater(),
diff --git a/test/test_info.jl b/test/test_info.jl
index 9f80115..c2eb945 100644
--- a/test/test_info.jl
+++ b/test/test_info.jl
@@ -49,7 +49,7 @@ let
     =#
 
     pomdp = TigerPOMDP()
-    s = initialstate(pomdp, rng)
+    s = rand(rng, initialstate(pomdp))
     up = VoidUpdater()
     policy = RandomPolicy(rng, pomdp)
 
@@ -62,6 +62,6 @@ let
     d = initialstate_distribution(pomdp)
     b = initialize_belief(up, d)
     a = action(policy, b)
-    sp, o, r = gen(DDNOut(:sp,:o,:r), pomdp, rand(rng, d), a, rng)
+    sp, o, r = @gen(:sp,:o,:r)(pomdp, rand(rng, d), a, rng)
     @inferred update_info(up, b, a, o)
 end
diff --git a/test/test_reward_model.jl b/test/test_reward_model.jl
index 260ad01..eb9378a 100644
--- a/test/test_reward_model.jl
+++ b/test/test_reward_model.jl
@@ -23,3 +23,19 @@ POMDPs.stateindex(m::RewardModelPOMDP1, s) = s
 POMDPs.actionindex(m::RewardModelPOMDP1, a) = a
 rm = StateActionReward(RewardModelPOMDP1())
 @test rm(1, 1) == 2
+
+m = BabyPOMDP()
+rm = LazyCachedSAR(m)
+for s in states(m)
+    for a in actions(m)
+        @test reward(m, s, a) == rm(s, a)
+    end
+end
+
+m = SimpleGridWorld()
+rm = LazyCachedSAR(m)
+for s in states(m)
+    for a in actions(m)
+        @test reward(m, s, a) == rm(s, a)
+    end
+end
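
A note on the src/state_action_reward.jl hunk above: mean_reward estimates the expected state-action reward R(s, a) by summing over the weighted support of the transition (and observation) distribution, so each reward term must be multiplied by its probability weight before dividing by the accumulated weight; the old code summed the rewards unweighted. Below is a minimal standalone sketch of the corrected weighting in plain Julia; the transitions pairs and reward_fn names are hypothetical placeholders, not part of POMDPs.jl or this package.

# Expected reward over a discrete next-state distribution: weight each reward
# term by its probability, mirroring the corrected `rsum += w*reward(...)` line.
function expected_reward(transitions, reward_fn)
    rsum = 0.0
    wsum = 0.0
    for (sp, w) in transitions
        rsum += w * reward_fn(sp)  # the bug was accumulating reward_fn(sp) unweighted
        wsum += w
    end
    return rsum / wsum             # normalize in case the weights do not sum to 1
end

transitions = [(:good, 0.7), (:bad, 0.3)]   # hypothetical (next state, probability) pairs
reward_fn(sp) = sp == :good ? 10.0 : -5.0   # hypothetical reward as a function of next state
expected_reward(transitions, reward_fn)     # 0.7*10.0 + 0.3*(-5.0) = 5.5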