Skip to content

Commit

Permalink
add assert and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximeBouton committed Aug 8, 2019
1 parent 6bacd51 commit 0379b72
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 5 deletions.
9 changes: 9 additions & 0 deletions docs/src/model_transformations.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

POMDPModelTools contains several tools for transforming problems into other classes so that they can be used by different solvers.

## Sparse Tabular MDPs and POMDPs

The `SparseTabularMDP` and `SparseTabularPOMDP` types represent discrete problems defined using the explicit interface. The transition and observation models are represented using sparse matrices. Solver writers can leverage these data structures to write efficient vectorized code. A problem writer can define a problem using the explicit interface and have it automatically converted to a sparse tabular representation by calling the constructors `SparseTabularMDP(::MDP)` or `SparseTabularPOMDP(::POMDP)`. See the following docs to learn more about the matrix representation and how to access the fields of the `SparseTabular` objects:

```@docs
SparseTabularMDP
SparseTabularPOMDP
```

## Fully Observable POMDP

```@docs
Expand Down
101 changes: 96 additions & 5 deletions src/sparse_tabular.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
"""
    SparseTabularMDP

An MDP object where states and actions are integers and the transition is represented by a list of sparse matrices.
This data structure can be useful to exploit in vectorized algorithms (e.g. see SparseValueIterationSolver).

# Fields
- `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` is the probability of transitioning from `s` to `sp` when taking action `a`.
- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in state `s`.
- `discount::Float64` The discount factor.

# Constructors
- `SparseTabularMDP(mdp::MDP)` : One can provide the matrices to the default constructor or one can construct a `SparseTabularMDP` from any discrete state MDP defined using the explicit interface.
Note that constructing the transition and reward matrices requires iterating over all the states and can take a while.
To learn more about how to define an MDP with the explicit interface please visit https://juliapomdp.github.io/POMDPs.jl/latest/explicit/ .
- `SparseTabularMDP(smdp::SparseTabularMDP; transition, reward, discount)` : This constructor returns a new sparse MDP that is a copy of the original `smdp` except for the fields specified by the keyword arguments.
"""
struct SparseTabularMDP <: MDP{Int64, Int64}
    T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
    R::Array{Float64, 2}                       # R[s, a]
    discount::Float64
end

Expand All @@ -10,6 +29,30 @@ function SparseTabularMDP(mdp::MDP)
return SparseTabularMDP(T, R, discount(mdp))
end

# Requirements for constructing a SparseTabularMDP from a generic MDP defined
# with the explicit interface: the problem must expose discrete, indexable
# state and action spaces and an explicit transition distribution.
@POMDP_require SparseTabularMDP(mdp::MDP) begin
P = typeof(mdp)
S = statetype(P)
A = actiontype(P)
# Problem-level queries: discount factor and sizes of the state/action spaces.
@req discount(::P)
@req n_states(::P)
@req n_actions(::P)
# Ordered spaces give a consistent state/action <-> matrix-index mapping.
@subreq ordered_states(mdp)
@subreq ordered_actions(mdp)
# Explicit model definition and index lookups for each state/action.
@req transition(::P,::S,::A)
@req reward(::P,::S,::A,::S)
@req stateindex(::P,::S)
@req actionindex(::P, ::A)
@req actions(::P, ::S)
# Sample one (s, a) pair to discover the concrete distribution type returned
# by `transition`, then require that it supports `support` and `pdf`.
as = actions(mdp)
ss = states(mdp)
a = first(as)
s = first(ss)
dist = transition(mdp, s, a)
D = typeof(dist)
@req support(::D)
@req pdf(::D,::S)
end

function SparseTabularMDP(mdp::SparseTabularMDP;
transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
reward::Union{Nothing, Array{Float64, 2}} = nothing,
Expand All @@ -20,6 +63,26 @@ function SparseTabularMDP(mdp::SparseTabularMDP;
return SparseTabularMDP(T, R, d)
end

"""
SparseTabularPOMDP
A POMDP object where states and actions are integers and the transition and observation distributions are represented by lists of sparse matrices.
This data structure can be useful to exploit in vectorized algorithms to gain performance (e.g. see SparseValueIterationSolver).
# Fields
- `T::Vector{SparseMatrixCSC{Float64, Int64}}` The transition model is represented as a vector of sparse matrices (one for each action). `T[a][s, sp]` the probability of transition from `s` to `sp` taking action `a`.
- `R::Array{Float64, 2}` The reward is represented as a matrix where the rows are states and the columns actions: `R[s, a]` is the reward of taking action `a` in state `s`.
- `O::Vector{SparseMatrixCSC{Float64, Int64}}` The observation model is represented as a vector of sparse matrices (one for each action). `O[a][sp, o]` is the probability of observing `o` from state `sp` after having taken action `a`.
- `discount::Float64` The discount factor
# Constructors
- `SparseTabularPOMDP(pomdp::POMDP)` : One can provide the matrices to the default constructor or one can construct a `SparseTabularPOMDP` from any discrete state POMDP defined using the explicit interface.
Note that constructing the transition and reward matrices requires iterating over all the states and can take a while.
To learn more about how to define a POMDP with the explicit interface please visit https://juliapomdp.github.io/POMDPs.jl/latest/explicit/ .
- `SparseTabularPOMDP(spomdp::SparseTabularPOMDP; transition, reward, observation, discount)` : This constructor returns a new sparse POMDP that is a copy of the original `spomdp` except for the fields specified by the keyword arguments.
"""
struct SparseTabularPOMDP <: POMDP{Int64, Int64, Int64}
T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
R::Array{Float64, 2} # R[s,sp]
Expand All @@ -34,6 +97,31 @@ function SparseTabularPOMDP(pomdp::POMDP)
return SparseTabularPOMDP(T, R, O, discount(pomdp))
end

# Requirements for constructing a SparseTabularPOMDP from a generic POMDP.
# NOTE(review): the original block was a verbatim copy of the SparseTabularMDP
# requirements (it re-declared `SparseTabularMDP(mdp::MDP)` and omitted the
# observation model entirely). Since this block follows the
# `SparseTabularPOMDP(pomdp::POMDP)` constructor, it should declare the POMDP
# requirements, including the observation model needed to build the `O` matrices.
@POMDP_require SparseTabularPOMDP(pomdp::POMDP) begin
P = typeof(pomdp)
S = statetype(P)
A = actiontype(P)
O = obstype(P)
# Problem-level queries: discount and sizes of state/action/observation spaces.
@req discount(::P)
@req n_states(::P)
@req n_actions(::P)
@req n_observations(::P)
# Ordered spaces give a consistent element <-> matrix-index mapping.
@subreq ordered_states(pomdp)
@subreq ordered_actions(pomdp)
@subreq ordered_observations(pomdp)
# Explicit model definition: transition, reward, and observation distributions.
@req transition(::P,::S,::A)
@req reward(::P,::S,::A,::S)
@req observation(::P,::A,::S)
# Index lookups used to fill the sparse matrices.
@req stateindex(::P,::S)
@req actionindex(::P, ::A)
@req obsindex(::P, ::O)
@req actions(::P, ::S)
# Sample one (s, a) pair to discover the concrete distribution type returned
# by `transition`, then require that it supports `support` and `pdf`.
as = actions(pomdp)
ss = states(pomdp)
a = first(as)
s = first(ss)
dist = transition(pomdp, s, a)
D = typeof(dist)
@req support(::D)
@req pdf(::D,::S)
end


function SparseTabularPOMDP(pomdp::SparseTabularPOMDP;
transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
reward::Union{Nothing, Array{Float64, 2}} = nothing,
Expand Down Expand Up @@ -61,7 +149,11 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
si = stateindex(mdp, s)
for a in actions(mdp, s)
ai = actionindex(mdp, a)
if !isterminal(mdp, s) # if terminal, the transition probabilities are all just zero
if isterminal(mdp, s) # if terminal, there is a probability of 1 of staying in that state
push!(transmat_row_A[ai], si)
push!(transmat_col_A[ai], si)
push!(transmat_data_A[ai], 1.0)
else
td = transition(mdp, s, a)
for (sp, p) in weighted_iterator(td)
if p > 0.0
Expand All @@ -75,8 +167,7 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
end
end
transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
# Note: assert below is not valid for terminal states
# @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
@assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
return transmats_A_S_S2
end

Expand Down Expand Up @@ -125,7 +216,7 @@ function observation_matrix_a_sp_o(pomdp::POMDP)
end
end
obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], n_states(pomdp), n_states(pomdp)) for a in 1:n_actions(pomdp)]

@assert all(all(sum(obsmats_A_SP_O[a], dims=2) .≈ ones(n_observations(pomdp))) for a in 1:n_actions(pomdp)) "Observation probabilities must sum to 1"
return obsmats_A_SP_O
end

Expand Down

0 comments on commit 0379b72

Please sign in to comment.