Use SparseTabularMDP from POMDPModelTools #32

Merged · 9 commits · Aug 13, 2019
6 changes: 5 additions & 1 deletion .travis.yml
@@ -1,9 +1,13 @@
language: julia
os:
- linux
- osx
- windows

julia:
- 0.7
- 1.0
- 1.1

notifications:
email: false

3 changes: 2 additions & 1 deletion Project.toml
@@ -11,6 +11,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
POMDPModelTools = ">= 0.1.7"
POMDPPolicies = ">= 0.1.4"

[extras]
@@ -19,4 +20,4 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "POMDPModels", "DelimitedFiles"]
test = ["Test", "POMDPModels", "DelimitedFiles"]
78 changes: 22 additions & 56 deletions src/sparse.jl
@@ -14,7 +14,7 @@ function SparseValueIterationSolver(;max_iterations=500,
return SparseValueIterationSolver(max_iterations, belres, include_Q, verbose, init_util)
end

@POMDP_require solve(solver::SparseValueIterationSolver, mdp::Union{MDP,POMDP}) begin
@POMDP_require solve(solver::SparseValueIterationSolver, mdp::MDP) begin
P = typeof(mdp)
S = statetype(P)
A = actiontype(P)
@@ -36,6 +36,7 @@ end
D = typeof(dist)
@req support(::D)
@req pdf(::D,::S)
@subreq SparseTabularMDP(mdp)
end

function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMatrix{F}, value_S::AbstractVector{F}, out_qvals_S_A) where {F}
@@ -45,59 +46,7 @@ function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMat
end
end

function transition_matrix_a_s_sp(mdp::MDP)
# Thanks to zach
na = n_actions(mdp)
ns = n_states(mdp)
transmat_row_A = [Int[] for _ in 1:n_actions(mdp)]
transmat_col_A = [Int[] for _ in 1:n_actions(mdp)]
transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)]

for s in states(mdp)
si = stateindex(mdp, s)
for a in actions(mdp, s)
ai = actionindex(mdp, a)
if !isterminal(mdp, s) # if terminal, the transition probabilities are all just zero
td = transition(mdp, s, a)
for (sp, p) in weighted_iterator(td)
if p > 0.0
spi = stateindex(mdp, sp)
push!(transmat_row_A[ai], si)
push!(transmat_col_A[ai], spi)
push!(transmat_data_A[ai], p)
end
end
end
end
end
transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
# Note: assert below is not valid for terminal states
# @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
return transmats_A_S_S2
end

function reward_s_a(mdp::MDP)
reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s)
for s in states(mdp)
if isterminal(mdp, s)
reward_S_A[stateindex(mdp, s), :] .= 0.0
else
for a in actions(mdp, s)
td = transition(mdp, s, a)
r = 0.0
for (sp, p) in weighted_iterator(td)
if p > 0.0
r += p*reward(mdp, s, a, sp)
end
end
reward_S_A[stateindex(mdp, s), actionindex(mdp, a)] = r
end
end
end
return reward_S_A
end

function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
function solve(solver::SparseValueIterationSolver, mdp::SparseTabularMDP)
nS = n_states(mdp)
nA = n_actions(mdp)
if isempty(solver.init_util)
Expand All @@ -107,8 +56,8 @@ function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
v_S = solver.init_util
end

transition_A_S_S2 = transition_matrix_a_s_sp(mdp)
reward_S_A = reward_s_a(mdp)
transition_A_S_S2 = transition_matrices(mdp)
reward_S_A = reward_matrix(mdp)
qvals_S_A = zeros(nS, nA)
maxchanges_T = zeros(solver.max_iterations)

@@ -137,3 +86,20 @@ function solve(solver::SparseValueIterationSolver, mdp::Union{MDP, POMDP})
end
return policy
end

function solve(solver::SparseValueIterationSolver, mdp::MDP)
return solve(solver, SparseTabularMDP(mdp))
end

function solve(::SparseValueIterationSolver, ::POMDP)
throw("""
ValueIterationError: `solve(::SparseValueIterationSolver, ::POMDP)` is not supported.
`SparseValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
If you still wish to use the transition and reward models from your POMDP, you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
```
solver = SparseValueIterationSolver()
mdp = UnderlyingMDP(pomdp)
solve(solver, mdp)
```
""")
end
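Note: with this change, `solve(solver, mdp::MDP)` converts the model with `SparseTabularMDP(mdp)` and relies on `transition_matrices` and `reward_matrix` from POMDPModelTools instead of the local helpers removed above. A minimal usage sketch, assuming `SimpleGridWorld` from POMDPModels purely as an illustrative model (it is not part of this PR):

```julia
using POMDPs, POMDPModels, POMDPModelTools, DiscreteValueIteration

mdp = SimpleGridWorld()  # any MDP with discrete, indexable states and actions
solver = SparseValueIterationSolver(max_iterations=100, belres=1e-6)

# The MDP method converts the model to a SparseTabularMDP internally...
policy = solve(solver, mdp)

# ...which is equivalent to converting it explicitly beforehand.
policy_tabular = solve(solver, SparseTabularMDP(mdp))
```

Both calls should produce the same policy, since the first dispatch simply forwards to the second.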
15 changes: 14 additions & 1 deletion src/vanilla.jl
@@ -50,7 +50,7 @@ end
# policy = ValueIterationPolicy(mdp)
# solve(solver, mdp, policy, verbose=true)
#####################################################################
function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)
function solve(solver::ValueIterationSolver, mdp::MDP; kwargs...)

# deprecation warning - can be removed when Julia 1.0 is adopted
if !isempty(kwargs)
@@ -134,3 +134,16 @@ function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}; kwargs...)
return ValueIterationPolicy(mdp, utility=util, policy=pol, include_Q=false)
end
end

function solve(::ValueIterationSolver, ::POMDP)
throw("""
ValueIterationError: `solve(::ValueIterationSolver, ::POMDP)` is not supported.
`ValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
If you still wish to use the transition and reward models from your POMDP, you can use the `UnderlyingMDP` wrapper from POMDPModelTools.jl as follows:
```
solver = ValueIterationSolver()
mdp = UnderlyingMDP(pomdp)
solve(solver, mdp)
```
""")
end
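Both solvers now reject POMDPs with the same hint. A runnable sketch of the suggested `UnderlyingMDP` workaround, using `TigerPOMDP` (already used in the tests below) purely as an example model:

```julia
using POMDPs, POMDPModels, POMDPModelTools, DiscreteValueIteration

pomdp = TigerPOMDP()

# solve(ValueIterationSolver(), pomdp) now throws a descriptive error;
# wrapping the POMDP treats its (normally hidden) state as fully observable.
mdp = UnderlyingMDP(pomdp)
policy = solve(ValueIterationSolver(), mdp)
```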
3 changes: 2 additions & 1 deletion test/test_basic_value_iteration.jl
@@ -95,4 +95,5 @@ end
@test test_simple_grid() == true
@test test_init_solution() == true
@test test_not_include_Q() == true
test_warning()
test_warning()
@test_throws String solve(ValueIterationSolver(), TigerPOMDP())
2 changes: 1 addition & 1 deletion test/test_sparse.jl
@@ -45,4 +45,4 @@ sparsepolicy = solve(sparsesolver, mdp)

@test sparsepolicy.qmat == policy.qmat
@test value(sparsepolicy, 2) ≈ 0.0

@test_throws String solve(SparseValueIterationSolver(), TigerPOMDP())