Skip to content

Commit

Permalink
speedups in value and action
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed May 21, 2017
1 parent 34fcbc0 commit 4a9bb82
Showing 1 changed file with 21 additions and 26 deletions.
47 changes: 21 additions & 26 deletions src/vanilla.jl
Expand Up @@ -66,41 +66,36 @@ alphas(policy::QMDPPolicy) = policy.alphas

"""
    action(policy::QMDPPolicy, b::DiscreteBelief)

Return the entry of `policy.action_map` selected by the column of `policy.alphas`
that has the largest inner product with the belief vector `b.b`.

Fails an assertion when the length of `b.b` does not equal the number of rows of
`policy.alphas`.
"""
function action(policy::QMDPPolicy, b::DiscreteBelief)
    alphas = policy.alphas
    ns = size(alphas, 1)
    @assert length(b.b) == ns "Length of belief and alpha-vector size mismatch"

    # A single matrix-vector product scores every column at once; the previous
    # per-column loop computed the same maximum and was then overwritten.
    util = alphas' * b.b
    ihi = indmax(util)
    return policy.action_map[ihi]
end

# Select an action for a belief `b` that is not already a `DiscreteBelief`.
#
# `b` is converted to a dense vector by evaluating `pdf(b, s)` for each state `s`
# yielded by `ordered_states(policy.pomdp)`, wrapped in a `DiscreteBelief`, and
# passed to the `DiscreteBelief` method of `action`.
function action(policy::QMDPPolicy, b)
    # Zero-filled rather than uninitialized, so a slot that the loop does not
    # write can never expose arbitrary memory contents.
    bv = zeros(Float64, n_states(policy.pomdp))
    for (i, s) in enumerate(ordered_states(policy.pomdp))
        bv[i] = pdf(b, s)
    end
    return action(policy, DiscreteBelief(bv))
end

"""
    value(policy::QMDPPolicy, b::DiscreteBelief)

Return the largest inner product between the belief vector `b.b` and any column
of `policy.alphas`.

Fails an assertion when the length of `b.b` does not equal the number of rows of
`policy.alphas`.
"""
function value(policy::QMDPPolicy, b::DiscreteBelief)
    alphas = policy.alphas
    ns = size(alphas, 1)
    @assert length(b.b) == ns "Length of belief and alpha-vector size mismatch"

    # One matrix-vector product yields the utility of every column; the maximum
    # of those utilities is the value of the belief.
    util = alphas' * b.b
    return maximum(util)
end

"""
    value(policy::QMDPPolicy, b)

Return the value of a belief `b` that is not already a `DiscreteBelief`.

`b` is converted with `belief_vector(policy, b)`, wrapped in a `DiscreteBelief`,
and passed to the `DiscreteBelief` method of `value`.
"""
function value(policy::QMDPPolicy, b)
    # Must dispatch to `value`, not `action`: this method reports a utility,
    # not the selected action.
    return value(policy, DiscreteBelief(belief_vector(policy, b)))
end

"""
    action(policy::QMDPPolicy, b)

Select an action for a belief `b` that is not already a `DiscreteBelief`.

`b` is converted with `belief_vector(policy, b)`, wrapped in a `DiscreteBelief`,
and passed to the `DiscreteBelief` method of `action`.
"""
function action(policy::QMDPPolicy, b)
    weights = belief_vector(policy, b)
    discrete = DiscreteBelief(weights)
    return action(policy, discrete)
end

"""
    belief_vector(policy::QMDPPolicy, b)

Return a `Vector{Float64}` of length `n_states(policy.pomdp)` whose `i`-th entry
is `pdf(b, s)` for the `i`-th state `s` yielded by `ordered_states(policy.pomdp)`.
"""
function belief_vector(policy::QMDPPolicy, b)
    # Zero-filled rather than uninitialized, so a slot that the loop does not
    # write can never expose arbitrary memory contents.
    bv = zeros(Float64, n_states(policy.pomdp))
    for (i, s) in enumerate(ordered_states(policy.pomdp))
        bv[i] = pdf(b, s)
    end
    return bv
end

0 comments on commit 4a9bb82

Please sign in to comment.