Skip to content

Commit

Permalink
Merge pull request #32 from JuliaPOMDP/sparse
Browse files Browse the repository at this point in the history
Use SparseTabularMDP from POMDPModelTools
  • Loading branch information
MaximeBouton authored Aug 13, 2019
2 parents 95f407d + e739781 commit d72ce40
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 61 deletions.
6 changes: 5 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
language: julia
os:
- linux
- osx
- windows

julia:
- 0.7
- 1.0
- 1.1

notifications:
email: false

Expand Down
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
POMDPModelTools = ">= 0.1.7"
POMDPPolicies = ">= 0.1.4"

[extras]
Expand All @@ -19,4 +20,4 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "POMDPModels", "DelimitedFiles"]
test = ["Test", "POMDPModels", "DelimitedFiles"]
78 changes: 22 additions & 56 deletions src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function SparseValueIterationSolver(;max_iterations=500,
return SparseValueIterationSolver(max_iterations, belres, include_Q, verbose, init_util)
end

@POMDP_require solve(solver::SparseValueIterationSolver, mdp::Union{MDP,POMDP}) begin
@POMDP_require solve(solver::SparseValueIterationSolver, mdp::MDP) begin
P = typeof(mdp)
S = statetype(P)
A = actiontype(P)
Expand All @@ -36,6 +36,7 @@ end
D = typeof(dist)
@req support(::D)
@req pdf(::D,::S)
@subreq SparseTabularMDP(mdp)
end

function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMatrix{F}, value_S::AbstractVector{F}, out_qvals_S_A) where {F}
Expand All @@ -45,59 +46,7 @@ function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMat
end
end

function transition_matrix_a_s_sp(mdp::MDP)
    # Build one sparse |S| x |S| transition matrix per action, assembled in
    # COO (row, col, value) form and converted with `sparse` at the end.
    # Thanks to zach
    ns = n_states(mdp)
    na = n_actions(mdp)
    rows = [Int[] for _ in 1:na]
    cols = [Int[] for _ in 1:na]
    probs = [Float64[] for _ in 1:na]

    for s in states(mdp)
        # Terminal states keep an all-zero row: no outgoing transitions.
        isterminal(mdp, s) && continue
        si = stateindex(mdp, s)
        for a in actions(mdp, s)
            ai = actionindex(mdp, a)
            dist = transition(mdp, s, a)
            for (sp, p) in weighted_iterator(dist)
                # Only record successors with positive probability mass.
                p > 0.0 || continue
                push!(rows[ai], si)
                push!(cols[ai], stateindex(mdp, sp))
                push!(probs[ai], p)
            end
        end
    end
    # Note: row sums are not asserted to be 1 — terminal states legitimately
    # sum to 0, so such a check would be invalid for them.
    return [sparse(rows[ai], cols[ai], probs[ai], ns, ns) for ai in 1:na]
end

function reward_s_a(mdp::MDP)
    # Expected immediate reward matrix R[s, a].
    # Entries default to -Inf so actions absent from `actions(mdp, s)` are
    # never preferred; terminal states are given zero reward for all actions.
    R = fill(-Inf, (n_states(mdp), n_actions(mdp)))
    for s in states(mdp)
        si = stateindex(mdp, s)
        if isterminal(mdp, s)
            R[si, :] .= 0.0
        else
            for a in actions(mdp, s)
                dist = transition(mdp, s, a)
                # Expectation of r(s, a, s') over successors with positive mass.
                expected_r = 0.0
                for (sp, p) in weighted_iterator(dist)
                    if p > 0.0
                        expected_r += p * reward(mdp, s, a, sp)
                    end
                end
                R[si, actionindex(mdp, a)] = expected_r
            end
        end
    end
    return R
end

function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
function solve(solver::SparseValueIterationSolver, mdp::SparseTabularMDP)
nS = n_states(mdp)
nA = n_actions(mdp)
if isempty(solver.init_util)
Expand All @@ -107,8 +56,8 @@ function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
v_S = solver.init_util
end

transition_A_S_S2 = transition_matrix_a_s_sp(mdp)
reward_S_A = reward_s_a(mdp)
transition_A_S_S2 = transition_matrices(mdp)
reward_S_A = reward_matrix(mdp)
qvals_S_A = zeros(nS, nA)
maxchanges_T = zeros(solver.max_iterations)

Expand Down Expand Up @@ -137,3 +86,20 @@ function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
end
return policy
end

# Generic-MDP entry point: convert to the sparse tabular representation once,
# then dispatch to the SparseTabularMDP method.
solve(solver::SparseValueIterationSolver, mdp::MDP) = solve(solver, SparseTabularMDP(mdp))

function solve(::SparseValueIterationSolver, ::POMDP)
    # Fail fast: this solver operates on fully observable MDPs only.
    # Fix: the usage example in the message previously instantiated
    # `ValueIterationSolver()`; it should show the solver the caller was
    # actually using, `SparseValueIterationSolver()`.
    # A plain `String` (not an `Exception`) is thrown deliberately — the test
    # suite asserts `@test_throws String`, so the thrown type must not change.
    throw("""
        ValueIterationError: `solve(::SparseValueIterationSolver, ::POMDP)` is not supported,
        `SparseValueIterationSolver` supports MDP models only, look at QMDP.jl for a POMDP solver that assumes full observability.
        If you still wish to use the transition and reward from your POMDP model you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
        ```
        solver = SparseValueIterationSolver()
        mdp = UnderlyingMDP(pomdp)
        solve(solver, mdp)
        ```
        """)
end
15 changes: 14 additions & 1 deletion src/vanilla.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
# policy = ValueIterationPolicy(mdp)
# solve(solver, mdp, policy, verbose=true)
#####################################################################
function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)
function solve(solver::ValueIterationSolver, mdp::MDP; kwargs...)

# deprecation warning - can be removed when Julia 1.0 is adopted
if !isempty(kwargs)
Expand Down Expand Up @@ -134,3 +134,16 @@ function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)
return ValueIterationPolicy(mdp, utility=util, policy=pol, include_Q=false)
end
end

function solve(::ValueIterationSolver, ::POMDP)
    # Reject POMDP models up front: value iteration here assumes full
    # observability, so only MDPs are supported.
    # A plain `String` is thrown (not an `Exception`); the test suite relies
    # on `@test_throws String`, so the thrown type must stay a String.
    msg = """
        ValueIterationError: `solve(::ValueIterationSolver, ::POMDP)` is not supported,
        `ValueIterationSolver` supports MDP models only, look at QMDP.jl for a POMDP solver that assumes full observability.
        If you still wish to use the transition and reward from your POMDP model you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
        ```
        solver = ValueIterationSolver()
        mdp = UnderlyingMDP(pomdp)
        solve(solver, mdp)
        ```
        """
    throw(msg)
end
3 changes: 2 additions & 1 deletion test/test_basic_value_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,5 @@ end
@test test_simple_grid() == true
@test test_init_solution() == true
@test test_not_include_Q() == true
test_warning()
test_warning()
@test_throws String solve(ValueIterationSolver(), TigerPOMDP())
2 changes: 1 addition & 1 deletion test/test_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ sparsepolicy = solve(sparsesolver, mdp)

@test sparsepolicy.qmat == policy.qmat
@test value(sparsepolicy, 2) ≈ 0.0

@test_throws String solve(SparseValueIterationSolver(), TigerPOMDP())

0 comments on commit d72ce40

Please sign in to comment.