Skip to content

Commit

Permalink
Fixed not taking terminality of states into account
Browse files Browse the repository at this point in the history
  • Loading branch information
Omastto1 committed Apr 11, 2021
1 parent e2eaa0c commit d6ab696
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
59 changes: 42 additions & 17 deletions src/pbvi.jl
Expand Up @@ -36,6 +36,20 @@ function _argmax(f, X)
return X[argmax(map(f, X))]
end

function belief_update(pomdp, b, b′, terminals, not_terminals)
if sum(b′[not_terminals]) != 0.
if !isempty(terminals)
b′[not_terminals] = b′[not_terminals] / (sum(b′[not_terminals]) / (1. - sum(b[terminals]) - sum(b′[terminals])))
b′[terminals] += b[terminals]
else
b′[not_terminals] /= sum(b′[not_terminals])
end
else
b′[terminals] .= 1/length(terminals)
end
return b′
end

function backup_belief(pomdp::POMDP, Γ, b)
S = ordered_states(pomdp)
A = ordered_actions(pomdp)
Expand All @@ -45,39 +59,41 @@ function backup_belief(pomdp::POMDP, Γ, b)

Γa = Vector{Vector{Float64}}(undef, length(A))

not_terminals = [stateindex(pomdp, s) for s in S if !isterminal(pomdp, s)]
terminals = [stateindex(pomdp, s) for s in S if isterminal(pomdp, s)]
for a in A
Γao = Vector{Vector{Float64}}(undef, length(O))
trans_probs = sum([pdf(transition(pomdp, s, a), sp) * b.b[stateindex(pomdp, s)] for s in S, sp in S], dims=1)
trans_probs = sum([pdf(transition(pomdp, S[is], a), sp) * b.b[is] for is in not_terminals, sp in S], dims=1)
if !isempty(terminals) trans_probs[terminals] .+= b.b[terminals] end

for o in O
# update beliefs
b′ = nothing
obs_probs = pdf.(map(sp -> observation(pomdp, a, sp), ordered_states(pomdp)), o)
pr_o_given_b_a = sum(obs_probs .* vec(trans_probs))
obs_probs = pdf.(map(sp -> observation(pomdp, a, sp), S), o)
b′ = obs_probs .* vec(trans_probs)

# P(o|b, a) = ∑(sp∈S) P(o|a, sp) ∑(s∈S) P(sp|s, a) * b(s)
if pr_o_given_b_a > 0.
b′ = DiscreteBelief(pomdp, b.state_list, [obs_probs[stateindex(pomdp, sp)] / pr_o_given_b_a * trans_probs[stateindex(pomdp, sp)] for sp in ordered_states(pomdp)])
if sum(b′) > 0.
b′ = DiscreteBelief(pomdp, b.state_list, belief_update(pomdp, b.b, b′, terminals, not_terminals))
else
b′ = DiscreteBelief(pomdp, b.state_list, zeros(length(S)))
end

# extract optimal alpha vector at resulting belief
Γao[obsindex(pomdp, o)] = _argmax-> α b′.b, Γ)
end

# construct new alpha vectors
αa = [r(s, a) + γ * sum(sum(pdf(transition(pomdp, s, a), sp) * pdf(observation(pomdp, s, a, sp), o) * Γao[i][j]
for (j, sp) in enumerate(S))
for (i, o) in enumerate(O))
for s in S]
αa = [r(s, a) + (!isterminal(pomdp, s) ? (γ * sum(sum(pdf(transition(pomdp, s, a), sp) * pdf(observation(pomdp, s, a, sp), o) * Γao[i][j]
for (j, sp) in enumerate(S))
for (i, o) in enumerate(O))) : 0.)
for s in S]

Γa[actionindex(pomdp, a)] = αa
end

# find the optimal alpha vector
idx = argmax(map(αa -> αa b.b, Γa))
alphavec = AlphaVec(Γa[idx], A[idx])

return alphavec
end

Expand All @@ -95,14 +111,23 @@ function improve(pomdp, B, Γ, ϵ)
end

function successors(pomdp, b, Bs)
S = ordered_states(pomdp)
not_terminals = [stateindex(pomdp, s) for s in S if !isterminal(pomdp, s)]
terminals = [stateindex(pomdp, s) for s in S if isterminal(pomdp, s)]
succs = []

for a in actions(pomdp)
trans_probs = sum([pdf(transition(pomdp, s, a), sp) * b[stateindex(pomdp, s)] for s in states(pomdp), sp in ordered_states(pomdp)], dims=1)
trans_probs = sum([pdf(transition(pomdp, S[is], a), sp) * b[is] for is in not_terminals, sp in S], dims=1)
if !isempty(terminals) trans_probs[terminals] .+= b[terminals] end

for o in observations(pomdp)
obs_probs = pdf.(map(sp -> observation(pomdp, a, sp), ordered_states(pomdp)), o)
pr_o_given_b_a = sum(obs_probs .* trans_probs')
if pr_o_given_b_a > 0.
b′ = [obs_probs[stateindex(pomdp, sp)] / pr_o_given_b_a * trans_probs[stateindex(pomdp, sp)] for sp in ordered_states(pomdp)]
#update belief
obs_probs = pdf.(map(sp -> observation(pomdp, a, sp), S), o)
b′ = obs_probs .* vec(trans_probs)

if sum(b′) > 0.
b′ = belief_update(pomdp, b, b′, terminals, not_terminals)

if !in(b′, Bs)
push!(succs, b′)
end
Expand Down
13 changes: 3 additions & 10 deletions test/runtests.jl
Expand Up @@ -10,7 +10,7 @@ using PointBasedValueIteration
pomdps = [TigerPOMDP(), BabyPOMDP(), MiniHallway()]

for pomdp in pomdps
solver = PBVISolver()
solver = PBVISolver(10, typeof(pomdp) == MiniHallway ? 1. : 0.01)
policy = solve(solver, pomdp)

sarsop = SARSOPSolver(verbose=false)
Expand All @@ -35,14 +35,7 @@ using PointBasedValueIteration
end

pbvi_vals = [value(policy, b) for b in B]
if typeof(pomdp) == MiniHallway
sarsop_vals = [value(sarsop_policy, b) for b in B]
# Test passes when the value function is multiplied
# @test isapprox(sarsop_vals * 3.4 , pbvi_vals, rtol=0.3)
@test_broken isapprox(sarsop_vals, pbvi_vals, rtol=0.3)
else
sarsop_vals = [value(sarsop_policy, b) for b in B]
@test isapprox(sarsop_vals, pbvi_vals, rtol=0.1)
end
sarsop_vals = [value(sarsop_policy, b) for b in B]
@test isapprox(sarsop_vals, pbvi_vals, rtol=0.1)
end
end

0 comments on commit d6ab696

Please sign in to comment.