Skip to content

Commit

Permalink
added timing and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Sep 5, 2017
1 parent 1d3b03f commit e690435
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions REQUIRE
Expand Up @@ -3,3 +3,4 @@ POMDPs
POMDPToolbox
Parameters
ParticleFilters
CPUTime
3 changes: 3 additions & 0 deletions src/BasicPOMCP.jl
Expand Up @@ -14,6 +14,7 @@ using POMDPs
using Parameters
using POMDPToolbox
using ParticleFilters
using CPUTime

import POMDPs: action, solve, updater

Expand Down Expand Up @@ -78,6 +79,7 @@ Partially Observable Monte Carlo Planning Solver. Options are set using the keyw
max_depth::Int = 20
c::Float64 = 1.0
tree_queries::Int = 1000
max_time::Float64 = Inf
estimate_value::Any = RolloutEstimator(RandomSolver())
default_action::Any = ExceptionRethrow()
rng::AbstractRNG = Base.GLOBAL_RNG
Expand Down Expand Up @@ -112,6 +114,7 @@ function POMCPTree(pomdp::POMDP, sz::Int=1000)
acts = collect(iterator(actions(pomdp)))
A = action_type(pomdp)
O = obs_type(pomdp)
sz = min(10_000_000, sz)
return POMCPTree{A,O}(sizehint!(Int[0], sz),
sizehint!(Vector{Int}[collect(1:length(acts))], sz),
sizehint!(Array{O}(1), sz),
Expand Down
6 changes: 5 additions & 1 deletion src/solver.jl
@@ -1,5 +1,5 @@
function action(p::POMCPPlanner, b)
local a
local a::action_type(p.problem)
try
a = search(p, b, POMCPTree(p.problem, p.solver.tree_queries))
catch ex
Expand All @@ -11,7 +11,11 @@ end

function search(p::POMCPPlanner, b, t::POMCPTree)
all_terminal = true
start_us = CPUtime_us()
for i in 1:p.solver.tree_queries
if CPUtime_us() - start_us >= 1e6*p.solver.max_time
break
end
s = rand(p.rng, b)
if !POMDPs.isterminal(p.problem, s)
simulate(p, s, POMCPObsNode(t, 1), p.solver.max_depth)
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Expand Up @@ -19,4 +19,10 @@ r = @inferred BasicPOMCP.simulate(planner, initial_state(pomdp, MersenneTwister(
sim = HistoryRecorder(max_steps=10)
simulate(sim, pomdp, planner, updater(pomdp))

solver = POMCPSolver(max_time=0.1, tree_queries=typemax(Int), rng = MersenneTwister(1))
planner = solve(solver, pomdp)
action(planner, initial_state_distribution(pomdp))
println("time below should be about 0.1 seconds")
@time action(planner, initial_state_distribution(pomdp))

nbinclude(joinpath(dirname(@__FILE__), "..", "notebooks", "Minimal_Example.ipynb"))

0 comments on commit e690435

Please sign in to comment.