Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/JuliaPOMDP/MCTS.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Nov 8, 2017
2 parents df62a07 + 66a7b76 commit c97b34b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/dpw.jl
Expand Up @@ -16,6 +16,12 @@ end
Call simulate and choose the approximate best action from the reward approximations
"""
function POMDPs.action(p::DPWPlanner, s)
if isterminal(p.mdp, s)
error("""
MCTS cannot handle terminal states. action was called with
s = $s
""")
end
S = state_type(p.mdp)
A = action_type(p.mdp)
if p.solver.keep_tree
Expand Down Expand Up @@ -46,7 +52,7 @@ function POMDPs.action(p::DPWPlanner, s)
end
end
# XXX some publications say to choose action that has been visited the most
return tree.a_labels[sanode] # choose action with highest approximate value
return tree.a_labels[sanode] # choose action with highest approximate value
end


Expand Down
13 changes: 13 additions & 0 deletions test/runtests.jl
Expand Up @@ -86,4 +86,17 @@ let
@test abs(t-1.0) < 0.5
end

# Verify that requesting an action from a terminal state raises an error.
let
    dpw_solver = DPWSolver(n_iterations=typemax(Int),
                           max_time=1.0,
                           depth=depth,
                           exploration_constant=ec)
    gw = GridWorld()

    planner = solve(dpw_solver, gw)
    terminal_state = GridWorldState(1, 1, true)
    # `action` must refuse terminal states rather than silently planning from one.
    @test_throws ErrorException action(planner, terminal_state)
end

nbinclude("../notebooks/Domain_Knowledge_Example.ipynb")
2 changes: 1 addition & 1 deletion test/visualization.jl
Expand Up @@ -11,7 +11,7 @@ a = action(policy, state)
tree = D3Tree(policy, state)

# dpw
solver = DPWSolver(n_iterations=n_iter, depth=depth, exploration_constant=ec)
solver = DPWSolver(n_iterations=n_iter, depth=depth, exploration_constant=ec, rng=MersenneTwister(13))
mdp = GridWorld()

policy = solve(solver, mdp)
Expand Down

0 comments on commit c97b34b

Please sign in to comment.