Skip to content

Commit

Permalink
Merge 0177d78 into dc93000
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Nov 21, 2017
2 parents dc93000 + 0177d78 commit 1bf4bd0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
20 changes: 15 additions & 5 deletions src/dpw.jl
Expand Up @@ -70,14 +70,24 @@ function simulate(dpw::DPWPlanner, snode::Int, d::Int)
end

# action progressive widening
if length(tree.children[snode]) <= sol.k_action*tree.total_n[snode]^sol.alpha_action # criterion for new action generation
a = next_action(dpw.next_action, dpw.mdp, s, DPWStateNode(tree, snode)) # action generation step
if !sol.check_repeat_action || !haskey(tree.a_lookup, (snode, a))
if dpw.solver.enable_action_pw
if length(tree.children[snode]) <= sol.k_action*tree.total_n[snode]^sol.alpha_action # criterion for new action generation
a = next_action(dpw.next_action, dpw.mdp, s, DPWStateNode(tree, snode)) # action generation step
if !sol.check_repeat_action || !haskey(tree.a_lookup, (snode, a))
n0 = init_N(sol.init_N, dpw.mdp, s, a)
insert_action_node!(tree, snode, a, n0,
init_Q(sol.init_Q, dpw.mdp, s, a),
sol.check_repeat_action
)
tree.total_n[snode] += n0
end
end
elseif isempty(tree.children[snode])
for a in iterator(actions(dpw.mdp, s))
n0 = init_N(sol.init_N, dpw.mdp, s, a)
insert_action_node!(tree, snode, a, n0,
init_Q(sol.init_Q, dpw.mdp, s, a),
sol.check_repeat_action
)
false)
tree.total_n[snode] += n0
end
end
Expand Down
10 changes: 8 additions & 2 deletions src/dpw_types.jl
Expand Up @@ -29,9 +29,13 @@ Fields:
defaults: k:10, alpha:0.5
keep_tree::Bool
If true, store the tree in the planner for reuse at the next timestep (and every time it is used in the future). There is a computational cost for maintaining the state dictionary necessary for this.
If true, store the tree in the planner for reuse at the next timestep (and every time it is used in the future). There is a computational cost for maintaining the state dictionary necessary for this.
default: false
enable_action_pw::Bool
If true, enable progressive widening on the action space; if false just use the whole action space.
default: true
check_repeat_state::Bool
check_repeat_action::Bool
When constructing the tree, check whether a state or action has been seen before (there is a computational cost to maintaining the dictionaries necessary for this)
Expand Down Expand Up @@ -78,6 +82,7 @@ mutable struct DPWSolver <: AbstractMCTSSolver
k_state::Float64
alpha_state::Float64
keep_tree::Bool
enable_action_pw::Bool
check_repeat_state::Bool
check_repeat_action::Bool
rng::AbstractRNG
Expand All @@ -101,14 +106,15 @@ function DPWSolver(;depth::Int=10,
k_state::Float64=10.0,
alpha_state::Float64=0.5,
keep_tree::Bool=false,
enable_action_pw::Bool=true,
check_repeat_state::Bool=true,
check_repeat_action::Bool=true,
rng::AbstractRNG=Base.GLOBAL_RNG,
estimate_value::Any = RolloutEstimator(RandomSolver(rng)),
init_Q::Any = 0.0,
init_N::Any = 0,
next_action::Any = RandomActionGenerator(rng))
DPWSolver(depth, exploration_constant, n_iterations, max_time, k_action, alpha_action, k_state, alpha_state, keep_tree, check_repeat_state, check_repeat_action, rng, estimate_value, init_Q, init_N, next_action)
DPWSolver(depth, exploration_constant, n_iterations, max_time, k_action, alpha_action, k_state, alpha_state, keep_tree, enable_action_pw, check_repeat_state, check_repeat_action, rng, estimate_value, init_Q, init_N, next_action)
end

#=
Expand Down
11 changes: 11 additions & 0 deletions test/dpw_test.jl
Expand Up @@ -9,3 +9,14 @@ a = action(policy, state)

clear_tree!(policy)
@test isnull(policy.tree)


# no action pw
solver = DPWSolver(n_iterations=n_iter, depth=depth, exploration_constant=ec, enable_action_pw=false)
mdp = GridWorld()

policy = solve(solver, mdp)

state = GridWorldState(1,1)

a = action(policy, state)

0 comments on commit 1bf4bd0

Please sign in to comment.