In [1]:
using POMDPs
using POMDPModels # for the GridWorld problem
importall MCTS # import MCTS names so we can add new methods to estimate_value, init_Q and init_N

# Incorporating Prior Knowledge

Suppose that we are solving a GridWorld problem and want to incorporate two pieces of prior knowledge: use 10 divided by the Manhattan distance to [9,3] as a heuristic value estimate (instead of running rollouts), and initialize Q to -5.0 for the special state [4,6] and every action, with a weight of N = 100 trials. The code in this notebook implements both. First, we create a type to store this information:

In [2]:
type ManhattanHeuristic
    target_state::GridWorldState
    special_state::GridWorldState
    special_Q::Float64
    special_N::Int
end

To tell the solver how to use the Manhattan distance, we implement a new method of estimate_value:

In [3]:
function estimate_value(p::AbstractMCTSPolicy{GridWorldState, GridWorldAction, ManhattanHeuristic},
                        s::GridWorldState,
                        depth::Int)
    # prior_knowledge(p) returns the ManhattanHeuristic object given to the solver
    targ = prior_knowledge(p).target_state
    m_dist = float(abs(s.x-targ.x)+abs(s.y-targ.y))
    val = 10.0/m_dist # note: this is Inf when s is the target itself (m_dist == 0)
    println("Set value for $s to $val") # this is not necessary - just shows that it's working later
    return val
end

estimate_value (generic function with 2 methods)
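The arithmetic inside estimate_value can be checked on its own. The following standalone sketch does not depend on MCTS or POMDPModels; Cell and manhattan_value are hypothetical names, with Cell standing in for GridWorldState and keeping only the x and y fields used above. For example, [4,6] is 5 + 3 = 8 steps from [9,3], so its estimate is 10/8 = 1.25. At the target itself the distance is zero, and the floating-point division yields Inf.

```julia
# Hypothetical stand-in for GridWorldState: only the coordinates matter here.
struct Cell
    x::Int
    y::Int
end

# Same formula as the estimate_value method above:
# 10 divided by the Manhattan distance between s and the target.
function manhattan_value(s::Cell, targ::Cell)
    m_dist = float(abs(s.x - targ.x) + abs(s.y - targ.y))
    return 10.0 / m_dist
end

target = Cell(9, 3)
v_special = manhattan_value(Cell(4, 6), target)  # 5 + 3 = 8 steps
v_near    = manhattan_value(Cell(5, 5), target)  # 4 + 2 = 6 steps
v_target  = manhattan_value(target, target)      # distance 0, so 10.0/0.0
println(v_special, " ", v_near, " ", v_target)
```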

To describe how to initialize Q and N, we implement new methods of init_Q and init_N:

In [4]:
function init_Q(p::AbstractMCTSPolicy{GridWorldState, GridWorldAction, ManhattanHeuristic},
                s::GridWorldState,
                a::GridWorldAction)
    pk = prior_knowledge(p)
    if s == pk.special_state
        return pk.special_Q
    else
        return 0.0
    end
end

function init_N(p::AbstractMCTSPolicy{GridWorldState, GridWorldAction, ManhattanHeuristic},
                s::GridWorldState,
                a::GridWorldAction)
    pk = prior_knowledge(p)
    if s == pk.special_state
        return pk.special_N
    else
        return 0
    end
end

init_N (generic function with 2 methods)
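To see what "a weight of N = 100 trials" means in practice, the sketch below assumes the solver keeps Q as a running average and, after each simulated return q, applies the common incremental update N ← N + 1, Q ← Q + (q - Q)/N. This is an assumption about the solver's internals, and QNode and record_return! are hypothetical names. Under this rule the prior behaves like 100 earlier trials that all returned -5.0, so a single new return barely moves Q, while a node starting from N = 0 jumps straight to its first return.

```julia
# Hypothetical state-action node holding a running-average value estimate.
mutable struct QNode
    Q::Float64   # current value estimate
    N::Int       # number of trials (prior or simulated) behind the estimate
end

# Incremental-mean update: the prior counts as N trials worth of the initial Q.
function record_return!(node::QNode, q::Float64)
    node.N += 1
    node.Q += (q - node.Q) / node.N
    return node
end

with_prior = QNode(-5.0, 100)  # init_Q = -5.0, init_N = 100 (the special state)
no_prior   = QNode(0.0, 0)     # init_Q = 0.0,  init_N = 0   (every other state)

# Both nodes receive the same simulated return.
record_return!(with_prior, 1.0)  # Q moves from -5.0 by only 6/101
record_return!(no_prior, 1.0)    # Q becomes exactly the first return
```

With N = 0 the first return replaces the initial Q entirely, so under this update rule init_Q only has a lasting effect when init_N is positive.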

Now we can use this prior knowledge in a solver.

In [5]:
heuristic = ManhattanHeuristic(GridWorldState(9,3),
                               GridWorldState(4,6),
                               -5.0, 100)

solver = MCTSSolver(n_iterations = 10, prior_knowledge=heuristic)

mdp = GridWorld()
policy = solve(solver, mdp)
action(policy, GridWorldState(4,6))

Set value for POMDPModels.GridWorldState(4,6,false) to 1.25
Set value for POMDPModels.GridWorldState(4,7,false) to 1.1111111111111112
Set value for POMDPModels.GridWorldState(3,7,false) to 1.0
Set value for POMDPModels.GridWorldState(3,6,false) to 1.1111111111111112
Set value for POMDPModels.GridWorldState(4,5,false) to 1.4285714285714286
Set value for POMDPModels.GridWorldState(5,5,false) to 1.6666666666666667
Set value for POMDPModels.GridWorldState(5,6,false) to 1.4285714285714286
Set value for POMDPModels.GridWorldState(4,8,false) to 1.0
Set value for POMDPModels.GridWorldState(2,7,false) to 0.9090909090909091
Set value for POMDPModels.GridWorldState(5,7,false) to 1.25


POMDPModels.GridWorldAction(:up)
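A plausible reason the answer is :up: assume the solver selects actions during search with the standard UCB1 rule Q + c*sqrt(log(N_s)/N), where N_s is the total visit count at the state, and breaks ties in favor of the first action. At [4,6] all four actions start with identical Q and N, so :up is tried first, and each return better than -5.0 keeps it ahead. The exploration constant c = 1.0 and the helper ucb_select are assumptions for illustration; the later Q and N values are copied from the tree printed for [4,6] in the next cell.

```julia
# Hypothetical UCB1 selection; a strict '>' keeps the first index on ties.
# All N here are positive, so there is no division by zero.
function ucb_select(Q::Vector{Float64}, N::Vector{Int}, c::Float64)
    total = sum(N)
    best, best_val = 0, -Inf
    for i in eachindex(Q)
        val = Q[i] + c * sqrt(log(total) / N[i])
        if val > best_val
            best, best_val = i, val
        end
    end
    return best
end

action_names = [:up, :down, :left, :right]

# Before any search: every action has the prior Q = -5.0 and N = 100.
first_pick = action_names[ucb_select(fill(-5.0, 4), fill(100, 4), 1.0)]

# After the search: :up has been visited 9 more times and its Q has risen.
later_pick = action_names[ucb_select([-4.908567654055299, -5.0, -5.0, -5.0],
                                     [109, 100, 100, 100], 1.0)]
```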

The print statements show that estimate_value was called 10 times. We can also verify that Q and N were set correctly for the special state by looping over the nodes stored in policy.tree:

In [6]:
for (s,sn) in policy.tree
    for san in sn.sanodes
        println("s:$s, a:$(san.action) Q:$(san.Q) N:$(san.N)")
    end
end

s:POMDPModels.GridWorldState(4,6,false), a:POMDPModels.GridWorldAction(:up) Q:-4.908567654055299 N:109
s:POMDPModels.GridWorldState(4,6,false), a:POMDPModels.GridWorldAction(:down) Q:-5.0 N:100
s:POMDPModels.GridWorldState(4,6,false), a:POMDPModels.GridWorldAction(:left) Q:-5.0 N:100
s:POMDPModels.GridWorldState(4,6,false), a:POMDPModels.GridWorldAction(:right) Q:-5.0 N:100
s:POMDPModels.GridWorldState(3,7,false), a:POMDPModels.GridWorldAction(:up) Q:0.8215340909090908 N:2
s:POMDPModels.GridWorldState(3,7,false), a:POMDPModels.GridWorldAction(:down) Q:0.0 N:0
s:POMDPModels.GridWorldState(3,7,false), a:POMDPModels.GridWorldAction(:left) Q:0.0 N:0
s:POMDPModels.GridWorldState(3,7,false), a:POMDPModels.GridWorldAction(:right) Q:0.0 N:0
s:POMDPModels.GridWorldState(5,5,false), a:POMDPModels.GridWorldAction(:up) Q:0.0 N:0
s:POMDPModels.GridWorldState(5,5,false), a:POMDPModels.GridWorldAction(:down) Q:0.0 N:0
s:POMDPModels.GridWorldState(5,5,false), a:POMDPModels.GridWorldAction(:left) Q:0.0