Skip to content

Commit

Permalink
Trees (#29)
Browse files Browse the repository at this point in the history
* now uses centralized tree structure

(tests do not pass and tree visualization does not work)

* got rid of some deprecations

* timing tests

* changed some things to work with VDP tag

* lowered max size for preallocated tree

* export n_children

* added isroot

* this is a lot of work - stopping 4 lunch

* woo - tree visualization works on the fast branch

* updated docs
  • Loading branch information
zsunberg committed Oct 7, 2017
1 parent 882d9ee commit 4b0b365
Show file tree
Hide file tree
Showing 16 changed files with 2,821 additions and 388 deletions.
4 changes: 3 additions & 1 deletion REQUIRE
@@ -1,6 +1,8 @@
julia 0.6
JSON
POMDPs 0.4
Compat
Blink
POMDPToolbox
CPUTime
D3Trees
Colors
2 changes: 1 addition & 1 deletion docs/src/index.md
Expand Up @@ -93,7 +93,7 @@ Depth = 2

An example of visualization of the search tree in a jupyter notebook is [here](https://nbviewer.jupyter.org/github/JuliaPOMDP/MCTS.jl/blob/master/notebooks/Test_Visualization.ipynb) (or [here](https://github.com/JuliaPOMDP/MCTS.jl/blob/master/notebooks/Test_Visualization.ipynb) is the version on github that will not display quite right but will still show you how it's done).

To display the tree in an Electron window using Blink.jl, run `blink(TreeVisualizer(policy, state))`.
To display the tree in a Google Chrome window, run `using D3Trees; inchrome(D3Tree(policy, state))`.

## Incorporating Additional Prior Knowledge

Expand Down
142 changes: 80 additions & 62 deletions notebooks/Domain_Knowledge_Example.ipynb
Expand Up @@ -18,8 +18,19 @@
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: Method definition Type(Array{Any, 1}, Type{POMDPModels.GridWorld}) in module POMDPModels overwritten.\n",
"WARNING: Method definition (::Type{POMDPModels.GridWorld})() in module POMDPModels at /home/zach/.julia/v0.6/POMDPModels/src/GridWorlds.jl:71 overwritten at /home/zach/.julia/v0.6/POMDPModels/src/GridWorlds.jl:74.\n"
]
}
],
"source": [
"using MCTS\n",
"importall POMDPs\n",
Expand All @@ -39,7 +50,9 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -89,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {
"collapsed": true
},
Expand All @@ -106,8 +119,8 @@
"end\n",
"\n",
"function up_priority(mdp, s, snode) # snode is the state node of type DPWStateNode\n",
" if haskey(snode.A, GridWorldAction(:up)) # \"up\" is already there\n",
" return GridWorldAction(rand([:up, :left, :down, :right])) # add a random action\n",
" if haskey(snode.tree.a_lookup, (snode.index, :up)) # \"up\" is already there\n",
" return GridWorldAction(rand([:left, :down, :right])) # add a random action\n",
" else\n",
" return GridWorldAction(:up)\n",
" end\n",
Expand All @@ -116,29 +129,30 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set value for POMDPModels.GridWorldState(1, 1, false) to 1.0\n",
"Set value for POMDPModels.GridWorldState(1, 2, false) to 1.1111111111111112\n",
"Set value for POMDPModels.GridWorldState(2, 1, false) to 1.1111111111111112\n",
"Set value for POMDPModels.GridWorldState(3, 1, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(1, 3, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(1, 3, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(2, 1, false) to 1.1111111111111112\n",
"Set value for POMDPModels.GridWorldState(2, 2, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(2, 1, false) to 1.1111111111111112\n",
"State-Action Nodes:\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:left Q:0.0 N:3\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:right Q:1.1039351851851853 N:3\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:up Q:0.7278935185185185 N:3\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:down Q:0.0 N:1\n",
"s:POMDPModels.GridWorldState(3, 1, false), a:up Q:0.0 N:1\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:right Q:0.0 N:1\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:up Q:1.1875 N:2\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:down Q:0.0 N:1\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:up Q:9.094375 N:4\n"
"s:POMDPModels.GridWorldState(1, 1, false), a:up, Q:1.1875 N:2\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:down, Q:0.0 N:4\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:right, Q:1.0951388888888889 N:2\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:left, Q:0.0 N:4\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:up, Q:9.061388888888889 N:4\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:right, Q:11.73 N:3\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:down, Q:11.73 N:3\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:up, Q:0.0 N:0\n"
]
}
],
Expand All @@ -150,9 +164,10 @@
"policy = solve(solver, mdp)\n",
"action(policy, GridWorldState(1,1))\n",
"println(\"State-Action Nodes:\")\n",
"for (s,sn) in policy.tree\n",
" for (a,san) in sn.A\n",
" println(\"s:$s, a:$a Q:$(san.Q) N:$(san.N)\")\n",
"tree = get(policy.tree)\n",
"for i in 1:length(tree.total_n)\n",
" for j in tree.children[i]\n",
" println(\"s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])\")\n",
" end\n",
"end"
]
Expand All @@ -168,7 +183,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {
"collapsed": true
},
Expand All @@ -186,7 +201,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {
"collapsed": true
},
Expand All @@ -204,7 +219,7 @@
"end\n",
"\n",
"function MCTS.next_action(h::MyHeuristic, mdp::GridWorld, s, snode::DPWStateNode)\n",
" if haskey(snode.A, h.priority_action)\n",
" if haskey(snode.tree.a_lookup, (snode.index, h.priority_action))\n",
" return GridWorldAction(rand(h.rng, [:up, :left, :down, :right])) # add a random other action\n",
" else\n",
" return h.priority_action\n",
Expand All @@ -214,28 +229,28 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set value for POMDPModels.GridWorldState(1, 1, false) to 1.0\n",
"Set value for POMDPModels.GridWorldState(1, 2, false) to 1.1111111111111112\n",
"Set value for POMDPModels.GridWorldState(1, 3, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(2, 1, false) to 1.1111111111111112\n",
"Set value for POMDPModels.GridWorldState(1, 3, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(2, 2, false) to 1.25\n",
"Set value for POMDPModels.GridWorldState(3, 1, false) to 1.25\n",
"State-Action Nodes:\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:left Q:0.0 N:4\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:right Q:0.5459201388888888 N:4\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:up Q:1.063637152777778 N:2\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:down Q:0.0 N:4\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:left Q:0.0 N:1\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:right Q:1.1875 N:1\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:up Q:1.1875 N:1\n",
"s:POMDPModels.GridWorldState(2, 1, false), a:down Q:1.128125 N:1\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:up Q:8.7975 N:4\n"
"s:POMDPModels.GridWorldState(1, 1, false), a:up, Q:0.7916666666666667 N:3\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:right, Q:1.129609375 N:2\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:down, Q:0.0 N:4\n",
"s:POMDPModels.GridWorldState(1, 1, false), a:left, Q:0.0 N:4\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:up, Q:9.07953125 N:4\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:left, Q:5.864999999999999 N:6\n",
"s:POMDPModels.GridWorldState(1, 2, false), a:right, Q:11.73 N:3\n"
]
}
],
Expand All @@ -248,9 +263,10 @@
"policy = solve(solver, mdp)\n",
"action(policy, GridWorldState(1,1))\n",
"println(\"State-Action Nodes:\")\n",
"for (s,sn) in policy.tree\n",
" for (a,san) in sn.A\n",
" println(\"s:$s, a:$a Q:$(san.Q) N:$(san.N)\")\n",
"tree = get(policy.tree)\n",
"for i in 1:length(tree.total_n)\n",
" for j in tree.children[i]\n",
" println(\"s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])\")\n",
" end\n",
"end"
]
Expand All @@ -266,7 +282,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"metadata": {
"collapsed": true
},
Expand All @@ -279,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 15,
"metadata": {
"collapsed": true
},
Expand All @@ -300,30 +316,32 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"State-Action Nodes\n",
"s:POMDPModels.GridWorldState(4,2,false), a:POMDPModels.GridWorldAction(:up) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(4,2,false), a:POMDPModels.GridWorldAction(:down) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(4,2,false), a:POMDPModels.GridWorldAction(:left) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(4,2,false), a:POMDPModels.GridWorldAction(:right) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(4,1,false), a:POMDPModels.GridWorldAction(:up) Q:5.688000922764596 N:1\n",
"s:POMDPModels.GridWorldState(4,1,false), a:POMDPModels.GridWorldAction(:down) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(4,1,false), a:POMDPModels.GridWorldAction(:left) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(4,1,false), a:POMDPModels.GridWorldAction(:right) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5,1,false), a:POMDPModels.GridWorldAction(:up) Q:7.350918906249999 N:1\n",
"s:POMDPModels.GridWorldState(5,1,false), a:POMDPModels.GridWorldAction(:down) Q:0.0 N:20\n",
"s:POMDPModels.GridWorldState(5,1,false), a:POMDPModels.GridWorldAction(:left) Q:-4.172483313482347 N:1\n",
"s:POMDPModels.GridWorldState(5,1,false), a:POMDPModels.GridWorldAction(:right) Q:5.403600876626366 N:1\n",
"s:POMDPModels.GridWorldState(5,2,false), a:POMDPModels.GridWorldAction(:up) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5,2,false), a:POMDPModels.GridWorldAction(:down) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5,2,false), a:POMDPModels.GridWorldAction(:left) Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5,2,false), a:POMDPModels.GridWorldAction(:right) Q:0.0 N:0\n"
"s:POMDPModels.GridWorldState(6, 1, false), a:up Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(6, 1, false), a:down Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(6, 1, false), a:left Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(6, 1, false), a:right Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 1, false), a:up Q:0.0 N:1\n",
"s:POMDPModels.GridWorldState(5, 1, false), a:down Q:0.0 N:20\n",
"s:POMDPModels.GridWorldState(5, 1, false), a:left Q:7.350918906249998 N:1\n",
"s:POMDPModels.GridWorldState(5, 1, false), a:right Q:7.350918906249999 N:1\n",
"s:POMDPModels.GridWorldState(5, 2, false), a:up Q:7.737809374999998 N:1\n",
"s:POMDPModels.GridWorldState(5, 2, false), a:down Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 2, false), a:left Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 2, false), a:right Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 3, false), a:up Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 3, false), a:down Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 3, false), a:left Q:0.0 N:0\n",
"s:POMDPModels.GridWorldState(5, 3, false), a:right Q:0.0 N:0\n"
]
}
],
Expand Down

0 comments on commit 4b0b365

Please sign in to comment.