Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POMDPs v0.8 Compatibility #26

Merged
merged 10 commits into from
Sep 20, 2019
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[compat]
Distributions = ">= 0.17"
POMDPs = "0.7.3, 0.9.0"
julia = "1"

[extras]
Expand Down
6 changes: 3 additions & 3 deletions src/POMDPModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ using LinearAlgebra
using SparseArrays
using UnicodePlots

import POMDPs: actions, n_actions, actionindex
import POMDPs: states, n_states, stateindex
import POMDPs: observations, n_observations, obsindex
import POMDPs: actions, actionindex
import POMDPs: states, stateindex
import POMDPs: observations, obsindex
import POMDPs: sampletype, generate_sr, initialstate, isterminal, discount
import POMDPs: implemented
import Distributions: pdf, mode, mean, support
Expand Down
3 changes: 0 additions & 3 deletions src/fully_observable_pomdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ struct FullyObservablePOMDP{S, A} <: POMDP{S,A,S}
end

POMDPs.observations(pomdp::FullyObservablePOMDP) = states(pomdp.mdp)
POMDPs.n_observations(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp)
POMDPs.obsindex(pomdp::FullyObservablePOMDP{S, A}, o::S) where {S, A} = stateindex(pomdp.mdp, o)

POMDPs.convert_o(T::Type{V}, o, pomdp::FullyObservablePOMDP) where {V<:AbstractArray} = convert_s(T, s, pomdp.mdp)
Expand Down Expand Up @@ -39,8 +38,6 @@ POMDPs.generate_sr(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = genera
POMDPs.reward(pomdp::FullyObservablePOMDP{S, A}, s::S, a::A) where {S,A} = reward(pomdp.mdp, s, a)
POMDPs.isterminal(pomdp::FullyObservablePOMDP, s) = isterminal(pomdp.mdp, s)
POMDPs.discount(pomdp::FullyObservablePOMDP) = discount(pomdp.mdp)
POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp)
POMDPs.n_actions(pomdp::FullyObservablePOMDP) = n_actions(pomdp.mdp)
POMDPs.stateindex(pomdp::FullyObservablePOMDP{S,A}, s::S) where {S,A} = stateindex(pomdp.mdp, s)
POMDPs.actionindex(pomdp::FullyObservablePOMDP{S, A}, a::A) where {S,A} = actionindex(pomdp.mdp, a)
POMDPs.convert_s(T::Type{V}, s, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_s(T, s, pomdp.mdp)
Expand Down
21 changes: 11 additions & 10 deletions src/ordered_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Return an `AbstractVector` of actions ordered according to `actionindex(mdp, a)`

`ordered_actions(mdp)` will always return an `AbstractVector{A}` `v` containing all of the actions in `actions(mdp)` in the order such that `actionindex(mdp, v[i]) == i`. You may wish to override this for your problem for efficiency.
"""
ordered_actions(mdp::Union{MDP,POMDP}) = ordered_vector(actiontype(typeof(mdp)), a->actionindex(mdp,a), actions(mdp), n_actions(mdp), "action")
ordered_actions(mdp::Union{MDP,POMDP}) = ordered_vector(actiontype(typeof(mdp)), a->actionindex(mdp,a), actions(mdp), "action")

"""
ordered_states(mdp)
Expand All @@ -16,7 +16,7 @@ Return an `AbstractVector` of states ordered according to `stateindex(mdp, s)`.

`ordered_states(mdp)` will always return an `AbstractVector{S}` `v` containing all of the states in `states(mdp)` in the order such that `stateindex(mdp, v[i]) == i`. You may wish to override this for your problem for efficiency.
"""
ordered_states(mdp::Union{MDP,POMDP}) = ordered_vector(statetype(typeof(mdp)), s->stateindex(mdp,s), states(mdp), n_states(mdp), "state")
ordered_states(mdp::Union{MDP,POMDP}) = ordered_vector(statetype(typeof(mdp)), s->stateindex(mdp,s), states(mdp), "state")

"""
ordered_observations(pomdp)
Expand All @@ -25,9 +25,10 @@ Return an `AbstractVector` of observations ordered according to `obsindex(pomdp,

`ordered_observations(pomdp)` will always return an `AbstractVector{O}` `v` containing all of the observations in `observations(pomdp)` in the order such that `obsindex(pomdp, v[i]) == i`. You may wish to override this for your problem for efficiency.
"""
ordered_observations(pomdp::POMDP) = ordered_vector(obstype(typeof(pomdp)), o->obsindex(pomdp,o), observations(pomdp), n_observations(pomdp), "observation")
ordered_observations(pomdp::POMDP) = ordered_vector(obstype(typeof(pomdp)), o->obsindex(pomdp,o), observations(pomdp), "observation")

function ordered_vector(T::Type, index::Function, space, len, singular, plural=singular*"s")
function ordered_vector(T::Type, index::Function, space, singular, plural=singular*"s")
len = length(space)
a = Array{T}(undef, len)
gotten = falses(len)
for x in space
Expand All @@ -39,7 +40,7 @@ function ordered_vector(T::Type, index::Function, space, len, singular, plural=s
index was $id.

n_$plural(...) was $len.
""")
""")
end
a[id] = x
gotten[id] = true
Expand All @@ -60,23 +61,23 @@ end
@POMDP_require ordered_actions(mdp::Union{MDP,POMDP}) begin
P = typeof(mdp)
@req actionindex(::P, ::actiontype(P))
@req n_actions(::P)
@req actions(::P)
as = actions(mdp)
@req length(::typeof(as))
end

@POMDP_require ordered_states(mdp::Union{MDP,POMDP}) begin
P = typeof(mdp)
@req stateindex(::P, ::statetype(P))
@req n_states(::P)
@req states(::P)
as = states(mdp)
ss = states(mdp)
@req length(::typeof(ss))
end

@POMDP_require ordered_observations(mdp::Union{MDP,POMDP}) begin
P = typeof(mdp)
@req obsindex(::P, ::obstype(P))
@req n_observations(::P)
@req observations(::P)
as = observations(mdp)
os = observations(mdp)
@req length(::typeof(os))
end
11 changes: 6 additions & 5 deletions src/policy_evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ Create an |S|x|S| sparse transition matrix for a given policy.
The row corresponds to the current state and column to the next state. Corresponds to ``T^π`` in equation (4.7) in Kochenderfer, *Decision Making Under Uncertainty*, 2015.
"""
function policy_transition_matrix(m::Union{MDP,POMDP}, p::Policy)
ns = n_states(m)
rows = Int[]
cols = Int[]
probs = Float64[]

for s in states(m)
state_space = states(m)
ns = length(state_space)
for s in state_space
if !isterminal(m, s) # if terminal, the transition probabilities are all just zero
si = stateindex(m, s)
a = action(p, s)
Expand All @@ -66,8 +66,9 @@ function policy_transition_matrix(m::Union{MDP,POMDP}, p::Policy)
end

function policy_reward_vector(m::Union{MDP,POMDP}, p::Policy; rewardfunction=POMDPs.reward)
r = zeros(n_states(m))
for s in states(m)
state_space = states(m)
r = zeros(length(state_space))
for s in state_space
if !isterminal(m, s) # if terminal, the transition probabilities are all just zero
si = stateindex(m, s)
a = action(p, s)
Expand Down
54 changes: 28 additions & 26 deletions src/sparse_tabular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ end
S = statetype(P)
A = actiontype(P)
@req discount(::P)
@req n_states(::P)
@req n_actions(::P)
@subreq ordered_states(mdp)
@subreq ordered_actions(mdp)
@req transition(::P,::S,::A)
Expand All @@ -49,6 +47,8 @@ end
@req actions(::P, ::S)
as = actions(mdp)
ss = states(mdp)
@req length(::typeof(as))
@req length(::typeof(ss))
a = first(as)
s = first(ss)
dist = transition(mdp, s, a)
Expand Down Expand Up @@ -113,8 +113,6 @@ end
A = actiontype(P)
O = obstype(P)
@req discount(::P)
@req n_states(::P)
@req n_actions(::P)
@subreq ordered_states(pomdp)
@subreq ordered_actions(pomdp)
@subreq ordered_observations(pomdp)
Expand All @@ -128,6 +126,8 @@ end
@req obsindex(::P, ::O)
as = actions(pomdp)
ss = states(pomdp)
@req length(::typeof(as))
@req length(::typeof(ss))
a = first(as)
s = first(ss)
dist = transition(pomdp, s, a)
Expand Down Expand Up @@ -160,8 +160,9 @@ const SparseTabularProblem = Union{SparseTabularMDP, SparseTabularPOMDP}

function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
# Thanks to zach
na = n_actions(mdp)
ns = n_states(mdp)
na = length(actions(mdp))
state_space = states(mdp)
ns = length(state_space)
transmat_row_A = [Int[] for _ in 1:n_actions(mdp)]
transmat_col_A = [Int[] for _ in 1:n_actions(mdp)]
transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)]
Expand All @@ -187,15 +188,17 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
end
end
end
transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], ns, ns) for a in 1:na]
# if an action is not valid from a state, the transition is 0.0 everywhere
# @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
# @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(ns)) for a in 1:na) "Transition probabilities must sum to 1"
return transmats_A_S_S2
end

function reward_s_a(mdp::Union{MDP, POMDP})
reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s)
for s in states(mdp)
state_space = states(mdp)
action_space = actions(mdp)
reward_S_A = fill(-Inf, (length(state_space), length(action_space))) # set reward for all actions to -Inf unless they are in actions(mdp, s)
for s in state_space
if isterminal(mdp, s)
reward_S_A[stateindex(mdp, s), :] .= 0.0
else
Expand Down Expand Up @@ -227,12 +230,15 @@ function terminal_states_set(mdp::Union{MDP, POMDP})
end

function observation_matrix_a_sp_o(pomdp::POMDP)
na = n_actions(pomdp)
ns = n_states(pomdp)
no = n_observations(pomdp)
obsmat_row_A = [Int[] for _ in 1:n_actions(pomdp)]
obsmat_col_A = [Int[] for _ in 1:n_actions(pomdp)]
obsmat_data_A = [Float64[] for _ in 1:n_actions(pomdp)]
state_space = states(pomdp)
action_space = actions(pomdp)
obs_space = observations(pomdp)
na = length(action_space)
ns = length(state_space)
no = length(obs_space)
obsmat_row_A = [Int[] for _ in 1:na]
obsmat_col_A = [Int[] for _ in 1:na]
obsmat_data_A = [Float64[] for _ in 1:na]

for sp in states(pomdp)
spi = stateindex(pomdp, sp)
Expand All @@ -249,19 +255,16 @@ function observation_matrix_a_sp_o(pomdp::POMDP)
end
end
end
obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], n_states(pomdp), n_states(pomdp)) for a in 1:n_actions(pomdp)]
@assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(n_observations(pomdp))) for a in 1:n_actions(pomdp)) "Observation probabilities must sum to 1"
obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], ns, ns) for a in 1:na]
@assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(no)) for a in 1:na) "Observation probabilities must sum to 1"
return obsmats_A_SP_O
end

# MDP and POMDP common methods

POMDPs.n_states(prob::SparseTabularProblem) = size(prob.T[1], 1)
POMDPs.n_actions(prob::SparseTabularProblem) = size(prob.T, 1)

POMDPs.states(p::SparseTabularProblem) = 1:n_states(p)
POMDPs.actions(p::SparseTabularProblem) = 1:n_actions(p)
POMDPs.actions(p::SparseTabularProblem, s::Int64) = [a for a in actions(p) if sum(transition_matrix(p, a)) ≈ n_states(p)]
POMDPs.states(p::SparseTabularProblem) = 1:size(p.T[1], 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call collect on these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As is, states is an iterator, one can call collect(states(mdp)). I don't think we want to collect by default, what would be the use case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should collect by default

POMDPs.actions(p::SparseTabularProblem) = 1:size(p.T, 1)
POMDPs.actions(p::SparseTabularProblem, s::Int64) = [a for a in actions(p) if sum(transition_matrix(p, a)) ≈ size(p.T[1], 1)]

POMDPs.stateindex(::SparseTabularProblem, s::Int64) = s
POMDPs.actionindex(::SparseTabularProblem, a::Int64) = a
Expand Down Expand Up @@ -303,9 +306,8 @@ Accessor function for the reward matrix R[s, a] of a sparse tabular problem.
reward_matrix(p::SparseTabularProblem) = p.R

# POMDP only methods
POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2)

POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p)
POMDPs.observations(p::SparseTabularPOMDP) = 1:size(p.O[1], 2)

POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...)

Expand Down
2 changes: 0 additions & 2 deletions src/underlying_mdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ POMDPs.reward(mdp::UnderlyingMDP{P, S, A}, s::S, a::A) where {P,S,A} = reward(md
POMDPs.reward(mdp::UnderlyingMDP{P, S, A}, s::S, a::A, sp::S) where {P,S,A} = reward(mdp.pomdp, s, a, sp)
POMDPs.isterminal(mdp ::UnderlyingMDP{P, S, A}, s::S) where {P,S,A} = isterminal(mdp.pomdp, s)
POMDPs.discount(mdp::UnderlyingMDP) = discount(mdp.pomdp)
POMDPs.n_actions(mdp::UnderlyingMDP) = n_actions(mdp.pomdp)
POMDPs.n_states(mdp::UnderlyingMDP) = n_states(mdp.pomdp)
POMDPs.stateindex(mdp::UnderlyingMDP{P, S, A}, s::S) where {P,S,A} = stateindex(mdp.pomdp, s)
POMDPs.stateindex(mdp::UnderlyingMDP{P, Int, A}, s::Int) where {P,A} = stateindex(mdp.pomdp, s) # fix ambiguity with src/convenience
POMDPs.stateindex(mdp::UnderlyingMDP{P, Bool, A}, s::Bool) where {P,A} = stateindex(mdp.pomdp, s)
Expand Down
111 changes: 57 additions & 54 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,65 +9,68 @@ using POMDPPolicies
import Distributions.Categorical
using SparseArrays

@testset "ordered" begin
include("test_ordered_spaces.jl")
end
@testset "POMDPModelTools" begin
@testset "ordered" begin
include("test_ordered_spaces.jl")
end

# require POMDPModels
@testset "genbeliefmdp" begin
include("test_generative_belief_mdp.jl")
end
@testset "implement" begin
include("test_implementations.jl")
end
@testset "weightediter" begin
include("test_weighted_iteration.jl")
end
@testset "sparsecat" begin
include("test_sparse_cat.jl")
end
@testset "bool" begin
include("test_bool.jl")
end
@testset "deterministic" begin
include("test_deterministic.jl")
end
@testset "uniform" begin
include("test_uniform.jl")
end
@testset "terminalstate" begin
include("test_terminal_state.jl")
end
# require POMDPModels
@testset "genbeliefmdp" begin
include("test_generative_belief_mdp.jl")
end
@testset "implement" begin
include("test_implementations.jl")
end
@testset "weightediter" begin
include("test_weighted_iteration.jl")
end
@testset "sparsecat" begin
include("test_sparse_cat.jl")
end
@testset "bool" begin
include("test_bool.jl")
end
@testset "deterministic" begin
include("test_deterministic.jl")
end
@testset "uniform" begin
include("test_uniform.jl")
end
@testset "terminalstate" begin
include("test_terminal_state.jl")
end

# require POMDPModels
@testset "info" begin
include("test_info.jl")
end
@testset "obsweight" begin
include("test_obs_weight.jl")
end
# require POMDPModels
@testset "info" begin
include("test_info.jl")
end
@testset "obsweight" begin
include("test_obs_weight.jl")
end

# require DiscreteValueIteration
@testset "visolve" begin
POMDPs.add_registry()
Pkg.add("DiscreteValueIteration")
using DiscreteValueIteration
include("test_fully_observable_pomdp.jl")
include("test_underlying_mdp.jl")
end
# require DiscreteValueIteration
@testset "visolve" begin
POMDPs.add_registry()
Pkg.add("DiscreteValueIteration")
using DiscreteValueIteration
include("test_fully_observable_pomdp.jl")
include("test_underlying_mdp.jl")
end

@testset "vis" begin
include("test_visualization.jl")
end
@testset "vis" begin
include("test_visualization.jl")
end

@testset "evaluation" begin
include("test_evaluation.jl")
end
@testset "evaluation" begin
include("test_evaluation.jl")
end

@testset "pretty printing" begin
include("test_pretty_printing.jl")
end
@testset "pretty printing" begin
include("test_pretty_printing.jl")
end

@testset "sparse tabular" begin
include("test_tabular.jl")
end

@testset "sparse tabular" begin
include("test_tabular.jl")
end
Loading