Skip to content

Commit

Permalink
added default_action option to DPWSolver (#37)
Browse files Browse the repository at this point in the history
* added default_action option to DPWSolver

* fixed errors
  • Loading branch information
zsunberg committed Feb 22, 2018
1 parent de99d58 commit d042a34
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 48 deletions.
9 changes: 7 additions & 2 deletions src/MCTS.jl
Expand Up @@ -30,13 +30,17 @@ export
init_Q,
children,
n_children,
isroot
isroot,
default_action

export
AbstractStateNode,
StateActionStateNode,
DPWStateActionNode,
DPWStateNode
DPWStateNode,

ExceptionRethrow,
ReportWhenUsed

abstract type AbstractMCTSPlanner{P<:Union{MDP,POMDP}} <: Policy end
abstract type AbstractMCTSSolver <: Solver end
Expand All @@ -49,6 +53,7 @@ include("dpw_types.jl")
include("dpw.jl")
include("action_gen.jl")
include("util.jl")
include("default_action.jl")
include("belief_mcts.jl")

include("visualization.jl")
Expand Down
23 changes: 23 additions & 0 deletions src/default_action.jl
@@ -0,0 +1,23 @@
"""
    ExceptionRethrow()

Default-action specifier for which `default_action` rethrows the exception it is given
instead of producing an action.
"""
struct ExceptionRethrow end

"""
    default_action(spec, mdp, s, ex)

Produce the action to use for state `s` given exception `ex`, according to `spec`:

- `ExceptionRethrow`: rethrow `ex`.
- `Function` `f`: return `f(s, ex)`.
- `POMDPs.Policy` `p`: return `action(p, s)`.
- `POMDPs.Solver` `sol`: solve `mdp` with `sol` and return the resulting policy's action at `s`.
- anything else: return `spec` itself.
"""
function default_action(::ExceptionRethrow, mdp, s, ex)
    return rethrow(ex)
end

function default_action(f::Function, mdp, s, ex)
    return f(s, ex)
end

function default_action(p::POMDPs.Policy, mdp, s, ex)
    return action(p, s)
end

function default_action(sol::POMDPs.Solver, mdp, s, ex)
    # A policy is solved for on every call; nothing is cached here.
    policy = solve(sol, mdp)
    return action(policy, s)
end

# Fallback: any other object is taken to be the action itself.
function default_action(a, mdp, s, ex)
    return a
end

"""
    ReportWhenUsed(a)

When the planner fails, returns action `a`, but also prints the exception.

`a` is resolved with `default_action(a, mdp, s, ex)`, so it may be anything that
`default_action` accepts, not only a literal action.
"""
struct ReportWhenUsed{T}
    a::T
end

"""
    default_action(r::ReportWhenUsed, mdp, s, ex)

Print `ex` to `STDERR`, resolve the wrapped default `r.a` through `default_action`,
issue a warning naming the action that will be used, and return that action.
"""
function default_action(r::ReportWhenUsed, mdp, s, ex)
    showerror(STDERR, ex)
    # showerror does not terminate the line; end it explicitly so the warning
    # below is not appended to the end of the error message.
    println(STDERR)
    a = default_action(r.a, mdp, s, ex)
    warn("Using default action $a")
    return a
end
90 changes: 49 additions & 41 deletions src/dpw.jl
Expand Up @@ -21,56 +21,64 @@ POMDPs.action(p::DPWPlanner, s) = first(action_info(p, s))
Construct an MCTSDPW tree and choose the best action. Also output some information.
"""
function POMDPToolbox.action_info(p::DPWPlanner, s)
if isterminal(p.mdp, s)
error("""
MCTS cannot handle terminal states. action was called with
s = $s
""")
end
local a::action_type(p.mdp)
info = Dict{Symbol, Any}()
try
if isterminal(p.mdp, s)
error("""
MCTS cannot handle terminal states. action was called with
s = $s
""")
end

S = state_type(p.mdp)
A = action_type(p.mdp)
if p.solver.keep_tree
if isnull(p.tree)
S = state_type(p.mdp)
A = action_type(p.mdp)
if p.solver.keep_tree
if isnull(p.tree)
tree = DPWTree{S,A}(p.solver.n_iterations)
p.tree = Nullable(tree)
else
tree = get(p.tree)
end
if haskey(tree.s_lookup, s)
snode = tree.s_lookup[s]
else
snode = insert_state_node!(tree, s, true)
end
else
tree = DPWTree{S,A}(p.solver.n_iterations)
p.tree = Nullable(tree)
else
tree = get(p.tree)
end
if haskey(tree.s_lookup, s)
snode = tree.s_lookup[s]
else
snode = insert_state_node!(tree, s, true)
snode = insert_state_node!(tree, s, p.solver.check_repeat_state)
end
else
tree = DPWTree{S,A}(p.solver.n_iterations)
p.tree = Nullable(tree)
snode = insert_state_node!(tree, s, p.solver.check_repeat_state)
end

i = 0
start_us = CPUtime_us()
for i = 1:p.solver.n_iterations
simulate(p, snode, p.solver.depth) # (not 100% sure we need to make a copy of the state here)
if CPUtime_us() - start_us >= p.solver.max_time * 1e6
break
i = 0
start_us = CPUtime_us()
for i = 1:p.solver.n_iterations
simulate(p, snode, p.solver.depth) # (not 100% sure we need to make a copy of the state here)
if CPUtime_us() - start_us >= p.solver.max_time * 1e6
break
end
end
end
info[:search_time_us] = CPUtime_us() - start_us
info[:tree_queries] = i
info[:tree] = tree

best_Q = -Inf
sanode = 0
for child in tree.children[snode]
if tree.q[child] > best_Q
best_Q = tree.q[child]
sanode = child
info[:search_time_us] = CPUtime_us() - start_us
info[:tree_queries] = i
info[:tree] = tree
best_Q = -Inf
sanode = 0
for child in tree.children[snode]
if tree.q[child] > best_Q
best_Q = tree.q[child]
sanode = child
end
end
# XXX some publications say to choose action that has been visited the most
a = tree.a_labels[sanode] # choose action with highest approximate value
catch ex
a = convert(action_type(p.mdp), default_action(p.solver.default_action, p.mdp, s, ex))
info[:exception] = ex
end
# XXX some publications say to choose action that has been visited the most
return tree.a_labels[sanode], info # choose action with highest approximate value

return a, info
end


Expand Down
14 changes: 12 additions & 2 deletions src/dpw_types.jl
Expand Up @@ -71,6 +71,13 @@ Fields:
If this is a function `f`, `f(mdp, s, snode)` will be called to set the value.
If this is an object `o`, `next_action(o, mdp, s, snode)` will be called.
default: RandomActionGenerator(rng)
default_action::Any
Function, action, Policy, or Solver used to determine the action if the MCTS planner fails with exception `ex`.
If this is a Function `f`, `f(s, ex)` will be called, where `s` is the state passed to the planner.
If this is a Policy `p`, `action(p, s)` will be called.
If this is a Solver `sol`, `action(solve(sol, mdp), s)` will be called.
If it is an object `a`, `default_action(a, mdp, s, ex)` will be called, and if this method is not implemented, `a` will be returned directly.
default: `ExceptionRethrow()`
"""
mutable struct DPWSolver <: AbstractMCTSSolver
depth::Int
Expand All @@ -90,6 +97,7 @@ mutable struct DPWSolver <: AbstractMCTSSolver
init_Q::Any
init_N::Any
next_action::Any
default_action::Any
end

"""
Expand All @@ -113,8 +121,10 @@ function DPWSolver(;depth::Int=10,
estimate_value::Any = RolloutEstimator(RandomSolver(rng)),
init_Q::Any = 0.0,
init_N::Any = 0,
next_action::Any = RandomActionGenerator(rng))
DPWSolver(depth, exploration_constant, n_iterations, max_time, k_action, alpha_action, k_state, alpha_state, keep_tree, enable_action_pw, check_repeat_state, check_repeat_action, rng, estimate_value, init_Q, init_N, next_action)
next_action::Any = RandomActionGenerator(rng),
default_action::Any = ExceptionRethrow()
)
DPWSolver(depth, exploration_constant, n_iterations, max_time, k_action, alpha_action, k_state, alpha_state, keep_tree, enable_action_pw, check_repeat_state, check_repeat_action, rng, estimate_value, init_Q, init_N, next_action, default_action)
end

#=
Expand Down
4 changes: 2 additions & 2 deletions test/dpw_test.jl
Expand Up @@ -5,7 +5,7 @@ policy = solve(solver, mdp)

state = GridWorldState(1,1)

a = action(policy, state)
a = @inferred action(policy, state)

clear_tree!(policy)
@test isnull(policy.tree)
Expand All @@ -19,4 +19,4 @@ policy = solve(solver, mdp)

state = GridWorldState(1,1)

a = action(policy, state)
a = @inferred action(policy, state)
2 changes: 1 addition & 1 deletion test/runtests.jl
Expand Up @@ -32,7 +32,7 @@ policy = solve(solver, mdp)

state = GridWorldState(1,1)

a = action(policy, state)
a = @inferred action(policy, state)

clear_tree!(policy)
@test isempty(policy.tree)
Expand Down

0 comments on commit d042a34

Please sign in to comment.