Skip to content

Commit

Permalink
move to cartesian indexing for dense ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
sglyon committed Mar 14, 2016
1 parent 17524da commit 5d8586d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 31 deletions.
77 changes: 50 additions & 27 deletions src/markov/ddp.jl
Expand Up @@ -75,7 +75,7 @@ type DiscreteDP{T<:Real,NQ,NR,Tbeta<:Real,Tind}
end

# check feasibility
R_max = _s_wise_max(R)
R_max = s_wise_max(R)
if any(R_max .== -Inf)
# First state index such that all actions yield -Inf
s = find(R_max .== -Inf) #-Only Gives True
Expand Down Expand Up @@ -500,10 +500,8 @@ the transition probability matrix `Q_sigma`.
"""
function RQ_sigma{T<:Integer}(ddp::DDP, sigma::Array{T})
    # `sigma[i]` is used directly as the second (column) index into `ddp.R` and
    # `ddp.Q` for row `i`, i.e. cartesian rather than linear indexing. The old
    # linear-index pass (`ddp.R[sigma]` plus `ind2sub`) was overwritten right
    # away, so it is dropped.
    R_sigma = [ddp.R[i, sigma[i]] for i in 1:length(sigma)]

    # Build one column per state from the slice `Q[i, sigma[i], :]`; the result
    # is transposed on return so that row `i` corresponds to state `i`.
    Q_sigma = hcat([getindex(ddp.Q, i, sigma[i], Colon())[:] for i=1:num_states(ddp)]...)
    return R_sigma, Q_sigma'
end

Expand All @@ -520,46 +518,74 @@ end
# Internal methods #
# ---------------- #

"""
State-wise maximum of `vals` for a dense `DiscreteDP`.

`ddp` is only used for dispatch; the work is forwarded to `_s_wise_max(vals)`.
"""
function s_wise_max(ddp::DiscreteDP, vals::Matrix)
    return _s_wise_max(vals)
end
## s_wise_max for DDP

"""
In-place state-wise maximum and argmax of `vals` for a dense `DiscreteDP`.

`ddp` is only used for dispatch; `vals`, `out` and `out_argmax` are forwarded
unchanged to `_s_wise_max!`, whose return value is returned.
"""
function s_wise_max!(ddp::DiscreteDP, vals::Matrix, out::Vector,
                     out_argmax::Vector)
    return _s_wise_max!(vals, out, out_argmax)
end
"""
State-wise maximum of `vals` for a dense `DiscreteDP`.

`ddp` is only used for dispatch; the work is forwarded to the one-argument
method `s_wise_max(vals)`.
"""
function s_wise_max(ddp::DiscreteDP, vals::AbstractMatrix)
    return s_wise_max(vals)
end

"""
State-wise maximum of `vals` for a state-action pair `DDPsa`.

Allocates a `Float64` output of length `num_states(ddp)` and hands it, together
with the unwrapped `a_indices` and `a_indptr` fields, to `_s_wise_max!`. The
return value is whatever that call returns.
"""
function s_wise_max(ddp::DDPsa, vals::Vector)
    a_indices = get(ddp.a_indices)
    a_indptr = get(ddp.a_indptr)
    out = Array(Float64, num_states(ddp))
    return _s_wise_max!(a_indices, a_indptr, vals, out)
end

"""
In-place state-wise maximum and argmax of `vals` for a state-action pair
`DDPsa`.

Unwraps the `a_indices` and `a_indptr` fields with `get` and forwards them,
along with `vals`, `out` and `out_argmax`, to `_s_wise_max!`.
"""
function s_wise_max!(ddp::DDPsa, vals::Vector, out::Vector, out_argmax::Vector)
    a_indices = get(ddp.a_indices)
    a_indptr = get(ddp.a_indptr)
    return _s_wise_max!(a_indices, a_indptr, vals, out, out_argmax)
end
"""
In-place state-wise maximum and argmax of `vals` for a dense `DiscreteDP`.

`ddp` is only used for dispatch; `vals`, `out` and `out_argmax` are forwarded
unchanged to the three-argument `s_wise_max!` method.
"""
function s_wise_max!(ddp::DiscreteDP, vals::AbstractMatrix, out::Vector,
                     out_argmax::Vector)
    return s_wise_max!(vals, out, out_argmax)
end

"""
Return the `Vector` `max_a vals(s, a)`, where `vals` is represented as a
`Matrix` of size `(num_states, num_actions)`.
`AbstractMatrix` of size `(num_states, num_actions)`.
"""
function _s_wise_max(vals::Matrix)
    # Reduce along the second dimension (one maximum per row), then flatten the
    # resulting single-column array into a plain vector.
    row_max = maximum(vals, 2)
    return vec(row_max)
end
"""
Return a vector holding the maximum of each row of `vals`.

The reduction runs along the second dimension and the single-column result is
flattened with `vec`.
"""
function s_wise_max(vals::AbstractMatrix)
    row_max = maximum(vals, 2)
    return vec(row_max)
end

"""
Populate `out` with `max_a vals(s, a)`, where `vals` is represented as a
`Matrix` of size `(num_states, num_actions)`.
`AbstractMatrix` of size `(num_states, num_actions)`.
"""
function _s_wise_max!(vals::Matrix, out::Vector)
    # `maximum!` reduces `vals` into `out` and returns `out`. The debug
    # `println` that used to run on every call has been removed so the routine
    # no longer writes to stdout.
    return maximum!(out, vals)
end
"""
Populate `out` with the maximum of each row of `vals` and return `out`.

The reduction is delegated to `maximum!(out, vals)`. The debug `println` that
used to run on every call has been removed so the routine no longer writes to
stdout.
"""
function s_wise_max!(vals::AbstractMatrix, out::Vector)
    return maximum!(out, vals)
end

"""
Populate `out` with `max_a vals(s, a)`, where `vals` is represented as a
`Matrix` of size `(num_states, num_actions)`.
`AbstractMatrix` of size `(num_states, num_actions)`.
Also fills `out_argmax` with the linear index associated with the indmax in each
row
"""
function _s_wise_max!(vals::Matrix, out::Vector, out_argmax::Vector)
    # Seed every slot with -Inf so that any finite entry of `vals` replaces it.
    fill!(out, -Inf)
    # `Base.findminmax!` with the "greater than" functor fills `out` with the
    # running maxima and `out_argmax` with the matching indices into `vals`.
    return Base.findminmax!(Base.MoreFun(), out, out_argmax, vals)
end
"""
Populate `out` with the maximum of each row of `vals` and `out_argmax` with the
column index at which that maximum is first attained.

`out_argmax` holds column indices, not linear indices into `vals`. A row in
which no entry is strictly greater than `-Inf` (for example a row made up only
of `-Inf`) yields `out[i] == -Inf` and `out_argmax[i] == 1`, so `out` never
keeps stale values from a previous call.

Returns the tuple `(out, out_argmax)`.
"""
function s_wise_max!(vals::AbstractMatrix, out::Vector, out_argmax::Vector)
    nr, nc = size(vals)
    for i_r in 1:nr
        # running best for this row; -Inf / column 1 are the defaults when no
        # entry beats the sentinel
        cur_max = -Inf
        cur_arg = 1

        for i_c in 1:nc
            @inbounds v_rc = vals[i_r, i_c]
            # strict comparison keeps the first column on ties
            if v_rc > cur_max
                cur_max = v_rc
                cur_arg = i_c
            end
        end

        # always write both outputs so every slot of `out` is refreshed
        out[i_r] = cur_max
        out_argmax[i_r] = cur_arg
    end
    out, out_argmax
end


## s_wise_max for DDPsa

"""
State-wise maximum of `vals` for a state-action pair `DDPsa`.

Allocates a `Float64` output of length `num_states(ddp)` and hands it, together
with the unwrapped `a_indices` and `a_indptr` fields, to the four-argument
`s_wise_max!` method. The return value is whatever that call returns.
"""
function s_wise_max(ddp::DDPsa, vals::Vector)
    a_indices = get(ddp.a_indices)
    a_indptr = get(ddp.a_indptr)
    out = Array(Float64, num_states(ddp))
    return s_wise_max!(a_indices, a_indptr, vals, out)
end

"""
In-place state-wise maximum and argmax of `vals` for a state-action pair
`DDPsa`.

Unwraps the `a_indices` and `a_indptr` fields with `get` and forwards them,
along with `vals`, `out` and `out_argmax`, to the five-argument `s_wise_max!`
method.
"""
function s_wise_max!(ddp::DDPsa, vals::Vector, out::Vector, out_argmax::Vector)
    a_indices = get(ddp.a_indices)
    a_indptr = get(ddp.a_indptr)
    return s_wise_max!(a_indices, a_indptr, vals, out, out_argmax)
end

"""
Populate `out` with `max_a vals(s, a)`, where `vals` is represented as a
`Vector` of size `(num_sa_pairs,)`.
"""
function _s_wise_max!(a_indices::Vector, a_indptr::Vector, vals::Vector,
function s_wise_max!(a_indices::Vector, a_indptr::Vector, vals::Vector,
out::Vector)
n = length(out)
for i in 1:n
Expand All @@ -583,7 +609,7 @@ Populate `out` with `max_a vals(s, a)`, where `vals` is represented as a
Also fills `out_argmax` with the cartesian index associated with the indmax in
each row
"""
function _s_wise_max!(a_indices::Vector, a_indptr::Vector, vals::Vector,
function s_wise_max!(a_indices::Vector, a_indptr::Vector, vals::Vector,
out::Vector, out_argmax::Vector)
n = length(out)
for i in 1:n
Expand Down Expand Up @@ -682,9 +708,6 @@ function *{T}(A::Array{T,3}, v::Vector)
return reshape(out, shape[1:end-1])
end

"""
Map each linear index in `x`, taken with respect to `size(ddp.R)`, to the second
component of its subscript and return the results as a collection.
"""
function Base.ind2sub(ddp::DiscreteDP, x::Vector)
    dims = size(ddp.R)
    return map(k -> ind2sub(dims, k)[2], x)
end

"""
Implements Value Iteration
NOTE: See `solve` for further details
Expand Down
22 changes: 18 additions & 4 deletions test/test_ddp.jl
Expand Up @@ -45,7 +45,7 @@ Tests for markov/ddp.jl

# Analytical solution for beta > 10/11, Example 6.2.1
v_star = [(5-5.5*beta)/((1-0.5*beta)*(1-beta)), -1/(1-beta)]
sigma_star = [1, 2]
sigma_star = [1, 1]

@testset "test bellman_operator methods" begin
# Check both Dense and State-Action Pair Formulation
Expand All @@ -54,6 +54,21 @@ Tests for markov/ddp.jl
end
end

@testset "test RQ_sigma" begin
    # Check RQ_sigma against direct cartesian indexing into ddp0.R and ddp0.Q
    # for every policy in `sigmas`.
    nr = size(R, 1)
    sigmas = ([1, 1], [1, 2], [2, 1], [2, 2])
    for sig in sigmas
        r, q = RQ_sigma(ddp0, sig)

        # reward picked for each state must equal R[state, sig[state]]
        for i_r in 1:nr
            @test r[i_r] == ddp0.R[i_r, sig[i_r]]
        end

        # each row of q must equal Q[state, sig[state], :]; this loop used to
        # be nested in the one above and so repeated the same checks nr times
        for i_s in 1:length(sig)
            @test vec(q[i_s, :]) == vec(ddp0.Q[i_s, sig[i_s], :])
        end
    end
end

@testset "test compute_greedy methods" begin
# Check both Dense and State-Action Pair Formulation
for ddp_item in ddp0_collection
Expand Down Expand Up @@ -96,7 +111,7 @@ Tests for markov/ddp.jl
end

for T in int_types
s = T[1, 2]
s = T[1, 1]
@test isapprox(evaluate_policy(ddp0, s), v_star)
end
end
Expand Down Expand Up @@ -228,7 +243,7 @@ Tests for markov/ddp.jl
@test_throws ArgumentError DiscreteDP(R, Q, beta)

# State-Action Pair Formulation
#
#
# s_indices = [0, 0, 1, 1, 2, 2]
# a_indices = [0, 1, 0, 1, 0, 1]
# R_sa = reshape(R, n*m)
Expand All @@ -237,5 +252,4 @@ Tests for markov/ddp.jl
# @test_throws ArgumentError DiscreteDP(R_sa, Q_sa, beta, s_indices, a_indices)
end


end # end @testset

0 comments on commit 5d8586d

Please sign in to comment.