-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8a1e0d0
commit 6bacd51
Showing
4 changed files
with
270 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
"""
    SparseTabularMDP

MDP over integer states and actions whose model is stored explicitly in tables.

# Fields
- `T`: one sparse matrix per action; `T[a][s, sp]` is the probability of reaching
  state `sp` from state `s` under action `a`.
- `R`: dense reward table indexed as `R[s, a]`.
- `discount`: discount factor.
"""
struct SparseTabularMDP <: MDP{Int64, Int64}
    T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
    R::Array{Float64, 2} # R[s, a]
    discount::Float64
end
|
||
"""
    SparseTabularMDP(mdp::MDP)

Build a `SparseTabularMDP` from `mdp`: transitions are tabulated with
`transition_matrix_a_s_sp`, rewards with `reward_s_a`, and the discount is taken
from `discount(mdp)`.
"""
function SparseTabularMDP(mdp::MDP)
    return SparseTabularMDP(
        transition_matrix_a_s_sp(mdp),
        reward_s_a(mdp),
        discount(mdp),
    )
end
|
||
"""
    SparseTabularMDP(mdp::SparseTabularMDP; transition=nothing, reward=nothing, discount=nothing)

Return a new `SparseTabularMDP` that reuses the fields of `mdp`, except for those
overridden through keyword arguments.

# Keywords
- `transition`: replacement for `mdp.T` (one sparse matrix per action).
- `reward`: replacement for `mdp.R`.
- `discount`: replacement for `mdp.discount`.

Fields that are not overridden are passed on as-is (not copied), so the result shares
those arrays with `mdp`.
"""
function SparseTabularMDP(mdp::SparseTabularMDP;
                          transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
                          reward::Union{Nothing, Array{Float64, 2}} = nothing,
                          discount::Union{Nothing, Float64} = nothing)
    # `something` returns its first argument that is not `nothing`.
    T = something(transition, mdp.T)
    R = something(reward, mdp.R)
    d = something(discount, mdp.discount)
    return SparseTabularMDP(T, R, d)
end
|
||
"""
    SparseTabularPOMDP

POMDP over integer states, actions and observations whose model is stored explicitly
in tables.

# Fields
- `T`: one sparse matrix per action; `T[a][s, sp]` is the probability of reaching
  state `sp` from state `s` under action `a`.
- `R`: dense reward table indexed as `R[s, a]`.
- `O`: one sparse matrix per action; `O[a][sp, o]` is the probability of observation
  `o` for action `a` and resulting state `sp`.
- `discount`: discount factor.
"""
struct SparseTabularPOMDP <: POMDP{Int64, Int64, Int64}
    T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
    R::Array{Float64, 2} # R[s, a]
    O::Vector{SparseMatrixCSC{Float64, Int64}} # O[a][sp, o]
    discount::Float64
end
|
||
"""
    SparseTabularPOMDP(pomdp::POMDP)

Build a `SparseTabularPOMDP` from `pomdp`: transitions are tabulated with
`transition_matrix_a_s_sp`, rewards with `reward_s_a`, observations with
`observation_matrix_a_sp_o`, and the discount is taken from `discount(pomdp)`.
"""
function SparseTabularPOMDP(pomdp::POMDP)
    return SparseTabularPOMDP(
        transition_matrix_a_s_sp(pomdp),
        reward_s_a(pomdp),
        observation_matrix_a_sp_o(pomdp),
        discount(pomdp),
    )
end
|
||
"""
    SparseTabularPOMDP(pomdp::SparseTabularPOMDP; transition=nothing, reward=nothing, observation=nothing, discount=nothing)

Return a new `SparseTabularPOMDP` that reuses the fields of `pomdp`, except for those
overridden through keyword arguments.

# Keywords
- `transition`: replacement for `pomdp.T` (one sparse matrix per action).
- `reward`: replacement for `pomdp.R`.
- `observation`: replacement for `pomdp.O` (one sparse matrix per action).
- `discount`: replacement for `pomdp.discount`.

Fields that are not overridden are passed on as-is (not copied), so the result shares
those arrays with `pomdp`.
"""
function SparseTabularPOMDP(pomdp::SparseTabularPOMDP;
                            transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
                            reward::Union{Nothing, Array{Float64, 2}} = nothing,
                            observation::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
                            discount::Union{Nothing, Float64} = nothing)
    # `something` returns its first argument that is not `nothing`.
    T = something(transition, pomdp.T)
    R = something(reward, pomdp.R)
    # The observation override must come from `observation`, not from `transition`.
    O = something(observation, pomdp.O)
    d = something(discount, pomdp.discount)
    return SparseTabularPOMDP(T, R, O, d)
end
|
||
# Union of both sparse tabular model types. The methods that only read the shared
# fields `T`, `R` and `discount` are defined once on this alias.
const SparseTabularProblem = Union{SparseTabularMDP, SparseTabularPOMDP}
|
||
|
||
"""
    transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})

Tabulate the transition model of `mdp` as one sparse matrix per action.

Returns a vector `T` of `n_actions(mdp)` matrices of size `n_states(mdp)` by
`n_states(mdp)`, where `T[ai][si, spi]` is the probability of moving from the state
with index `si` to the state with index `spi` under the action with index `ai`.
Only strictly positive probabilities are stored.

Rows of terminal states, and rows of actions that are not in `actions(mdp, s)`, are
left empty, so rows are not guaranteed to sum to one.
"""
function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
    # Thanks to zach
    na = n_actions(mdp)
    ns = n_states(mdp)
    # Per-action coordinate lists (row, column, value) handed to `sparse` below.
    rows = [Int[] for _ in 1:na]
    cols = [Int[] for _ in 1:na]
    vals = [Float64[] for _ in 1:na]

    for s in states(mdp)
        # If terminal, the transition probabilities are all just zero.
        isterminal(mdp, s) && continue
        si = stateindex(mdp, s)
        for a in actions(mdp, s)
            ai = actionindex(mdp, a)
            for (sp, p) in weighted_iterator(transition(mdp, s, a))
                p > 0.0 || continue
                push!(rows[ai], si)
                push!(cols[ai], stateindex(mdp, sp))
                push!(vals[ai], p)
            end
        end
    end

    return [sparse(rows[ai], cols[ai], vals[ai], ns, ns) for ai in 1:na]
end
|
||
"""
    reward_s_a(mdp::Union{MDP, POMDP})

Tabulate the expected immediate reward of `mdp` as a dense `n_states(mdp)` by
`n_actions(mdp)` matrix indexed by state index and action index.

Each entry is the sum of `p * reward(mdp, s, a, sp)` over the outcomes `sp` of
`transition(mdp, s, a)` that have probability `p > 0`. Rows of terminal states are
all zero, and entries for actions not in `actions(mdp, s)` stay at `-Inf`.
"""
function reward_s_a(mdp::Union{MDP, POMDP})
    # set reward for all actions to -Inf unless they are in actions(mdp, s)
    table = fill(-Inf, (n_states(mdp), n_actions(mdp)))
    for s in states(mdp)
        if isterminal(mdp, s)
            table[stateindex(mdp, s), :] .= 0.0
            continue
        end
        for a in actions(mdp, s)
            expected = 0.0
            for (sp, p) in weighted_iterator(transition(mdp, s, a))
                p > 0.0 || continue
                expected += p*reward(mdp, s, a, sp)
            end
            table[stateindex(mdp, s), actionindex(mdp, a)] = expected
        end
    end
    return table
end
|
||
"""
    observation_matrix_a_sp_o(pomdp::POMDP)

Tabulate the observation model of `pomdp` as one sparse matrix per action.

Returns a vector `O` of `n_actions(pomdp)` matrices of size `n_states(pomdp)` by
`n_observations(pomdp)`, where `O[ai][spi, oi]` is the probability given by
`observation(pomdp, a, sp)` to the observation with index `oi`, for the action with
index `ai` and the state with index `spi`. Only strictly positive probabilities are
stored.
"""
function observation_matrix_a_sp_o(pomdp::POMDP)
    na = n_actions(pomdp)
    ns = n_states(pomdp)
    no = n_observations(pomdp)
    # Per-action coordinate lists (row, column, value) handed to `sparse` below.
    rows = [Int[] for _ in 1:na]
    cols = [Int[] for _ in 1:na]
    vals = [Float64[] for _ in 1:na]

    for sp in states(pomdp)
        spi = stateindex(pomdp, sp)
        for a in actions(pomdp)
            ai = actionindex(pomdp, a)
            for (o, p) in weighted_iterator(observation(pomdp, a, sp))
                p > 0.0 || continue
                push!(rows[ai], spi)
                push!(cols[ai], obsindex(pomdp, o))
                push!(vals[ai], p)
            end
        end
    end

    # Rows index states and columns index observations, so each matrix is ns-by-no.
    return [sparse(rows[ai], cols[ai], vals[ai], ns, no) for ai in 1:na]
end
|
||
# MDP and POMDP common methods

# Sizes are read off the tables: `T` holds one matrix per action and each matrix
# has one row per state.
POMDPs.n_states(prob::SparseTabularProblem) = size(prob.T[1], 1)
POMDPs.n_actions(prob::SparseTabularProblem) = length(prob.T)

# States and actions are the integers 1:n.
POMDPs.states(prob::SparseTabularProblem) = 1:n_states(prob)
POMDPs.actions(prob::SparseTabularProblem) = 1:n_actions(prob)

# A state or an action is its own index.
POMDPs.stateindex(prob::SparseTabularProblem, s::Int64) = s
POMDPs.actionindex(prob::SparseTabularProblem, a::Int64) = a

POMDPs.discount(prob::SparseTabularProblem) = prob.discount

# Distribution over next states: the nonzero entries of row `s` of `T[a]`.
function POMDPs.transition(prob::SparseTabularProblem, s::Int64, a::Int64)
    next_states, probs = findnz(prob.T[a][s, :]) # XXX not memory efficient
    return SparseCat(next_states, probs)
end

POMDPs.reward(prob::SparseTabularProblem, s::Int64, a::Int64) = prob.R[s, a]
|
||
# POMDP only methods

# The number of observations is the column count of the observation matrices.
POMDPs.n_observations(pomdp::SparseTabularPOMDP) = size(pomdp.O[1], 2)

POMDPs.observations(pomdp::SparseTabularPOMDP) = 1:n_observations(pomdp)

# Distribution over observations: the nonzero entries of row `sp` of `O[a]`.
function POMDPs.observation(pomdp::SparseTabularPOMDP, a::Int64, sp::Int64)
    obs, probs = findnz(pomdp.O[a][sp, :])
    return SparseCat(obs, probs)
end

# An observation is its own index.
POMDPs.obsindex(::SparseTabularPOMDP, o::Int64) = o
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Check that, for every state/action pair of `pb1`, the tabular model `pb2` gives each
# outcome listed by `transition(pb1, s, a)` exactly the same probability.
function test_transition(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP})
    for s in states(pb1), a in actions(pb1)
        expected = transition(pb1, s, a)
        tabular = transition(pb2, stateindex(pb1, s), actionindex(pb1, a))
        for (sp, p) in weighted_iterator(expected)
            @test pdf(tabular, stateindex(pb1, sp)) == p
        end
    end
end
|
||
# Check that `reward(pb1, s, a)` equals the tabulated reward of `pb2` at the matching
# state and action indices, for every state/action pair of `pb1`.
function test_reward(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP})
    for s in states(pb1), a in actions(pb1)
        tabular = reward(pb2, stateindex(pb1, s), actionindex(pb1, a))
        @test reward(pb1, s, a) == tabular
    end
end
|
||
# Check that, for every state/action pair of `pb1`, the tabular model `pb2` gives each
# observation listed by `observation(pb1, a, s)` exactly the same probability.
function test_observation(pb1::POMDP, pb2::SparseTabularPOMDP)
    for s in states(pb1), a in actions(pb1)
        expected = observation(pb1, a, s)
        tabular = observation(pb2, actionindex(pb1, a), stateindex(pb1, s))
        for (o, p) in weighted_iterator(expected)
            @test pdf(tabular, obsindex(pb1, o)) == p
        end
    end
end
|
||
|
||
## MDP | ||
|
||
mdp = RandomMDP(100, 4, 0.95)
smdp = SparseTabularMDP(mdp)

# Copy-constructing without keywords reproduces every field.
smdp2 = SparseTabularMDP(smdp)
@test smdp2.T == smdp.T
@test smdp2.R == smdp.R
@test smdp2.discount == smdp.discount

# Overriding the reward table leaves the transitions untouched.
smdp3 = SparseTabularMDP(smdp; reward = zeros(n_states(mdp), n_actions(mdp)))
@test smdp3.T == smdp.T
@test smdp3.R != smdp.R

# Overriding the transitions (identity matrices here) leaves the rewards untouched.
smdp3 = SparseTabularMDP(
    smdp;
    transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for _ in 1:n_actions(mdp)],
)
@test smdp3.T != smdp.T
@test smdp3.R == smdp.R

# Overriding the discount.
smdp3 = SparseTabularMDP(smdp; discount = 0.5)
@test smdp3.discount != smdp.discount
|
||
|
||
## POMDP | ||
|
||
pomdp = TigerPOMDP()
spomdp = SparseTabularPOMDP(pomdp)

# Copy-constructing without keywords reproduces every field.
spomdp2 = SparseTabularPOMDP(spomdp)
@test spomdp2.T == spomdp.T
@test spomdp2.R == spomdp.R
@test spomdp2.O == spomdp.O
@test spomdp2.discount == spomdp.discount

# Replacement tables are sized from `pomdp` (not from the unrelated `mdp` above) so
# that the overridden model stays dimensionally consistent.
# NOTE(review): the `observation` keyword is not exercised by these tests.

# Overriding the reward table leaves the transitions untouched.
spomdp3 = SparseTabularPOMDP(spomdp; reward = zeros(n_states(pomdp), n_actions(pomdp)))
@test spomdp3.T == spomdp.T
@test spomdp3.R != spomdp.R

# Overriding the transitions (identity matrices here) leaves the rewards untouched.
spomdp3 = SparseTabularPOMDP(
    spomdp;
    transition = [sparse(1:n_states(pomdp), 1:n_states(pomdp), 1.0) for _ in 1:n_actions(pomdp)],
)
@test spomdp3.T != spomdp.T
@test spomdp3.R == spomdp.R

# Overriding the discount.
spomdp3 = SparseTabularPOMDP(spomdp; discount = 0.5)
@test spomdp3.discount != spomdp.discount
|
||
|
||
## Tests | ||
|
||
# The tabular POMDP reports the same sizes as the problem it was built from.
@test n_states(pomdp) == n_states(spomdp)
@test n_actions(pomdp) == n_actions(spomdp)
@test n_observations(pomdp) == n_observations(spomdp)
# The state, action and observation spaces have as many elements as their reported sizes.
@test length(states(spomdp)) == n_states(spomdp)
@test length(actions(spomdp)) == n_actions(spomdp)
@test length(observations(spomdp)) == n_observations(spomdp)
# States, actions and observations of the tabular model are all Int64.
@test statetype(spomdp) == Int64
@test actiontype(spomdp) == Int64
@test obstype(spomdp) == Int64
@test discount(spomdp) == discount(pomdp)

# Compare the tabulated models against the original problems entry by entry.
test_transition(mdp, smdp)
test_transition(pomdp, spomdp)
test_reward(pomdp, spomdp)
test_observation(pomdp, spomdp)