Skip to content
This repository has been archived by the owner on Aug 7, 2021. It is now read-only.

Commit

Permalink
added some exceptions, made visualization better
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jan 28, 2017
1 parent e230edd commit 2b7a038
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 48 deletions.
28 changes: 25 additions & 3 deletions src/POMCP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,17 @@ export
ParticleReinvigorator,
reinvigorate!,
handle_unseen_observation,
DefaultReinvigoratorStub
DefaultReinvigoratorStub,

NoDecision,
AllSamplesTerminal,
ExceptionRethrow,
default_action


include("tree.jl")
include("particle_filter.jl")
include("exceptions.jl")

abstract AbstractPOMCPSolver <: POMDPs.Solver

Expand Down Expand Up @@ -108,6 +114,13 @@ Fields:
Number of actions to be considered at each node.
If <= 0, the entire action space will be considered.
default: 0
default_action::Any
Function, action, or Policy used to determine the action if POMCP fails with exception `ex`.
If this is a Function `f`, `f(belief, ex)` will be called.
If this is a Policy `p`, `action(p, belief)` will be called.
If it is an object `a`, `default_action(a, belief, ex)` will be called, and
if this method is not implemented, `a` will be returned directly.
"""
type POMCPSolver <: AbstractPOMCPSolver
eps::Float64 # will stop simulations when discount^depth is less than this
Expand All @@ -123,6 +136,7 @@ type POMCPSolver <: AbstractPOMCPSolver
init_N::Any

num_sparse_actions::Int # = 0 or less if not used
default_action::Any
end

"""
Expand Down Expand Up @@ -168,8 +182,8 @@ Fields:
k_action::Float64
alpha_action::Float64
k_state::Float64
alpha_state::Float64
k_observation::Float64
alpha_observation::Float64
These constants control the double progressive widening. A new observation
or action will be added if the number of children is less than or equal to kN^alpha.
defaults: k:10, alpha:0.5
Expand All @@ -195,6 +209,13 @@ Fields:
If this is an object `o`, `next_action(o, pomdp, b, h)` will be called.
default: RandomActionGenerator(rng)
default_action::Any
Function, action, or Policy used to determine the action if POMCP fails with exception `ex`.
If this is a Function `f`, `f(belief, ex)` will be called.
If this is a Policy `p`, `action(p, belief)` will be called.
If it is an object `a`, `default_action(a, belief, ex)` will be called, and
if this method is not implemented, `a` will be returned directly.
For more information on the k and alpha parameters, see Couëtoux, A., Hoock, J.-B., Sokolovska, N., Teytaud, O., & Bonnard, N. (2011). Continuous Upper Confidence Trees. In Learning and Intelligent Optimization. Rome, Italy. Retrieved from http://link.springer.com/chapter/10.1007/978-3-642-25566-3_32
"""
type POMCPDPWSolver <: AbstractPOMCPSolver
Expand All @@ -216,6 +237,7 @@ type POMCPDPWSolver <: AbstractPOMCPSolver
init_V::Any
init_N::Any
next_action::Any
default_action::Any
end

"""
Expand Down
12 changes: 8 additions & 4 deletions src/constructor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function POMCPSolver(;eps=0.01,
estimate_value=RolloutEstimator(POMDPToolbox.RandomSolver()),
init_V=0.0,
init_N=0,
num_sparse_actions=0)
num_sparse_actions=0,
default_action=ExceptionRethrow())

return POMCPSolver(eps,
max_depth,
Expand All @@ -25,7 +26,8 @@ function POMCPSolver(;eps=0.01,
estimate_value,
init_V,
init_N,
num_sparse_actions)
num_sparse_actions,
default_action)
end

"""
Expand All @@ -49,7 +51,8 @@ function POMCPDPWSolver(;eps=0.01,
k_action::Float64=10.,
init_V=0.0,
init_N=0,
next_action=RandomActionGenerator())
next_action=RandomActionGenerator(),
default_action=ExceptionRethrow())

return POMCPDPWSolver(eps,
max_depth,
Expand All @@ -65,5 +68,6 @@ function POMCPDPWSolver(;eps=0.01,
k_action,
init_V,
init_N,
next_action)
next_action,
default_action)
end
26 changes: 26 additions & 0 deletions src/exceptions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
    NoDecision

Abstract supertype for exceptions signalling that POMCP failed to choose an action.
Subtypes may define their own `Base.show` method; the method below is the fallback.
"""
abstract NoDecision <: Exception

# Fallback display for NoDecision subtypes without a specialized `show` method.
# The concrete type is printed rather than interpolating `nd` itself: interpolating
# `nd` calls `show(io, nd)`, which would dispatch back to this very method and
# recurse until the stack overflows.
Base.show(io::IO, nd::NoDecision) = print(io, """
POMCP failed to choose an action because the following exception was thrown:
$(typeof(nd))
To specify an action for this case, use the default_action solver parameter.
""")

"""
    AllSamplesTerminal(belief)

Exception reporting that every state sampled from the belief was terminal, so no
action could be chosen. The offending belief is stored in the `belief` field.
"""
immutable AllSamplesTerminal <: NoDecision
    belief
end

# Human-readable description; the three pieces are printed back to back and each
# ends with a newline.
function Base.show(io::IO, ast::AllSamplesTerminal)
    print(io,
          "POMCP failed to choose an action because all states sampled from the belief were terminal.\n",
          "To see the belief, catch this exception as ex and see ex.belief.\n",
          "To specify an action for this case, use the default_action solver parameter.\n")
end


"""
    ExceptionRethrow

Marker type: when passed to `default_action`, the exception is rethrown.
"""
immutable ExceptionRethrow end

"""
    default_action(spec, belief, ex)

Determine the fallback action for `belief` after exception `ex`, dispatching on `spec`:

- `ExceptionRethrow`: rethrow `ex`.
- `Function` `f`: return `f(belief, ex)`.
- `POMDPs.Policy` `p`: return `action(p, belief)`.
- anything else: return `spec` itself.
"""
function default_action(::ExceptionRethrow, belief, ex)
    return rethrow(ex)
end

function default_action(f::Function, belief, ex)
    return f(belief, ex)
end

function default_action(p::POMDPs.Policy, belief, ex)
    return action(p, belief)
end

function default_action(o, belief, ex)
    return o
end
28 changes: 21 additions & 7 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ function create_policy{S}(solver::Union{POMCPSolver,POMCPDPWSolver}, pomdp::POMD
end

function action(policy::POMCPPlanner, belief, a=nothing)
return search(policy, belief, policy.solver.tree_queries)
try
a = search(policy, belief, policy.solver.tree_queries)
catch ex
a = default_action(policy.solver.default_action, belief, ex)
end
return a
end

"""
Expand Down Expand Up @@ -42,19 +47,28 @@ function search{RootBelief}(pomcp::POMCPPlanner, belief::RootBelief, tree_querie
end

function search(pomcp::POMCPPlanner, b::BeliefNode, tree_queries)
#XXX hack
# XXX hack
pomcp._tree_ref = b
# end hack

all_terminal = true
for i in 1:tree_queries
s = rand(pomcp.solver.rng, b)
simulate(pomcp, b, s, 0)
b.N += 1
if !POMDPs.isterminal(pomcp.problem, s)
simulate(pomcp, b, s, 0)
b.N += 1
all_terminal = false
end
end

if all_terminal
throw(AllSamplesTerminal(b.B))
end

best_V = -Inf
local best_node # guessing that type stability is not important enough to make a difference at this point
for node in values(b.children)
best_node = first(values(b.children))
best_V = best_node.V
@assert !isnan(best_V)
for node in collect(values(b.children))[2:end]
if node.V >= best_V
best_V = node.V
best_node = node
Expand Down
38 changes: 4 additions & 34 deletions src/visualization.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import JSON
import MCTS: node_tag, tooltip_tag
import MCTS: node_tag, tooltip_tag, blink

type POMCPTreeVisualizer
type POMCPTreeVisualizer <: AbstractTreeVisualizer
node::BeliefNode
end

# Convenience method: wrap the belief node in a POMCPTreeVisualizer and forward to
# the `blink` method for that visualizer.
function blink(n::BeliefNode)
    return blink(POMCPTreeVisualizer(n))
end

typealias NodeDict Dict{Int, Dict{String, Any}}

function create_json(v::POMCPTreeVisualizer)
Expand Down Expand Up @@ -69,35 +71,3 @@ function recursive_push!(nd::NodeDict, n::ActNode, parent_id=-1)
end
return nd
end

# Render the tree held by `visualizer` as a self-contained HTML fragment and write it,
# followed by a newline, to `f`. The stylesheet and script are read from the `src`
# directory of the MCTS package, and the enclosing <div> gets a random id suffix.
function Base.show(f::IO, ::MIME"text/html", visualizer::POMCPTreeVisualizer)
    json, root_id = create_json(visualizer)

    # Both assets live next to each other in the MCTS package sources.
    asset_dir = joinpath(Pkg.dir("MCTS"), "src")
    css = @compat readstring(joinpath(asset_dir, "tree_vis.css"))
    js = @compat readstring(joinpath(asset_dir, "tree_vis.js"))
    div = "treevis$(randstring())"

    # The common four-space indentation is stripped by the triple-quoted literal,
    # so the emitted markup starts at column zero.
    page = """
    <div id="$div">
    <style>
    $css
    </style>
    <script>
    (function(){
    var treeData = $json;
    var rootID = $root_id;
    var div = "$div";
    $js
    })();
    </script>
    </div>
    """

    print(f, page, '\n')
end

0 comments on commit 2b7a038

Please sign in to comment.