diff --git a/docs/src/model_transformations.md b/docs/src/model_transformations.md index b3612b1..6e8666e 100644 --- a/docs/src/model_transformations.md +++ b/docs/src/model_transformations.md @@ -12,6 +12,9 @@ SparseTabularPOMDP transition_matrix reward_vector observation_matrix +transition_matrices +reward_matrix +observation_matrices ``` ## Fully Observable POMDP diff --git a/src/POMDPModelTools.jl b/src/POMDPModelTools.jl index 19ba08b..d9c0aab 100644 --- a/src/POMDPModelTools.jl +++ b/src/POMDPModelTools.jl @@ -94,7 +94,10 @@ export SparseTabularPOMDP, transition_matrix, reward_vector, - observation_matrix + observation_matrix, + transition_matrices, + reward_matrix, + observation_matrices include("sparse_tabular.jl") diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl index 4b2799c..f705051 100644 --- a/src/sparse_tabular.jl +++ b/src/sparse_tabular.jl @@ -273,11 +273,11 @@ It returns a sparse matrix containing the transition probabilities when taking a transition_matrix(p::SparseTabularProblem, a) = p.T[a] """ - transition_matrix(p::SparseTabularProblem) + transition_matrices(p::SparseTabularProblem) Accessor function for the transition model of a sparse tabular problem. It returns a list of sparse matrices for each action of the problem. """ -transition_matrix(p::SparseTabularProblem) = p.T +transition_matrices(p::SparseTabularProblem) = p.T """ reward_vector(p::SparseTabularProblem, a) @@ -310,8 +310,8 @@ It returns a sparse matrix containing the observation probabilities when having observation_matrix(p::SparseTabularPOMDP, a::Int64) = p.O[a] """ - observation_matrix(p::SparseTabularPOMDP) + observation_matrices(p::SparseTabularPOMDP) Accessor function for the observation model of a sparse tabular POMDP. It returns a list of sparse matrices for each action of the problem. """ -observation_matrix(p::SparseTabularPOMDP) = p.O \ No newline at end of file +observation_matrices(p::SparseTabularPOMDP) = p.O \ No newline at end of file diff --git a/test/test_tabular.jl b/test/test_tabular.jl index cdeacc8..241ab71 100644 --- a/test/test_tabular.jl +++ b/test/test_tabular.jl @@ -91,8 +91,8 @@ spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(mdp), 1:n_s spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5) @test spomdp3.discount != spomdp.discount @test size(observation_matrix(spomdp, 1)) == (n_states(spomdp), n_observations(spomdp)) -@test observation_matrix(spomdp) == spomdp2.O -@test transition_matrix(spomdp) == spomdp2.T +@test observation_matrices(spomdp) == spomdp2.O +@test transition_matrices(spomdp) == spomdp2.T @test reward_matrix(spomdp) == spomdp2.R ## Tests