Skip to content

Commit

Permalink
Check that sp is non-terminal before interpolating
Browse files: browse the repository at this point in the history
  • Loading branch information
Shushman committed Jul 11, 2018
1 parent 1ebf240 commit 6c04154
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/local_approximation_vi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,13 @@ function solve(solver::LocalApproximationValueIterationSolver, mdp::Union{MDP,PO
# Generative Model
for j in 1:solver.n_generative_samples
sp, r = generate_sr(mdp, s, a, solver.rng)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += r + discount_factor*compute_value(policy.interp, sp_point)
u += r

# Only interpolate sp if it is non-terminal
if !isterminal(mdp,sp)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += discount_factor*compute_value(policy.interp, sp_point)
end
end
u = u / solver.n_generative_samples
else
Expand All @@ -150,8 +155,13 @@ function solve(solver::LocalApproximationValueIterationSolver, mdp::Union{MDP,PO
for (sp, p) in weighted_iterator(dist)
p == 0.0 ? continue : nothing
r = reward(mdp, s, a, sp)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += p * (r + discount_factor*compute_value(policy.interp, sp_point))
u += p*r

# Only interpolate sp if it is non-terminal
if !isterminal(mdp,sp)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += p * (discount_factor*compute_value(policy.interp, sp_point))
end
end # next-states
end

Expand Down Expand Up @@ -222,17 +232,27 @@ function action_value(policy::LocalApproximationValueIterationPolicy, s::S, a::A
if policy.is_mdp_generative
for j in 1:policy.n_generative_samples
sp, r = generate_sr(mdp, s, a, Base.GLOBAL_RNG)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += r + discount_factor*compute_value(policy.interp, sp_point)
u += r

# Only interpolate sp if it is non-terminal
if !isterminal(mdp,sp)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += discount_factor*compute_value(policy.interp, sp_point)
end
end
u = u / policy.n_generative_samples
else
dist = transition(mdp,s,a)
for (sp, p) in weighted_iterator(dist)
p == 0.0 ? continue : nothing
r = reward(mdp, s, a, sp)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += p * (r + discount_factor*compute_value(policy.interp, sp_point))
u += p*r

# Only interpolate sp if it is non-terminal
if !isterminal(mdp,sp)
sp_point = POMDPs.convert_s(Vector{Float64}, sp, mdp)
u += p*(discount_factor*compute_value(policy.interp, sp_point))
end
end
end

Expand Down

0 comments on commit 6c04154

Please sign in to comment.