ddp-sa test pass!
sglyon committed Mar 14, 2016
1 parent 5d8586d commit 9e59574
Showing 2 changed files with 116 additions and 66 deletions.
16 changes: 5 additions & 11 deletions src/markov/ddp.jl
@@ -66,7 +66,7 @@ type DiscreteDP{T<:Real,NQ,NR,Tbeta<:Real,Tind}
msg = "R must be 2-dimensional without s-a formulation"
throw(ArgumentError(msg))
end
- beta < 0 || beta >= 1 && throw(ArgumentError("beta must be [0, 1)"))
+ (beta < 0 || beta >= 1) && throw(ArgumentError("beta must be [0, 1)"))

# verify input integrity 2
num_states, num_actions = size(R)
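The added parentheses are the substantive fix: in Julia, `&&` binds tighter than `||`, so the old line parsed as `beta < 0 || (beta >= 1 && throw(...))`, and a negative `beta` short-circuited the `||` without ever throwing. A minimal standalone sketch (not part of the commit):

# With beta = -0.5 the old form returns true silently; the new form throws.
beta = -0.5
beta < 0 || beta >= 1 && throw(ArgumentError("beta must be [0, 1)"))    # no error
(beta < 0 || beta >= 1) && throw(ArgumentError("beta must be [0, 1)"))  # ArgumentError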
@@ -102,7 +102,7 @@ type DiscreteDP{T<:Real,NQ,NR,Tbeta<:Real,Tind}
if NR != 1
throw(ArgumentError("R must be 1-dimensional with s-a formulation"))
end
- beta < 0 || beta >= 1 && throw(ArgumentError("beta must be [0, 1)"))
+ (beta < 0 || beta >= 1) && throw(ArgumentError("beta must be [0, 1)"))

# verify input integrity (same length)
num_sa_pairs, num_states = size(Q)
@@ -128,7 +128,7 @@ type DiscreteDP{T<:Real,NQ,NR,Tbeta<:Real,Tind}
n = maximum(s_indices)
msg = "Duplicate s-a pair found"
as_ptr = sparse(a_indices, s_indices, 1:num_sa_pairs, m, n,
- (x,y)->error(msg))
+ (x,y)->throw(ArgumentError(msg)))
a_indices = as_ptr.rowval
a_indptr = as_ptr.colptr
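The `sparse` call above doubles as duplicate detection: the combine function is invoked only when two entries share the same `(action, state)` coordinate, and the changed line turns that event into an `ArgumentError` rather than a bare `error`. A standalone sketch of the trick with hypothetical indices (`sparse` lived in `Base` in the Julia of this era):

# The pair (a=1, s=1) appears twice, so sparse calls the combine
# function, which throws instead of combining the two values.
s_ind = [1, 1, 2]
a_ind = [1, 1, 1]
sparse(a_ind, s_ind, 1:3, 1, 2,
       (x, y) -> throw(ArgumentError("Duplicate s-a pair found")))  # throws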

@@ -541,8 +541,8 @@ s_wise_max!(vals::AbstractMatrix, out::Vector) = (println("calling this one! ");
Populate `out` with `max_a vals(s, a)`, where `vals` is represented as an
`AbstractMatrix` of size `(num_states, num_actions)`.
- Also fills `out_argmax` with the linear index associated with the indmax in each
- row
+ Also fills `out_argmax` with the column number associated with the indmax in
+ each row
"""
function s_wise_max!(vals::AbstractMatrix, out::Vector, out_argmax::Vector)
# naive implementation where I just iterate over the rows
@@ -562,11 +562,6 @@ function s_wise_max!(vals::AbstractMatrix, out::Vector, out_argmax::Vector)
end

end
- # HACK: convert to linear index for intermediate testing
- # sv = size(vals)
- # for (i, c) in enumerate(out_argmax)
- #     out_argmax[i] = sub2ind(sv, i, c)
- # end
out, out_argmax
end
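For reference, a self-contained sketch of what the dense method above computes; this is an illustration, not the package implementation:

# For each row (state) of vals, record the maximum and the column
# (action) attaining it -- the roles played by `out` and `out_argmax` above.
function rowwise_max(vals::AbstractMatrix)
    out = fill(-Inf, size(vals, 1))
    out_argmax = ones(Int, size(vals, 1))
    for j in 1:size(vals, 2), i in 1:size(vals, 1)
        if vals[i, j] > out[i]
            out[i] = vals[i, j]
            out_argmax[i] = j
        end
    end
    out, out_argmax
end

For example, `rowwise_max([1.0 2.0; 4.0 3.0])` returns `([2.0, 4.0], [2, 1])`.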

@@ -616,7 +611,6 @@ function s_wise_max!(a_indices::Vector, a_indptr::Vector, vals::Vector,
if a_indptr[i] != a_indptr[i+1]
m = a_indptr[i]
for j in a_indptr[i]+1:(a_indptr[i+1]-1)
- @show i, j, m, vals[j], vals[m]
if vals[j] > vals[m]
m = j
end
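In the s-a formulation there is no rectangular `vals`: `a_indptr[i]:a_indptr[i+1]-1` delimits the positions in `vals` that belong to state `i`, CSC-style, which is what the loop above scans. A simplified, `Float64`-only sketch:

# a_indptr has length n+1; states with no feasible action satisfy
# a_indptr[i] == a_indptr[i+1] and contribute nothing.
function sa_max(a_indptr::Vector{Int}, vals::Vector{Float64})
    n = length(a_indptr) - 1
    out = fill(-Inf, n)
    for i in 1:n
        for j in a_indptr[i]:(a_indptr[i+1]-1)
            vals[j] > out[i] && (out[i] = vals[j])
        end
    end
    out
end

Here `sa_max([1, 3, 4], [5.0, 10.0, -1.0])` gives `[10.0, -1.0]`.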
166 changes: 111 additions & 55 deletions test/test_ddp.jl
@@ -37,7 +37,7 @@ Tests for markov/ddp.jl
ddp0_sa = DiscreteDP(R_sa, Q_sa, beta, s_indices, a_indices)

# List of ddp formulations
- ddp0_collection = (ddp0,)
+ ddp0_collection = (ddp0, ddp0_sa)

# Maximum Iteration and Epsilon for Tests
max_iter = 200
@@ -47,15 +47,16 @@ Tests for markov/ddp.jl
v_star = [(5-5.5*beta)/((1-0.5*beta)*(1-beta)), -1/(1-beta)]
sigma_star = [1, 1]
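These closed forms can be checked by hand, assuming the standard two-state example defined earlier in this file (elided from the diff, so the specific R and Q are an assumption here): action 1 in state 1 pays 5 and moves to either state with probability 1/2, while state 2 is absorbing with reward -1. Then

v_2 = -1 + \beta v_2 \implies v_2 = -\frac{1}{1-\beta}
v_1 = 5 + \beta \frac{v_1 + v_2}{2} \implies v_1 = \frac{5 - 5.5\beta}{(1 - 0.5\beta)(1-\beta)}

which matches `v_star` above.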

@testset "test bellman_operator methods" begin
@testset "bellman_operator methods" begin
# Check both Dense and State-Action Pair Formulation
- for ddp_item in ddp0_collection
- @test isapprox(bellman_operator(ddp_item, v_star), v_star)
+ for ddp in ddp0_collection
+ @test isapprox(bellman_operator(ddp, v_star), v_star)
end
end
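The property under test is that `v_star` is a fixed point of the Bellman operator. A sketch of the dense operator in terms of plain arrays (the package methods act on a `DiscreteDP`; this helper is an illustration only):

# (Tv)(s) = max_a { R[s, a] + beta * sum_sp Q[s, a, sp] * v[sp] };
# sigma records the maximizing action, which is what compute_greedy returns.
function bellman_greedy(R::AbstractMatrix, Q::AbstractArray, beta::Real, v::Vector)
    n, m = size(R)
    Tv = fill(-Inf, n)
    sigma = ones(Int, n)
    for s in 1:n, a in 1:m
        val = R[s, a] + beta * sum(Q[s, a, sp] * v[sp] for sp in 1:n)
        if val > Tv[s]
            Tv[s], sigma[s] = val, a
        end
    end
    Tv, sigma
end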

@testset "test RQ_sigma" begin
@testset "RQ_sigma" begin
nr, nc = size(R)
# test for DDP
sigmas = ([1, 1], [1, 2], [2, 1], [2, 2])
for sig in sigmas
r, q = RQ_sigma(ddp0, sig)
@@ -67,63 +68,67 @@ Tests for markov/ddp.jl
end
end
end

+ # TODO: add test for DDPsa
end
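`RQ_sigma` extracts the reward vector and transition matrix induced by committing to a fixed policy, which is the input to policy evaluation. A simplified sketch of the dense case (a hypothetical helper, not the package code):

# Following sigma in state s yields reward R[s, sigma[s]] and
# transition row Q[s, sigma[s], :].
function rq_sigma(R::AbstractMatrix, Q::AbstractArray, sigma::Vector{Int})
    n = length(sigma)
    r = [R[s, sigma[s]] for s in 1:n]
    q = zeros(n, n)
    for s in 1:n, sp in 1:n
        q[s, sp] = Q[s, sigma[s], sp]
    end
    r, q
end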

@testset "test compute_greedy methods" begin
@testset "compute_greedy methods" begin
# Check both Dense and State-Action Pair Formulation
for ddp_item in ddp0_collection
@test compute_greedy(ddp_item, v_star) == sigma_star
for ddp in ddp0_collection
@test compute_greedy(ddp, v_star) == sigma_star
end
end

@testset "test evaluate_policy methods" begin
@testset "evaluate_policy methods" begin
# Check both Dense and State-Action Pair Formulation
for ddp_item in ddp0_collection
@test isapprox(evaluate_policy(ddp_item, sigma_star), v_star)
for ddp in ddp0_collection
@test isapprox(evaluate_policy(ddp, sigma_star), v_star)
end
end
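`evaluate_policy` computes the value of following a policy forever, i.e. the solution of v = r_sigma + beta * Q_sigma * v. With the hypothetical `rq_sigma` helper sketched above, this is a single linear solve (`I` needs `using LinearAlgebra` on Julia 1.x; the 0.4-era equivalent was `eye(n)`):

using LinearAlgebra

# beta < 1 makes (I - beta * Q_sigma) invertible, so the system has a
# unique solution; for sigma_star it should reproduce v_star.
r, q = rq_sigma(R, Q, sigma_star)
v_sigma = (I - beta * q) \ r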

@testset "test methods for subtypes != (Float64, Int)" begin
@testset "methods for subtypes != (Float64, Int)" begin
float_types = [Float16, Float32, Float64, BigFloat]
int_types = [Int8, Int16, Int32, Int64, Int128,
UInt8, UInt16, UInt32, UInt64, UInt128]

- for f in (bellman_operator, compute_greedy)
- for T in float_types
- f_f64 = f(ddp0, [1.0, 1.0])
- f_T = f(ddp0, ones(T, 2))
- @test isapprox(f_f64, convert(Vector{eltype(f_f64)}, f_T))
- end
+ for ddp in ddp0_collection
+ for f in (bellman_operator, compute_greedy)
+ for T in float_types
+ f_f64 = f(ddp, [1.0, 1.0])
+ f_T = f(ddp, ones(T, 2))
+ @test isapprox(f_f64, convert(Vector{eltype(f_f64)}, f_T))
+ end

- # only Integer subtypes can be Rational type params
- # NOTE: Only the integer types below don't overflow for this example
- for T in [Int64, Int128]
- @test f(ddp0, [1//1, 1//1]) == f(ddp0, ones(Rational{T}, 2))
+ # only Integer subtypes can be Rational type params
+ # NOTE: Only the integer types below don't overflow for this example
+ for T in [Int64, Int128]
+ @test f(ddp, [1//1, 1//1]) == f(ddp, ones(Rational{T}, 2))
end
end
end

- for T in float_types, S in int_types
- v = ones(T, 2)
- s = ones(S, 2)
- # just test that we can call the method and the result is
- # deterministic
- @test bellman_operator!(ddp0, v, s) == bellman_operator!(ddp0, v, s)
- end
+ for T in float_types, S in int_types
+ v = ones(T, 2)
+ s = ones(S, 2)
+ # just test that we can call the method and the result is
+ # deterministic
+ @test bellman_operator!(ddp, v, s) == bellman_operator!(ddp, v, s)
+ end

- for T in int_types
- s = T[1, 1]
- @test isapprox(evaluate_policy(ddp0, s), v_star)
+ for T in int_types
+ s = T[1, 1]
+ @test isapprox(evaluate_policy(ddp, s), v_star)
end
end
end

@testset "test compute_greedy! changes ddpr.v" begin
@testset "compute_greedy! changes ddpr.v" begin
res = solve(ddp0, VFI)
res.Tv[:] = 500.0
compute_greedy!(ddp0, res)
@test maxabs(res.Tv - 500.0) > 0
end

@testset "test value_iteration" begin
@testset "value_iteration" begin
# Check both Dense and State-Action Pair Formulation
for ddp_item in ddp0_collection
# Compute Result
@@ -142,7 +147,7 @@ end
end
end
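Value function iteration repeatedly applies the Bellman operator. A sketch in terms of the `bellman_greedy` helper above; the `epsilon * (1 - beta) / (2 * beta)` tolerance is the textbook epsilon-optimality bound and an assumption about the solver's exact stopping rule:

# Iterate v <- Tv until the sup-norm change certifies that the greedy
# policy is epsilon-optimal (or max_iter is hit).
function vfi(R, Q, beta, v0; max_iter=200, epsilon=1e-4)
    v = copy(v0)
    for it in 1:max_iter
        Tv, _ = bellman_greedy(R, Q, beta, v)
        done = maximum(abs.(Tv .- v)) < epsilon * (1 - beta) / (2 * beta)
        v = Tv
        done && break
    end
    v
end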

@testset "test policy_iteration" begin
@testset "policy_iteration" begin
# Check both Dense and State-Action Pair Formulation
for ddp_item in ddp0_collection
res = solve(ddp_item, PFI)
@@ -159,7 +164,7 @@ end
end
end
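Policy iteration alternates evaluation and greedy improvement and stops at a policy that is greedy against its own value; with finitely many policies this terminates exactly. A sketch reusing the hypothetical helpers (and the `LinearAlgebra` import) from the earlier sketches:

# Evaluate the current policy, switch to the greedy policy against its
# value, and stop at a fixed point.
function pfi(R, Q, beta, sigma0)
    sigma = copy(sigma0)
    while true
        r, q = rq_sigma(R, Q, sigma)
        v = (I - beta * q) \ r
        sigma_new = bellman_greedy(R, Q, beta, v)[2]
        sigma_new == sigma && return v, sigma
        sigma = sigma_new
    end
end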

@testset "test DiscreteDP{Rational,_,_,Rational} maintains Rational" begin
@testset "DiscreteDP{Rational,_,_,Rational} maintains Rational" begin
ddp_rational = DiscreteDP(map(Rational{BigInt}, R),
map(Rational{BigInt}, Q),
map(Rational{BigInt}, beta))
@@ -170,7 +175,7 @@
@test eltype(solve(ddp_rational, vi, MPFI; max_iter=1, k=1, epsilon=Inf).v) == Rational{BigInt}
end

@testset "test DiscreteDP{Rational{BigInt},_,_,Rational{BigInt}} works" begin
@testset "DiscreteDP{Rational{BigInt},_,_,Rational{BigInt}} works" begin
ddp_rational = DiscreteDP(map(Rational{BigInt}, R),
map(Rational{BigInt}, Q),
map(Rational{BigInt}, beta))
@@ -185,7 +190,7 @@
@test r1.mc.p == r3.mc.p
end

@testset "test modified_policy_iteration" begin
@testset "modified_policy_iteration" begin
for ddp_item in ddp0_collection
res = solve(ddp_item, MPFI)
v_init = [0.0, 1.0]
@@ -211,27 +216,78 @@ end
end
end

@testset "test ddp_no_feasible_action_error" begin
#Dense Matrix
n, m = 2, 2
R = [-Inf -Inf; 1.0 2.0]
@testset "DDPsa constructor" begin
@testset "feasbile action pair" begin
_R = [1.0, 0.0, 0.0, 1.0]
_Q = fill(1/3, 4, 3)
_s_ind = [1, 1, 3, 3]
_a_ind = [1, 2, 1, 2]
@test_throws ArgumentError DiscreteDP(_R, _Q, beta, _s_ind, _a_ind)
end

- Q = Array(Float64, n, m, n)
- Q[:, :, 1] = [0.5 0.0; 0.0 0.0]
- Q[:, :, 2] = [0.5 1.0; 1.0 1.0]
- beta = 0.95
+ _R, _Q = R_sa, Q_sa
+ _s_ind = [1, 1, 2]
+ _a_ind = [1, 2, 1]

- @test_throws ArgumentError DiscreteDP(R, Q, beta)
@testset "beta in [0, 1)" begin
@test_throws ArgumentError DiscreteDP(_R, _Q, -eps(), _s_ind, _a_ind)
@test_throws ArgumentError DiscreteDP(_R, _Q, 1.0, _s_ind, _a_ind)
@test_throws ArgumentError DiscreteDP(_R, _Q, 1+eps(), _s_ind, _a_ind)
end

- # # State-Action Pair Formulation
- # s_indices = [1, 1, 3, 3]
- # a_indices = [1, 2, 1, 2]
- # #TODO: @sglyon We need to construct R_sa, Q_sa right?
- #
- # @test_throws ArgumentError DiscreteDP(R, Q, beta, s_indices, a_indices)
@testset "argument sizes" begin
# NQ != 2
@test_throws ArgumentError DiscreteDP(_R, rand(4, 3, 1), beta, _s_ind, _a_ind)

# NR != 1
@test_throws ArgumentError DiscreteDP(rand(4, 1), _Q, beta, _s_ind, _a_ind)

# incorrect lengths
@test_throws ArgumentError DiscreteDP(rand(2), _Q, beta, _s_ind, _a_ind)
@test_throws ArgumentError DiscreteDP(_R, rand(5, 2), beta, _s_ind, _a_ind)
@test_throws ArgumentError DiscreteDP(_R, _Q, beta, rand(1:3, 2), _a_ind)
@test_throws ArgumentError DiscreteDP(_R, _Q, beta, _s_ind, rand(1:3, 2))
end

@testset "duplicate sa pair" begin
@test_throws ArgumentError DiscreteDP(_R, _Q, beta, _s_ind, [1, 1, 2])
end
end

@testset "DDP constructor" begin
@testset "beta in [0, 1)" begin
@test_throws ArgumentError DiscreteDP(R, Q, -eps())
@test_throws ArgumentError DiscreteDP(R, Q, 1.0)
@test_throws ArgumentError DiscreteDP(R, Q, 1+eps())
end

@testset "feasbile action pair" begin
#Dense Matrix
n, m = 2, 2
_R = [-Inf -Inf; 1.0 2.0]

_Q = Array(Float64, n, m, n)
_Q[:, :, 1] = [0.5 0.0; 0.0 0.0]
_Q[:, :, 2] = [0.5 1.0; 1.0 1.0]
_beta = 0.95

@test_throws ArgumentError DiscreteDP(_R, _Q, _beta)
end

@testset "R, Q sizes" begin
# NQ != 3
@test_throws ArgumentError DiscreteDP(R, zeros(2, 2), beta)

# NR != 2
@test_throws ArgumentError DiscreteDP(zeros(1), Q, beta)

# incompatible dimensions
@test_throws ArgumentError DiscreteDP(zeros(2, 3), Q, beta)
@test_throws ArgumentError DiscreteDP(R, zeros(2, 3, 2), beta)
end
end

@testset "test ddp_negative_inf_error()" begin
@testset "ddp_negative_inf_error()" begin
# Dense Matrix
n, m = 3, 2
R = [0 1;
