From 6bacd51fd1c469954b2aa0847e4d483183a38208 Mon Sep 17 00:00:00 2001 From: MaximeBouton Date: Wed, 7 Aug 2019 15:14:39 -0700 Subject: [PATCH 1/3] implement SparseTabular --- src/POMDPModelTools.jl | 6 ++ src/sparse_tabular.jl | 156 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 5 ++ test/test_tabular.jl | 103 +++++++++++++++++++++++++++ 4 files changed, 270 insertions(+) create mode 100644 src/sparse_tabular.jl create mode 100644 test/test_tabular.jl diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index cec0ab2..c95a68b 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -89,4 +89,10 @@ export showdistribution include("distributions/pretty_printing.jl") +export + SparseTabularMDP, + SparseTabularPOMDP + +include("sparse_tabular.jl") + end # module diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl new file mode 100644 index 0000000..11e97fb --- /dev/null +++ b/src/sparse_tabular.jl @@ -0,0 +1,156 @@ +struct SparseTabularMDP <: MDP{Int64, Int64} + T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp] + R::Array{Float64, 2} # R[s,sp] + discount::Float64 +end + +function SparseTabularMDP(mdp::MDP) + T = transition_matrix_a_s_sp(mdp) + R = reward_s_a(mdp) + return SparseTabularMDP(T, R, discount(mdp)) +end + +function SparseTabularMDP(mdp::SparseTabularMDP; + transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing, + reward::Union{Nothing, Array{Float64, 2}} = nothing, + discount::Union{Nothing, Float64} = nothing) + T = transition != nothing ? transition : mdp.T + R = reward != nothing ? reward : mdp.R + d = discount != nothing ? discount : mdp.discount + return SparseTabularMDP(T, R, d) +end + +struct SparseTabularPOMDP <: POMDP{Int64, Int64, Int64} + T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp] + R::Array{Float64, 2} # R[s,sp] + O::Vector{SparseMatrixCSC{Float64, Int64}} # O[a][sp, o] + discount::Float64 +end + +function SparseTabularPOMDP(pomdp::POMDP) + T = transition_matrix_a_s_sp(pomdp) + R = reward_s_a(pomdp) + O = observation_matrix_a_sp_o(pomdp) + return SparseTabularPOMDP(T, R, O, discount(pomdp)) +end + +function SparseTabularPOMDP(pomdp::SparseTabularPOMDP; + transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing, + reward::Union{Nothing, Array{Float64, 2}} = nothing, + observation::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing, + discount::Union{Nothing, Float64} = nothing) + T = transition != nothing ? transition : pomdp.T + R = reward != nothing ? reward : pomdp.R + d = discount != nothing ? discount : pomdp.discount + O = observation != nothing ? 
observation : pomdp.O
+    return SparseTabularPOMDP(T, R, O, d)
+end
+
+const SparseTabularProblem = Union{SparseTabularMDP, SparseTabularPOMDP}
+
+
+function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
+    # Thanks to zach
+    na = n_actions(mdp)
+    ns = n_states(mdp)
+    transmat_row_A = [Int[] for _ in 1:n_actions(mdp)]
+    transmat_col_A = [Int[] for _ in 1:n_actions(mdp)]
+    transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)]
+
+    for s in states(mdp)
+        si = stateindex(mdp, s)
+        for a in actions(mdp, s)
+            ai = actionindex(mdp, a)
+            if !isterminal(mdp, s) # if terminal, the transition probabilities are all just zero
+                td = transition(mdp, s, a)
+                for (sp, p) in weighted_iterator(td)
+                    if p > 0.0
+                        spi = stateindex(mdp, sp)
+                        push!(transmat_row_A[ai], si)
+                        push!(transmat_col_A[ai], spi)
+                        push!(transmat_data_A[ai], p)
+                    end
+                end
+            end
+        end
+    end
+    transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
+    # Note: assert below is not valid for terminal states
+    # @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
+    return transmats_A_S_S2
+end
+
+function reward_s_a(mdp::Union{MDP, POMDP})
+    reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s)
+    for s in states(mdp)
+        if isterminal(mdp, s)
+            reward_S_A[stateindex(mdp, s), :] .= 0.0
+        else
+            for a in actions(mdp, s)
+                td = transition(mdp, s, a)
+                r = 0.0
+                for (sp, p) in weighted_iterator(td)
+                    if p > 0.0
+                        r += p*reward(mdp, s, a, sp)
+                    end
+                end
+                reward_S_A[stateindex(mdp, s), actionindex(mdp, a)] = r
+            end
+        end
+    end
+    return reward_S_A
+end
+
+function observation_matrix_a_sp_o(pomdp::POMDP)
+    na = n_actions(pomdp)
+    ns = n_states(pomdp)
+    no = n_observations(pomdp)
+    obsmat_row_A = [Int[] for _ in 1:n_actions(pomdp)]
+    obsmat_col_A = [Int[] for _ in 1:n_actions(pomdp)]
+    obsmat_data_A = [Float64[] for _ in 1:n_actions(pomdp)]
+
+    for sp in states(pomdp)
+        spi = stateindex(pomdp, sp)
+        for a in actions(pomdp)
+            ai = actionindex(pomdp, a)
+            od = observation(pomdp, a, sp)
+            for (o, p) in weighted_iterator(od)
+                if p > 0.0
+                    oi = obsindex(pomdp, o)
+                    push!(obsmat_row_A[ai], spi)
+                    push!(obsmat_col_A[ai], oi)
+                    push!(obsmat_data_A[ai], p)
+                end
+            end
+        end
+    end
+    obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], n_states(pomdp), n_observations(pomdp)) for a in 1:n_actions(pomdp)]
+
+    return obsmats_A_SP_O
+end
+
+# MDP and POMDP common methods
+
+POMDPs.n_states(prob::SparseTabularProblem) = size(prob.T[1], 1)
+POMDPs.n_actions(prob::SparseTabularProblem) = size(prob.T, 1)
+
+POMDPs.states(p::SparseTabularProblem) = 1:n_states(p)
+POMDPs.actions(p::SparseTabularProblem) = 1:n_actions(p)
+
+POMDPs.stateindex(::SparseTabularProblem, s::Int64) = s
+POMDPs.actionindex(::SparseTabularProblem, a::Int64) = a
+
+POMDPs.discount(p::SparseTabularProblem) = p.discount
+
+POMDPs.transition(p::SparseTabularProblem, s::Int64, a::Int64) = SparseCat(findnz(p.T[a][s, :])...) # XXX not memory efficient
+
+POMDPs.reward(prob::SparseTabularProblem, s::Int64, a::Int64) = prob.R[s, a]
+
+# POMDP only methods
+POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2)
+
+POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p)
+
+POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...)
+
+POMDPs.obsindex(p::SparseTabularPOMDP, o::Int64) = o
diff --git a/test/runtests.jl b/test/runtests.jl
index 9d3fd13..4a5a324 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -7,6 +7,7 @@ using Pkg
 using POMDPSimulators
 using POMDPPolicies
 import Distributions.Categorical
+using SparseArrays
 
 @testset "ordered" begin
     include("test_ordered_spaces.jl")
@@ -66,3 +67,7 @@ end
 @testset "pretty printing" begin
     include("test_pretty_printing.jl")
 end
+
+@testset "sparse tabular" begin
+    include("test_tabular.jl")
+end
\ No newline at end of file
diff --git a/test/test_tabular.jl b/test/test_tabular.jl
new file mode 100644
index 0000000..652c29e
--- /dev/null
+++ b/test/test_tabular.jl
@@ -0,0 +1,103 @@
+function test_transition(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP})
+    for s in states(pb1)
+        for a in actions(pb1)
+            td1 = transition(pb1, s, a)
+            si = stateindex(pb1, s)
+            ai = actionindex(pb1, a)
+            td2 = transition(pb2, si, ai)
+            for (sp, p) in weighted_iterator(td1)
+                spi = stateindex(pb1, sp)
+                @test pdf(td2, spi) == p
+            end
+        end
+    end
+end
+
+function test_reward(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP})
+    for s in states(pb1)
+        for a in actions(pb1)
+            si = stateindex(pb1, s)
+            ai = actionindex(pb1, a)
+            @test reward(pb1, s, a) == reward(pb2, si, ai)
+        end
+    end
+end
+
+function test_observation(pb1::POMDP, pb2::SparseTabularPOMDP)
+    for s in states(pb1)
+        for a in actions(pb1)
+            od1 = observation(pb1, a, s)
+            si = stateindex(pb1, s)
+            ai = actionindex(pb1, a)
+            od2 = observation(pb2, ai, si)
+            for (o, p) in weighted_iterator(od1)
+                oi = obsindex(pb1, o)
+                @test pdf(od2, oi) == p
+            end
+        end
+    end
+end
+
+
+## MDP
+
+mdp = RandomMDP(100, 4, 0.95)
+
+smdp = SparseTabularMDP(mdp)
+
+smdp2 = SparseTabularMDP(smdp)
+
+@test smdp2.T == smdp.T
+@test smdp2.R == smdp.R
+@test smdp2.discount == smdp.discount
+
+smdp3 = SparseTabularMDP(smdp, reward = zeros(n_states(mdp), n_actions(mdp)))
+@test smdp3.T == smdp.T
+@test smdp3.R != smdp.R
+smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)])
+@test smdp3.T != smdp.T
+@test smdp3.R == smdp.R
+smdp3 = SparseTabularMDP(smdp, discount = 0.5)
+@test smdp3.discount != smdp.discount
+
+
+## POMDP
+
+pomdp = TigerPOMDP()
+
+spomdp = SparseTabularPOMDP(pomdp)
+
+spomdp2 = SparseTabularPOMDP(spomdp)
+
+@test spomdp2.T == spomdp.T
+@test spomdp2.R == spomdp.R
+@test spomdp2.O == spomdp.O
+@test spomdp2.discount == spomdp.discount
+
+spomdp3 = SparseTabularPOMDP(spomdp, reward = zeros(n_states(pomdp), n_actions(pomdp)))
+@test spomdp3.T == spomdp.T
+@test spomdp3.R != spomdp.R
+spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(pomdp), 1:n_states(pomdp), 1.0) for a in 1:n_actions(pomdp)])
+@test spomdp3.T != spomdp.T
+@test spomdp3.R == spomdp.R
+spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5)
+@test spomdp3.discount != spomdp.discount
+
+
+## Tests
+
+@test n_states(pomdp) == n_states(spomdp)
+@test n_actions(pomdp) == n_actions(spomdp)
+@test n_observations(pomdp) == n_observations(spomdp)
+@test length(states(spomdp)) == n_states(spomdp)
+@test length(actions(spomdp)) == n_actions(spomdp)
+@test length(observations(spomdp)) == n_observations(spomdp)
+@test statetype(spomdp) == Int64
+@test actiontype(spomdp) == Int64
+@test obstype(spomdp) == Int64
+@test discount(spomdp) == discount(pomdp)
+
+test_transition(mdp, smdp)
+test_transition(pomdp, spomdp)
+test_reward(pomdp, spomdp)
+test_observation(pomdp, spomdp)
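[Usage note, not part of the patch] After this first patch, a problem defined with the explicit interface can be converted once and then queried through plain integer indices. A minimal sketch, assuming `TigerPOMDP` from POMDPModels (the same model the tests above use):

```julia
using POMDPs, POMDPModels, POMDPModelTools

# One-time conversion: iterates over every state-action pair to build T, R, and O.
pomdp = TigerPOMDP()
spomdp = SparseTabularPOMDP(pomdp)

# After conversion, states, actions, and observations are Int64 indices.
td = transition(spomdp, 1, 1)    # SparseCat over next-state indices
r  = reward(spomdp, 1, 1)        # dense matrix lookup R[1, 1]
od = observation(spomdp, 1, 1)   # SparseCat over observation indices; signature is observation(p, a, sp)
```

The conversion cost is paid up front; every subsequent lookup is a matrix read.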
From 0379b72831162d9fec90a4967ffd23f26de7c46b Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Thu, 8 Aug 2019 14:29:19 -0700
Subject: [PATCH 2/3] add assert and docs

---
 docs/src/model_transformations.md |   9 +++
 src/sparse_tabular.jl             | 101 ++++++++++++++++++++++++++++--
 2 files changed, 105 insertions(+), 5 deletions(-)

diff --git a/docs/src/model_transformations.md b/docs/src/model_transformations.md
index 816c4c6..b5bd47d 100644
--- a/docs/src/model_transformations.md
+++ b/docs/src/model_transformations.md
@@ -2,6 +2,15 @@
 
 POMDPModelTools contains several tools for transforming problems into other classes so that they can be used by different solvers.
 
+## Sparse Tabular MDPs and POMDPs
+
+The `SparseTabularMDP` and `SparseTabularPOMDP` types represent discrete problems defined using the explicit interface. The transition and observation models are represented using sparse matrices. Solver writers can leverage these data structures to write efficient vectorized code. A problem writer can define a problem using the explicit interface and convert it automatically to a sparse tabular representation by calling the constructors `SparseTabularMDP(::MDP)` or `SparseTabularPOMDP(::POMDP)`. See the docstrings below for more about the matrix representation and how to access the fields of the `SparseTabular` objects:
+
+```@docs
+SparseTabularMDP
+SparseTabularPOMDP
+```
+
 ## Fully Observable POMDP
 
 ```@docs
diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl
index 11e97fb..4af9e17 100644
--- a/src/sparse_tabular.jl
+++ b/src/sparse_tabular.jl
@@ -1,6 +1,25 @@
+"""
+    SparseTabularMDP
+
+An MDP object where states and actions are integers and the transition is represented by a list of sparse matrices.
+This data structure can be exploited in vectorized algorithms (e.g. see SparseValueIterationSolver).
+
+# Fields
+- `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` is the probability of transitioning from `s` to `sp` when taking action `a`.
+- `R::Vector{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward for taking action `a` in state `s`.
+- `discount::Float64` The discount factor
+
+# Constructors
+
+- `SparseTabularMDP(mdp::MDP)` : One can provide the matrices to the default constructor, or construct a `SparseTabularMDP` from any discrete-state MDP defined using the explicit interface.
+Note that constructing the transition and reward matrices requires iterating over all the states and can take a while.
+To learn more about how to define an MDP with the explicit interface, please visit https://juliapomdp.github.io/POMDPs.jl/latest/explicit/ .
+- `SparseTabularMDP(smdp::SparseTabularMDP; transition, reward, discount)` : This constructor returns a new sparse MDP that is a copy of the original `smdp` except for the fields specified by the keyword arguments.
+
+"""
 struct SparseTabularMDP <: MDP{Int64, Int64}
     T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
-    R::Array{Float64, 2} # R[s,sp]
+    R::Array{Float64, 2} # R[s, a]
     discount::Float64
 end
@@ -10,6 +29,30 @@ function SparseTabularMDP(mdp::MDP)
     return SparseTabularMDP(T, R, discount(mdp))
 end
 
+@POMDP_require SparseTabularMDP(mdp::MDP) begin
+    P = typeof(mdp)
+    S = statetype(P)
+    A = actiontype(P)
+    @req discount(::P)
+    @req n_states(::P)
+    @req n_actions(::P)
+    @subreq ordered_states(mdp)
+    @subreq ordered_actions(mdp)
+    @req transition(::P,::S,::A)
+    @req reward(::P,::S,::A,::S)
+    @req stateindex(::P,::S)
+    @req actionindex(::P, ::A)
+    @req actions(::P, ::S)
+    as = actions(mdp)
+    ss = states(mdp)
+    a = first(as)
+    s = first(ss)
+    dist = transition(mdp, s, a)
+    D = typeof(dist)
+    @req support(::D)
+    @req pdf(::D,::S)
+end
+
 function SparseTabularMDP(mdp::SparseTabularMDP;
                           transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
                           reward::Union{Nothing, Array{Float64, 2}} = nothing,
@@ -20,6 +63,26 @@ function SparseTabularMDP(mdp::SparseTabularMDP;
     return SparseTabularMDP(T, R, d)
 end
 
+"""
+    SparseTabularPOMDP
+
+A POMDP object where states and actions are integers and the transition and observation distributions are represented by lists of sparse matrices.
+This data structure can be exploited in vectorized algorithms to gain performance (e.g. see SparseValueIterationSolver).
+
+# Fields
+- `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` is the probability of transitioning from `s` to `sp` when taking action `a`.
+- `R::Vector{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward for taking action `a` in state `s`.
+- `O::Vector{SparseMatrixCSC{Float64, Int64}}` The observation model is represented as a vector of sparse matrices (one for each action). `O[a][sp, o]` is the probability of observing `o` from state `sp` after having taken action `a`.
+- `discount::Float64` The discount factor
+
+# Constructors
+
+- `SparseTabularPOMDP(pomdp::POMDP)` : One can provide the matrices to the default constructor, or construct a `SparseTabularPOMDP` from any discrete-state POMDP defined using the explicit interface.
+Note that constructing the transition, reward, and observation matrices requires iterating over all the states and can take a while.
+To learn more about how to define a POMDP with the explicit interface, please visit https://juliapomdp.github.io/POMDPs.jl/latest/explicit/ .
+- `SparseTabularPOMDP(spomdp::SparseTabularPOMDP; transition, reward, observation, discount)` : This constructor returns a new sparse POMDP that is a copy of the original `spomdp` except for the fields specified by the keyword arguments.
+
+"""
 struct SparseTabularPOMDP <: POMDP{Int64, Int64, Int64}
     T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
     R::Array{Float64, 2} # R[s,sp]
     O::Vector{SparseMatrixCSC{Float64, Int64}} # O[a][sp, o]
     discount::Float64
 end
@@ -34,6 +97,31 @@ function SparseTabularPOMDP(pomdp::POMDP)
     return SparseTabularPOMDP(T, R, O, discount(pomdp))
 end
 
+@POMDP_require SparseTabularPOMDP(pomdp::POMDP) begin
+    P = typeof(pomdp)
+    S = statetype(P)
+    A = actiontype(P)
+    O = obstype(P)
+    @req discount(::P)
+    @req n_states(::P)
+    @req n_actions(::P)
+    @req n_observations(::P)
+    @subreq ordered_states(pomdp)
+    @subreq ordered_actions(pomdp)
+    @subreq ordered_observations(pomdp)
+    @req transition(::P,::S,::A)
+    @req observation(::P,::A,::S)
+    @req reward(::P,::S,::A,::S)
+    @req stateindex(::P,::S)
+    @req actionindex(::P, ::A)
+    @req obsindex(::P, ::O)
+    @req actions(::P, ::S)
+    as = actions(pomdp)
+    ss = states(pomdp)
+    a = first(as)
+    s = first(ss)
+    dist = transition(pomdp, s, a)
+    D = typeof(dist)
+    @req support(::D)
+    @req pdf(::D,::S)
+end
+
+
 function SparseTabularPOMDP(pomdp::SparseTabularPOMDP;
                             transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
                             reward::Union{Nothing, Array{Float64, 2}} = nothing,
@@ -61,7 +149,11 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
         si = stateindex(mdp, s)
         for a in actions(mdp, s)
             ai = actionindex(mdp, a)
-            if !isterminal(mdp, s) # if terminal, the transition probabilities are all just zero
+            if isterminal(mdp, s) # if terminal, the state transitions to itself with probability 1
+                push!(transmat_row_A[ai], si)
+                push!(transmat_col_A[ai], si)
+                push!(transmat_data_A[ai], 1.0)
+            else
                 td = transition(mdp, s, a)
                 for (sp, p) in weighted_iterator(td)
                     if p > 0.0
@@ -75,8 +167,7 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
             end
         end
     end
     transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
-    # Note: assert below is not valid for terminal states
-    # @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
+    @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
     return transmats_A_S_S2
 end
@@ -125,7 +216,7 @@ function observation_matrix_a_sp_o(pomdp::POMDP)
         end
     end
     obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], n_states(pomdp), n_observations(pomdp)) for a in 1:n_actions(pomdp)]
-
+    @assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(n_states(pomdp))) for a in 1:n_actions(pomdp)) "Observation probabilities must sum to 1"
     return obsmats_A_SP_O
 end

From 73af5be30a42861b78f83a391ef09a4c0e90765d Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Fri, 9 Aug 2019 09:35:29 -0700
Subject: [PATCH 3/3] accessor functions

---
 docs/src/model_transformations.md |  3 +++
 src/POMDPModelTools.jl            |  5 ++++-
 src/sparse_tabular.jl             | 30 +++++++++++++++++++++++++++---
 test/test_tabular.jl              |  3 +++
 4 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/docs/src/model_transformations.md b/docs/src/model_transformations.md
index b5bd47d..b3612b1 100644
--- a/docs/src/model_transformations.md
+++ b/docs/src/model_transformations.md
@@ -9,6 +9,9 @@ The `SparseTabularMDP` and `SparseTabularPOMDP` types represent discrete problem
 ```@docs
 SparseTabularMDP
 SparseTabularPOMDP
+transition_matrix
+reward_vector
+observation_matrix
 ```
 
 ## Fully Observable POMDP
diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl
index c95a68b..19ba08b 100644
--- a/src/POMDPModelTools.jl
+++ b/src/POMDPModelTools.jl
@@ -91,7 +91,10 @@ include("distributions/pretty_printing.jl")
 
 export
     SparseTabularMDP,
-    SparseTabularPOMDP
+    SparseTabularPOMDP,
+    transition_matrix,
+    reward_vector,
+    observation_matrix
 
 include("sparse_tabular.jl")
 
diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl
index 4af9e17..d288893 100644
--- a/src/sparse_tabular.jl
+++ b/src/sparse_tabular.jl
@@ -3,10 +3,11 @@
 
 An MDP object where states and actions are integers and the transition is represented by a list of sparse matrices.
 This data structure can be exploited in vectorized algorithms (e.g. see SparseValueIterationSolver).
+The recommended way to access the transition and reward matrices is through the provided accessor functions: `transition_matrix` and `reward_vector`.
 
 # Fields
 - `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` is the probability of transitioning from `s` to `sp` when taking action `a`.
-- `R::Vector{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward for taking action `a` in state `s`.
+- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward for taking action `a` in state `s`.
 - `discount::Float64` The discount factor
 
 # Constructors
@@ -68,10 +69,11 @@
 
 A POMDP object where states and actions are integers and the transition and observation distributions are represented by lists of sparse matrices.
 This data structure can be exploited in vectorized algorithms to gain performance (e.g. see SparseValueIterationSolver).
+The recommended way to access the transition, reward, and observation matrices is through the provided accessor functions: `transition_matrix`, `reward_vector`, `observation_matrix`.
 
 # Fields
 - `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` is the probability of transitioning from `s` to `sp` when taking action `a`.
-- `R::Vector{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward for taking action `a` in state `s`.
+- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns are actions: `R[s, a]` is the reward for taking action `a` in state `s`.
 - `O::Vector{SparseMatrixCSC{Float64, Int64}}` The observation model is represented as a vector of sparse matrices (one for each action). `O[a][sp, o]` is the probability of observing `o` from state `sp` after having taken action `a`.
 - `discount::Float64` The discount factor
 
 # Constructors
@@ -235,7 +237,22 @@ POMDPs.discount(p::SparseTabularProblem) = p.discount
 
 POMDPs.transition(p::SparseTabularProblem, s::Int64, a::Int64) = SparseCat(findnz(p.T[a][s, :])...) # XXX not memory efficient
 
-POMDPs.reward(prob::SparseTabularProblem, s::Int64, a::Int64) = prob.R[s, a]
+POMDPs.reward(p::SparseTabularProblem, s::Int64, a::Int64) = p.R[s, a]
+
+"""
+    transition_matrix(p::SparseTabularProblem, a)
+Accessor function for the transition model of a sparse tabular problem.
+It returns a sparse matrix containing the transition probabilities when taking action `a`: T[s, sp] = Pr(sp | s, a).
+"""
+transition_matrix(p::SparseTabularProblem, a) = p.T[a]
+
+"""
+    reward_vector(p::SparseTabularProblem, a)
+Accessor function for the reward function of a sparse tabular problem.
+It returns a vector containing the reward for all the states when taking action `a`: R(s, a).
+The length of the returned vector is equal to the number of states.
+"""
+reward_vector(p::SparseTabularProblem, a) = view(p.R, :, a)
 
 # POMDP only methods
 POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2)
@@ -245,3 +262,10 @@ POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p)
 
 POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...)
 
 POMDPs.obsindex(p::SparseTabularPOMDP, o::Int64) = o
+
+"""
+    observation_matrix(p::SparseTabularPOMDP, a::Int64)
+Accessor function for the observation model of a sparse tabular POMDP.
+It returns a sparse matrix containing the observation probabilities after taking action `a`: O[sp, o] = Pr(o | sp, a).
+"""
+observation_matrix(p::SparseTabularPOMDP, a::Int64) = p.O[a]
\ No newline at end of file
diff --git a/test/test_tabular.jl b/test/test_tabular.jl
index 652c29e..0b27c97 100644
--- a/test/test_tabular.jl
+++ b/test/test_tabular.jl
@@ -60,6 +60,8 @@ smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:n_states(mdp), 1:n_states(
 smdp3 = SparseTabularMDP(smdp, discount = 0.5)
 @test smdp3.discount != smdp.discount
 
+@test size(transition_matrix(smdp, 1)) == (n_states(smdp), n_states(smdp))
+@test length(reward_vector(smdp, 1)) == n_states(smdp)
 
 ## POMDP
 
@@ -82,6 +84,7 @@ spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(pomdp), 1:n_s
 @test spomdp3.R == spomdp.R
 spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5)
 @test spomdp3.discount != spomdp.discount
+@test size(observation_matrix(spomdp, 1)) == (n_states(spomdp), n_observations(spomdp))
 
 ## Tests
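[Usage note, not part of the patch] The accessors introduced in this last patch are meant to be consumed by vectorized solver code. Below is a minimal sketch of a value-iteration sweep written against them, assuming `RandomMDP` from POMDPModels as in the tests; the function name `sparse_value_iteration` is hypothetical, chosen only for illustration:

```julia
using POMDPs, POMDPModels, POMDPModelTools

# Hypothetical helper: repeated Bellman backups as sparse matrix-vector products,
# Q[:, a] = R[:, a] + γ * T[a] * V, then V = max over actions.
function sparse_value_iteration(m::SparseTabularMDP; iters::Int=1000)
    V = zeros(n_states(m))
    for _ in 1:iters
        Q = hcat((reward_vector(m, a) .+ discount(m) .* (transition_matrix(m, a) * V)
                  for a in actions(m))...)
        V = vec(maximum(Q, dims=2))
    end
    return V
end

V = sparse_value_iteration(SparseTabularMDP(RandomMDP(100, 4, 0.95)))
```

Each backup touches only the stored nonzeros of `T[a]`, which is the payoff of the sparse representation over a dense `n_states × n_states × n_actions` array.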