Skip to content

Commit

Permalink
implement SparseTabular
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximeBouton committed Aug 7, 2019
1 parent 8a1e0d0 commit 6bacd51
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/POMDPModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,10 @@ export
showdistribution
include("distributions/pretty_printing.jl")

export
SparseTabularMDP,
SparseTabularPOMDP

include("sparse_tabular.jl")

end # module
156 changes: 156 additions & 0 deletions src/sparse_tabular.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
struct SparseTabularMDP <: MDP{Int64, Int64}
T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
R::Array{Float64, 2} # R[s,sp]
discount::Float64
end

function SparseTabularMDP(mdp::MDP)
T = transition_matrix_a_s_sp(mdp)
R = reward_s_a(mdp)
return SparseTabularMDP(T, R, discount(mdp))
end

function SparseTabularMDP(mdp::SparseTabularMDP;
transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
reward::Union{Nothing, Array{Float64, 2}} = nothing,
discount::Union{Nothing, Float64} = nothing)
T = transition != nothing ? transition : mdp.T
R = reward != nothing ? reward : mdp.R
d = discount != nothing ? discount : mdp.discount
return SparseTabularMDP(T, R, d)
end

struct SparseTabularPOMDP <: POMDP{Int64, Int64, Int64}
T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][s, sp]
R::Array{Float64, 2} # R[s,sp]
O::Vector{SparseMatrixCSC{Float64, Int64}} # O[a][sp, o]
discount::Float64
end

function SparseTabularPOMDP(pomdp::POMDP)
T = transition_matrix_a_s_sp(pomdp)
R = reward_s_a(pomdp)
O = observation_matrix_a_sp_o(pomdp)
return SparseTabularPOMDP(T, R, O, discount(pomdp))
end

function SparseTabularPOMDP(pomdp::SparseTabularPOMDP;
transition::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
reward::Union{Nothing, Array{Float64, 2}} = nothing,
observation::Union{Nothing, Vector{SparseMatrixCSC{Float64, Int64}}} = nothing,
discount::Union{Nothing, Float64} = nothing)
T = transition != nothing ? transition : pomdp.T
R = reward != nothing ? reward : pomdp.R
d = discount != nothing ? discount : pomdp.discount
O = observation != nothing ? transition : pomdp.O
return SparseTabularPOMDP(T, R, O, d)
end

const SparseTabularProblem = Union{SparseTabularMDP, SparseTabularPOMDP}


function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
# Thanks to zach
na = n_actions(mdp)
ns = n_states(mdp)
transmat_row_A = [Int[] for _ in 1:n_actions(mdp)]
transmat_col_A = [Int[] for _ in 1:n_actions(mdp)]
transmat_data_A = [Float64[] for _ in 1:n_actions(mdp)]

for s in states(mdp)
si = stateindex(mdp, s)
for a in actions(mdp, s)
ai = actionindex(mdp, a)
if !isterminal(mdp, s) # if terminal, the transition probabilities are all just zero
td = transition(mdp, s, a)
for (sp, p) in weighted_iterator(td)
if p > 0.0
spi = stateindex(mdp, sp)
push!(transmat_row_A[ai], si)
push!(transmat_col_A[ai], spi)
push!(transmat_data_A[ai], p)
end
end
end
end
end
transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
# Note: assert below is not valid for terminal states
# @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
return transmats_A_S_S2
end

function reward_s_a(mdp::Union{MDP, POMDP})
reward_S_A = fill(-Inf, (n_states(mdp), n_actions(mdp))) # set reward for all actions to -Inf unless they are in actions(mdp, s)
for s in states(mdp)
if isterminal(mdp, s)
reward_S_A[stateindex(mdp, s), :] .= 0.0
else
for a in actions(mdp, s)
td = transition(mdp, s, a)
r = 0.0
for (sp, p) in weighted_iterator(td)
if p > 0.0
r += p*reward(mdp, s, a, sp)
end
end
reward_S_A[stateindex(mdp, s), actionindex(mdp, a)] = r
end
end
end
return reward_S_A
end

function observation_matrix_a_sp_o(pomdp::POMDP)
na = n_actions(pomdp)
ns = n_states(pomdp)
no = n_observations(pomdp)
obsmat_row_A = [Int[] for _ in 1:n_actions(pomdp)]
obsmat_col_A = [Int[] for _ in 1:n_actions(pomdp)]
obsmat_data_A = [Float64[] for _ in 1:n_actions(pomdp)]

for sp in states(pomdp)
spi = stateindex(pomdp, sp)
for a in actions(pomdp)
ai = actionindex(pomdp, a)
od = observation(pomdp, a, sp)
for (o, p) in weighted_iterator(od)
if p > 0.0
oi = obsindex(pomdp, o)
push!(obsmat_row_A[ai], spi)
push!(obsmat_col_A[ai], oi)
push!(obsmat_data_A[ai], p)
end
end
end
end
obsmats_A_SP_O = [sparse(obsmat_row_A[a], obsmat_col_A[a], obsmat_data_A[a], n_states(pomdp), n_states(pomdp)) for a in 1:n_actions(pomdp)]

return obsmats_A_SP_O
end

# MDP and POMDP common methods

POMDPs.n_states(prob::SparseTabularProblem) = size(prob.T[1], 1)
POMDPs.n_actions(prob::SparseTabularProblem) = size(prob.T, 1)

POMDPs.states(p::SparseTabularProblem) = 1:n_states(p)
POMDPs.actions(p::SparseTabularProblem) = 1:n_actions(p)

POMDPs.stateindex(::SparseTabularProblem, s::Int64) = s
POMDPs.actionindex(::SparseTabularProblem, a::Int64) = a

POMDPs.discount(p::SparseTabularProblem) = p.discount

POMDPs.transition(p::SparseTabularProblem, s::Int64, a::Int64) = SparseCat(findnz(p.T[a][s, :])...) # XXX not memory efficient

POMDPs.reward(prob::SparseTabularProblem, s::Int64, a::Int64) = prob.R[s, a]

# POMDP only methods
POMDPs.n_observations(p::SparseTabularPOMDP) = size(p.O[1], 2)

POMDPs.observations(p::SparseTabularPOMDP) = 1:n_observations(p)

POMDPs.observation(p::SparseTabularPOMDP, a::Int64, sp::Int64) = SparseCat(findnz(p.O[a][sp, :])...)

POMDPs.obsindex(p::SparseTabularPOMDP, o::Int64) = o
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Pkg
using POMDPSimulators
using POMDPPolicies
import Distributions.Categorical
using SparseArrays

@testset "ordered" begin
include("test_ordered_spaces.jl")
Expand Down Expand Up @@ -66,3 +67,7 @@ end
@testset "pretty printing" begin
include("test_pretty_printing.jl")
end

@testset "sparse tabular" begin
include("test_tabular.jl")
end
103 changes: 103 additions & 0 deletions test/test_tabular.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
function test_transition(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP})
for s in states(pb1)
for a in actions(pb1)
td1 = transition(pb1, s, a)
si = stateindex(pb1, s)
ai = actionindex(pb1, a)
td2 = transition(pb2, si, ai)
for (sp, p) in weighted_iterator(td1)
spi = stateindex(pb1, sp)
@test pdf(td2, spi) == p
end
end
end
end

function test_reward(pb1::Union{MDP, POMDP}, pb2::Union{SparseTabularMDP, SparseTabularPOMDP})
for s in states(pb1)
for a in actions(pb1)
si = stateindex(pb1, s)
ai = actionindex(pb1, a)
@test reward(pb1, s, a) == reward(pb2, si, ai)
end
end
end

function test_observation(pb1::POMDP, pb2::SparseTabularPOMDP)
for s in states(pb1)
for a in actions(pb1)
od1 = observation(pb1, a, s)
si = stateindex(pb1, s)
ai = actionindex(pb1, a)
od2 = observation(pb2, ai, si)
for (o, p) in weighted_iterator(od1)
oi = obsindex(pb1, o)
@test pdf(od2, oi) == p
end
end
end
end


## MDP

mdp = RandomMDP(100, 4, 0.95)

smdp = SparseTabularMDP(mdp)

smdp2 = SparseTabularMDP(smdp)

@test smdp2.T == smdp.T
@test smdp2.R == smdp.R
@test smdp2.discount == smdp.discount

smdp3 = SparseTabularMDP(smdp, reward = zeros(n_states(mdp), n_actions(mdp)))
@test smdp3.T == smdp.T
@test smdp3.R != smdp.R
smdp3 = SparseTabularMDP(smdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)])
@test smdp3.T != smdp.T
@test smdp3.R == smdp.R
smdp3 = SparseTabularMDP(smdp, discount = 0.5)
@test smdp3.discount != smdp.discount


## POMDP

pomdp = TigerPOMDP()

spomdp = SparseTabularPOMDP(pomdp)

spomdp2 = SparseTabularPOMDP(spomdp)

@test spomdp2.T == spomdp.T
@test spomdp2.R == spomdp.R
@test spomdp2.O == spomdp.O
@test spomdp2.discount == spomdp.discount

spomdp3 = SparseTabularPOMDP(spomdp, reward = zeros(n_states(mdp), n_actions(mdp)))
@test spomdp3.T == spomdp.T
@test spomdp3.R != spomdp.R
spomdp3 = SparseTabularPOMDP(spomdp, transition = [sparse(1:n_states(mdp), 1:n_states(mdp), 1.0) for a in 1:n_actions(mdp)])
@test spomdp3.T != spomdp.T
@test spomdp3.R == spomdp.R
spomdp3 = SparseTabularPOMDP(spomdp, discount = 0.5)
@test spomdp3.discount != spomdp.discount


## Tests

@test n_states(pomdp) == n_states(spomdp)
@test n_actions(pomdp) == n_actions(spomdp)
@test n_observations(pomdp) == n_observations(spomdp)
@test length(states(spomdp)) == n_states(spomdp)
@test length(actions(spomdp)) == n_actions(spomdp)
@test length(observations(spomdp)) == n_observations(spomdp)
@test statetype(spomdp) == Int64
@test actiontype(spomdp) == Int64
@test obstype(spomdp) == Int64
@test discount(spomdp) == discount(pomdp)

test_transition(mdp, smdp)
test_transition(pomdp, spomdp)
test_reward(pomdp, spomdp)
test_observation(pomdp, spomdp)

0 comments on commit 6bacd51

Please sign in to comment.