diff --git a/docs/src/alpha_vector.md b/docs/src/alpha_vector.md
index 8f5b359..64e6184 100644
--- a/docs/src/alpha_vector.md
+++ b/docs/src/alpha_vector.md
@@ -1,7 +1,10 @@
 # Alpha Vector Policy
 
-Represents a policy with a set of alpha vectors. Can be constructed with `AlphaVectorPolicy(pomdp, alphas, action_map)`, where alphas is either a vector of vectors or an |S| x |A| matrix. The `action_map` argument is a vector of actions with length equal to the number of alpha vectors. If this argument is not provided, ordered_actions is used to generate a default action map.
+Represents a policy with a set of alpha vectors. Can be constructed with [`AlphaVectorPolicy(pomdp, alphas, action_map)`](@ref), where alphas is either a vector of vectors or an |S| x (number of alpha vectors) matrix. The `action_map` argument is a vector of actions with length equal to the number of alpha vectors. If this argument is not provided, ordered_actions is used to generate a default action map.
+
+Determining the estimated value and optimal action depends on calculating the dot product between alpha vectors and a belief vector. [`POMDPPolicies.beliefvec(pomdp, b)`](@ref) is used to create this vector and can be overridden for new belief types for efficiency.
 
 ```@docs
 AlphaVectorPolicy
+POMDPPolicies.beliefvec
 ```
diff --git a/src/POMDPPolicies.jl b/src/POMDPPolicies.jl
index b6240fd..6773e84 100644
--- a/src/POMDPPolicies.jl
+++ b/src/POMDPPolicies.jl
@@ -3,6 +3,7 @@ module POMDPPolicies
 using LinearAlgebra
 using Random
 using StatsBase # for Weights
+using SparseArrays # for sparse vectors in alpha_vector.jl
 
 using POMDPs
 import POMDPs: action, value, solve, updater
diff --git a/src/alpha_vector.jl b/src/alpha_vector.jl
index 26e1b9d..fbac410 100644
--- a/src/alpha_vector.jl
+++ b/src/alpha_vector.jl
@@ -1,13 +1,7 @@
-######################################################################
-# alpha_vector.jl
-#
-# implements policy that is a set of alpha vectors
-######################################################################
-
 """
     AlphaVectorPolicy{P<:POMDP, A}
 
-Represents a policy with a set of alpha vector
+Represents a policy with a set of alpha vectors
 
 Constructor:
 
@@ -28,7 +22,8 @@ end
 function AlphaVectorPolicy(pomdp::POMDP, alphas)
     AlphaVectorPolicy(pomdp, alphas, ordered_actions(pomdp))
 end
-# assumes alphas is |S| x |A|
+
+# assumes alphas is |S| x (number of alpha vecs)
 function AlphaVectorPolicy(p::POMDP, alphas::Matrix{Float64}, action_map)
     # turn alphas into vector of vectors
     num_actions = size(alphas, 2)
@@ -44,17 +39,21 @@ end
 
 updater(p::AlphaVectorPolicy) = DiscreteUpdater(p.pomdp)
 
-value(p::AlphaVectorPolicy, b::DiscreteBelief) = value(p, b.b)
-function value(p::AlphaVectorPolicy, b::Vector{Float64})
-    maximum(dot(b,a) for a in p.alphas)
+
+# The three functions below rely on beliefvec being implemented for the belief type
+# Implementations of beliefvec are below
+function value(p::AlphaVectorPolicy, b)
+    bvec = beliefvec(p.pomdp, b)
+    maximum(dot(bvec,a) for a in p.alphas)
 end
 
-function action(p::AlphaVectorPolicy, b::DiscreteBelief)
+function action(p::AlphaVectorPolicy, b)
+    bvec = beliefvec(p.pomdp, b)
     num_vectors = length(p.alphas)
     best_idx = 1
     max_value = -Inf
     for i = 1:num_vectors
-        temp_value = dot(b.b, p.alphas[i])
+        temp_value = dot(bvec, p.alphas[i])
         if temp_value > max_value
             max_value = temp_value
             best_idx = i
@@ -63,11 +62,12 @@ function action(p::AlphaVectorPolicy, b::DiscreteBelief)
     return p.action_map[best_idx]
 end
 
-function actionvalues(p::AlphaVectorPolicy, b::DiscreteBelief)
+function actionvalues(p::AlphaVectorPolicy, b)
+    bvec = beliefvec(p.pomdp, b)
     num_vectors = length(p.alphas)
     max_values = -Inf*ones(n_actions(p.pomdp))
     for i = 1:num_vectors
-        temp_value = dot(b.b, p.alphas[i])
+        temp_value = dot(bvec, p.alphas[i])
         ai = actionindex(p.pomdp, p.action_map[i])
         if temp_value > max_values[ai]
             max_values[ai] = temp_value
@@ -76,45 +76,26 @@ function actionvalues(p::AlphaVectorPolicy, b::DiscreteBelief)
     return max_values
 end
 
-function value(p::AlphaVectorPolicy, b::SparseCat)
-    maximum(sparsecat_dot(p.pomdp, a, b) for a in p.alphas)
-end
+"""
+    POMDPPolicies.beliefvec(m::POMDP, b)
 
-function action(p::AlphaVectorPolicy, b::SparseCat)
-    num_vectors = length(p.alphas)
-    best_idx = 1
-    max_value = -Inf
-    for i = 1:num_vectors
-        temp_value = sparsecat_dot(p.pomdp, p.alphas[i], b)
-        if temp_value > max_value
-            max_value = temp_value
-            best_idx = i
-        end
-    end
-    return p.action_map[best_idx]
+Return a vector-like representation of the belief `b` suitable for calculating the dot product with the alpha vectors.
+"""
+function beliefvec end
+
+function beliefvec(m::POMDP, b::SparseCat)
+    return sparsevec(collect(stateindex(m, s) for s in b.vals), collect(b.probs), n_states(m))
 end
-
-function actionvalues(p::AlphaVectorPolicy, b::SparseCat)
-    num_vectors = length(p.alphas)
-    max_values = -Inf*ones(n_actions(p.pomdp))
-    for i = 1:num_vectors
-        temp_value = sparsecat_dot(p.pomdp, p.alphas[i], b)
-        ai = actionindex(p.pomdp, p.action_map[i])
-        if ( temp_value > max_values[ai])
-            max_values[ai] = temp_value
-        end
+beliefvec(m::POMDP, b::DiscreteBelief) = b.b
+beliefvec(m::POMDP, b::AbstractArray) = b
+
+function beliefvec(m::POMDP, b)
+    sup = support(b)
+    bvec = zeros(length(sup)) # maybe this should be sparse?
+    for s in sup
+        bvec[stateindex(m, s)] = pdf(b, s)
     end
-    return max_values
-end
-
-# perform dot product between an alpha vector and a sparse cat object
-function sparsecat_dot(problem::POMDP, alpha::Vector{Float64}, b::SparseCat)
-    val = 0.
-    for (s, p) in weighted_iterator(b)
-        si = stateindex(problem, s)
-        val += alpha[si]*p
-    end
-    return val
+    return bvec
 end
 
 function Base.push!(p::AlphaVectorPolicy, alpha::Vector{Float64}, a)
diff --git a/test/test_alpha_policy.jl b/test/test_alpha_policy.jl
index c5b5a2d..0a2fb4a 100644
--- a/test/test_alpha_policy.jl
+++ b/test/test_alpha_policy.jl
@@ -23,6 +23,10 @@ let
     @test isapprox(value(policy, sparse_b0), -16.0629)
     @test isapprox(actionvalues(policy, sparse_b0), [-16.0629, -19.4557])
     @test action(policy, sparse_b0) == false
+
+    # BoolDistribution (if it works for this, it should work for an arbitrary distribution)
+    bd = initialstate_distribution(pomdp)::BoolDistribution
+    @test action(policy, bd) == false
 
     # try pushing new vector
     push!(policy, [0.0,0.0], true)
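
For context, here is a minimal sketch of how downstream code might use the `beliefvec` hook this diff introduces, specializing it for a custom belief type as the docs change suggests ("can be overridden for new belief types for efficiency"). The `MyBelief` type and its fields are illustrative assumptions, not part of this change:

```julia
using POMDPs
using POMDPPolicies
using SparseArrays

# Hypothetical belief type that already stores state indices alongside
# their probabilities (not part of this PR).
struct MyBelief
    state_indices::Vector{Int}
    probs::Vector{Float64}
end

# Specialize the new hook so value/action/actionvalues skip the generic
# support/pdf fallback and take dot products against a sparse vector.
function POMDPPolicies.beliefvec(m::POMDP, b::MyBelief)
    return sparsevec(b.state_indices, b.probs, n_states(m))
end
```

With this method defined, `action(policy, b::MyBelief)` and `value(policy, b::MyBelief)` dispatch through the generic methods added above, which call `beliefvec(p.pomdp, b)` before taking the dot product with each alpha vector.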