Merge pull request #8 from JuliaPOMDP/alpha_arb_bel
Alpha vector policies work with arbitrary beliefs
zsunberg committed Dec 19, 2018
2 parents dfab103 + 88cf930 commit 176f0f4
Showing 4 changed files with 41 additions and 52 deletions.
5 changes: 4 additions & 1 deletion docs/src/alpha_vector.md
@@ -1,7 +1,10 @@
# Alpha Vector Policy

Represents a policy with a set of alpha vectors. Can be constructed with `AlphaVectorPolicy(pomdp, alphas, action_map)`, where alphas is either a vector of vectors or an |S| x |A| matrix. The `action_map` argument is a vector of actions with length equal to the number of alpha vectors. If this argument is not provided, ordered_actions is used to generate a default action map.
Represents a policy with a set of alpha vectors. Can be constructed with [`AlphaVectorPolicy(pomdp, alphas, action_map)`](@ref), where alphas is either a vector of vectors or an |S| x (number of alpha vectors) matrix. The `action_map` argument is a vector of actions with length equal to the number of alpha vectors. If this argument is not provided, ordered_actions is used to generate a default action map.

Determining the estimated value and optimal action depends on calculating the dot product between alpha vectors and a belief vector. [`POMDPPolicies.beliefvec(pomdp, b)`](@ref) is used to create this vector and can be overridden for new belief types for efficiency.

```@docs
AlphaVectorPolicy
POMDPPolicies.beliefvec
```
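For illustration, a minimal sketch of the construction described in the updated documentation above. `pomdp` stands for any already-constructed two-state, two-action POMDP; the alpha values and action symbols are made up and not part of the package:

```julia
using POMDPs, POMDPPolicies

# Each column is one alpha vector, so the matrix is |S| x (number of alpha
# vectors) rather than |S| x |A|.
alphas = [1.0  0.0;
          0.0  1.0]
action_map = [:left, :right]   # one (placeholder) action per alpha vector

# `pomdp` is assumed to be an existing two-state POMDP (placeholder).
policy = AlphaVectorPolicy(pomdp, alphas, action_map)

# value/action maximize the dot product of the belief vector with each alpha
# vector; a plain probability vector also works, because beliefvec passes
# AbstractArrays through unchanged.
b = [0.7, 0.3]
value(policy, b)    # 0.7
action(policy, b)   # :left
```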
1 change: 1 addition & 0 deletions src/POMDPPolicies.jl
@@ -3,6 +3,7 @@ module POMDPPolicies
using LinearAlgebra
using Random
using StatsBase # for Weights
using SparseArrays # for sparse vectors in alpha_vector.jl

using POMDPs
import POMDPs: action, value, solve, updater
83 changes: 32 additions & 51 deletions src/alpha_vector.jl
@@ -1,13 +1,7 @@
######################################################################
# alpha_vector.jl
#
# implements policy that is a set of alpha vectors
######################################################################

"""
    AlphaVectorPolicy{P<:POMDP, A}
Represents a policy with a set of alpha vector
Represents a policy with a set of alpha vectors
Constructor:
@@ -28,7 +22,8 @@ end
function AlphaVectorPolicy(pomdp::POMDP, alphas)
    AlphaVectorPolicy(pomdp, alphas, ordered_actions(pomdp))
end
# assumes alphas is |S| x |A|

# assumes alphas is |S| x (number of alpha vecs)
function AlphaVectorPolicy(p::POMDP, alphas::Matrix{Float64}, action_map)
    # turn alphas into vector of vectors
    num_actions = size(alphas, 2)
@@ -44,17 +39,21 @@ end

updater(p::AlphaVectorPolicy) = DiscreteUpdater(p.pomdp)

value(p::AlphaVectorPolicy, b::DiscreteBelief) = value(p, b.b)
function value(p::AlphaVectorPolicy, b::Vector{Float64})
    maximum(dot(b,a) for a in p.alphas)

# The three functions below rely on beliefvec being implemented for the belief type
# Implementations of beliefvec are below
function value(p::AlphaVectorPolicy, b)
    bvec = beliefvec(p.pomdp, b)
    maximum(dot(bvec,a) for a in p.alphas)
end

function action(p::AlphaVectorPolicy, b::DiscreteBelief)
function action(p::AlphaVectorPolicy, b)
    bvec = beliefvec(p.pomdp, b)
    num_vectors = length(p.alphas)
    best_idx = 1
    max_value = -Inf
    for i = 1:num_vectors
        temp_value = dot(b.b, p.alphas[i])
        temp_value = dot(bvec, p.alphas[i])
        if temp_value > max_value
            max_value = temp_value
            best_idx = i
@@ -63,11 +62,12 @@ function action(p::AlphaVectorPolicy, b::DiscreteBelief)
    return p.action_map[best_idx]
end

function actionvalues(p::AlphaVectorPolicy, b::DiscreteBelief)
function actionvalues(p::AlphaVectorPolicy, b)
    bvec = beliefvec(p.pomdp, b)
    num_vectors = length(p.alphas)
    max_values = -Inf*ones(n_actions(p.pomdp))
    for i = 1:num_vectors
        temp_value = dot(b.b, p.alphas[i])
        temp_value = dot(bvec, p.alphas[i])
        ai = actionindex(p.pomdp, p.action_map[i])
        if temp_value > max_values[ai]
            max_values[ai] = temp_value
@@ -76,45 +76,26 @@ function actionvalues(p::AlphaVectorPolicy, b::DiscreteBelief)
    return max_values
end

function value(p::AlphaVectorPolicy, b::SparseCat)
    maximum(sparsecat_dot(p.pomdp, a, b) for a in p.alphas)
end
"""
    POMDPPolicies.beliefvec(m::POMDP, b)
function action(p::AlphaVectorPolicy, b::SparseCat)
    num_vectors = length(p.alphas)
    best_idx = 1
    max_value = -Inf
    for i = 1:num_vectors
        temp_value = sparsecat_dot(p.pomdp, p.alphas[i], b)
        if temp_value > max_value
            max_value = temp_value
            best_idx = i
        end
    end
    return p.action_map[best_idx]
Return a vector-like representation of the belief `b` suitable for calculating the dot product with the alpha vectors.
"""
function beliefvec end

function beliefvec(m::POMDP, b::SparseCat)
    return sparsevec(collect(stateindex(m, s) for s in b.vals), collect(b.probs), n_states(m))
end

function actionvalues(p::AlphaVectorPolicy, b::SparseCat)
    num_vectors = length(p.alphas)
    max_values = -Inf*ones(n_actions(p.pomdp))
    for i = 1:num_vectors
        temp_value = sparsecat_dot(p.pomdp, p.alphas[i], b)
        ai = actionindex(p.pomdp, p.action_map[i])
        if ( temp_value > max_values[ai])
            max_values[ai] = temp_value
        end
beliefvec(m::POMDP, b::DiscreteBelief) = b.b
beliefvec(m::POMDP, b::AbstractArray) = b

function beliefvec(m::POMDP, b)
    sup = support(b)
    bvec = zeros(length(sup)) # maybe this should be sparse?
    for s in sup
        bvec[stateindex(m, s)] = pdf(b, s)
    end
    return max_values
end

# perform dot product between an alpha vector and a sparse cat object
function sparsecat_dot(problem::POMDP, alpha::Vector{Float64}, b::SparseCat)
    val = 0.
    for (s, p) in weighted_iterator(b)
        si = stateindex(problem, s)
        val += alpha[si]*p
    end
    return val
    return bvec
end

function Base.push!(p::AlphaVectorPolicy, alpha::Vector{Float64}, a)
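Since the policy functions above all go through `beliefvec`, a custom belief type only needs that one method to work with an `AlphaVectorPolicy`. A hedged sketch, modeled on the `SparseCat` method in this diff; `WeightedStateBelief` is a hypothetical type, not part of the package:

```julia
using POMDPs
using SparseArrays
import POMDPPolicies

# Hypothetical belief type used only for illustration.
struct WeightedStateBelief{S}
    states::Vector{S}        # states with nonzero probability
    probs::Vector{Float64}   # corresponding probabilities
end

# A sparse belief vector keeps the dot products with the alpha vectors cheap
# compared to the generic dense fallback.
function POMDPPolicies.beliefvec(m::POMDP, b::WeightedStateBelief)
    return sparsevec([stateindex(m, s) for s in b.states], b.probs, n_states(m))
end
```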
4 changes: 4 additions & 0 deletions test/test_alpha_policy.jl
@@ -23,6 +23,10 @@ let
    @test isapprox(value(policy, sparse_b0), -16.0629)
    @test isapprox(actionvalues(policy, sparse_b0), [-16.0629, -19.4557])
    @test action(policy, sparse_b0) == false

    # Bool_distribution (if it works for this, it should work for an arbitrary distribution)
    bd = initialstate_distribution(pomdp)::BoolDistribution
    @test action(policy, bd) == false

    # try pushing new vector
    push!(policy, [0.0,0.0], true)
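For reference, the new test exercises roughly this path: an arbitrary distribution is first converted by the generic `beliefvec` fallback into a probability vector indexed by `stateindex`, and the dot products against the alpha vectors are taken from there. A sketch of that flow with made-up numbers (the real alpha vectors and expected values come from earlier, unchanged parts of the test file):

```julia
using LinearAlgebra

# Made-up two-state example mirroring the structure of the test above.
alphas = [[-10.0, -20.0], [-15.0, -15.0]]   # one alpha vector per action
action_map = [false, true]

bvec = [0.9, 0.1]    # what beliefvec(pomdp, bd) might return for a belief bd

vals = [dot(bvec, a) for a in alphas]   # [-11.0, -15.0]
action_map[argmax(vals)]                # false: the first alpha vector wins
```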
