Skip to content

Commit

Permalink
break ties in MaxUCB randomly
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Nov 18, 2017
1 parent 0183384 commit e8be0e7
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/POMCPOW.jl
Expand Up @@ -119,9 +119,9 @@ Fields:
default: `RandomActionGenerator(rng)`
- `default_action::Any`:
Function, action, or Policy used to determine the action if POMCP fails with exception `ex`.
- If this is a Function `f`, `f(belief, ex)` will be called.
+ If this is a Function `f`, `f(pomdp, belief, ex)` will be called.
  If this is a Policy `p`, `action(p, belief)` will be called.
- If it is an object `a`, `default_action(a, belief, ex)` will be called, and
+ If it is an object `a`, `default_action(a, pomdp, belief, ex)` will be called, and
  if this method is not implemented, `a` will be returned directly.
"""
@with_kw mutable struct POMCPOWSolver <: AbstractPOMCPSolver
Expand Down
26 changes: 20 additions & 6 deletions src/criteria.jl
Expand Up @@ -2,11 +2,13 @@ struct MaxUCB
c::Float64
end

function select_best(crit::MaxUCB, h_node::POWTreeObsNode)
function select_best(crit::MaxUCB, h_node::POWTreeObsNode, rng)
tree = h_node.tree
h = h_node.node
best_criterion_val = -Inf
local best_node
local best_node::Int
istied = false
local tied::Vector{Int}
ltn = log(tree.total_n[h])
for node in tree.tried[h]
n = tree.n[node]
Expand All @@ -17,17 +19,29 @@ function select_best(crit::MaxUCB, h_node::POWTreeObsNode)
else
criterion_value = tree.v[node] + crit.c*sqrt(ltn/n)
end
if criterion_value >= best_criterion_val
if criterion_value > best_criterion_val
best_criterion_val = criterion_value
best_node = node
istied = false
elseif criterion_value == best_criterion_val
if istied
push!(tied, node)
else
istied = true
tied = [best_node, node]
end
end
end
return best_node
if istied
return rand(rng, tied)
else
return best_node
end
end

struct MaxQ end

function select_best(crit::MaxQ, h_node::POWTreeObsNode)
function select_best(crit::MaxQ, h_node::POWTreeObsNode, rng)
tree = h_node.tree
h = h_node.node
best_node = first(tree.tried[h])
Expand All @@ -44,7 +58,7 @@ end

struct MaxTries end

function select_best(crit::MaxTries, h_node::POWTreeObsNode)
function select_best(crit::MaxTries, h_node::POWTreeObsNode, rng)
tree = h_node.tree
h = h_node.node
best_node = first(tree.tried[h])
Expand Down
4 changes: 2 additions & 2 deletions src/planner2.jl
Expand Up @@ -35,7 +35,7 @@ function action{P,NBU}(pomcp::POMCPOWPlanner{P,NBU}, b)
try
a = search(pomcp, tree)
catch ex
a = convert(A, default_action(pomcp.solver.default_action, b, ex))
a = convert(A, default_action(pomcp.solver.default_action, pomcp.problem, b, ex))
end
return a
end
Expand All @@ -60,7 +60,7 @@ function search(pomcp::POMCPOWPlanner, tree::POMCPOWTree)
throw(AllSamplesTerminal(tree.root_belief))
end

best_node = select_best(pomcp.solver.final_criterion, POWTreeObsNode(tree,1))
best_node = select_best(pomcp.solver.final_criterion, POWTreeObsNode(tree,1), pomcp.solver.rng)

return tree.a_labels[best_node]
end
2 changes: 1 addition & 1 deletion src/solver2.jl
Expand Up @@ -53,7 +53,7 @@ function simulate{B,S,A,O}(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O},
end
total_n = tree.total_n[h]

best_node = select_best(pomcp.criterion, h_node)
best_node = select_best(pomcp.criterion, h_node, pomcp.solver.rng)
a = tree.a_labels[best_node]

sp, o, r = generate_sor(pomcp.problem, s, a, sol.rng)
Expand Down

0 comments on commit e8be0e7

Please sign in to comment.