From 2b7a03847eb0bbae9cf464de3147087670004bd2 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Sat, 28 Jan 2017 11:16:19 -0800 Subject: [PATCH] added some exceptions, made visualization better --- src/POMCP.jl | 28 +++++++++++++++++++++++++--- src/constructor.jl | 12 ++++++++---- src/exceptions.jl | 26 ++++++++++++++++++++++++++ src/solver.jl | 28 +++++++++++++++++++++------- src/visualization.jl | 38 ++++---------------------------------- 5 files changed, 84 insertions(+), 48 deletions(-) create mode 100644 src/exceptions.jl diff --git a/src/POMCP.jl b/src/POMCP.jl index 136be65..315d153 100644 --- a/src/POMCP.jl +++ b/src/POMCP.jl @@ -46,11 +46,17 @@ export ParticleReinvigorator, reinvigorate!, handle_unseen_observation, - DefaultReinvigoratorStub + DefaultReinvigoratorStub, + + NoDecision, + AllSamplesTerminal, + ExceptionRethrow, + default_action include("tree.jl") include("particle_filter.jl") +include("exceptions.jl") abstract AbstractPOMCPSolver <: POMDPs.Solver @@ -108,6 +114,13 @@ Fields: Number of actions to be considered at each node. If <= 0, the entire action space will be considered. default: 0 + + default_action::Any + Function, action, or Policy used to determine the action if POMCP fails with exception `ex`. + If this is a Function `f`, `f(belief, ex)` will be called. + If this is a Policy `p`, `action(p, belief)` will be called. + If it is an object `a`, `default_action(a, belief, ex)` will be called, and + if this method is not implemented, `a` will be returned directly. 
""" type POMCPSolver <: AbstractPOMCPSolver eps::Float64 # will stop simulations when discount^depth is less than this @@ -123,6 +136,7 @@ type POMCPSolver <: AbstractPOMCPSolver init_N::Any num_sparse_actions::Int # = 0 or less if not used + default_action::Any end """ @@ -168,8 +182,8 @@ Fields: k_action::Float64 alpha_action::Float64 - k_state::Float64 - alpha_state::Float64 + k_observation::Float64 + alpha_observation::Float64 These constants control the double progressive widening. A new observation or action will be added if the number of children is less than or equal to kN^alpha. defaults: k:10, alpha:0.5 @@ -195,6 +209,13 @@ Fields: If this is an object `o`, `next_action(o, pomdp, b, h)` will be called. default: RandomActionGenerator(rng) + default_action::Any + Function, action, or Policy used to determine the action if POMCP fails with exception `ex`. + If this is a Function `f`, `f(belief, ex)` will be called. + If this is a Policy `p`, `action(p, belief)` will be called. + If it is an object `a`, `default_action(a, belief, ex)` will be called, and + if this method is not implemented, `a` will be returned directly. + For more information on the k and alpha parameters, see Couëtoux, A., Hoock, J.-B., Sokolovska, N., Teytaud, O., & Bonnard, N. (2011). Continuous Upper Confidence Trees. In Learning and Intelligent Optimization. Rome, Italy. 
Retrieved from http://link.springer.com/chapter/10.1007/978-3-642-25566-3_32 """ type POMCPDPWSolver <: AbstractPOMCPSolver @@ -216,6 +237,7 @@ type POMCPDPWSolver <: AbstractPOMCPSolver init_V::Any init_N::Any next_action::Any + default_action::Any end """ diff --git a/src/constructor.jl b/src/constructor.jl index df8e6da..1227cff 100644 --- a/src/constructor.jl +++ b/src/constructor.jl @@ -14,7 +14,8 @@ function POMCPSolver(;eps=0.01, estimate_value=RolloutEstimator(POMDPToolbox.RandomSolver()), init_V=0.0, init_N=0, - num_sparse_actions=0) + num_sparse_actions=0, + default_action=ExceptionRethrow()) return POMCPSolver(eps, max_depth, @@ -25,7 +26,8 @@ function POMCPSolver(;eps=0.01, estimate_value, init_V, init_N, - num_sparse_actions) + num_sparse_actions, + default_action) end """ @@ -49,7 +51,8 @@ function POMCPDPWSolver(;eps=0.01, k_action::Float64=10., init_V=0.0, init_N=0, - next_action=RandomActionGenerator()) + next_action=RandomActionGenerator(), + default_action=ExceptionRethrow()) return POMCPDPWSolver(eps, max_depth, @@ -65,5 +68,6 @@ function POMCPDPWSolver(;eps=0.01, k_action, init_V, init_N, - next_action) + next_action, + default_action) end diff --git a/src/exceptions.jl b/src/exceptions.jl new file mode 100644 index 0000000..db85a09 --- /dev/null +++ b/src/exceptions.jl @@ -0,0 +1,26 @@ +abstract NoDecision <: Exception +Base.show(io::IO, nd::NoDecision) = print(io, """ + POMCP failed to choose an action because the following exception was thrown: + $nd + + To specify an action for this case, use the default_action solver parameter. + """) + +immutable AllSamplesTerminal <: NoDecision + belief +end +Base.show(io::IO, ast::AllSamplesTerminal) = print(io, """ + POMCP failed to choose an action because all states sampled from the belief were terminal. + + To see the belief, catch this exception as ex and see ex.belief. + + To specify an action for this case, use the default_action solver parameter. 
+ """) + + +immutable ExceptionRethrow end + +default_action(::ExceptionRethrow, belief, ex) = rethrow(ex) +default_action(f::Function, belief, ex) = f(belief, ex) +default_action(p::POMDPs.Policy, belief, ex) = action(p, belief) +default_action(o, belief, ex) = o diff --git a/src/solver.jl b/src/solver.jl index 2d99c13..10ab44c 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -14,7 +14,12 @@ function create_policy{S}(solver::Union{POMCPSolver,POMCPDPWSolver}, pomdp::POMD end function action(policy::POMCPPlanner, belief, a=nothing) - return search(policy, belief, policy.solver.tree_queries) + try + a = search(policy, belief, policy.solver.tree_queries) + catch ex + a = default_action(policy.solver.default_action, belief, ex) + end + return a end """ @@ -42,19 +47,28 @@ function search{RootBelief}(pomcp::POMCPPlanner, belief::RootBelief, tree_querie end function search(pomcp::POMCPPlanner, b::BeliefNode, tree_queries) - #XXX hack + # XXX hack pomcp._tree_ref = b # end hack + all_terminal = true for i in 1:tree_queries s = rand(pomcp.solver.rng, b) - simulate(pomcp, b, s, 0) - b.N += 1 + if !POMDPs.isterminal(pomcp.problem, s) + simulate(pomcp, b, s, 0) + b.N += 1 + all_terminal = false + end + end + + if all_terminal + throw(AllSamplesTerminal(b.B)) end - best_V = -Inf - local best_node # guessing that type stability is not important enough to make a difference at this point - for node in values(b.children) + best_node = first(values(b.children)) + best_V = best_node.V + @assert !isnan(best_V) + for node in collect(values(b.children))[2:end] if node.V >= best_V best_V = node.V best_node = node diff --git a/src/visualization.jl b/src/visualization.jl index 2689cc1..e1a9a7e 100644 --- a/src/visualization.jl +++ b/src/visualization.jl @@ -1,10 +1,12 @@ import JSON -import MCTS: node_tag, tooltip_tag +import MCTS: node_tag, tooltip_tag, blink -type POMCPTreeVisualizer +type POMCPTreeVisualizer <: AbstractTreeVisualizer node::BeliefNode end +blink(n::BeliefNode) = 
blink(POMCPTreeVisualizer(n)) + typealias NodeDict Dict{Int, Dict{String, Any}} function create_json(v::POMCPTreeVisualizer) @@ -69,35 +71,3 @@ function recursive_push!(nd::NodeDict, n::ActNode, parent_id=-1) end return nd end - -function Base.show(f::IO, ::MIME"text/html", visualizer::POMCPTreeVisualizer) - json, root_id = create_json(visualizer) - # write("/tmp/tree_dump.json", json) - css = @compat readstring(joinpath(Pkg.dir("MCTS"), "src", "tree_vis.css")) - js = @compat readstring(joinpath(Pkg.dir("MCTS"), "src", "tree_vis.js")) - div = "treevis$(randstring())" - - html_string = """ -
- - -
- """ - # html_string = "visualization doesn't work yet :(" - - # for debugging - # outfile = open("/tmp/pomcp_debug.html","w") - # write(outfile,html_string) - # close(outfile) - - println(f,html_string) -end