# Incorporating Domain Knowledge

Aside from tuning the solver parameters (c, k, alpha), MCTS currently offers several means of incorporating domain knowledge. The following solver parameters control the planner's behavior:

- `estimate_value` determines how the value is estimated at the leaf nodes (this is usually done using a rollout simulation).
- `init_N` and `init_Q` determine how N(s,a) and Q(s,a) are initialized when a new node is created.
- `next_action` determines which new actions are tried in double progressive widening.
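
To see where `next_action` fits in: a common form of the double progressive widening rule adds a new action at a state node only while the node's number of children is at most k·N(s)^alpha. Here N(s) is the node's visit count, and k and alpha are the widening parameters mentioned above. The sketch below assumes exactly that rule. `may_add_action` and `widening_trace` are hypothetical helpers, and MCTS.jl's own bookkeeping may differ in details.

```julia
# Hypothetical helper: the progressive widening test, assuming the rule
# "add an action while n_children <= k * total_n^alpha".
may_add_action(n_children::Integer, total_n::Integer; k=10.0, alpha=0.5) =
    n_children <= k * total_n^alpha

# Number of actions a node would hold after each visit, adding at most
# one new action per visit whenever the widening test passes.
function widening_trace(n_visits; k=1.0, alpha=0.5)
    n_children = 0
    trace = Int[]
    for total_n in 0:n_visits-1
        if may_add_action(n_children, total_n; k=k, alpha=alpha)
            n_children += 1
        end
        push!(trace, n_children)
    end
    return trace
end

println(widening_trace(10))  # [1, 2, 2, 2, 3, 3, 3, 3, 3, 4]
```

With k = 1 and alpha = 0.5, the number of actions grows roughly like the square root of the visit count. This is what lets DPW cope with large or continuous action spaces.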

There are three ways of specifying these parameters: 1) with constant values, 2) with functions, and 3) with custom objects.
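
All three styles can work side by side because Julia dispatches on the type of the value passed in. The sketch below illustrates that idea with a hypothetical `resolve_init_Q` function and a hypothetical `SpecialStatePrior` type. It is not MCTS.jl's internal code, which may be organized differently, and tuples stand in for grid positions.

```julia
# Hypothetical custom object carrying its own parameters.
struct SpecialStatePrior
    special_state::Tuple{Int,Int}
    special_Q::Float64
end

# 1) Constant: the same value for every new node.
resolve_init_Q(x::Number, mdp, s, a) = float(x)
# 2) Function: called with (mdp, s, a).
resolve_init_Q(f::Function, mdp, s, a) = f(mdp, s, a)
# 3) Custom object: behavior defined by a method for its type.
resolve_init_Q(p::SpecialStatePrior, mdp, s, a) = s == p.special_state ? p.special_Q : 0.0

dummy_mdp = nothing  # placeholder; none of these methods inspect the problem
prior_fn = (m, s, a) -> s == (1, 2) ? 11.73 : 0.0

println(resolve_init_Q(11.73, dummy_mdp, (1, 1), :up))                                    # 11.73
println(resolve_init_Q(prior_fn, dummy_mdp, (1, 1), :up))                                 # 0.0
println(resolve_init_Q(SpecialStatePrior((1, 2), 11.73), dummy_mdp, (1, 2), :up))         # 11.73
```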

```julia
using MCTS
using POMDPs
using POMDPModels
using Random
mdp = SimpleGridWorld();
```

## Constant Values

`init_N`, `init_Q`, and `estimate_value` can be set to constant values (though this is a bad idea for `estimate_value`). `next_action` cannot be specified this way. For example, the following code initializes every new N to 3 and every new Q to 11.73.
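
One way to read these constants: Q may be kept as a running mean of the returns a node has seen. In that case init_N=3 and init_Q=11.73 act like three imaginary visits that each returned 11.73. The sketch below assumes that update rule. It also assumes that the later real visits returned 0.0, which is what the printed numbers imply. Under those assumptions, it reproduces the Q values printed for state [1,1] in the output of the next cell (N=4, 5, 6). `update` and `replay` are hypothetical helpers.

```julia
# Running-mean update (assumed): increment N, then move Q toward the new return.
function update(q::Float64, n::Int, ret::Float64)
    n += 1
    q += (ret - q) / n
    return q, n
end

# Replay a sequence of returns starting from the prior (init_Q, init_N).
function replay(q0::Float64, n0::Int, returns)
    q, n = q0, n0
    trace = Tuple{Int,Float64}[]
    for r in returns
        q, n = update(q, n, r)
        push!(trace, (n, q))
    end
    return trace
end

trace = replay(11.73, 3, [0.0, 0.0, 0.0])  # returns of 0.0 are an assumption
println(trace)  # N=4 → Q≈8.7975, N=5 → Q≈7.038, N=6 → Q≈5.865
```

A larger `init_N` makes the initial guess harder to move, because each real return is averaged against more pseudo-visits.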

```julia
solver = MCTSSolver(n_iterations=3, depth=4,
                    init_N=3,
                    init_Q=11.73)
policy = solve(solver, mdp)
action(policy, GWPos(1,1))
println("State-Action Nodes")
tree = policy.tree
for sn in MCTS.state_nodes(tree)
    for san in MCTS.children(sn)
        println("s:$(MCTS.state(sn)), a:$(action(san)) Q:$(MCTS.q(san)) N:$(MCTS.n(san))")
    end
end
```

```
State-Action Nodes
s:[1, 1], a:up Q:7.037999999999999 N:5
s:[1, 1], a:down Q:8.7975 N:4
s:[1, 1], a:left Q:5.864999999999999 N:6
s:[1, 1], a:right Q:11.73 N:3
s:[1, 2], a:up Q:8.7975 N:4
s:[1, 2], a:down Q:11.73 N:3
s:[1, 2], a:left Q:11.73 N:3
s:[1, 2], a:right Q:11.73 N:3
s:[2, 2], a:up Q:11.73 N:3
s:[2, 2], a:down Q:11.73 N:3
s:[2, 2], a:left Q:11.73 N:3
s:[2, 2], a:right Q:11.73 N:3
s:[2, 1], a:up Q:11.73 N:3
s:[2, 1], a:down Q:11.73 N:3
s:[2, 1], a:left Q:11.73 N:3
s:[2, 1], a:right Q:11.73 N:3
```


## Functions

`init_N`, `init_Q`, `estimate_value`, and `next_action` can also be functions. The following code will:

- initialize Q to 0.0 everywhere except state [1,2], where it will be 11.73
- initialize N to 0 everywhere except state [1,2], where it will be 3
- estimate the value as 10 divided by the Manhattan distance to state [9,3]
- always try action `:up` first in double progressive widening

Note: the `?` below is part of Julia's [ternary operator](http://docs.julialang.org/en/release-0.5/manual/control-flow/#control-flow).
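
`cond ? x : y` evaluates to `x` when `cond` is true and to `y` otherwise, so it is a compact form of an `if`/`else` expression. A minimal comparison follows. It uses the same shape as `special_Q`, with tuples standing in for `GWPos`, and the names `q_ternary` and `q_ifelse` are illustrative only.

```julia
# Ternary form, same shape as special_Q.
q_ternary(s) = s == (1, 2) ? 11.73 : 0.0

# Equivalent if/else form.
function q_ifelse(s)
    if s == (1, 2)
        return 11.73
    else
        return 0.0
    end
end

for s in [(1, 1), (1, 2), (2, 2)]
    println(s, " => ", q_ternary(s), " / ", q_ifelse(s))
end
```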

```julia
special_Q(mdp, s, a) = s == GWPos(1,2) ? 11.73 : 0.0
special_N(mdp, s, a) = s == GWPos(1,2) ? 3 : 0

function manhattan_value(mdp, s, remaining_depth) # remaining_depth is ignored
    m_dist = abs(s.x-9)+abs(s.y-3) # Manhattan distance to [9,3]
    val = 10.0/m_dist              # Inf at [9,3] itself
    println("Set value for $s to $val") # not necessary; only shows that it is being called
    return val
end

function up_priority(mdp, s, snode) # snode is the state node, of type DPWStateNode
    if haskey(snode.tree.a_lookup, (snode.index, :up)) # :up has already been added
        return rand([:left, :down, :right]) # add a random other action
    else
        return :up
    end
end;
```
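
The logic of `up_priority` depends on the tree only through the "has this action already been added?" check. The self-contained sketch below replaces the `a_lookup` lookup with a `Set` of actions already tried at a node. `next_action_sketch` and `ALL_ACTIONS` are hypothetical names used only for illustration.

```julia
using Random

const ALL_ACTIONS = [:up, :down, :left, :right]

# Try the priority action first; once it is present, propose a random other action.
function next_action_sketch(tried::Set{Symbol}, rng::AbstractRNG; priority::Symbol=:up)
    if priority in tried
        return rand(rng, setdiff(ALL_ACTIONS, [priority]))
    else
        return priority
    end
end

rng = MersenneTwister(1)
tried = Set{Symbol}()
for _ in 1:4
    a = next_action_sketch(tried, rng)
    push!(tried, a)
    println(a)  # the first action printed is always :up
end
```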

```julia
solver = DPWSolver(n_iterations=8, depth=4,
                   init_N=special_N, init_Q=special_Q,
                   estimate_value=manhattan_value,
                   next_action=up_priority)
policy = solve(solver, mdp)
action(policy, GWPos(1,1))
println("State-Action Nodes:")
tree = policy.tree
for i in 1:length(tree.total_n)
    for j in tree.children[i]
        println("s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])")
    end
end
```

```
Set value for [1, 2] to 1.1111111111111112
Set value for [1, 3] to 1.25
Set value for [2, 1] to 1.1111111111111112
Set value for [1, 1] to 1.0
Set value for [1, 1] to 1.0
Set value for [2, 2] to 1.25
Set value for [2, 3] to 1.4285714285714286
Set value for [3, 1] to 1.25
State-Action Nodes:
s:[1, 1], a:up, Q:1.085133101851852 N:3
s:[1, 1], a:right, Q:1.134156746031746 N:4
s:[1, 1], a:left, Q:0.8810953125 N:4
s:[1, 1], a:down, Q:0.8810953125 N:4
s:[1, 2], a:up, Q:9.094375 N:4
s:[2, 1], a:up, Q:1.2383928571428573 N:2
s:[2, 1], a:right, Q:1.1875 N:1
s:[2, 2], a:up, Q:1.3571428571428572 N:1
```
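
The `Set value` lines above can be checked by hand. Each one is 10 divided by the Manhattan distance from the leaf state to [9,3]. Below is a self-contained recomputation with tuples standing in for `GWPos`. `manhattan` and `value_estimate` are illustrative names, not part of MCTS.jl.

```julia
# Manhattan distance between two grid cells given as (x, y) tuples.
manhattan(s, t) = abs(s[1] - t[1]) + abs(s[2] - t[2])

# Same heuristic as manhattan_value: 10 / distance to the target.
value_estimate(s; target=(9, 3)) = 10.0 / manhattan(s, target)

for s in [(1, 2), (1, 3), (2, 1), (1, 1), (2, 2), (2, 3), (3, 1)]
    println("Value for $s: ", value_estimate(s))
end
```

For example, [1,1] is 8 + 2 = 10 steps from [9,3], giving 1.0. At [9,3] itself the distance is 0 and the estimate is `Inf`.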


## Objects

Functions are not always suitable, for example when the solver needs to be serialized. In such cases, arbitrary objects can be passed to the solver to encode the behavior, and the same object can be passed to several solver parameters to govern all of them. See the solver's docstring for which functions will be called on the object(s). The following code implements the same heuristics as the function-based code above. The printed tree differs because the search is randomized.
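
To illustrate why plain data objects are convenient here, a struct whose fields are ordinary values can be round-tripped through Julia's `Serialization` standard library. `HeuristicParams` is a hypothetical stand-in for an object like `MyHeuristic`, with tuples in place of `GWPos` and no RNG field. It is not part of MCTS.jl.

```julia
using Serialization

# Hypothetical parameter object holding only plain data.
struct HeuristicParams
    target_state::Tuple{Int,Int}
    special_state::Tuple{Int,Int}
    special_Q::Float64
    special_N::Int
    priority_action::Symbol
end

params = HeuristicParams((9, 3), (1, 2), 11.73, 3, :up)

io = IOBuffer()
serialize(io, params)   # write the object to an in-memory buffer
seekstart(io)           # rewind before reading it back
restored = deserialize(io)

println(restored)
```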

```julia
mutable struct MyHeuristic
    target_state::GWPos
    special_state::GWPos
    special_Q::Float64
    special_N::Int
    priority_action::Symbol
    rng::AbstractRNG
end;
```

```julia
MCTS.init_Q(h::MyHeuristic, mdp::SimpleGridWorld, s, a) = s == h.special_state ? h.special_Q : 0.0
MCTS.init_N(h::MyHeuristic, mdp::SimpleGridWorld, s, a) = s == h.special_state ? h.special_N : 0

function MCTS.estimate_value(h::MyHeuristic, mdp::SimpleGridWorld, s, remaining_depth)
    targ = h.target_state
    m_dist = abs(s.x-targ.x)+abs(s.y-targ.y) # Manhattan distance to the target
    val = 10.0/m_dist
    println("Set value for $s to $val") # not necessary; only shows that it is being called
    return val
end

function MCTS.next_action(h::MyHeuristic, mdp::SimpleGridWorld, s, snode::DPWStateNode)
    if haskey(snode.tree.a_lookup, (snode.index, h.priority_action)) # priority action already added
        others = setdiff([:up, :left, :down, :right], [h.priority_action])
        return rand(h.rng, others) # add a random other action
    else
        return h.priority_action
    end
end;
```

```julia
heur = MyHeuristic(GWPos(9,3), GWPos(1,2), 11.73, 3, :up, Random.GLOBAL_RNG)
solver = DPWSolver(n_iterations=8, depth=4,
                   init_N=heur, init_Q=heur,
                   estimate_value=heur,
                   next_action=heur)
policy = solve(solver, mdp)
action(policy, GWPos(1,1))
println("State-Action Nodes:")
tree = policy.tree
for i in 1:length(tree.total_n)
    for j in tree.children[i]
        println("s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])")
    end
end
```

```
Set value for [1, 2] to 1.1111111111111112
Set value for [1, 3] to 1.25
Set value for [2, 1] to 1.1111111111111112
Set value for [1, 1] to 1.0
Set value for [2, 2] to 1.25
Set value for [1, 1] to 1.0
Set value for [2, 3] to 1.4285714285714286
Set value for [2, 1] to 1.1111111111111112
State-Action Nodes:
s:[1, 1], a:up, Q:0.9781366030092592 N:6
s:[1, 1], a:right, Q:1.0663283482142856 N:5
s:[1, 1], a:down, Q:0.8810953125 N:4
s:[1, 2], a:up, Q:7.4898437499999995 N:5
s:[1, 2], a:down, Q:9.0654296875 N:4
s:[1, 2], a:left, Q:9.01184375 N:4
s:[1, 2], a:right, Q:11.73 N:3
s:[2, 1], a:up, Q:1.2383928571428573 N:2
s:[2, 1], a:down, Q:1.0036574074074074 N:3
s:[2, 2], a:up, Q:1.3571428571428572 N:1
```


## Rollouts

The most common way to estimate the value of a state node is with rollout simulations. This can be done with an arbitrary policy or solver by passing a `RolloutEstimator` object as the `estimate_value` parameter. The following code does this with a policy that moves towards state [9,3].
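
Conceptually, a rollout estimate runs a policy forward from the leaf state for a bounded number of steps and sums the discounted rewards. The sketch below makes several assumptions for illustration, so it does not reproduce the exact dynamics or rewards of `SimpleGridWorld`:

- a hypothetical deterministic grid
- a single reward of 10.0 collected on entering [9,3], which ends the episode
- a seek-target policy mirroring `SeekTarget`
- a discount of 0.95

`seek_action` and `rollout` are illustrative names, and `max_depth` only mirrors the role of the keyword used below.

```julia
# Hypothetical deterministic moves on the grid.
const MOVES = Dict(:up => (0, 1), :down => (0, -1), :left => (-1, 0), :right => (1, 0))

# Close the horizontal gap first, then the vertical one.
function seek_action(target, s)
    if target[1] > s[1]
        return :right
    elseif target[1] < s[1]
        return :left
    elseif target[2] > s[2]
        return :up
    else
        return :down
    end
end

# Discounted return of following seek_action from s for at most max_depth steps.
function rollout(s, target; max_depth=50, discount=0.95)
    total = 0.0
    for t in 0:max_depth-1
        s == target && break              # target treated as terminal
        d = MOVES[seek_action(target, s)]
        s = (s[1] + d[1], s[2] + d[2])
        r = s == target ? 10.0 : 0.0      # reward only on entering the target
        total += discount^t * r
    end
    return total
end

println(rollout((5, 1), (9, 3)))                 # six steps: 10 * 0.95^5 ≈ 7.738
println(rollout((5, 1), (9, 3); max_depth=3))    # too short to reach the target: 0.0
```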

```julia
mutable struct SeekTarget <: Policy
    target::GWPos
end
```

```julia
function POMDPs.action(p::SeekTarget, s::GWPos)
    # Close the horizontal gap first, then the vertical one.
    if p.target.x > s.x
        return :right
    elseif p.target.x < s.x
        return :left
    elseif p.target.y > s.y
        return :up
    else
        return :down # target is below, or s is already the target
    end
end
```

```julia
solver = MCTSSolver(n_iterations=5, depth=20,
                    estimate_value=RolloutEstimator(SeekTarget(GWPos(9,3)); max_depth=50))
policy = solve(solver, mdp)
action(policy, GWPos(5,1))
println("State-Action Nodes")
tree = policy.tree
for sn in MCTS.state_nodes(tree)
    for san in MCTS.children(sn)
        println("s:$(MCTS.state(sn)), a:$(action(san)) Q:$(MCTS.q(san)) N:$(MCTS.n(san))")
    end
end
```

```
State-Action Nodes
s:[5, 1], a:up Q:-1.2931903038081076 N:2
s:[5, 1], a:down Q:6.634204312890622 N:1
s:[5, 1], a:left Q:4.401266686517653 N:1
s:[5, 1], a:right Q:5.987369392383786 N:1
s:[4, 1], a:up Q:4.632912301597529 N:1
s:[4, 1], a:down Q:0.0 N:0
s:[4, 1], a:left Q:0.0 N:0
s:[4, 1], a:right Q:0.0 N:0
s:[5, 2], a:up Q:0.0 N:0
s:[5, 2], a:down Q:0.0 N:0
s:[5, 2], a:left Q:0.0 N:0
s:[5, 2], a:right Q:0.0 N:0
s:[6, 1], a:up Q:6.3024940972460906 N:1
s:[6, 1], a:down Q:0.0 N:0
s:[6, 1], a:left Q:0.0 N:0
s:[6, 1], a:right Q:0.0 N:0
s:[4, 2], a:up Q:0.0 N:0
s:[4, 2], a:down Q:0.0 N:0
s:[4, 2], a:left Q:0.0 N:0
s:[4, 2], a:right Q:0.0 N:0
s:[6, 2], a:up Q:0.0 N:0
s:[6, 2], a:down Q:0.0 N:0
s:[6, 2], a:left Q:0.0 N:0
s:[6, 2], a:right Q:0.0 N:0
```
