From 6cd3f72d0b52f932f95ce97a7a149d773495ebdb Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Sat, 20 May 2017 16:40:21 -0700 Subject: [PATCH] added support for terminal states --- src/vanilla.jl | 58 ++++++++++--------- ...sic_value_iteration_disallowing_actions.jl | 2 +- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/vanilla.jl b/src/vanilla.jl index e24b96c..eed2532 100644 --- a/src/vanilla.jl +++ b/src/vanilla.jl @@ -144,33 +144,39 @@ function solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}, policy=creat tic() # state loop for (istate,s) in enumerate(states) - old_util = util[istate] # for residual - max_util = -Inf - # action loop - # util(s) = max_a( R(s,a) + discount_factor * sum(T(s'|s,a)util(s') ) sub_aspace = actions(mdp, s) - for a in iterator(sub_aspace) - iaction = action_index(mdp, a) - dist = transition(mdp, s, a) # creates distribution over neighbors - u = 0.0 - for sp in iterator(dist) - p = pdf(dist, sp) - p == 0.0 ? continue : nothing # skip if zero prob - r = reward(mdp, s, a, sp) - isp = state_index(mdp, sp) - u += p * (r + discount_factor * util[isp]) - end - new_util = u - if new_util > max_util - max_util = new_util - pol[istate] = iaction - end - include_Q ? (qmat[istate, iaction] = new_util) : nothing - end # action - # update the value array - util[istate] = max_util - diff = abs(max_util - old_util) - diff > residual ? (residual = diff) : nothing + if isterminal(mdp, s) + util[istate] = 0.0 + pol[istate] = 1 + else + old_util = util[istate] # for residual + max_util = -Inf + # action loop + # util(s) = max_a( R(s,a) + discount_factor * sum(T(s'|s,a)util(s') ) + for a in iterator(sub_aspace) + iaction = action_index(mdp, a) + dist = transition(mdp, s, a) # creates distribution over neighbors + u = 0.0 + for sp in iterator(dist) + p = pdf(dist, sp) + p == 0.0 ? continue : nothing # skip if zero prob + r = reward(mdp, s, a, sp) + isp = state_index(mdp, sp) + u += p * (r + discount_factor * util[isp]) + + end + new_util = u + if new_util > max_util + max_util = new_util + pol[istate] = iaction + end + include_Q ? (qmat[istate, iaction] = new_util) : nothing + end # action + # update the value array + util[istate] = max_util + diff = abs(max_util - old_util) + diff > residual ? (residual = diff) : nothing + end end # state iter_time = toq() total_time += iter_time diff --git a/test/runtests_basic_value_iteration_disallowing_actions.jl b/test/runtests_basic_value_iteration_disallowing_actions.jl index ef52a9b..62cd7ea 100644 --- a/test/runtests_basic_value_iteration_disallowing_actions.jl +++ b/test/runtests_basic_value_iteration_disallowing_actions.jl @@ -33,7 +33,7 @@ function POMDPs.actions(mdp::SpecialGridWorld, s::GridWorldState) elseif sidx == 7 acts = [GridWorldAction(:up), GridWorldAction(:down), GridWorldAction(:left), GridWorldAction(:right)] end - return GridWorldActionSpace(acts) + return acts end