-
Notifications
You must be signed in to change notification settings - Fork 0
/
decomposed_policy.jl
40 lines (33 loc) · 1.24 KB
/
decomposed_policy.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
using POMDPs
using POMDPPolicies
using AutomotivePOMDPs
using Flux
"""
    DecPolicy{P<:Policy, M<:Union{MDP,POMDP}, O} <: Policy

A policy for decomposed problems: evaluates the wrapped `policy` on each
decomposed sub-state of `problem` and combines the per-sub-state action
values with the reduction operator `op` (e.g. `+` or `min` — see
`_actionvalues`).
"""
struct DecPolicy{P <: Policy, M <: Union{MDP, POMDP}, O} <: Policy
# underlying single-entity policy, queried once per sub-state
policy::P
# the full (non-decomposed) problem; supplies the action space
problem::M
op::O # reduction operator
end
"""
    POMDPs.action(policy::DecPolicy, s::OCObs)

Return the action of `policy.problem` whose combined decomposed
action value is largest for the observation `s`.
"""
function POMDPs.action(policy::DecPolicy, s::OCObs)
    values = actionvalues(policy, s)
    best = argmax(values)
    return actions(policy.problem)[best]
end
"""
    POMDPs.action(policy::DecPolicy, s::Vector{OCObs})

Variant used by the KMarkov updater: stacks the observation history
column-wise into a single array before computing action values, then
returns the highest-valued action of `policy.problem`.
"""
function POMDPs.action(policy::DecPolicy, s::Vector{OCObs})
    stacked = reduce(hcat, s)
    best = argmax(actionvalues(policy, stacked))
    return actions(policy.problem)[best]
end
"""
    POMDPPolicies.actionvalues(policy::DecPolicy, s)

Decompose `s` into sub-states, evaluate the wrapped policy on each,
reduce the results with `policy.op`, and strip Flux tracking from the
final value array via `Flux.data`.
"""
function POMDPPolicies.actionvalues(policy::DecPolicy, s)
    sub_states = decompose_state(policy.problem, s)
    combined = _actionvalues(policy, sub_states)
    # NOTE(review): `Flux.data` is the pre-0.10 Tracker API — confirm the
    # project pins a compatible Flux version.
    return Flux.data(combined)
end
"""
    _actionvalues(policy::DecPolicy, s_dec::AbstractArray)

Evaluate the wrapped policy's action values on every decomposed
sub-state in `s_dec` and fold them together with `policy.op`.
Assumes the wrapped policy carries no hidden state between calls.
"""
function _actionvalues(policy::DecPolicy, s_dec::AbstractArray) # no hidden state!
    return mapreduce(sub -> actionvalues(policy.policy, sub), policy.op, s_dec)
end
"""
    decompose_state(pomdp::OCPOMDP, s)

Split the flat state `s` into one ego+pedestrian sub-state per
pedestrian, for pedestrians `1:pomdp.max_peds`.
"""
function decompose_state(pomdp::OCPOMDP, s)
    return map(ped_idx -> get_singlestate(pomdp, s, ped_idx), 1:pomdp.max_peds)
end
"""
    get_singlestate(pomdp::OCPOMDP, s, i::Int; n_features::Int = 2)

Extract the sub-state pairing the ego features with the features of
pedestrian `i` from the feature array `s`. Slices along the first
dimension only, so any trailing (batch) dimensions are preserved —
beware of batch size when interpreting the result.

# Arguments
- `s`: feature array; the first dimension is laid out as
  `n_features` ego features followed by `n_features` per pedestrian.
- `i::Int`: 1-based pedestrian index.
- `n_features::Int = 2`: number of features per entity (previously a
  hard-coded constant; keyword keeps the call backward compatible).

Returns the concatenation of the ego slice and pedestrian `i`'s slice.
"""
function get_singlestate(pomdp::OCPOMDP, s, i::Int; n_features::Int = 2)
    # Ego features occupy the first `n_features` rows.
    ego = selectdim(s, 1, 1:n_features)
    # Pedestrian `i` occupies the next contiguous `n_features`-row band.
    ped = selectdim(s, 1, (n_features*i + 1):(n_features*(i + 1)))
    return vcat(ego, ped)
end