Skip to content

Commit

Permalink
Only call generate_sr when the observation is not needed; add action_info
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Feb 20, 2018
1 parent 81e5946 commit 6c6aa14
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/POMCPOW.jl
Expand Up @@ -16,6 +16,7 @@ using BasicPOMCP: convert_estimator

import Base: mean, rand, insert!
import POMDPs: action, solve
import POMDPToolbox: action_info

import MCTS: n_children, next_action, isroot, node_tag, tooltip_tag

Expand Down
22 changes: 15 additions & 7 deletions src/planner2.jl
Expand Up @@ -24,38 +24,46 @@ end

Base.srand(p::POMCPOWPlanner, seed) = srand(p.solver.rng, seed)

function action{P,NBU}(pomcp::POMCPOWPlanner{P,NBU}, b)


# Plan an action for belief `b` and return it together with an `info` dictionary.
#
# A new POMCPOWTree rooted at `b` is built, stored on the planner as `pomcp.tree`,
# and handed to `search`. When `search` succeeds, the tree is also stored under
# `info[:tree]`. When `search` throws, the action comes from `default_action`
# (converted to the action type `A`) and `info[:tree]` is not set.
#
# NOTE(review): this listing is a rendered diff without +/- markers. The lines
# `a = search(pomcp, tree)` and `return a` look like the pre-change versions of
# the lines directly after them — confirm against the repository.
function action_info{P,NBU}(pomcp::POMCPOWPlanner{P,NBU}, b)
# Type parameters for the tree, derived from the problem type P and the
# node belief updater type NBU.
S = state_type(P)
A = action_type(P)
O = obs_type(P)
B = belief_type(NBU,P)
info = Dict{Symbol, Any}()
tree = POMCPOWTree{B,A,O,typeof(b)}(b, 2*pomcp.solver.tree_queries)
pomcp.tree = tree
# Declared before the try block so that both branches assign the same typed local.
local a::A
try
a = search(pomcp, tree)
a = search(pomcp, tree, info)
info[:tree] = tree
catch ex
# Any exception raised by the search is passed to the configured default action.
a = convert(A, default_action(pomcp.solver.default_action, pomcp.problem, b, ex))
end
return a
return a, info
end

function search(pomcp::POMCPOWPlanner, tree::POMCPOWTree)
# Return only the first element of `action_info(pomcp, b)`; presumably this is
# the chosen action with the info dictionary dropped — NOTE(review): confirm.
action(pomcp::POMCPOWPlanner, b) = first(action_info(pomcp, b))

function search(pomcp::POMCPOWPlanner, tree::POMCPOWTree, info::Dict{Symbol,Any}=Dict{Symbol,Any}())
all_terminal = true
# gc_enable(false)
start_time = CPUtime_us()
i = 0
start_us = CPUtime_us()
for i in 1:pomcp.solver.tree_queries
s = rand(pomcp.solver.rng, tree.root_belief)
if !POMDPs.isterminal(pomcp.problem, s)
max_depth = min(pomcp.solver.max_depth, ceil(Int, log(pomcp.solver.eps)/log(discount(pomcp.problem))))
simulate(pomcp, POWTreeObsNode(tree, 1), s, max_depth)
all_terminal = false
end
if CPUtime_us() - start_time >= pomcp.solver.max_time*1e6
if CPUtime_us() - start_us >= pomcp.solver.max_time*1e6
break
end
end
# gc_enable(true)
info[:search_time_us] = CPUtime_us() - start_us
info[:tree_queries] = i

if all_terminal
throw(AllSamplesTerminal(tree.root_belief))
Expand Down
16 changes: 11 additions & 5 deletions src/solver2.jl
Expand Up @@ -41,13 +41,11 @@ function simulate{B,S,A,O}(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O},
best_node = select_best(pomcp.criterion, h_node, pomcp.solver.rng)
a = tree.a_labels[best_node]

sp, o, r = generate_sor(pomcp.problem, s, a, sol.rng)
if r == Inf
warn("POMCPOW: +Inf reward. This is not recommended and may cause future errors.")
end

new_node = false
if tree.n_a_children[best_node] <= sol.k_observation*(tree.n[best_node]^sol.alpha_observation)

sp, o, r = generate_sor(pomcp.problem, s, a, sol.rng)

if sol.check_repeat_obs && haskey(tree.a_child_lookup, (best_node,o))
hao = tree.a_child_lookup[(best_node, o)]
else
Expand All @@ -64,6 +62,14 @@ function simulate{B,S,A,O}(pomcp::POMCPOWPlanner, h_node::POWTreeObsNode{B,A,O},
tree.n_a_children[best_node] += 1
end
push!(tree.generated[best_node], o=>hao)
else

sp, r = generate_sr(pomcp.problem, s, a, sol.rng)

end

if r == Inf
warn("POMCPOW: +Inf reward. This is not recommended and may cause future errors.")
end

if new_node
Expand Down
7 changes: 5 additions & 2 deletions test/runtests.jl
Expand Up @@ -18,16 +18,19 @@ b = initial_state_distribution(pomdp)
B = POMCPOW.belief_type(POMCPOW.POWNodeFilter, typeof(pomdp))
tree = POMCPOWTree{B,Bool,Bool,typeof(b)}(b, 2*planner.solver.tree_queries)
@inferred POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10)
# @code_warntype POMCPOW.simulate(planner, POMCPOW.POWTreeObsNode(tree, 1), true, 10)

pomdp = LightDark1D()
solver = POMCPOWSolver(default_action=485)
planner = solve(solver, pomdp)

b = ParticleCollection([LightDark1DState(-1, 0)])
@test action(planner, b) == 485
@test @inferred(action(planner, b)) == 485

b = initial_state_distribution(pomdp)
action(planner, b)
@inferred action(planner, b)

a, info = action_info(planner, b)
d3t = D3Tree(planner)
d3t = D3Tree(info[:tree])
# inchrome(d3t)

0 comments on commit 6c6aa14

Please sign in to comment.