diff --git a/docs/src/model_transformations.md b/docs/src/model_transformations.md
index 816c4c6..b3612b1 100644
--- a/docs/src/model_transformations.md
+++ b/docs/src/model_transformations.md
@@ -2,6 +2,18 @@
 
 POMDPModelTools contains several tools for transforming problems into other classes so that they can be used by different solvers.
 
+## Sparse Tabular MDPs and POMDPs
+
+The `SparseTabularMDP` and `SparseTabularPOMDP` types represent discrete problems defined using the explicit interface, with the transition and observation models stored as sparse matrices. Solver writers can leverage these data structures to write efficient vectorized code. A problem writer can define a problem using the explicit interface and have it automatically converted to a sparse tabular representation by calling the constructors `SparseTabularMDP(::MDP)` or `SparseTabularPOMDP(::POMDP)`. See the following docs to learn more about the matrix representation and how to access the fields of the `SparseTabular` objects:
+
+```@docs
+SparseTabularMDP
+SparseTabularPOMDP
+transition_matrix
+reward_vector
+observation_matrix
+```
+
 ## Fully Observable POMDP
 
 ```@docs
diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl
index cec0ab2..19ba08b 100644
--- a/src/POMDPModelTools.jl
+++ b/src/POMDPModelTools.jl
@@ -89,4 +89,13 @@ export showdistribution
 
 include("distributions/pretty_printing.jl")
 
+export
+    SparseTabularMDP,
+    SparseTabularPOMDP,
+    transition_matrix,
+    reward_vector,
+    observation_matrix
+
+include("sparse_tabular.jl")
+
 end # module
diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl
new file mode 100644
index 0000000..d288893
--- /dev/null
+++ b/src/sparse_tabular.jl
@@ -0,0 +1,271 @@
+"""
+    SparseTabularMDP
+
+An MDP object where states and actions are integers and the transition model is represented by a list of sparse matrices.
+This data structure can be exploited in vectorized algorithms (e.g. see SparseValueIterationSolver).
+The recommended way to access the transition and reward matrices is through the provided accessor functions: `transition_matrix` and `reward_vector`.
+
+# Fields
+- `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` is the probability of transitioning from `s` to `sp` when taking action `a`.
+- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward of taking action `a` in state `s`.
+- `discount::Float64` The discount factor.
+
+# Constructors
+
+- `SparseTabularMDP(mdp::MDP)` : One can provide the matrices directly to the default constructor, or construct a `SparseTabularMDP` from any discrete-state MDP defined using the explicit interface.
+Note that constructing the transition and reward matrices requires iterating over all states, which can take a while.
+To learn more about how to define an MDP with the explicit interface, please visit https://juliapomdp.github.io/POMDPs.jl/latest/explicit/ .
+- `SparseTabularMDP(smdp::SparseTabularMDP; transition, reward, discount)` : This constructor returns a new sparse MDP that is a copy of the original `smdp`, except for the fields specified by the keyword arguments.
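+
+# Example
+
+A minimal usage sketch (`mdp` stands for any discrete-state MDP defined with the explicit interface; the names `smdp`, `smdp2`, `T1`, and `r1` are illustrative):
+
+```julia
+smdp = SparseTabularMDP(mdp)                   # convert by iterating over all states and actions
+smdp2 = SparseTabularMDP(smdp, discount = 0.9) # copy smdp, overriding only the discount
+T1 = transition_matrix(smdp, 1)                # |S| x |S| sparse transition matrix for action 1
+r1 = reward_vector(smdp, 1)                    # length |S| reward vector for action 1
+```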
+ +""" +struct SparseTabularMDP <: MDP{Int64, Int64} + T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp] + R::Array{Float64, 2} # R[s, a] + discount::Float64 +end + +function SparseTabularMDP(mdp::MDP) + T = transition_matrix_a_s_sp(mdp) + R = reward_s_a(mdp) + return SparseTabularMDP(T, R, discount(mdp)) +end + +@POMDP_require SparseTabularMDP(mdp::MDP) begin + P = typeof(mdp) + S = statetype(P) + A = actiontype(P) + @req discount(::P) + @req n_states(::P) + @req n_actions(::P) + @subreq ordered_states(mdp) + @subreq ordered_actions(mdp) + @req transition(::P,::S,::A) + @req reward(::P,::S,::A,::S) + @req stateindex(::P,::S) + @req actionindex(::P, ::A) + @req actions(::P, ::S) + as = actions(mdp) + ss = states(mdp) + a = first(as) + s = first(ss) + dist = transition(mdp, s, a) + D = typeof(dist) + @req support(::D) + @req pdf(::D,::S) +end + +function SparseTabularMDP(mdp::SparseTabularMDP; + transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing, + reward::Union{Nothing, Array{Float64, 2}} = nothing, + discount::Union{Nothing, Float64} = nothing) + T = transition != nothing ? transition : mdp.T + R = reward != nothing ? reward : mdp.R + d = discount != nothing ? discount : mdp.discount + return SparseTabularMDP(T, R, d) +end + +""" + SparseTabularPOMDP + +A POMDP object where states and actions are integers and the transition and observation distributions are represented by lists of sparse matrices. +This data structure can be useful to exploit in vectorized algorithms to gain performance (e.g. see SparseValueIterationSolver). +The recommended way to access the transition, reward, and observation matrices is through the provided accessor functions: `transition_matrix`, `reward_vector`, `observation_matrix`. + +# Fields +- `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` the probability of transition from `s` to `sp` taking action `a`. +- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in sate `s`. +- `O::Vector{SparseMatrixCSC{Float64, Int64}}` The observation model is represented as a vector of sparse matrices (one for each action). `O[a][sp, o]` is the probability of observing `o` from state `sp` after having taken action `a`. +- `discount::Float64` The discount factor + +# Constructors + +- `SparseTabularPOMDP(pomdp::POMDP)` : One can provide the matrices to the default constructor or one can construct a `SparseTabularPOMDP` from any discrete state MDP defined using the explicit interface. +Note that constructing the transition and reward matrices requires to iterate over all the states and can take a while. +To learn more information about how to define an MDP with the explicit interface please visit https://juliapomdp.github.io/POMDPs.jl/latest/explicit/ . +- `SparseTabularPOMDP(spomdp::SparseTabularMDP; transition, reward, observation, discount)` : This constructor returns a new sparse POMDP that is a copy of the original smdp except for the field specified by the keyword arguments. 
+ +""" +struct SparseTabularPOMDP <: POMDP{Int64, Int64, Int64} + T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp] + R::Array{Float64, 2} # R[s,sp] + O::Vector{SparseMatrixCSC{Float64, Int64}} # O[a][sp, o] + discount::Float64 +end + +function SparseTabularPOMDP(pomdp::POMDP) + T = transition_matrix_a_s_sp(pomdp) + R = reward_s_a(pomdp) + O = observation_matrix_a_sp_o(pomdp) + return SparseTabularPOMDP(T, R, O, discount(pomdp)) +end + +@POMDP_require SparseTabularMDP(mdp::MDP) begin + P = typeof(mdp) + S = statetype(P) + A = actiontype(P) + @req discount(::P) + @req n_states(::P) + @req n_actions(::P) + @subreq ordered_states(mdp) + @subreq ordered_actions(mdp) + @req transition(::P,::S,::A) + @req reward(::P,::S,::A,::S) + @req stateindex(::P,::S) + @req actionindex(::P, ::A) + @req actions(::P, ::S) + as = actions(mdp) + ss = states(mdp) + a = first(as) + s = first(ss) + dist = transition(mdp, s, a) + D = typeof(dist) + @req support(::D) + @req pdf(::D,::S) +end + + +function SparseTabularPOMDP(pomdp::SparseTabularPOMDP; + transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing, + reward::Union{Nothing, Array{Float64, 2}} = nothing, + observation::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing, + discount::Union{Nothing, Float64} = nothing) + T = transition != nothing ? transition : pomdp.T + R = reward != nothing ? reward : pomdp.R + d = discount != nothing ? discount : pomdp.discount + O = observation != nothing ? transition : pomdp.O + return SparseTabularPOMDP(T, R, O, d) +end + +const SparseTabularProblem = Union{SparseTabularMDP, SparseTabularPOMDP} + + +function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP}) + # Thanks to zach + na = n_actions(mdp) + ns = n_states(mdp) + transmat_row_A = [Int[] for _ in 1:n_actions(mdp)] + transmat_col_A = [Int[] for _ in 1:n_actions(mdp)] + transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)] + + for s in states(mdp) + si = stateindex(mdp, s) + for a in actions(mdp, s) + ai = actionindex(mdp, a) + if isterminal(mdp, s) # if terminal, there is a probability of 1 of staying in that state + push!(transmat_row_A[ai], si) + push!(transmat_col_A[ai], si) + push!(transmat_data_A[ai], 1.0) + else + td = transition(mdp, s, a) + for (sp, p) in weighted_iterator(td) + if p > 0.0 + spi = stateindex(mdp, sp) + push!(transmat_row_A[ai], si) + push!(transmat_col_A[ai], spi) + push!(transmat_data_A[ai], p) + end + end + end + end + end + transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)] + @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1" + return transmats_A_S_S2 +end + +function reward_s_a(mdp::Union{MDP, POMDP}) + reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s) + for s in states(mdp) + if isterminal(mdp, s) + reward_S_A[stateindex(mdp, s), :] .= 0.0 + else + for a in actions(mdp, s) + td = transition(mdp, s, a) + r = 0.0 + for (sp, p) in weighted_iterator(td) + if p > 0.0 + r += p*reward(mdp, s, a, sp) + end + end + reward_S_A[stateindex(mdp, s), actionindex(mdp, a)] = r + end + end + end + return reward_S_A +end + +function observation_matrix_a_sp_o(pomdp::POMDP) + na = n_actions(pomdp) + ns = n_states(pomdp) + no = n_observations(pomdp) + obsmat_row_A = [Int[] for _ in 1:n_actions(pomdp)] + obsmat_col_A = [Int[] for _ in 1:n_actions(pomdp)] + 
+    obsmat_data_A = [Float64[] for _ in 1:na]
+
+    for sp in states(pomdp)
+        spi = stateindex(pomdp, sp)
+        for a in actions(pomdp)
+            ai = actionindex(pomdp, a)
+            od = observation(pomdp, a, sp)
+            for (o, p) in weighted_iterator(od)
+                if p > 0.0
+                    oi = obsindex(pomdp, o)
+                    push!(obsmat_row_A[ai], spi)
+                    push!(obsmat_col_A[ai], oi)
+                    push!(obsmat_data_A[ai], p)
+                end
+            end
+        end
+    end
+    obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], ns, no) for a in 1:na]
+    @assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(ns)) for a in 1:na) "Observation probabilities must sum to 1"
+    return obsmats_A_SP_O
+end
+
+# MDP and POMDP common methods
+
+POMDPs.n_states(prob::SparseTabularProblem) = size(prob.T[1], 1)
+POMDPs.n_actions(prob::SparseTabularProblem) = length(prob.T)
+
+POMDPs.states(p::SparseTabularProblem) = 1:n_states(p)
+POMDPs.actions(p::SparseTabularProblem) = 1:n_actions(p)
+
+POMDPs.stateindex(::SparseTabularProblem, s::Int64) = s
+POMDPs.actionindex(::SparseTabularProblem, a::Int64) = a
+
+POMDPs.discount(p::SparseTabularProblem) = p.discount
+
+POMDPs.transition(p::SparseTabularProblem, s::Int64, a::Int64) = SparseCat(findnz(p.T[a][s, :])...) # XXX not memory efficient
+
+POMDPs.reward(p::SparseTabularProblem, s::Int64, a::Int64) = p.R[s, a]
+
+"""
+    transition_matrix(p::SparseTabularProblem, a)
+
+Accessor function for the transition model of a sparse tabular problem.
+It returns a sparse matrix containing the transition probabilities when taking action `a`: `T[s, sp] = Pr(sp | s, a)`.
+"""
+transition_matrix(p::SparseTabularProblem, a) = p.T[a]
+
+"""
+    reward_vector(p::SparseTabularProblem, a)
+
+Accessor function for the reward function of a sparse tabular problem.
+It returns a vector containing the reward for all the states when taking action `a`: `R(s, a)`.
+The length of the returned vector is equal to the number of states.
+"""
+reward_vector(p::SparseTabularProblem, a) = view(p.R, :, a)
+
+# POMDP only methods
+
+POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2)
+
+POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p)
+
+POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...)
+
+POMDPs.obsindex(p::SparseTabularPOMDP, o::Int64) = o
+
+"""
+    observation_matrix(p::SparseTabularPOMDP, a::Int64)
+
+Accessor function for the observation model of a sparse tabular POMDP.
+It returns a sparse matrix containing the observation probabilities for action `a`: `O[sp, o] = Pr(o | sp, a)`.
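+
+Together with `transition_matrix`, this enables vectorized belief updates, e.g. (a sketch, not part of this package; `b` is a dense belief vector over states, `a` an action index, `o` an observation index):
+
+```julia
+bp = observation_matrix(p, a)[:, o] .* (transition_matrix(p, a)' * b)
+bp ./= sum(bp) # bp[sp] ∝ Pr(o | sp, a) * Σ_s Pr(sp | s, a) * b[s]
+```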
+""" +observation_matrix(p::SparseTabularPOMDP, a::Int64) = p.O[a] \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9d3fd13..4a5a324 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using Pkg using POMDPSimulators using POMDPPolicies import Distributions.Categorical +using SparseArrays @testset "ordered" begin include("test_ordered_spaces.jl") @@ -66,3 +67,7 @@ end @testset "pretty printing" begin include("test_pretty_printing.jl") end + +@testset "sparse tabular" begin + include("test_tabular.jl") +end \ No newline at end of file diff --git a/test/test_tabular.jl b/test/test_tabular.jl new file mode 100644 index 0000000..0b27c97 --- /dev/null +++ b/test/test_tabular.jl @@ -0,0 +1,106 @@ +function test_transition(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP}) + for s in states(pb1) + for a in actions(pb1) + td1 = transition(pb1, s, a) + si = stateindex(pb1, s) + ai = actionindex(pb1, a) + td2 = transition(pb2, si, ai) + for (sp, p) in weighted_iterator(td1) + spi = stateindex(pb1, sp) + @test pdf(td2, spi) == p + end + end + end +end + +function test_reward(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP}) + for s in states(pb1) + for a in actions(pb1) + si = stateindex(pb1, s) + ai = actionindex(pb1, a) + @test reward(pb1, s, a) == reward(pb2, si, ai) + end + end +end + +function test_observation(pb1::POMDP, pb2::SparseTabularPOMDP) + for s in states(pb1) + for a in actions(pb1) + od1 = observation(pb1, a, s) + si = stateindex(pb1, s) + ai = actionindex(pb1, a) + od2 = observation(pb2, ai, si) + for (o, p) in weighted_iterator(od1) + oi = obsindex(pb1, o) + @test pdf(od2, oi) == p + end + end + end +end + + +## MDP + +mdp = RandomMDP(100, 4, 0.95) + +smdp = SparseTabularMDP(mdp) + +smdp2 = SparseTabularMDP(smdp) + +@test smdp2.T == smdp.T +@test smdp2.R == smdp.R +@test smdp2.discount == smdp.discount + +smdp3 = SparseTabularMDP(smdp, reward = zeros(n_states(mdp), n_actions(mdp))) +@test smdp3.T == smdp.T +@test smdp3.R != smdp.R +smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)]) +@test smdp3.T != smdp.T +@test smdp3.R == smdp.R +smdp3 = SparseTabularMDP(smdp, discount = 0.5) +@test smdp3.discount != smdp.discount + +@test size(transition_matrix(smdp, 1)) == (n_states(smdp), n_states(smdp)) +@test length(reward_vector(smdp, 1)) == n_states(smdp) + +## POMDP + +pomdp = TigerPOMDP() + +spomdp = SparseTabularPOMDP(pomdp) + +spomdp2 = SparseTabularPOMDP(spomdp) + +@test spomdp2.T == spomdp.T +@test spomdp2.R == spomdp.R +@test spomdp2.O == spomdp.O +@test spomdp2.discount == spomdp.discount + +spomdp3 = SparseTabularPOMDP(spomdp, reward = zeros(n_states(mdp), n_actions(mdp))) +@test spomdp3.T == spomdp.T +@test spomdp3.R != spomdp.R +spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)]) +@test spomdp3.T != spomdp.T +@test spomdp3.R == spomdp.R +spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5) +@test spomdp3.discount != spomdp.discount +@test size(observation_matrix(spomdp, 1)) == (n_states(spomdp), n_observations(spomdp)) + + +## Tests + +@test n_states(pomdp) == n_states(spomdp) +@test n_actions(pomdp) == n_actions(spomdp) +@test n_observations(pomdp) == n_observations(spomdp) +@test length(states(spomdp)) == n_states(spomdp) +@test length(actions(spomdp)) == n_actions(spomdp) +@test length(observations(spomdp)) == n_observations(spomdp) +@test 
+@test statetype(spomdp) == Int64
+@test actiontype(spomdp) == Int64
+@test obstype(spomdp) == Int64
+@test discount(spomdp) == discount(pomdp)
+
+test_transition(mdp, smdp)
+test_transition(pomdp, spomdp)
+test_reward(pomdp, spomdp)
+test_observation(pomdp, spomdp)