diff --git a/.travis.yml b/.travis.yml
index 08c1a30..d727559 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,9 +1,13 @@
 language: julia
 os:
   - linux
+  - osx
+  - windows
+
 julia:
-  - 0.7
   - 1.0
+  - 1.1
+
 notifications:
   email: false
 
diff --git a/Project.toml b/Project.toml
index 167739a..ddc0c7a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,6 +11,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [compat]
+POMDPModelTools = ">= 0.1.7"
 POMDPPolicies = ">= 0.1.4"
 
 [extras]
@@ -19,4 +20,4 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "POMDPModels", "DelimitedFiles"]
+test = ["Test", "POMDPModels", "DelimitedFiles"]
\ No newline at end of file
diff --git a/src/sparse.jl b/src/sparse.jl
index 8a8e726..010b3bc 100644
--- a/src/sparse.jl
+++ b/src/sparse.jl
@@ -14,7 +14,7 @@ function SparseValueIterationSolver(;max_iterations=500,
     return SparseValueIterationSolver(max_iterations, belres, include_Q, verbose, init_util)
 end
 
-@POMDP_require solve(solver::SparseValueIterationSolver, mdp::Union{MDP,POMDP}) begin
+@POMDP_require solve(solver::SparseValueIterationSolver, mdp::MDP) begin
     P = typeof(mdp)
     S = statetype(P)
     A = actiontype(P)
@@ -36,6 +36,7 @@ end
     D = typeof(dist)
     @req support(::D)
     @req pdf(::D,::S)
+    @subreq SparseTabularMDP(mdp)
 end
 
 function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMatrix{F}, value_S::AbstractVector{F}, out_qvals_S_A) where {F}
@@ -45,59 +46,7 @@ function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMat
     end
 end
 
-function transition_matrix_a_s_sp(mdp::MDP)
-    # Thanks to zach
-    na = n_actions(mdp)
-    ns = n_states(mdp)
-    transmat_row_A = [Int[] for _ in 1:n_actions(mdp)]
-    transmat_col_A = [Int[] for _ in 1:n_actions(mdp)]
-    transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)]
-
-    for s in states(mdp)
-        si = stateindex(mdp, s)
-        for a in actions(mdp, s)
-            ai = actionindex(mdp, a)
-            if !isterminal(mdp, s) # if terminal, the transition probabilities are all just zero
-                td = transition(mdp, s, a)
-                for (sp, p) in weighted_iterator(td)
-                    if p > 0.0
-                        spi = stateindex(mdp, sp)
-                        push!(transmat_row_A[ai], si)
-                        push!(transmat_col_A[ai], spi)
-                        push!(transmat_data_A[ai], p)
-                    end
-                end
-            end
-        end
-    end
-    transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
-    # Note: assert below is not valid for terminal states
-    # @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
-    return transmats_A_S_S2
-end
-
-function reward_s_a(mdp::MDP)
-    reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s)
-    for s in states(mdp)
-        if isterminal(mdp, s)
-            reward_S_A[stateindex(mdp, s), :] .= 0.0
-        else
-            for a in actions(mdp, s)
-                td = transition(mdp, s, a)
-                r = 0.0
-                for (sp, p) in weighted_iterator(td)
-                    if p > 0.0
-                        r += p*reward(mdp, s, a, sp)
-                    end
-                end
-                reward_S_A[stateindex(mdp, s), actionindex(mdp, a)] = r
-            end
-        end
-    end
-    return reward_S_A
-end
-
-function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
+function solve(solver::SparseValueIterationSolver, mdp::SparseTabularMDP)
     nS = n_states(mdp)
     nA = n_actions(mdp)
     if isempty(solver.init_util)
@@ -107,8 +56,8 @@ function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
         v_S = solver.init_util
     end
 
-    transition_A_S_S2 = transition_matrix_a_s_sp(mdp)
-    reward_S_A = reward_s_a(mdp)
+    transition_A_S_S2 = transition_matrices(mdp)
+    reward_S_A = reward_matrix(mdp)
     qvals_S_A = zeros(nS, nA)
     maxchanges_T = zeros(solver.max_iterations)
 
@@ -137,3 +86,20 @@ function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
     end
     return policy
 end
+
+function solve(solver::SparseValueIterationSolver, mdp::MDP)
+    return solve(solver, SparseTabularMDP(mdp))
+end
+
+function solve(::SparseValueIterationSolver, ::POMDP)
+    throw("""
+          ValueIterationError: `solve(::SparseValueIterationSolver, ::POMDP)` is not supported.
+          `SparseValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
+          If you still wish to use the transition and reward from your POMDP model, you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
+          ```
+          solver = SparseValueIterationSolver()
+          mdp = UnderlyingMDP(pomdp)
+          solve(solver, mdp)
+          ```
+          """)
+end
\ No newline at end of file
diff --git a/src/vanilla.jl b/src/vanilla.jl
index e833c79..e524cc8 100644
--- a/src/vanilla.jl
+++ b/src/vanilla.jl
@@ -50,7 +50,7 @@ end
 # policy = ValueIterationPolicy(mdp)
 # solve(solver, mdp, policy, verbose=true)
 #####################################################################
-function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)
+function solve(solver::ValueIterationSolver, mdp::MDP; kwargs...)
 
     # deprecation warning - can be removed when Julia 1.0 is adopted
     if !isempty(kwargs)
@@ -134,3 +134,16 @@ function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)
         return ValueIterationPolicy(mdp, utility=util, policy=pol, include_Q=false)
     end
 end
+
+function solve(::ValueIterationSolver, ::POMDP)
+    throw("""
+          ValueIterationError: `solve(::ValueIterationSolver, ::POMDP)` is not supported.
+          `ValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
+          If you still wish to use the transition and reward from your POMDP model, you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
+          ```
+          solver = ValueIterationSolver()
+          mdp = UnderlyingMDP(pomdp)
+          solve(solver, mdp)
+          ```
+          """)
+end
\ No newline at end of file
diff --git a/test/test_basic_value_iteration.jl b/test/test_basic_value_iteration.jl
index 8b16851..5d43442 100644
--- a/test/test_basic_value_iteration.jl
+++ b/test/test_basic_value_iteration.jl
@@ -95,4 +95,5 @@
 @test test_simple_grid() == true
 @test test_init_solution() == true
 @test test_not_include_Q() == true
-test_warning()
\ No newline at end of file
+test_warning()
+@test_throws String solve(ValueIterationSolver(), TigerPOMDP())
\ No newline at end of file
diff --git a/test/test_sparse.jl b/test/test_sparse.jl
index 1296251..77d9359 100644
--- a/test/test_sparse.jl
+++ b/test/test_sparse.jl
@@ -45,4 +45,4 @@ sparsepolicy = solve(sparsesolver, mdp)
 
 @test sparsepolicy.qmat == policy.qmat
 @test value(sparsepolicy, 2) ≈ 0.0
-
+@test_throws String solve(SparseValueIterationSolver(), TigerPOMDP())
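
Reviewer note: a minimal sketch of the new dispatch path in `src/sparse.jl`. With this patch, `solve(solver::SparseValueIterationSolver, mdp::MDP)` wraps the model in a `SparseTabularMDP` and dispatches to the tabular method, so the hand-rolled `transition_matrix_a_s_sp`/`reward_s_a` builders are replaced by `transition_matrices`/`reward_matrix` from POMDPModelTools (hence the new `>= 0.1.7` compat bound). The package name `DiscreteValueIteration` and the `SimpleGridWorld` model below are illustrative assumptions, not part of this diff.

```julia
# Sketch only: assumes this repository is DiscreteValueIteration.jl and
# that POMDPModels provides SimpleGridWorld; neither name appears in the diff.
using POMDPs, POMDPModels, POMDPModelTools, DiscreteValueIteration

mdp = SimpleGridWorld()    # any MDP with discrete, indexable states and actions
solver = SparseValueIterationSolver(max_iterations=500, belres=1e-6)

# The new solve(solver, mdp::MDP) method converts once up front:
# SparseTabularMDP(mdp) precomputes sparse transition matrices and a
# reward matrix, which the value-iteration loop then reuses each sweep.
policy = solve(solver, mdp)
```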
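Likewise, a sketch of the new guard methods: passing a POMDP to either solver now throws a `String` (which the added `@test_throws String` tests exercise), and the error message points to the `UnderlyingMDP` workaround. Same naming assumptions as above.

```julia
using POMDPs, POMDPModels, POMDPModelTools, DiscreteValueIteration

pomdp = TigerPOMDP()
solver = SparseValueIterationSolver()

# solve(solver, pomdp)  # now throws the String error message above

# Workaround suggested by the error message: wrap the POMDP so its
# state is treated as fully observable, then solve the wrapped MDP.
mdp = UnderlyingMDP(pomdp)
policy = solve(solver, mdp)
```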