Commit

fixed bug with state action reward
zsunberg committed Sep 11, 2020
1 parent e1f5080 commit e491890
Showing 8 changed files with 32 additions and 23 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -17,7 +17,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 Distributions = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23"
 POMDPLinter = "0.1"
 POMDPModels = "0.4.7"
-POMDPs = "0.8, 0.9"
+POMDPs = "0.9"
 UnicodePlots = "1"
 julia = "1"
@@ -26,8 +26,9 @@ BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
 POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
 POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
 POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
+DiscreteValueIteration = "4b033969-44f6-5439-a48b-c11fa3648068"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["BeliefUpdaters", "POMDPModels", "POMDPPolicies", "Test", "Pkg", "POMDPSimulators"]
+test = ["BeliefUpdaters", "POMDPModels", "POMDPPolicies", "Test", "Pkg", "POMDPSimulators", "DiscreteValueIteration"]
2 changes: 1 addition & 1 deletion src/POMDPModelTools.jl
@@ -26,7 +26,7 @@ include("visualization.jl")
 export
     action_info,
     solve_info,
-    update_info,
+    update_info
 include("info.jl")

 export
4 changes: 0 additions & 4 deletions src/fully_observable_pomdp.jl
@@ -43,7 +43,3 @@ POMDPs.convert_a(T::Type{A}, vec::V, pomdp::FullyObservablePOMDP) where {A,V<:Ab
 POMDPs.reward(pomdp::FullyObservablePOMDP, s, a) = reward(pomdp.mdp, s, a)
 POMDPs.initialstate(m::FullyObservablePOMDP) = initialstate(m.mdp)
 POMDPs.initialobs(m::FullyObservablePOMDP, s) = Deterministic(s)
-
-# deprecated in POMDPs v0.9
-POMDPs.initialstate_distribution(pomdp::FullyObservablePOMDP) = initialstate_distribution(pomdp.mdp)
-POMDPs.initialstate(pomdp::FullyObservablePOMDP, rng::AbstractRNG) = initialstate(pomdp.mdp, rng)
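The deleted methods implemented the pre-0.9 interface, in which initialstate(m, rng) sampled a state directly and initialstate_distribution returned the distribution. Under POMDPs v0.9, initialstate(m) itself returns the initial-state distribution, so the wrapper only forwards the underlying MDP's distribution and returns a Deterministic initial observation. A minimal usage sketch under the new interface, assuming the wrapper's one-argument constructor and SimpleGridWorld from POMDPModels as the wrapped MDP:

using POMDPs, POMDPModels, POMDPModelTools, Random

mdp = SimpleGridWorld()
pomdp = FullyObservablePOMDP(mdp)    # assumed one-argument constructor

rng = MersenneTwister(1)
s = rand(rng, initialstate(pomdp))   # v0.9: initialstate returns a distribution, not a sample
o = rand(rng, initialobs(pomdp, s))  # Deterministic(s), so o == s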
4 changes: 2 additions & 2 deletions src/state_action_reward.jl
@@ -70,7 +70,7 @@ function mean_reward(m::MDP, s, a)
     rsum = 0.0
     wsum = 0.0
     for (sp, w) in weighted_iterator(td)
-        rsum += reward(m, s, a, sp)
+        rsum += w*reward(m, s, a, sp)
         wsum += w
     end
     return rsum/wsum
@@ -87,7 +87,7 @@ function mean_reward(m::POMDP, s, a)
     for (sp, w) in weighted_iterator(td)
         od = observation(m, s, a, sp)
         for (o, ow) in weighted_iterator(od)
-            rsum += reward(m, s, a, sp, o)
+            rsum += ow*w*reward(m, s, a, sp, o)
             wsum += ow*w
         end
     end
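This is the bug the commit title refers to: mean_reward is meant to compute the expected immediate reward, E[R(s,a)] = sum over s' of T(s'|s,a) R(s,a,s') (with an additional observation weight O(o|s,a,s') in the POMDP method), but the old loop summed the reward terms unweighted before dividing by the weight sum. A self-contained sketch of the corrected calculation with made-up weights and rewards, not tied to any model in the repository:

let
    weights = [0.3, 0.7]   # hypothetical transition weights T(s'|s,a) for two next states
    rewards = [10.0, 0.0]  # hypothetical R(s,a,s') for those next states
    rsum = 0.0
    wsum = 0.0
    for (w, r) in zip(weights, rewards)
        rsum += w*r        # weight each term, as in the fix
        wsum += w
    end
    rsum/wsum              # 3.0, the true expectation; the unweighted loop returned 10.0 here
end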
14 changes: 5 additions & 9 deletions test/runtests.jl
@@ -9,6 +9,7 @@ using BeliefUpdaters
 using POMDPPolicies
 import Distributions.Categorical
 using SparseArrays
+using DiscreteValueIteration

 @testset "POMDPModelTools" begin
     @testset "ordered" begin
@@ -52,15 +53,10 @@ using SparseArrays
         include("test_obs_weight.jl")
     end

-    # require DiscreteValueIteration
-    @warn("skipping value iteration smoke testing - this should be replaced or re-enabled")
-    # @testset "visolve" begin
-    # POMDPs.add_registry()
-    # Pkg.add("DiscreteValueIteration")
-    # using DiscreteValueIteration
-    # include("test_fully_observable_pomdp.jl")
-    # include("test_underlying_mdp.jl")
-    # end
+    @testset "visolve" begin
+        include("test_fully_observable_pomdp.jl")
+        include("test_underlying_mdp.jl")
+    end

     @testset "vis" begin
         include("test_visualization.jl")
6 changes: 3 additions & 3 deletions test/test_fully_observable_pomdp.jl
@@ -9,8 +9,8 @@ let
     @test observations(pomdp) == states(pomdp)
     @test statetype(pomdp) == obstype(pomdp)

-    s_po = initialstate(pomdp, MersenneTwister(1))
-    s_mdp = initialstate(mdp, MersenneTwister(1))
+    s_po = rand(MersenneTwister(1), initialstate(pomdp))
+    s_mdp = rand(MersenneTwister(1), initialstate(mdp))
     @test s_po == s_mdp

     solver = ValueIterationSolver(max_iterations = 100)
@@ -20,7 +20,7 @@
         mdp_policy.util == pomdp_policy.util
     end

-    is = initialstate(mdp, MersenneTwister(3))
+    is = rand(MersenneTwister(3), initialstate(mdp))
     for (sp, o, r) in stepthrough(pomdp,
                                   FunctionPolicy(o->:left),
                                   PreviousObservationUpdater(),
4 changes: 2 additions & 2 deletions test/test_info.jl
@@ -49,7 +49,7 @@ let
     =#

     pomdp = TigerPOMDP()
-    s = initialstate(pomdp, rng)
+    s = rand(rng, initialstate(pomdp))

     up = VoidUpdater()
     policy = RandomPolicy(rng, pomdp)
@@ -62,6 +62,6 @@
     d = initialstate_distribution(pomdp)
     b = initialize_belief(up, d)
     a = action(policy, b)
-    sp, o, r = gen(DDNOut(:sp,:o,:r), pomdp, rand(rng, d), a, rng)
+    sp, o, r = @gen(:sp,:o,:r)(pomdp, rand(rng, d), a, rng)
     @inferred update_info(up, b, a, o)
 end
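The gen(DDNOut(:sp,:o,:r), ...) form belonged to the v0.8 dynamic decision network interface; POMDPs v0.9 replaces it with the @gen macro, which takes the requested output names and returns a callable applied to the model, state, action, and rng. A short sketch of one generative step with the new syntax, using TigerPOMDP as in this test file (the action choice is illustrative):

using POMDPs, POMDPModels, Random

let
    pomdp = TigerPOMDP()
    rng = MersenneTwister(7)
    s = rand(rng, initialstate(pomdp))              # sample the initial-state distribution
    a = first(actions(pomdp))                       # any legal action would do here
    sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng)  # replaces gen(DDNOut(:sp,:o,:r), ...)
end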
16 changes: 16 additions & 0 deletions test/test_reward_model.jl
@@ -23,3 +23,19 @@ POMDPs.stateindex(m::RewardModelPOMDP1, s) = s
 POMDPs.actionindex(m::RewardModelPOMDP1, a) = a
 rm = StateActionReward(RewardModelPOMDP1())
 @test rm(1, 1) == 2
+
+m = BabyPOMDP()
+rm = LazyCachedSAR(m)
+for s in states(m)
+    for a in actions(m)
+        @test reward(m, s, a) == rm(s, a)
+    end
+end
+
+m = SimpleGridWorld()
+rm = LazyCachedSAR(m)
+for s in states(m)
+    for a in actions(m)
+        @test reward(m, s, a) == rm(s, a)
+    end
+end
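These added tests cover LazyCachedSAR, the caching counterpart of StateActionReward: LazyCachedSAR(m) wraps a model and rm(s, a) reproduces the state-action reward (checked here against reward(m, s, a) for BabyPOMDP and SimpleGridWorld), with values presumably computed on first use and cached thereafter. As an illustrative sketch, not repository code, a solver-style loop could use it to fill a dense reward matrix indexed by stateindex and actionindex:

using POMDPs, POMDPModels, POMDPModelTools

let
    m = SimpleGridWorld()
    rm = LazyCachedSAR(m)
    R = Matrix{Float64}(undef, length(states(m)), length(actions(m)))
    for s in states(m), a in actions(m)
        R[stateindex(m, s), actionindex(m, a)] = rm(s, a)  # cheap after the first evaluation
    end
    R
end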
