From 545fd3ec2d1c140ccdba20fdfaadbeb1cacd8834 Mon Sep 17 00:00:00 2001 From: MaximeBouton Date: Mon, 16 Sep 2019 16:12:49 -0700 Subject: [PATCH 1/9] removed n_states, n_actions, ... --- Project.toml | 1 + src/POMDPModelTools.jl | 6 +- src/fully_observable_pomdp.jl | 3 - src/ordered_spaces.jl | 21 +++--- src/policy_evaluation.jl | 11 +-- src/sparse_tabular.jl | 54 +++++++------- src/underlying_mdp.jl | 2 - test/runtests.jl | 111 ++++++++++++++-------------- test/test_fully_observable_pomdp.jl | 2 - test/test_ordered_spaces.jl | 12 +-- test/test_tabular.jl | 23 +++--- test/test_underlying_mdp.jl | 1 - 12 files changed, 117 insertions(+), 130 deletions(-) diff --git a/Project.toml b/Project.toml index 0d9163b..efcc3c7 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] Distributions = ">= 0.17" +POMDPs = "0.7.3, 0.9.0" julia = "1" [extras] diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index d9c0aab..8ca73f0 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -6,9 +6,9 @@ using LinearAlgebra using SparseArrays using UnicodePlots -import POMDPs: actions, n_actions, actionindex -import POMDPs: states, n_states, stateindex -import POMDPs: observations, n_observations, obsindex +import POMDPs: actions, actionindex +import POMDPs: states, stateindex +import POMDPs: observations, obsindex import POMDPs: sampletype, generate_sr, initialstate, isterminal, discount import POMDPs: implemented import Distributions: pdf, mode, mean, support diff --git a/src/fully_observable_pomdp.jl b/src/fully_observable_pomdp.jl index 90fe461..a7defe3 100644 --- a/src/fully_observable_pomdp.jl +++ b/src/fully_observable_pomdp.jl @@ -8,7 +8,6 @@ struct FullyObservablePOMDP{S, A} <: POMDP{S,A,S} end POMDPs.observations(pomdp::FullyObservablePOMDP) = states(pomdp.mdp) -POMDPs.n_observations(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp) POMDPs.obsindex(pomdp::FullyObservablePOMDP{S, A}, o::S) where {S, A} = stateindex(pomdp.mdp, o) POMDPs.convert_o(T::Type{V}, o, pomdp::FullyObservablePOMDP) where {V<:AbstractArray} = convert_s(T, s, pomdp.mdp) @@ -39,8 +38,6 @@ POMDPs.generate_sr(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = genera POMDPs.reward(pomdp::FullyObservablePOMDP{S, A}, s::S, a::A) where {S,A} = reward(pomdp.mdp, s, a) POMDPs.isterminal(pomdp::FullyObservablePOMDP, s) = isterminal(pomdp.mdp, s) POMDPs.discount(pomdp::FullyObservablePOMDP) = discount(pomdp.mdp) -POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp) -POMDPs.n_actions(pomdp::FullyObservablePOMDP) = n_actions(pomdp.mdp) POMDPs.stateindex(pomdp::FullyObservablePOMDP{S,A}, s::S) where {S,A} = stateindex(pomdp.mdp, s) POMDPs.actionindex(pomdp::FullyObservablePOMDP{S, A}, a::A) where {S,A} = actionindex(pomdp.mdp, a) POMDPs.convert_s(T::Type{V}, s, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_s(T, s, pomdp.mdp) diff --git a/src/ordered_spaces.jl b/src/ordered_spaces.jl index 7df6d23..df759e4 100644 --- a/src/ordered_spaces.jl +++ b/src/ordered_spaces.jl @@ -7,7 +7,7 @@ Return an `AbstractVector` of actions ordered according to `actionindex(mdp, a)` `ordered_actions(mdp)` will always return an `AbstractVector{A}` `v` containing all of the actions in `actions(mdp)` in the order such that `actionindex(mdp, v[i]) == i`. You may wish to override this for your problem for efficiency. 
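 
 For example, a small problem with three actions might support this interface as follows (a hypothetical sketch; `MyMDP` and its action space are made up for illustration):
 
 ```julia
 using POMDPs, POMDPModelTools
 
 struct MyMDP <: MDP{Int, Int} end
 POMDPs.actions(::MyMDP) = [-1, 0, 1]
 POMDPs.actionindex(::MyMDP, a::Int) = a + 2  # maps -1, 0, 1 to indices 1, 2, 3
 
 ordered_actions(MyMDP())  # == [-1, 0, 1]
 ```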
""" -ordered_actions(mdp::Union{MDP,POMDP}) = ordered_vector(actiontype(typeof(mdp)), a->actionindex(mdp,a), actions(mdp), n_actions(mdp), "action") +ordered_actions(mdp::Union{MDP,POMDP}) = ordered_vector(actiontype(typeof(mdp)), a->actionindex(mdp,a), actions(mdp), "action") """ ordered_states(mdp) @@ -16,7 +16,7 @@ Return an `AbstractVector` of states ordered according to `stateindex(mdp, a)`. `ordered_states(mdp)` will always return a `AbstractVector{A}` `v` containing all of the states in `states(mdp)` in the order such that `stateindex(mdp, v[i]) == i`. You may wish to override this for your problem for efficiency. """ -ordered_states(mdp::Union{MDP,POMDP}) = ordered_vector(statetype(typeof(mdp)), s->stateindex(mdp,s), states(mdp), n_states(mdp), "state") +ordered_states(mdp::Union{MDP,POMDP}) = ordered_vector(statetype(typeof(mdp)), s->stateindex(mdp,s), states(mdp), "state") """ ordered_observations(pomdp) @@ -25,9 +25,10 @@ Return an `AbstractVector` of observations ordered according to `obsindex(pomdp, `ordered_observations(mdp)` will always return a `AbstractVector{A}` `v` containing all of the observations in `observations(pomdp)` in the order such that `obsindex(pomdp, v[i]) == i`. You may wish to override this for your problem for efficiency. """ -ordered_observations(pomdp::POMDP) = ordered_vector(obstype(typeof(pomdp)), o->obsindex(pomdp,o), observations(pomdp), n_observations(pomdp), "observation") +ordered_observations(pomdp::POMDP) = ordered_vector(obstype(typeof(pomdp)), o->obsindex(pomdp,o), observations(pomdp), "observation") -function ordered_vector(T::Type, index::Function, space, len, singular, plural=singular*"s") +function ordered_vector(T::Type, index::Function, space, singular, plural=singular*"s") + len = length(space) a = Array{T}(undef, len) gotten = falses(len) for x in space @@ -39,7 +40,7 @@ function ordered_vector(T::Type, index::Function, space, len, singular, plural=s index was $id. n_$plural(...) was $len. - """) + """) end a[id] = x gotten[id] = true @@ -60,23 +61,23 @@ end @POMDP_require ordered_actions(mdp::Union{MDP,POMDP}) begin P = typeof(mdp) @req actionindex(::P, ::actiontype(P)) - @req n_actions(::P) @req actions(::P) as = actions(mdp) + @req length(::typeof(as)) end @POMDP_require ordered_states(mdp::Union{MDP,POMDP}) begin P = typeof(mdp) @req stateindex(::P, ::statetype(P)) - @req n_states(::P) @req states(::P) - as = states(mdp) + ss = states(mdp) + @req length(::typeof(ss)) end @POMDP_require ordered_observations(mdp::Union{MDP,POMDP}) begin P = typeof(mdp) @req obsindex(::P, ::obstype(P)) - @req n_observations(::P) @req observations(::P) - as = observations(mdp) + os = observations(mdp) + @req length(::typeof(os)) end diff --git a/src/policy_evaluation.jl b/src/policy_evaluation.jl index e501dc7..dfe90db 100644 --- a/src/policy_evaluation.jl +++ b/src/policy_evaluation.jl @@ -41,12 +41,12 @@ Create an |S|x|S| sparse transition matrix for a given policy. The row corresponds to the current state and column to the next state. Corresponds to ``T^π`` in equation (4.7) in Kochenderfer, *Decision Making Under Uncertainty*, 2015. 
""" function policy_transition_matrix(m::Union{MDP,POMDP}, p::Policy) - ns = n_states(m) rows = Int[] cols = Int[] probs = Float64[] - - for s in states(m) + state_space = states(m) + ns = length(state_space) + for s in state_space if !isterminal(m, s) # if terminal, the transition probabilities are all just zero si = stateindex(m, s) a = action(p, s) @@ -66,8 +66,9 @@ function policy_transition_matrix(m::Union{MDP,POMDP}, p::Policy) end function policy_reward_vector(m::Union{MDP,POMDP}, p::Policy; rewardfunction=POMDPs.reward) - r = zeros(n_states(m)) - for s in states(m) + state_space = states(m) + r = zeros(length(state_space)) + for s in state_space if !isterminal(m, s) # if terminal, the transition probabilities are all just zero si = stateindex(m, s) a = action(p, s) diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl index 48405e8..3aaa725 100644 --- a/src/sparse_tabular.jl +++ b/src/sparse_tabular.jl @@ -38,8 +38,6 @@ end S = statetype(P) A = actiontype(P) @req discount(::P) - @req n_states(::P) - @req n_actions(::P) @subreq ordered_states(mdp) @subreq ordered_actions(mdp) @req transition(::P,::S,::A) @@ -49,6 +47,8 @@ end @req actions(::P, ::S) as = actions(mdp) ss = states(mdp) + @req length(::typeof(as)) + @req length(::typeof(ss)) a = first(as) s = first(ss) dist = transition(mdp, s, a) @@ -113,8 +113,6 @@ end A = actiontype(P) O = obstype(P) @req discount(::P) - @req n_states(::P) - @req n_actions(::P) @subreq ordered_states(pomdp) @subreq ordered_actions(pomdp) @subreq ordered_observations(pomdp) @@ -128,6 +126,8 @@ end @req obsindex(::P, ::O) as = actions(pomdp) ss = states(pomdp) + @req length(::typeof(as)) + @req length(::typeof(ss)) a = first(as) s = first(ss) dist = transition(pomdp, s, a) @@ -160,8 +160,9 @@ const SparseTabularProblem = Union{SparseTabularMDP, SparseTabularPOMDP} function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP}) # Thanks to zach - na = n_actions(mdp) - ns = n_states(mdp) + na = length(actions(mdp)) + state_space = states(mdp) + ns = length(state_space) transmat_row_A = [Int[] for _ in 1:n_actions(mdp)] transmat_col_A = [Int[] for _ in 1:n_actions(mdp)] transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)] @@ -187,15 +188,17 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP}) end end end - transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)] + transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], ns, ns) for a in 1:na] # if an action is not valid from a state, the transition is 0.0 everywhere - # @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1" + # @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(ns)) for a in 1:na) "Transition probabilities must sum to 1" return transmats_A_S_S2 end function reward_s_a(mdp::Union{MDP, POMDP}) - reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s) - for s in states(mdp) + state_space = states(mdp) + action_space = actions(mdp) + reward_S_A = fill(-Inf, (length(state_space), length(action_space))) # set reward for all actions to -Inf unless they are in actions(mdp, s) + for s in state_space if isterminal(mdp, s) reward_S_A[stateindex(mdp, s), :] .= 0.0 else @@ -227,12 +230,15 @@ function terminal_states_set(mdp::Union{MDP, POMDP}) end function observation_matrix_a_sp_o(pomdp::POMDP) - na = n_actions(pomdp) 
-    ns = n_states(pomdp)
-    no = n_observations(pomdp)
-    obsmat_row_A = [Int[] for _ in 1:n_actions(pomdp)]
-    obsmat_col_A = [Int[] for _ in 1:n_actions(pomdp)]
-    obsmat_data_A = [Float64[] for _ in 1:n_actions(pomdp)]
+    state_space = states(pomdp)
+    action_space = actions(pomdp)
+    obs_space = observations(pomdp)
+    na = length(action_space)
+    ns = length(state_space)
+    no = length(obs_space)
+    obsmat_row_A = [Int[] for _ in 1:na]
+    obsmat_col_A = [Int[] for _ in 1:na]
+    obsmat_data_A = [Float64[] for _ in 1:na]
 
     for sp in states(pomdp)
         spi = stateindex(pomdp, sp)
@@ -249,19 +255,16 @@ function observation_matrix_a_sp_o(pomdp::POMDP)
             end
         end
     end
-    obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], n_states(pomdp), n_states(pomdp)) for a in 1:n_actions(pomdp)]
-    @assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(n_observations(pomdp))) for a in 1:n_actions(pomdp)) "Observation probabilities must sum to 1"
+    obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], ns, no) for a in 1:na]
+    @assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(ns)) for a in 1:na) "Observation probabilities must sum to 1"
     return obsmats_A_SP_O
 end
 
 # MDP and POMDP common methods
-POMDPs.n_states(prob::SparseTabularProblem) = size(prob.T[1], 1)
-POMDPs.n_actions(prob::SparseTabularProblem) = size(prob.T, 1)
-
-POMDPs.states(p::SparseTabularProblem) = 1:n_states(p)
-POMDPs.actions(p::SparseTabularProblem) = 1:n_actions(p)
-POMDPs.actions(p::SparseTabularProblem, s::Int64) = [a for a in actions(p) if sum(transition_matrix(p, a)) ≈ n_states(p)]
+POMDPs.states(p::SparseTabularProblem) = 1:size(p.T[1], 1)
+POMDPs.actions(p::SparseTabularProblem) = 1:size(p.T, 1)
+POMDPs.actions(p::SparseTabularProblem, s::Int64) = [a for a in actions(p) if sum(transition_matrix(p, a)) ≈ size(p.T[1], 1)]
 
 POMDPs.stateindex(::SparseTabularProblem, s::Int64) = s
 POMDPs.actionindex(::SparseTabularProblem, a::Int64) = a
@@ -303,9 +306,8 @@ Accessor function for the reward matrix R[s, a] of a sparse tabular problem.
 """
 reward_matrix(p::SparseTabularProblem) = p.R
 
 # POMDP only methods
-POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2)
-POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p)
+POMDPs.observations(p::SparseTabularPOMDP) = 1:size(p.O[1], 2)
 
 POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...)
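
The tabular conversion above can be exercised end to end with the accessors defined here; a minimal sketch, assuming `TigerPOMDP` from POMDPModels:

```julia
using POMDPModels, POMDPModelTools

pomdp = TigerPOMDP()
spomdp = SparseTabularPOMDP(pomdp)   # tabulates T, R, and O once, up front

T1 = transition_matrix(spomdp, 1)    # |S| x |S| sparse matrix for action 1
O1 = observation_matrix(spomdp, 1)   # |S| x |O| sparse matrix for action 1
R  = reward_matrix(spomdp)           # |S| x |A| dense matrix; -Inf marks invalid state-action pairs
```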
diff --git a/src/underlying_mdp.jl b/src/underlying_mdp.jl index 99bfcd4..3428b87 100644 --- a/src/underlying_mdp.jl +++ b/src/underlying_mdp.jl @@ -29,8 +29,6 @@ POMDPs.reward(mdp::UnderlyingMDP{P, S, A}, s::S, a::A) where {P,S,A} = reward(md POMDPs.reward(mdp::UnderlyingMDP{P, S, A}, s::S, a::A, sp::S) where {P,S,A} = reward(mdp.pomdp, s, a, sp) POMDPs.isterminal(mdp ::UnderlyingMDP{P, S, A}, s::S) where {P,S,A} = isterminal(mdp.pomdp, s) POMDPs.discount(mdp::UnderlyingMDP) = discount(mdp.pomdp) -POMDPs.n_actions(mdp::UnderlyingMDP) = n_actions(mdp.pomdp) -POMDPs.n_states(mdp::UnderlyingMDP) = n_states(mdp.pomdp) POMDPs.stateindex(mdp::UnderlyingMDP{P, S, A}, s::S) where {P,S,A} = stateindex(mdp.pomdp, s) POMDPs.stateindex(mdp::UnderlyingMDP{P, Int, A}, s::Int) where {P,A} = stateindex(mdp.pomdp, s) # fix ambiguity with src/convenience POMDPs.stateindex(mdp::UnderlyingMDP{P, Bool, A}, s::Bool) where {P,A} = stateindex(mdp.pomdp, s) diff --git a/test/runtests.jl b/test/runtests.jl index 4a5a324..6786e8e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,65 +9,68 @@ using POMDPPolicies import Distributions.Categorical using SparseArrays -@testset "ordered" begin - include("test_ordered_spaces.jl") -end +@testset "POMDPModelTools" begin + @testset "ordered" begin + include("test_ordered_spaces.jl") + end -# require POMDPModels -@testset "genbeliefmdp" begin - include("test_generative_belief_mdp.jl") -end -@testset "implement" begin - include("test_implementations.jl") -end -@testset "weightediter" begin - include("test_weighted_iteration.jl") -end -@testset "sparsecat" begin - include("test_sparse_cat.jl") -end -@testset "bool" begin - include("test_bool.jl") -end -@testset "deterministic" begin - include("test_deterministic.jl") -end -@testset "uniform" begin - include("test_uniform.jl") -end -@testset "terminalstate" begin - include("test_terminal_state.jl") -end + # require POMDPModels + @testset "genbeliefmdp" begin + include("test_generative_belief_mdp.jl") + end + @testset "implement" begin + include("test_implementations.jl") + end + @testset "weightediter" begin + include("test_weighted_iteration.jl") + end + @testset "sparsecat" begin + include("test_sparse_cat.jl") + end + @testset "bool" begin + include("test_bool.jl") + end + @testset "deterministic" begin + include("test_deterministic.jl") + end + @testset "uniform" begin + include("test_uniform.jl") + end + @testset "terminalstate" begin + include("test_terminal_state.jl") + end -# require POMDPModels -@testset "info" begin - include("test_info.jl") -end -@testset "obsweight" begin - include("test_obs_weight.jl") -end + # require POMDPModels + @testset "info" begin + include("test_info.jl") + end + @testset "obsweight" begin + include("test_obs_weight.jl") + end -# require DiscreteValueIteration -@testset "visolve" begin - POMDPs.add_registry() - Pkg.add("DiscreteValueIteration") - using DiscreteValueIteration - include("test_fully_observable_pomdp.jl") - include("test_underlying_mdp.jl") -end + # require DiscreteValueIteration + @testset "visolve" begin + POMDPs.add_registry() + Pkg.add("DiscreteValueIteration") + using DiscreteValueIteration + include("test_fully_observable_pomdp.jl") + include("test_underlying_mdp.jl") + end -@testset "vis" begin - include("test_visualization.jl") -end + @testset "vis" begin + include("test_visualization.jl") + end -@testset "evaluation" begin - include("test_evaluation.jl") -end + @testset "evaluation" begin + include("test_evaluation.jl") + end -@testset "pretty printing" begin - 
include("test_pretty_printing.jl") -end + @testset "pretty printing" begin + include("test_pretty_printing.jl") + end + + @testset "sparse tabular" begin + include("test_tabular.jl") + end -@testset "sparse tabular" begin - include("test_tabular.jl") end \ No newline at end of file diff --git a/test/test_fully_observable_pomdp.jl b/test/test_fully_observable_pomdp.jl index 4a19022..654220e 100644 --- a/test/test_fully_observable_pomdp.jl +++ b/test/test_fully_observable_pomdp.jl @@ -4,11 +4,9 @@ let pomdp = FullyObservablePOMDP(mdp) @test observations(pomdp) == states(pomdp) - @test n_observations(pomdp) == n_states(pomdp) @test statetype(pomdp) == obstype(pomdp) @test observations(pomdp) == states(pomdp) - @test n_observations(pomdp) == n_states(pomdp) @test statetype(pomdp) == obstype(pomdp) s_po = initialstate(pomdp, MersenneTwister(1)) diff --git a/test/test_ordered_spaces.jl b/test/test_ordered_spaces.jl index 7cc840b..51bd17b 100644 --- a/test/test_ordered_spaces.jl +++ b/test/test_ordered_spaces.jl @@ -2,9 +2,7 @@ let struct TigerPOMDPTestFixture <: POMDP{Bool, Int, Bool} end POMDPs.states(::TigerPOMDPTestFixture) = (true, false) POMDPs.stateindex(::TigerPOMDPTestFixture, s) = Int(s) + 1 - POMDPs.n_states(::TigerPOMDPTestFixture) = 2 POMDPs.actions(m::TigerPOMDPTestFixture) = 0:2 - POMDPs.n_actions(m::TigerPOMDPTestFixture) = 3 POMDPs.actionindex(m::TigerPOMDPTestFixture, s::Int) = s+1 POMDPs.observations(::TigerPOMDPTestFixture) = (true, false) POMDPs.obsindex(::TigerPOMDPTestFixture, o) = Int(o) + 1 @@ -19,14 +17,6 @@ end struct TM <: POMDP{Int, Int, Int} end POMDPs.states(::TM) = [1,3] -POMDPs.n_states(::TM) = 2 POMDPs.stateindex(::TM, s::Int) = s -@test_throws ErrorException ordered_states(TM()) - -struct TM2 <: POMDP{Int, Int, Int} end -POMDPs.states(::TM2) = [1,3] -POMDPs.n_states(::TM2) = 3 -POMDPs.stateindex(::TM2, s::Int) = s - -@test_logs (:warn,) ordered_states(TM2()) +@test_throws ErrorException ordered_states(TM()) \ No newline at end of file diff --git a/test/test_tabular.jl b/test/test_tabular.jl index 9ba8b07..f7f8f70 100644 --- a/test/test_tabular.jl +++ b/test/test_tabular.jl @@ -53,17 +53,17 @@ smdp2 = SparseTabularMDP(smdp) @test smdp2.R == smdp.R @test smdp2.discount == smdp.discount -smdp3 = SparseTabularMDP(smdp, reward = zeros(n_states(mdp), n_actions(mdp))) +smdp3 = SparseTabularMDP(smdp, reward = zeros(length(states(mdp)), length(actions(mdp)))) @test smdp3.T == smdp.T @test smdp3.R != smdp.R -smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)]) +smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:length(states(mdp)), 1:length(states(mdp)), 1.0) for a in 1:length(actions(mdp))]) @test smdp3.T != smdp.T @test smdp3.R == smdp.R smdp3 = SparseTabularMDP(smdp, discount = 0.5) @test smdp3.discount != smdp.discount -@test size(transition_matrix(smdp, 1)) == (n_states(smdp), n_states(smdp)) -@test length(reward_vector(smdp, 1)) == n_states(smdp) +@test size(transition_matrix(smdp, 1)) == (length(states(smdp)), length(states(smdp))) +@test length(reward_vector(smdp, 1)) == length(states(smdp)) gw = SimpleGridWorld() sparsegw = SparseTabularMDP(gw) @@ -86,27 +86,24 @@ spomdp2 = SparseTabularPOMDP(spomdp) @test spomdp2.O == spomdp.O @test spomdp2.discount == spomdp.discount -spomdp3 = SparseTabularPOMDP(spomdp, reward = zeros(n_states(mdp), n_actions(mdp))) +spomdp3 = SparseTabularPOMDP(spomdp, reward = zeros(length(states(mdp)), length(actions(mdp)))) @test spomdp3.T == spomdp.T @test spomdp3.R != 
spomdp.R -spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)]) +spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:length(states(mdp)), 1:length(states(mdp)), 1.0) for a in 1:length(actions(mdp))]) @test spomdp3.T != spomdp.T @test spomdp3.R == spomdp.R spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5) @test spomdp3.discount != spomdp.discount -@test size(observation_matrix(spomdp, 1)) == (n_states(spomdp), n_observations(spomdp)) +@test size(observation_matrix(spomdp, 1)) == (length(states(spomdp)), length(observations(spomdp))) @test observation_matrices(spomdp) == spomdp2.O @test transition_matrices(spomdp) == spomdp2.T @test reward_matrix(spomdp) == spomdp2.R ## Tests -@test n_states(pomdp) == n_states(spomdp) -@test n_actions(pomdp) == n_actions(spomdp) -@test n_observations(pomdp) == n_observations(spomdp) -@test length(states(spomdp)) == n_states(spomdp) -@test length(actions(spomdp)) == n_actions(spomdp) -@test length(observations(spomdp)) == n_observations(spomdp) +@test length(states(pomdp)) == length(states(spomdp)) +@test length(actions(pomdp)) == length(actions(spomdp)) +@test length(observations(pomdp)) == length(observations(spomdp)) @test statetype(spomdp) == Int64 @test actiontype(spomdp) == Int64 @test obstype(spomdp) == Int64 diff --git a/test/test_underlying_mdp.jl b/test/test_underlying_mdp.jl index 26ed853..3cfeca9 100644 --- a/test/test_underlying_mdp.jl +++ b/test/test_underlying_mdp.jl @@ -3,7 +3,6 @@ let mdp = UnderlyingMDP(pomdp) - @test n_states(mdp) == n_states(pomdp) @test states(mdp) == states(pomdp) s_mdp = rand(MersenneTwister(1), initialstate_distribution(mdp)) s_pomdp = rand(MersenneTwister(1), initialstate_distribution(pomdp)) From 7ee8c2985c89444360bf40640b433e57d2fbaf87 Mon Sep 17 00:00:00 2001 From: MaximeBouton Date: Tue, 17 Sep 2019 10:14:01 -0700 Subject: [PATCH 2/9] restore n_states, n_actions --- src/fully_observable_pomdp.jl | 2 ++ src/underlying_mdp.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/fully_observable_pomdp.jl b/src/fully_observable_pomdp.jl index a7defe3..455fbf9 100644 --- a/src/fully_observable_pomdp.jl +++ b/src/fully_observable_pomdp.jl @@ -44,3 +44,5 @@ POMDPs.convert_s(T::Type{V}, s, pomdp::FullyObservablePOMDP) where V<:AbstractAr POMDPs.convert_s(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:AbstractArray} = convert_s(T, vec, pomdp.mdp) POMDPs.convert_a(T::Type{V}, a, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_a(T, a, pomdp.mdp) POMDPs.convert_a(T::Type{A}, vec::V, pomdp::FullyObservablePOMDP) where {A,V<:AbstractArray} = convert_a(T, vec, pomdp.mdp) +POMDPs.n_actions(pomdp::FullyObservablePOMDP) = n_actions(pomdp.mdp) +POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp) \ No newline at end of file diff --git a/src/underlying_mdp.jl b/src/underlying_mdp.jl index 3428b87..6e0c384 100644 --- a/src/underlying_mdp.jl +++ b/src/underlying_mdp.jl @@ -35,3 +35,5 @@ POMDPs.stateindex(mdp::UnderlyingMDP{P, Bool, A}, s::Bool) where {P,A} = statein POMDPs.actionindex(mdp::UnderlyingMDP{P, S, A}, a::A) where {P,S,A} = actionindex(mdp.pomdp, a) POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Int}, a::Int) where {P,S} = actionindex(mdp.pomdp, a) POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Bool}, a::Bool) where {P,S} = actionindex(mdp.pomdp, a) +POMDPs.n_actions(mdp::UnderlyingMDP) = n_actions(mdp.pomdp) +POMDPs.n_states(mdp::UnderlyingMDP) = n_states(mdp.pomdp) \ No newline at end of file From 
d621206276bfe0b99bbba056d2ceb61b68b70a0b Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 18 Sep 2019 16:53:51 -0700 Subject: [PATCH 3/9] added info update (no docs yet) --- Project.toml | 2 +- src/POMDPModelTools.jl | 7 ++-- src/info.jl | 64 +++++++++++++++++++++-------- test/runtests.jl | 2 +- test/test_fully_observable_pomdp.jl | 2 +- test/test_info.jl | 12 ++++-- 6 files changed, 62 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index efcc3c7..dbab86b 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] Distributions = ">= 0.17" -POMDPs = "0.7.3, 0.9.0" +POMDPs = "0.7.3, 0.8.0" julia = "1" [extras] diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index 8ca73f0..dae3bef 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -22,11 +22,12 @@ include("visualization.jl") # info interface export - generate_sri, - generate_sori, + add_infonode, action_info, solve_info, - update_info + update_info, + generate_sri, + generate_sori include("info.jl") export diff --git a/src/info.jl b/src/info.jl index 4f3f3b3..f97f807 100644 --- a/src/info.jl +++ b/src/info.jl @@ -1,24 +1,6 @@ # functions for passing out info from simulations, similar to the info return from openai gym # maintained by @zsunberg -""" -Return a tuple containing the next state and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step. - -By default, returns `nothing` as info. -""" -function generate_sri(p::MDP, s, a, rng::AbstractRNG) - return generate_sr(p, s, a, rng)..., nothing -end - -""" -Return a tuple containing the next state, observation, and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step. - -By default, returns `nothing` as info. -""" -function generate_sori(p::POMDP, s, a, rng::AbstractRNG) - return generate_sor(p, s, a, rng)..., nothing -end - """ a, ai = action_info(policy, x) @@ -51,3 +33,49 @@ By default, returns `nothing` as info. function update_info(up::Updater, b, a, o) return update(up, b, a, o), nothing end + +""" + add_infonode(ddn::DDNStructure) + +Create a new DDNStructure object with a new node labeled :info with parents :s and :a +""" +function add_infonode(ddn) # for DDNStructure, but it is not declared in v0.7.3, so there is not annotation + add_node(ddn, :info, ConstantDDNNode(nothing), (:s, :a)) +end + +function add_infonode(ddn::POMDPs.DDNStructureV7{nodenames}) where nodenames + return POMDPs.DDNStructureV7{(nodenames..., :info)}() +end + +############################################################### +# Note all generate functions will be deprecated in POMDPs v0.8 +############################################################### + + +if DDNStructure(MDP) isa POMDPs.DDNStructureV7 + """ + Return a tuple containing the next state and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step. + + By default, returns `nothing` as info. + """ + function generate_sri(p::MDP, s, a, rng::AbstractRNG) + return generate_sr(p, s, a, rng)..., nothing + end + + """ + Return a tuple containing the next state, observation, and reward and information (usually a `NamedTuple`, `Dict` or `nothing`) from that step. + + By default, returns `nothing` as info. 
+ """ + function generate_sori(p::POMDP, s, a, rng::AbstractRNG) + return generate_sor(p, s, a, rng)..., nothing + end + + POMDPs.gen(::DDNOut{(:sp,:o,:r,:i)}, m, s, a, rng) = generate_sori(m, s, a, rng) + POMDPs.gen(::DDNOut{(:sp,:o,:r,:info)}, m, s, a, rng) = generate_sori(m, s, a, rng) + POMDPs.gen(::DDNOut{(:sp,:r,:i)}, m, s, a, rng) = generate_sri(m, s, a, rng) + POMDPs.gen(::DDNOut{(:sp,:r,:info)}, m, s, a, rng) = generate_sri(m, s, a, rng) +else + @deprecate generate_sri(args...) gen(DDNOut(:sp,:r,:info), args...) + @deprecate generate_sori(args...) gen(DDNOut(:sp,:o,:r,:info), args...) +end diff --git a/test/runtests.jl b/test/runtests.jl index 6786e8e..85f6a44 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,4 +73,4 @@ using SparseArrays include("test_tabular.jl") end -end \ No newline at end of file +end diff --git a/test/test_fully_observable_pomdp.jl b/test/test_fully_observable_pomdp.jl index 654220e..82b32bd 100644 --- a/test/test_fully_observable_pomdp.jl +++ b/test/test_fully_observable_pomdp.jl @@ -1,5 +1,5 @@ let - mdp = GridWorld() + mdp = SimpleGridWorld() pomdp = FullyObservablePOMDP(mdp) diff --git a/test/test_info.jl b/test/test_info.jl index 21007be..9c7d73b 100644 --- a/test/test_info.jl +++ b/test/test_info.jl @@ -23,14 +23,20 @@ let rng = MersenneTwister(7) mdp = LegacyGridWorld() + POMDPs.DDNStructure(::Type{typeof(mdp)}) = DDNStructure(MDP) |> add_infonode + @test :info in nodenames(DDNStructure(mdp)) s = initialstate(mdp, rng) a = rand(rng, actions(mdp)) - @inferred generate_sri(mdp, s, a, rng) + sp, r, i = @inferred gen(DDNOut(:sp,:r,:info), mdp, s, a, rng) + @test i === nothing pomdp = TigerPOMDP() + POMDPs.DDNStructure(::Type{typeof(pomdp)}) = DDNStructure(POMDP) |> add_infonode + @test :info in nodenames(DDNStructure(pomdp)) s = initialstate(pomdp, rng) a = rand(rng, actions(pomdp)) - @inferred generate_sori(pomdp, s, a, rng) + sp, o, r, i = @inferred gen(DDNOut(:sp,:o,:r,:info), pomdp, s, a, rng) + @test i === nothing up = VoidUpdater() policy = RandomPolicy(rng, pomdp) @@ -43,6 +49,6 @@ let d = initialstate_distribution(pomdp) b = initialize_belief(up, d) a = action(policy, b) - sp, o = generate_so(pomdp, rand(rng, d), a, rng) + sp, o, r = gen(DDNOut(:sp,:o,:r), pomdp, rand(rng, d), a, rng) @inferred update_info(up, b, a, o) end From 5d007da3ca9a5d63b1b3992f6b8b086d4e6e6b6d Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 18 Sep 2019 17:03:48 -0700 Subject: [PATCH 4/9] only make value iteration solve an mdp --- test/test_fully_observable_pomdp.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fully_observable_pomdp.jl b/test/test_fully_observable_pomdp.jl index 82b32bd..fa0fe3a 100644 --- a/test/test_fully_observable_pomdp.jl +++ b/test/test_fully_observable_pomdp.jl @@ -15,6 +15,6 @@ let solver = ValueIterationSolver(max_iterations = 100) mdp_policy = solve(solver, mdp) - pomdp_policy = solve(solver, pomdp) + pomdp_policy = solve(solver, UnderlyingMDP(pomdp)) @test mdp_policy.util == pomdp_policy.util end From 3052882b20abcc7c119767c19fefba327003303a Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 18 Sep 2019 17:35:27 -0700 Subject: [PATCH 5/9] updating docs --- Project.toml | 2 +- docs/make.jl | 8 ++------ docs/mkdocs.yml | 23 ----------------------- docs/src/index.md | 2 +- docs/src/interface_extensions.md | 2 -- src/info.jl | 23 ++++++++++++++++++++++- test/test_underlying_mdp.jl | 2 +- 7 files changed, 27 insertions(+), 35 deletions(-) delete mode 100644 docs/mkdocs.yml diff --git 
a/Project.toml b/Project.toml index dbab86b..469f8c6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "POMDPModelTools" uuid = "08074719-1b2a-587c-a292-00f91cc44415" authors = ["JuliaPOMDP Contributors"] -version = "0.1.6" +version = "0.2.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/docs/make.jl b/docs/make.jl index 4d144b8..846fa6c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,3 +1,5 @@ +push!(LOAD_PATH, "../src/") + using Documenter, POMDPModelTools makedocs( @@ -8,10 +10,4 @@ makedocs( deploydocs( repo = "github.com/JuliaPOMDP/POMDPModelTools.jl.git", - julia = "1.0", - osname = "linux", - target = "build", - deps = nothing, - make = nothing ) - diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml deleted file mode 100644 index 73669ff..0000000 --- a/docs/mkdocs.yml +++ /dev/null @@ -1,23 +0,0 @@ -site_name: POMDPModelTools.jl -repo_url: https://github.com/JuliaPOMDP/POMDPModelTools.jl -site_description: Interface extensions and tools for POMDPs.jl models and solvers. -site_authors: JuliaPOMDP - -theme: readthedocs - -extra_css: - - assets/Documenter.css - -markdown_extensions: - - extra - - tables - - fenced_code - #- mdx_math - -extra_javascript: - - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML - - assets/mathjaxhelper.js - -docs_dir: 'build' - -# do NOT include pages - it will automatically discover them diff --git a/docs/src/index.md b/docs/src/index.md index 5e9964c..c706d4f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,6 +1,6 @@ # About -POMDPModelTools is a collection of interface extensions and tools to make writing models and solvers for [POMDPs.jl](github.com/JuliaPOMDP/POMDPs.jl) easier. +POMDPModelTools is a collection of interface extensions and tools to make writing models and solvers for [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl) easier. ```@contents ``` diff --git a/docs/src/interface_extensions.md b/docs/src/interface_extensions.md index baa9d81..f057ea2 100644 --- a/docs/src/interface_extensions.md +++ b/docs/src/interface_extensions.md @@ -36,8 +36,6 @@ ordered_observations It is often the case that useful information besides the belief, state, action, etc is generated by a function in POMDPs.jl. This information can be useful for debugging or understanding the behavior of a solver, updater, or problem. The info interface provides a standard way for problems, policies, solvers or updaters to output this information. The recording simulators from [POMDPSimulators.jl](https://github.com/JuliaPOMDP/POMDPSimulators.jl) automatically record this information. ```@docs -generate_sri -generate_sori action_info solve_info update_info diff --git a/src/info.jl b/src/info.jl index f97f807..0df4b5c 100644 --- a/src/info.jl +++ b/src/info.jl @@ -37,7 +37,28 @@ end """ add_infonode(ddn::DDNStructure) -Create a new DDNStructure object with a new node labeled :info with parents :s and :a +Create a new DDNStructure object with a new node labeled :info for returning miscellaneous informationabout a simulation step. 
+ +# Example (using POMDPs v0.8) + +``` +using POMDPs, POMDPModelTools, POMDPPolicies + +struct MyMDP <: MDP{Int, Int} end +POMDPs.DDNStructure(::Type{MyMDP}) = DDNStructure(MDP) |> add_infonode +function POMDPs.gen(m::MyMDP, s, a, rng) + r1 = rand(rng) + r2 = randn(rng) + return (sp = s + a + r1 + r2, r = s^2, info=(r1=r1, r2=r2)) +end + +m = MyMDP() +@show nodenames(DDNStructure(m)) +for (s,info) in stepthrough(m, FunctionPolicy(s->1), "s,info", max_steps=5) + @show s + @show info +end +``` """ function add_infonode(ddn) # for DDNStructure, but it is not declared in v0.7.3, so there is not annotation add_node(ddn, :info, ConstantDDNNode(nothing), (:s, :a)) diff --git a/test/test_underlying_mdp.jl b/test/test_underlying_mdp.jl index 3cfeca9..4e8b4e0 100644 --- a/test/test_underlying_mdp.jl +++ b/test/test_underlying_mdp.jl @@ -11,7 +11,7 @@ let solver = ValueIterationSolver(max_iterations = 100) mdp_policy = solve(solver, mdp) - pomdp_policy = solve(solver, pomdp) + pomdp_policy = solve(solver, UnderlyingMDP(pomdp)) @test mdp_policy.util == pomdp_policy.util actionindex(mdp, 1) From d5f2579d3cec3ebf07018091285b176798108e69 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 18 Sep 2019 18:29:46 -0700 Subject: [PATCH 6/9] added docs for info --- docs/make.jl | 2 +- docs/src/interface_extensions.md | 8 ++++++++ src/info.jl | 20 ++++++++++++++------ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 846fa6c..d6b10db 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -4,7 +4,7 @@ using Documenter, POMDPModelTools makedocs( modules = [POMDPModelTools], - format = :html, + format = Documenter.HTML(), sitename = "POMDPModelTools.jl" ) diff --git a/docs/src/interface_extensions.md b/docs/src/interface_extensions.md index f057ea2..016bdda 100644 --- a/docs/src/interface_extensions.md +++ b/docs/src/interface_extensions.md @@ -35,6 +35,14 @@ ordered_observations It is often the case that useful information besides the belief, state, action, etc is generated by a function in POMDPs.jl. This information can be useful for debugging or understanding the behavior of a solver, updater, or problem. The info interface provides a standard way for problems, policies, solvers or updaters to output this information. The recording simulators from [POMDPSimulators.jl](https://github.com/JuliaPOMDP/POMDPSimulators.jl) automatically record this information. +To specify info for a problem (in POMDPs v0.8 and above), one should modify the problem's DDN with the `add_infonode` function, then return the info in `gen`. There is an example of this pattern in the docstring below: + +```@docs +add_infonode +``` + +To specify info from policies, solvers, or updaters, implement the following functions: + ```@docs action_info solve_info diff --git a/src/info.jl b/src/info.jl index 0df4b5c..0daddd3 100644 --- a/src/info.jl +++ b/src/info.jl @@ -34,34 +34,42 @@ function update_info(up::Updater, b, a, o) return update(up, b, a, o), nothing end +# once POMDPs v0.8 is released, this should be a jldoctest """ add_infonode(ddn::DDNStructure) Create a new DDNStructure object with a new node labeled :info for returning miscellaneous informationabout a simulation step. +Typically, the object in info is associative (i.e. a `Dict` or `NamedTuple`) with keys corresponding to different pieces of information. 
+ # Example (using POMDPs v0.8) -``` -using POMDPs, POMDPModelTools, POMDPPolicies +```julia +using POMDPs, POMDPModelTools, POMDPPolicies, POMDPSimulators, Random struct MyMDP <: MDP{Int, Int} end -POMDPs.DDNStructure(::Type{MyMDP}) = DDNStructure(MDP) |> add_infonode + +# add the info node to the DDN +POMDPs.DDNStructure(::Type{MyMDP}) = mdp_ddn() |> add_infonode + +# the dynamics involve two random numbers - here we record the values for each in info function POMDPs.gen(m::MyMDP, s, a, rng) r1 = rand(rng) r2 = randn(rng) - return (sp = s + a + r1 + r2, r = s^2, info=(r1=r1, r2=r2)) + return (sp=s+a+r1+r2, r=s^2, info=(r1=r1, r2=r2)) end m = MyMDP() @show nodenames(DDNStructure(m)) -for (s,info) in stepthrough(m, FunctionPolicy(s->1), "s,info", max_steps=5) +p = FunctionPolicy(s->1) +for (s,info) in stepthrough(m, p, 1, "s,info", max_steps=5, rng=MersenneTwister(2)) @show s @show info end ``` """ function add_infonode(ddn) # for DDNStructure, but it is not declared in v0.7.3, so there is not annotation - add_node(ddn, :info, ConstantDDNNode(nothing), (:s, :a)) + add_node(ddn, :info, ConstantDDNNode(nothing), nodenames(ddn)) end function add_infonode(ddn::POMDPs.DDNStructureV7{nodenames}) where nodenames From fd9511875951452b1d72447ffbc15b1fc6882e49 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Wed, 18 Sep 2019 19:27:38 -0700 Subject: [PATCH 7/9] cleaned up some stuff for 0.8 compat, but not ready yet --- src/distributions/bool.jl | 5 ---- src/distributions/deterministic.jl | 1 + src/distributions/sparse_cat.jl | 1 + src/distributions/uniform.jl | 2 ++ src/fully_observable_pomdp.jl | 40 ++++++++++++++++++++++-------- src/generative_belief_mdp.jl | 11 +++++--- src/underlying_mdp.jl | 11 +++++--- test/test_bool.jl | 3 ++- test/test_deterministic.jl | 8 +++--- test/test_sparse_cat.jl | 8 +++--- test/test_uniform.jl | 16 ++++++------ 11 files changed, 67 insertions(+), 39 deletions(-) diff --git a/src/distributions/bool.jl b/src/distributions/bool.jl index 50af37b..7399d5b 100644 --- a/src/distributions/bool.jl +++ b/src/distributions/bool.jl @@ -23,11 +23,6 @@ function Base.iterate(d::BoolDistribution, state::Bool) end support(d::BoolDistribution) = [true, false] - -==(d1::BoolDistribution, d2::BoolDistribution) = d1.p == d2.p - -Base.hash(d::BoolDistribution) = hash(d.p) - Base.length(d::BoolDistribution) = 2 Base.show(io::IO, m::MIME"text/plain", d::BoolDistribution) = showdistribution(io, m, d, title="BoolDistribution") diff --git a/src/distributions/deterministic.jl b/src/distributions/deterministic.jl index 5b59482..6f72a82 100644 --- a/src/distributions/deterministic.jl +++ b/src/distributions/deterministic.jl @@ -13,6 +13,7 @@ rand(rng::AbstractRNG, d::Deterministic) = d.val rand(d::Deterministic) = d.val support(d::Deterministic) = (d.val,) sampletype(::Type{Deterministic{T}}) where T = T +Random.gentype(::Type{Deterministic{T}}) where T = T pdf(d::Deterministic, x) = convert(Float64, x == d.val) mode(d::Deterministic) = d.val mean(d::Deterministic{N}) where N<:Number = d.val / 1 # / 1 is to make this return a similar type to Statistics.mean diff --git a/src/distributions/sparse_cat.jl b/src/distributions/sparse_cat.jl index 37b3f46..fc80287 100644 --- a/src/distributions/sparse_cat.jl +++ b/src/distributions/sparse_cat.jl @@ -90,6 +90,7 @@ end Base.length(d::SparseCat) = min(length(d.vals), length(d.probs)) Base.eltype(D::Type{SparseCat{V,P}}) where {V, P} = Pair{eltype(V), eltype(P)} sampletype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V) 
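+# Random.gentype is the POMDPs v0.8 replacement for sampletype; both are defined while v0.7 support remains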
+Random.gentype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V) function mean(d::SparseCat) vsum = zero(eltype(d.vals)) diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl index 65bd8c8..5c56473 100644 --- a/src/distributions/uniform.jl +++ b/src/distributions/uniform.jl @@ -24,6 +24,7 @@ end support(d::Uniform) = d.set sampletype(::Type{Uniform{T}}) where T = eltype(T) +Random.gentype(::Type{Uniform{T}}) where T = eltype(T) function pdf(d::Uniform, s) if s in d.set @@ -49,6 +50,7 @@ end pdf(d::UnsafeUniform, s) = 1.0/length(d.collection) support(d::UnsafeUniform) = d.collection sampletype(::Type{UnsafeUniform{T}}) where T = eltype(T) +Random.gentype(::Type{UnsafeUniform{T}}) where T = eltype(T) # Common Implementations diff --git a/src/fully_observable_pomdp.jl b/src/fully_observable_pomdp.jl index 455fbf9..286dea1 100644 --- a/src/fully_observable_pomdp.jl +++ b/src/fully_observable_pomdp.jl @@ -3,8 +3,18 @@ Turn `MDP` `mdp` into a `POMDP` where the observations are the states of the MDP. """ -struct FullyObservablePOMDP{S, A} <: POMDP{S,A,S} - mdp::MDP{S, A} +struct FullyObservablePOMDP{M,S,A} <: POMDP{S,A,S} + mdp::M +end + +function FullyObservablePOMDP(m::MDP) + return FullyObservablePOMDP{typeof(m), statetype(m), actiontype(m)}(m) +end +mdptype(::Type{FullyObservablePOMDP{M,<:Any,<:Any}}) where M = M + +function POMDPs.DDNStructure(::Type{M}) where M <: FullyObservablePOMDP + MM = mdptype(M) + add_node(DDNStructure(MM), :o, FunctionDDNNode((m,sp)->sp), :sp) end POMDPs.observations(pomdp::FullyObservablePOMDP) = states(pomdp.mdp) @@ -14,12 +24,10 @@ POMDPs.convert_o(T::Type{V}, o, pomdp::FullyObservablePOMDP) where {V<:AbstractA POMDPs.convert_o(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:AbstractArray} = convert_s(T, vec, pomdp.mdp) -function POMDPs.generate_o(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) - return s -end +POMDPs.gen(::DDNVar{:o}, m::FullyObservablePOMDP, sp, rng) = sp -function POMDPs.observation(pomdp::FullyObservablePOMDP, s, a) - return Deterministic(s) +function POMDPs.observation(pomdp::FullyObservablePOMDP, a, sp) + return Deterministic(sp) end function POMDPs.observation(pomdp::FullyObservablePOMDP, s, a, sp) @@ -33,9 +41,6 @@ POMDPs.actions(pomdp::FullyObservablePOMDP) = actions(pomdp.mdp) POMDPs.transition(pomdp::FullyObservablePOMDP{S,A}, s::S, a::A) where {S,A} = transition(pomdp.mdp, s, a) POMDPs.initialstate_distribution(pomdp::FullyObservablePOMDP) = initialstate_distribution(pomdp.mdp) POMDPs.initialstate(pomdp::FullyObservablePOMDP, rng::AbstractRNG) = initialstate(pomdp.mdp, rng) -POMDPs.generate_s(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_s(pomdp.mdp, s, a, rng) -POMDPs.generate_sr(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_sr(pomdp.mdp, s, a, rng) -POMDPs.reward(pomdp::FullyObservablePOMDP{S, A}, s::S, a::A) where {S,A} = reward(pomdp.mdp, s, a) POMDPs.isterminal(pomdp::FullyObservablePOMDP, s) = isterminal(pomdp.mdp, s) POMDPs.discount(pomdp::FullyObservablePOMDP) = discount(pomdp.mdp) POMDPs.stateindex(pomdp::FullyObservablePOMDP{S,A}, s::S) where {S,A} = stateindex(pomdp.mdp, s) @@ -44,5 +49,18 @@ POMDPs.convert_s(T::Type{V}, s, pomdp::FullyObservablePOMDP) where V<:AbstractAr POMDPs.convert_s(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:AbstractArray} = convert_s(T, vec, pomdp.mdp) POMDPs.convert_a(T::Type{V}, a, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_a(T, a, pomdp.mdp) POMDPs.convert_a(T::Type{A}, vec::V, 
pomdp::FullyObservablePOMDP) where {A,V<:AbstractArray} = convert_a(T, vec, pomdp.mdp) + +POMDPs.gen(d::DDNOut, m::FullyObservablePOMDP, s, a, rng) = gen(d, m.mdp, s, a, rng) +POMDPs.gen(d::DDNNode, m::FullyObservablePOMDP, args...) = gen(d, m.mdp, args...) +POMDPs.gen(m::FullyObservablePOMDP, s, a, rng) = gen(m.mdp, s, a, rng) + + +# deprecated in POMDPs v0.8 +POMDPs.generate_s(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_s(pomdp.mdp, s, a, rng) +POMDPs.generate_sr(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_sr(pomdp.mdp, s, a, rng) +POMDPs.reward(pomdp::FullyObservablePOMDP{S, A}, s::S, a::A) where {S,A} = reward(pomdp.mdp, s, a) POMDPs.n_actions(pomdp::FullyObservablePOMDP) = n_actions(pomdp.mdp) -POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp) \ No newline at end of file +POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp) +function POMDPs.generate_o(pomdp::FullyObservablePOMDP, s, rng::AbstractRNG) + return s +end diff --git a/src/generative_belief_mdp.jl b/src/generative_belief_mdp.jl index cbebc98..945e8b8 100644 --- a/src/generative_belief_mdp.jl +++ b/src/generative_belief_mdp.jl @@ -14,15 +14,20 @@ function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater} GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up) end -function generate_sr(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG) +function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG) s = rand(rng, b) if isterminal(bmdp.pomdp, s) bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b) return bp, 0.0 end - sp, o, r = generate_sor(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or? + sp, o, r = gen(DDNOut(:sp,:o,:r), bmdp.pomdp, s, a, rng) # maybe this should have been generate_or? 
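+    # DDNOut(:sp,:o,:r) requests the next state, observation, and reward from the wrapped POMDP in one call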
     bp = update(bmdp.updater, b, a, o)
-    return bp, r
+    return (sp=bp, r=r)
+end
+
+function generate_sr(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG)
+    x = gen(bmdp, b, a, rng)
+    return x.sp, x.r
 end
 
 function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG)
diff --git a/src/underlying_mdp.jl b/src/underlying_mdp.jl
index 6e0c384..dba4dda 100644
--- a/src/underlying_mdp.jl
+++ b/src/underlying_mdp.jl
@@ -20,8 +20,6 @@ UnderlyingMDP(m::MDP) = m
 
 POMDPs.transition(mdp::UnderlyingMDP{P, S, A}, s::S, a::A) where {P,S,A}= transition(mdp.pomdp, s, a)
 POMDPs.initialstate_distribution(mdp::UnderlyingMDP) = initialstate_distribution(mdp.pomdp)
-POMDPs.generate_s(mdp::UnderlyingMDP, s, a, rng::AbstractRNG) = generate_s(mdp.pomdp, s, a, rng)
-POMDPs.generate_sr(mdp::UnderlyingMDP, s, a, rng::AbstractRNG) = generate_sr(mdp.pomdp, s, a, rng)
 POMDPs.initialstate(mdp::UnderlyingMDP, rng::AbstractRNG) = initialstate(mdp.pomdp, rng)
 POMDPs.states(mdp::UnderlyingMDP) = states(mdp.pomdp)
 POMDPs.actions(mdp::UnderlyingMDP) = actions(mdp.pomdp)
@@ -35,5 +33,12 @@ POMDPs.stateindex(mdp::UnderlyingMDP{P, Bool, A}, s::Bool) where {P,A} = statein
 POMDPs.actionindex(mdp::UnderlyingMDP{P, S, A}, a::A) where {P,S,A} = actionindex(mdp.pomdp, a)
 POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Int}, a::Int) where {P,S} = actionindex(mdp.pomdp, a)
 POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Bool}, a::Bool) where {P,S} = actionindex(mdp.pomdp, a)
+
+POMDPs.gen(d::Union{DDNOut,DDNNode}, mdp::UnderlyingMDP, s, a, rng) = gen(d, mdp.pomdp, s, a, rng)
+POMDPs.gen(mdp::UnderlyingMDP, s, a, rng) = gen(mdp.pomdp, s, a, rng)
+
+# deprecated in POMDPs v0.8
 POMDPs.n_actions(mdp::UnderlyingMDP) = n_actions(mdp.pomdp)
-POMDPs.n_states(mdp::UnderlyingMDP) = n_states(mdp.pomdp)
\ No newline at end of file
+POMDPs.n_states(mdp::UnderlyingMDP) = n_states(mdp.pomdp)
+POMDPs.generate_s(mdp::UnderlyingMDP, s, a, rng::AbstractRNG) = generate_s(mdp.pomdp, s, a, rng)
+POMDPs.generate_sr(mdp::UnderlyingMDP, s, a, rng::AbstractRNG) = generate_sr(mdp.pomdp, s, a, rng)
diff --git a/test/test_bool.jl b/test/test_bool.jl
index 50907d2..d151947 100644
--- a/test/test_bool.jl
+++ b/test/test_bool.jl
@@ -10,9 +10,10 @@ let
     # testing ==
     d2 = BoolDistribution(0.3)
     @test d == d2
+    @test BoolDistribution(0.4) != BoolDistribution(0.1)
 
     # testing hash
-    @test hash(d) == hash(d.p)
+    @test hash(d) == hash(d2)
 
     @test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="BoolDistribution"), d)
 end
diff --git a/test/test_deterministic.jl b/test/test_deterministic.jl
index 10b4ef3..de2c2ec 100644
--- a/test/test_deterministic.jl
+++ b/test/test_deterministic.jl
@@ -3,8 +3,8 @@ d = Deterministic(1)
 @test rand(d) == 1
 @test rand(MersenneTwister(4), d) == 1
 @test collect(support(d)) == [1]
-@test sampletype(d) == typeof(1)
-@test sampletype(typeof(d)) == typeof(1)
+@test Random.gentype(d) == typeof(1)
+@test Random.gentype(typeof(d)) == typeof(1)
 @test pdf(d, 0) == 0.0
 @test pdf(d, 1) == 1.0
 @test mode(d) == 1
@@ -15,8 +15,8 @@ d2 = Deterministic(:symbol)
 @test rand(d2) == :symbol
 @test rand(MersenneTwister(4), d2) == :symbol
 @test collect(support(d2)) == [:symbol]
-@test sampletype(d2) == typeof(:symbol)
-@test sampletype(typeof(d2)) == typeof(:symbol)
+@test Random.gentype(d2) == typeof(:symbol)
+@test Random.gentype(typeof(d2)) == typeof(:symbol)
 @test pdf(d2, :another) == 0.0
 @test pdf(d2, :symbol) == 1.0
 @test mode(d2) == :symbol
diff --git a/test/test_sparse_cat.jl b/test/test_sparse_cat.jl
index 3d5d93d..17891fc 100644
--- a/test/test_sparse_cat.jl +++ b/test/test_sparse_cat.jl @@ -5,8 +5,8 @@ let @test pdf(d, :c) == 0.0 @test pdf(d, :a) == 0.4 @test mode(d) == :b - @test sampletype(d) == Symbol - @test sampletype(typeof(d)) == Symbol + @test Random.gentype(d) == Symbol + @test Random.gentype(typeof(d)) == Symbol @inferred rand(Random.GLOBAL_RNG, d) dt = SparseCat((:a, :b, :d), (0.4, 0.5, 0.1)) @@ -15,8 +15,8 @@ let @test pdf(dt, :c) == 0.0 @test pdf(dt, :a) == 0.4 @test mode(dt) == :b - @test sampletype(dt) == Symbol - @test sampletype(typeof(dt)) == Symbol + @test Random.gentype(dt) == Symbol + @test Random.gentype(typeof(dt)) == Symbol @inferred rand(Random.GLOBAL_RNG, dt) rng = MersenneTwister(14) diff --git a/test/test_uniform.jl b/test/test_uniform.jl index 8fc8da5..df39602 100644 --- a/test/test_uniform.jl +++ b/test/test_uniform.jl @@ -3,8 +3,8 @@ d = Uniform([1]) @test rand(d) == 1 @test rand(MersenneTwister(4), d) == 1 @test collect(support(d)) == [1] -@test sampletype(d) == typeof(1) -@test sampletype(typeof(d)) == typeof(1) +@test Random.gentype(d) == typeof(1) +@test Random.gentype(typeof(d)) == typeof(1) @test pdf(d, 0) == 0.0 @test pdf(d, 1) == 1.0 @test mode(d) == 1 @@ -18,8 +18,8 @@ d2 = Uniform((:symbol,)) @test rand(d2) == :symbol @test rand(MersenneTwister(4), d2) == :symbol @test collect(support(d2)) == [:symbol] -@test sampletype(d2) == typeof(:symbol) -@test sampletype(typeof(d2)) == typeof(:symbol) +@test Random.gentype(d2) == typeof(:symbol) +@test Random.gentype(typeof(d2)) == typeof(:symbol) @test pdf(d2, :another) == 0.0 @test pdf(d2, :symbol) == 1.0 @test mode(d2) == :symbol @@ -35,8 +35,8 @@ d3 = UnsafeUniform([1]) @test rand(d3) == 1 @test rand(MersenneTwister(4), d3) == 1 @test collect(support(d3)) == [1] -@test sampletype(d3) == typeof(1) -@test sampletype(typeof(d3)) == typeof(1) +@test Random.gentype(d3) == typeof(1) +@test Random.gentype(typeof(d3)) == typeof(1) @test pdf(d3, 1) == 1.0 @test mean(d3) == 1 @test mode(d3) == 1 @@ -49,8 +49,8 @@ d4 = UnsafeUniform((:symbol,)) @test rand(d4) == :symbol @test rand(MersenneTwister(4), d4) == :symbol @test collect(support(d4)) == [:symbol] -@test sampletype(d4) == typeof(:symbol) -@test sampletype(typeof(d4)) == typeof(:symbol) +@test Random.gentype(d4) == typeof(:symbol) +@test Random.gentype(typeof(d4)) == typeof(:symbol) # @test pdf(d4, :another) == 0.0 # this will not work @test pdf(d4, :symbol) == 1.0 @test mode(d4) == :symbol From 56c6cdb56e8028b1e21298d8f9a4893873abd4e6 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Thu, 19 Sep 2019 17:14:06 -0700 Subject: [PATCH 8/9] tests all pass with both POMDPs 0.7.3 and 0.8 --- Project.toml | 3 ++- src/fully_observable_pomdp.jl | 23 +++++++++++++---------- src/underlying_mdp.jl | 3 ++- test/runtests.jl | 1 + test/test_fully_observable_pomdp.jl | 10 ++++++++++ test/test_underlying_mdp.jl | 4 ++++ 6 files changed, 32 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 469f8c6..0054d74 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ POMDPs = "0.7.3, 0.8.0" julia = "1" [extras] +BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd" @@ -25,4 +26,4 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "POMDPModels", "POMDPSimulators", "POMDPPolicies", "Pkg"] +test = ["Test", "POMDPModels", "POMDPSimulators", 
"POMDPPolicies", "BeliefUpdaters", "Pkg"] diff --git a/src/fully_observable_pomdp.jl b/src/fully_observable_pomdp.jl index 286dea1..135afd9 100644 --- a/src/fully_observable_pomdp.jl +++ b/src/fully_observable_pomdp.jl @@ -10,21 +10,23 @@ end function FullyObservablePOMDP(m::MDP) return FullyObservablePOMDP{typeof(m), statetype(m), actiontype(m)}(m) end -mdptype(::Type{FullyObservablePOMDP{M,<:Any,<:Any}}) where M = M + +mdptype(::Type{FullyObservablePOMDP{M,S,A}}) where {M,S,A} = M function POMDPs.DDNStructure(::Type{M}) where M <: FullyObservablePOMDP MM = mdptype(M) - add_node(DDNStructure(MM), :o, FunctionDDNNode((m,sp)->sp), :sp) + add_obsnode(DDNStructure(MM)) end +add_obsnode(ddn) = add_node(ddn, :o, FunctionDDNNode((m,sp)->sp), (:sp,)) # for ::DDNStructure, but this is not declared yet POMDPs in v0.7.3 + POMDPs.observations(pomdp::FullyObservablePOMDP) = states(pomdp.mdp) POMDPs.obsindex(pomdp::FullyObservablePOMDP{S, A}, o::S) where {S, A} = stateindex(pomdp.mdp, o) POMDPs.convert_o(T::Type{V}, o, pomdp::FullyObservablePOMDP) where {V<:AbstractArray} = convert_s(T, s, pomdp.mdp) POMDPs.convert_o(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:AbstractArray} = convert_s(T, vec, pomdp.mdp) - -POMDPs.gen(::DDNVar{:o}, m::FullyObservablePOMDP, sp, rng) = sp +POMDPs.gen(::DDNNode{:o}, m::FullyObservablePOMDP, sp, rng) = sp function POMDPs.observation(pomdp::FullyObservablePOMDP, a, sp) return Deterministic(sp) @@ -38,27 +40,28 @@ end POMDPs.states(pomdp::FullyObservablePOMDP) = states(pomdp.mdp) POMDPs.actions(pomdp::FullyObservablePOMDP) = actions(pomdp.mdp) -POMDPs.transition(pomdp::FullyObservablePOMDP{S,A}, s::S, a::A) where {S,A} = transition(pomdp.mdp, s, a) +POMDPs.transition(pomdp::FullyObservablePOMDP, s, a) = transition(pomdp.mdp, s, a) POMDPs.initialstate_distribution(pomdp::FullyObservablePOMDP) = initialstate_distribution(pomdp.mdp) POMDPs.initialstate(pomdp::FullyObservablePOMDP, rng::AbstractRNG) = initialstate(pomdp.mdp, rng) POMDPs.isterminal(pomdp::FullyObservablePOMDP, s) = isterminal(pomdp.mdp, s) POMDPs.discount(pomdp::FullyObservablePOMDP) = discount(pomdp.mdp) -POMDPs.stateindex(pomdp::FullyObservablePOMDP{S,A}, s::S) where {S,A} = stateindex(pomdp.mdp, s) -POMDPs.actionindex(pomdp::FullyObservablePOMDP{S, A}, a::A) where {S,A} = actionindex(pomdp.mdp, a) +POMDPs.stateindex(pomdp::FullyObservablePOMDP, s) = stateindex(pomdp.mdp, s) +POMDPs.actionindex(pomdp::FullyObservablePOMDP, a) = actionindex(pomdp.mdp, a) POMDPs.convert_s(T::Type{V}, s, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_s(T, s, pomdp.mdp) POMDPs.convert_s(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:AbstractArray} = convert_s(T, vec, pomdp.mdp) POMDPs.convert_a(T::Type{V}, a, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_a(T, a, pomdp.mdp) POMDPs.convert_a(T::Type{A}, vec::V, pomdp::FullyObservablePOMDP) where {A,V<:AbstractArray} = convert_a(T, vec, pomdp.mdp) -POMDPs.gen(d::DDNOut, m::FullyObservablePOMDP, s, a, rng) = gen(d, m.mdp, s, a, rng) POMDPs.gen(d::DDNNode, m::FullyObservablePOMDP, args...) = gen(d, m.mdp, args...) 
 POMDPs.gen(m::FullyObservablePOMDP, s, a, rng) = gen(m.mdp, s, a, rng)
-
+POMDPs.reward(pomdp::FullyObservablePOMDP, s, a) = reward(pomdp.mdp, s, a)
 
 # deprecated in POMDPs v0.8
+add_obsnode(ddn::POMDPs.DDNStructureV7{(:s,:a,:sp,:r)}) = POMDPs.DDNStructureV7{(:s,:a,:sp,:o,:r)}()
+add_obsnode(ddn::POMDPs.DDNStructureV7) = error("FullyObservablePOMDP only supports MDPs with the standard DDN Structure (DDNStructureV7{(:s,:a,:sp,:r)}) with POMDPs v0.7.")
+
 POMDPs.generate_s(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_s(pomdp.mdp, s, a, rng)
 POMDPs.generate_sr(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_sr(pomdp.mdp, s, a, rng)
-POMDPs.reward(pomdp::FullyObservablePOMDP{S, A}, s::S, a::A) where {S,A} = reward(pomdp.mdp, s, a)
 POMDPs.n_actions(pomdp::FullyObservablePOMDP) = n_actions(pomdp.mdp)
 POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp)
 function POMDPs.generate_o(pomdp::FullyObservablePOMDP, s, rng::AbstractRNG)
diff --git a/src/underlying_mdp.jl b/src/underlying_mdp.jl
index dba4dda..cc56340 100644
--- a/src/underlying_mdp.jl
+++ b/src/underlying_mdp.jl
@@ -34,7 +34,8 @@ POMDPs.actionindex(mdp::UnderlyingMDP{P, S, A}, a::A) where {P,S,A} = actioninde
 POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Int}, a::Int) where {P,S} = actionindex(mdp.pomdp, a)
 POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Bool}, a::Bool) where {P,S} = actionindex(mdp.pomdp, a)
 
-POMDPs.gen(d::Union{DDNOut,DDNNode}, mdp::UnderlyingMDP, s, a, rng) = gen(d, mdp.pomdp, s, a, rng)
+POMDPs.gen(d::DDNOut, mdp::UnderlyingMDP, s, a, rng) = gen(d, mdp.pomdp, s, a, rng)
+POMDPs.gen(d::DDNNode, mdp::UnderlyingMDP, s, a, rng) = gen(d, mdp.pomdp, s, a, rng)
 POMDPs.gen(mdp::UnderlyingMDP, s, a, rng) = gen(mdp.pomdp, s, a, rng)
 
 # deprecated in POMDPs v0.8
diff --git a/test/runtests.jl b/test/runtests.jl
index 85f6a44..0ccdb9f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,6 +5,7 @@ using Random
 using Test
 using Pkg
 using POMDPSimulators
+using BeliefUpdaters
 using POMDPPolicies
 import Distributions.Categorical
 using SparseArrays
diff --git a/test/test_fully_observable_pomdp.jl b/test/test_fully_observable_pomdp.jl
index fa0fe3a..084b1c8 100644
--- a/test/test_fully_observable_pomdp.jl
+++ b/test/test_fully_observable_pomdp.jl
@@ -17,4 +17,14 @@ let
     mdp_policy = solve(solver, mdp)
     pomdp_policy = solve(solver, UnderlyingMDP(pomdp))
     @test mdp_policy.util == pomdp_policy.util
+
+    is = initialstate(mdp, MersenneTwister(3))
+    for (sp, o, r) in stepthrough(pomdp,
+                                  FunctionPolicy(o->:left),
+                                  PreviousObservationUpdater(),
+                                  is, is, "sp,o,r",
+                                  rng=MersenneTwister(2),
+                                  max_steps=10)
+        @test sp == o
+    end
 end
diff --git a/test/test_underlying_mdp.jl b/test/test_underlying_mdp.jl
index 4e8b4e0..5505d4e 100644
--- a/test/test_underlying_mdp.jl
+++ b/test/test_underlying_mdp.jl
@@ -16,6 +16,10 @@ let
 
     actionindex(mdp, 1)
 
+    for (sp, r) in stepthrough(mdp, FunctionPolicy(o->1), "sp,r", rng=MersenneTwister(2), max_steps=10)
+        @test sp isa statetype(pomdp)
+    end
+
     # test mdp passthrough
     m = SimpleGridWorld()
     @test UnderlyingMDP(m) === m
From 44cf4cd8b847ea84d0d5ced9b31b4ef7cc09ae2e Mon Sep 17 00:00:00 2001
From: Zachary Sunberg
Date: Thu, 19 Sep 2019 17:36:42 -0700
Subject: [PATCH 9/9] don't test for inference in GenerativeBeliefMDP in Julia
 1.0

---
 .travis.yml                        | 1 +
 test/test_generative_belief_mdp.jl | 4 +++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/.travis.yml b/.travis.yml
index d02387d..8b05c0b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,6 +5,7 @@ os:
 
 julia:
   - 1.0
+  - 1
 
 notifications:
  email: false
diff --git a/test/test_generative_belief_mdp.jl b/test/test_generative_belief_mdp.jl
index 0b175fc..617e7e7 100644
--- a/test/test_generative_belief_mdp.jl
+++ b/test/test_generative_belief_mdp.jl
@@ -4,5 +4,7 @@ let
     bmdp = GenerativeBeliefMDP(pomdp, up)
 
     b = initialstate(bmdp, Random.GLOBAL_RNG)
-    @inferred generate_sr(bmdp, b, true, MersenneTwister(4))
+    if VERSION >= v"1.1"
+        @inferred generate_sr(bmdp, b, true, MersenneTwister(4))
+    end
 end
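
For reference, the belief-space step being checked for inferrability above can be run standalone; a minimal sketch, assuming `BabyPOMDP` from POMDPModels and `DiscreteUpdater` from BeliefUpdaters (the test file defines its own `pomdp` and `up`):

```julia
using POMDPs, POMDPModels, POMDPModelTools, BeliefUpdaters, Random

pomdp = BabyPOMDP()
up = DiscreteUpdater(pomdp)
bmdp = GenerativeBeliefMDP(pomdp, up)

b = initialstate(bmdp, Random.GLOBAL_RNG)      # an initial belief, not a state
step = gen(bmdp, b, true, MersenneTwister(4))  # one belief-space step
step.sp, step.r                                # the updated belief and the reward
```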