From 73af5be30a42861b78f83a391ef09a4c0e90765d Mon Sep 17 00:00:00 2001 From: MaximeBouton Date: Fri, 9 Aug 2019 09:35:29 -0700 Subject: [PATCH] accessor functions --- docs/src/model_transformations.md | 3 +++ src/POMDPModelTools.jl | 5 ++++- src/sparse_tabular.jl | 30 +++++++++++++++++++++++++++--- test/test_tabular.jl | 3 +++ 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/docs/src/model_transformations.md b/docs/src/model_transformations.md index b5bd47d..b3612b1 100644 --- a/docs/src/model_transformations.md +++ b/docs/src/model_transformations.md @@ -9,6 +9,9 @@ The `SparseTabularMDP` and `SparseTabularPOMDP` represents discrete problems def ```@docs SparseTabularMDP SparseTabularPOMDP +transition_matrix +reward_vector +observation_matrix ``` ## Fully Observable POMDP diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index c95a68b..19ba08b 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -91,7 +91,10 @@ include("distributions/pretty_printing.jl") export SparseTabularMDP, - SparseTabularPOMDP + SparseTabularPOMDP, + transition_matrix, + reward_vector, + observation_matrix include("sparse_tabular.jl") diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl index 4af9e17..d288893 100644 --- a/src/sparse_tabular.jl +++ b/src/sparse_tabular.jl @@ -3,10 +3,11 @@ An MDP object where states and actions are integers and the transition is represented by a list of sparse matrices. This data structure can be useful to exploit in vectorized algorithm (e.g. see SparseValueIterationSolver). +The recommended way to access the transition and reward matrices is through the provided accessor functions: `transition_matrix` and `reward_vector`. # Fields - `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` the probability of transition from `s` to `sp` taking action `a`. 
-- `R::Vector{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in sate `s`. +- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in state `s`. - `discount::Float64` The discount factor # Constructors @@ -68,10 +69,11 @@ end A POMDP object where states and actions are integers and the transition and observation distributions are represented by lists of sparse matrices. This data structure can be useful to exploit in vectorized algorithms to gain performance (e.g. see SparseValueIterationSolver). +The recommended way to access the transition, reward, and observation matrices is through the provided accessor functions: `transition_matrix`, `reward_vector`, `observation_matrix`. # Fields - `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` the probability of transition from `s` to `sp` taking action `a`. -- `R::Vector{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in sate `s`. +- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in state `s`. - `O::Vector{SparseMatrixCSC{Float64, Int64}}` The observation model is represented as a vector of sparse matrices (one for each action). `O[a][sp, o]` is the probability of observing `o` from state `sp` after having taken action `a`. - `discount::Float64` The discount factor @@ -235,7 +237,22 @@ POMDPs.discount(p::SparseTabularProblem) = p.discount POMDPs.transition(p::SparseTabularProblem, s::Int64, a::Int64) = SparseCat(findnz(p.T[a][s, :])...) 
# XXX not memory efficient -POMDPs.reward(prob::SparseTabularProblem, s::Int64, a::Int64) = prob.R[s, a] +POMDPs.reward(p::SparseTabularProblem, s::Int64, a::Int64) = p.R[s, a] + +""" + transition_matrix(p::SparseTabularProblem, a) +Accessor function for the transition model of a sparse tabular problem. +It returns a sparse matrix containing the transition probabilities when taking action a: T[s, sp] = Pr(sp | s, a). +""" +transition_matrix(p::SparseTabularProblem, a) = p.T[a] + +""" + reward_vector(p::SparseTabularProblem, a) +Accessor function for the reward function of a sparse tabular problem. +It returns a vector containing the reward for all the states when taking action a: R(s, a). +The length of the return vector is equal to the number of states. +""" +reward_vector(p::SparseTabularProblem, a) = view(p.R, :, a) # POMDP only methods POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2) @@ -245,3 +262,10 @@ POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p) POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...) POMDPs.obsindex(p::SparseTabularPOMDP, o::Int64) = o + +""" + observation_matrix(p::SparseTabularPOMDP, a::Int64) +Accessor function for the observation model of a sparse tabular POMDP. +It returns a sparse matrix containing the observation probabilities when having taken action a: O[sp, o] = Pr(o | sp, a). 
+""" +observation_matrix(p::SparseTabularPOMDP, a::Int64) = p.O[a] \ No newline at end of file diff --git a/test/test_tabular.jl b/test/test_tabular.jl index 652c29e..0b27c97 100644 --- a/test/test_tabular.jl +++ b/test/test_tabular.jl @@ -60,6 +60,8 @@ smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:n_states(mdp), 1:n_states( smdp3 = SparseTabularMDP(smdp, discount = 0.5) @test smdp3.discount != smdp.discount +@test size(transition_matrix(smdp, 1)) == (n_states(smdp), n_states(smdp)) +@test length(reward_vector(smdp, 1)) == n_states(smdp) ## POMDP @@ -82,6 +84,7 @@ spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(mdp), 1:n_s @test spomdp3.R == spomdp.R spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5) @test spomdp3.discount != spomdp.discount +@test size(observation_matrix(spomdp, 1)) == (n_states(spomdp), n_observations(spomdp)) ## Tests