In [341]:
from gym.envs.classic_control import CartPoleEnv
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm
from ete3 import Tree, TreeStyle, TextFace
g_accuracy = 1e12

In [342]:
class CartPoleWrapper(CartPoleEnv):
    def __init__(self):
        super().__init__()

    def reset(self, init_state=None):
        rand_obs = super().reset()  # call the base reset to do all of the other stuff

        if init_state is None:
            return rand_obs

        self.state = np.array(init_state)  # and then edit state if we want (state is a base class attribute)
        return self.state
    
    def get_mcts_state(self, acc):
        return tuple([int(dim * acc) for dim in self.state])
    
    def get_tree_state(self, mcts_state):
        return str([round(float(dim)/g_accuracy, 2) for dim in [mcts_state[0], mcts_state[2]]])

env = CartPoleWrapper()

# Gym Tree Graph

In [356]:
class MCTSTree:
    def __init__(self):
        self.tree = Tree()
        print(self.tree)        
        
    def update_tree(self, prev_obs, curr_obs):
        # needs to add new nodes when discovered and also increment the counter
        prev_len = len(self.tree.search_nodes(name=str(prev_obs)))
        curr_len = len(self.tree.search_nodes(name=str(curr_obs)))
        parent_state, child_state = env.get_tree_state(prev_obs), env.get_tree_state(curr_obs)
        
        if curr_len==0 and prev_len==0:
            # if neither the parent of the child node exist then this must be a root
            parent = self.tree.add_child(name=str(prev_obs), dist=1, support=1)
            child = parent.add_child(name=str(curr_obs), dist=1, support=1)
            
            parent.add_features(state=parent_state)
            child.add_features(state=child_state)
            return
        
        # check if the curr_obs already exists
        curr_node = self.tree.search_nodes(name=str(curr_obs))
        if len(curr_node) != 0:
            curr_node[0].support += 1
        else:
            # need to find the parent of this and then add this as a child to that parent          
            parent = self.tree.search_nodes(name=str(prev_obs))[0]
            child = parent.add_child(name=str(curr_obs), dist=1, support=1)
            
            
            child.add_features(state=child_state)
            child_face = TextFace(child_state)
            done = abs(curr_obs[0]) > env.x_threshold*g_accuracy or abs(curr_obs[2]) > env.theta_threshold_radians*g_accuracy
            if done:
                c = cm.viridis(128.5) # viridis goes from 0-255
                child_face.background.color = "#{0:02x}{1:02x}{2:02x}".format(*[int(round(i * 255)) for i in [c[0], c[1], c[2]]])
            
            child.add_face(child_face, column=1, position = "branch-bottom")
            print(parent.up.state)
            
    def show():
        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.show_branch_support = True
        ts.title.add_face(TextFace("Hello ETE", fsize=20), column=0)
        # each node contains 3 attributes: node.dist, node.name, node.support
        self.tree.show(tree_style=ts) # , show_internal=True)

        
mcts_tree = MCTSTree()
            


--


In [357]:
def execute_episode():
    
    for eps in range(20):
        prev_obs = env.reset([0, 0, 0.01, 0]) # list of 4 elements
        prev_obs = env.get_mcts_state(g_accuracy)
        done = False
        score = 0

        while not done:
            action = env.action_space.sample()   # choose random action (0-left or 1-right)      
            observation, reward, done, info = env.step(action)
            obs = env.get_mcts_state(g_accuracy)
            mcts_tree.update_tree(prev_obs, obs)
            prev_obs = obs

            score += 1
        print(score)
    mcts_tree.tree.show()

In [358]:
execute_episode()

[0.0, 0.01]
[0.0, 0.01]
[0.0, 0.0]
[0.01, -0.01]
[0.02, -0.02]
[0.04, -0.05]
[0.06, -0.08]
[0.07, -0.1]
[0.09, -0.13]
[0.11, -0.16]
[0.12, -0.18]
[0.13, -0.19]
[0.13, -0.2]
14
[0.0, 0.01]
[0.0, 0.0]
[0.01, -0.01]
[0.02, -0.01]
[0.02, -0.02]
[0.04, -0.04]
[0.05, -0.07]
[0.06, -0.08]
[0.07, -0.1]
[0.07, -0.1]
[0.08, -0.12]
[0.09, -0.13]
[0.09, -0.13]
[0.08, -0.13]
[0.08, -0.12]
[0.06, -0.11]
[0.05, -0.09]
[0.04, -0.08]
[0.03, -0.08]
[0.03, -0.08]
[0.03, -0.09]
[0.03, -0.11]
[0.04, -0.13]
[0.04, -0.14]
[0.04, -0.16]
[0.04, -0.16]
[0.04, -0.18]
29
[0.0, 0.01]
[0.0, 0.01]
[0.0, 0.0]
[0.0, 0.0]
[-0.0, 0.01]
[-0.01, 0.02]
[-0.02, 0.04]
[-0.03, 0.05]
[-0.03, 0.06]
[-0.03, 0.06]
[-0.04, 0.07]
[-0.04, 0.07]
[-0.03, 0.06]
[-0.02, 0.06]
[-0.01, 0.04]
[-0.0, 0.03]
[0.01, 0.02]
[0.02, -0.0]
[0.04, -0.03]
[0.06, -0.05]
[0.07, -0.06]
[0.08, -0.08]
[0.1, -0.09]
[0.11, -0.11]
[0.13, -0.14]
[0.15, -0.18]
27


AttributeError: 'TreeNode' object has no attribute 'state'

## Node Annotation and Presentation

In [5]:
# Creates a tree
t = Tree( '((H:0.3,I:0.1):0.5, A:1, (B:0.4,(C:0.5,(J:1.3, (F:1.2, D:0.1):0.5):0.5):0.5):0.5);' )
print(t)

# Let's locate some nodes using the get common ancestor method
ancestor=t.get_common_ancestor("J", "F", "C")
# the search_nodes method (I take only the first match )
A = t.search_nodes(name="A")[0]
# and using the shorcut to finding nodes by name
C= t&"C"
H= t&"H"
I= t&"I"

# Let's now add some custom features to our nodes. add_features can be
# used to add many features at the same time.
C.add_features(vowel=False, confidence=1.0)
A.add_features(vowel=True, confidence=0.5)
ancestor.add_features(nodetype="internal")

# Or, using the oneliner notation
(t&"H").add_features(vowel=False, confidence=0.2)

# But we can automatize this. (note that i will overwrite the previous
# values)
for leaf in t.traverse():
    if leaf.name in "AEIOU":
        leaf.add_features(vowel=True, confidence=random.random())
    else:
        leaf.add_features(vowel=False, confidence=random.random())

# Now we use these information to analyze the tree.
print("This tree has", len(t.search_nodes(vowel=True)), "vowel nodes")
print("Which are", [leaf.name for leaf in t.iter_leaves() if leaf.vowel==True])

# But features may refer to any kind of data, not only simple
# values. For example, we can calculate some values and store them
# within nodes.
#
# Let's detect leaf nodes under "ancestor" with distance higher thatn
# 1. Note that I'm traversing a subtree which starts from "ancestor"
matches = [leaf for leaf in ancestor.traverse() if leaf.dist>1.0]

# And save this pre-computed information into the ancestor node
ancestor.add_feature("long_branch_nodes", matches)

# Prints the precomputed nodes
print("These are nodes under ancestor with long branches", [n.name for n in ancestor.long_branch_nodes])
ancestor.add_feature('k', 6)


      /-H
   /-|
  |   \-I
  |
--|--A
  |
  |   /-B
   \-|
     |   /-C
      \-|
        |   /-J
         \-|
           |   /-F
            \-|
               \-D
This tree has 8 vowel nodes
Which are ['I', 'A']
These are nodes under ancestor with long branches ['J', 'F']


In [6]:
ts = TreeStyle()
ts.show_leaf_name = True
ts.show_branch_length = True
ts.show_branch_support = True
ts.title.add_face(TextFace("Hello ETE", fsize=20), column=0)
t.show(tree_style=ts)

## Node Basics

In [3]:
t = Tree("(A:1,(B:1,(E:1,D:1):0.5):0.5);" )
print(t)
node = t.search_nodes(name="B")[0]
while node:
    print(node)
    node = node.up


   /-A
--|
  |   /-B
   \-|
     |   /-E
      \-|
         \-D

--B

   /-B
--|
  |   /-E
   \-|
      \-D

   /-A
--|
  |   /-B
   \-|
     |   /-E
      \-|
         \-D


In [4]:
D = t.search_nodes(name="D")[0]

# I get all nodes with distance=0.5
nodes = t.search_nodes(dist=0.5)
print(len(nodes), "nodes have distance=0.5", nodes)

# We can limit the search to leaves and node names (faster method).
D = t.get_leaves_by_name(name="D")
print(D)

2 nodes have distance=0.5 [Tree node '' (-0x7fffffda5f5056d2), Tree node '' (-0x7fffffda5f505670)]
[Tree node 'D' (-0x7fffffda5f50565b)]


In [4]:
t1 = Tree(format=5)
# We create a random tree topology
t1.populate(15)
print(t1)
print(t1.children)
print(t1.get_children())
print(t1.up)
print(t1.name)
print(t1.dist)
print(t1.is_leaf())
print(t1.get_tree_root())
print(t1.children[0].get_tree_root())
print(t1.children[0].children[0].get_tree_root())
# You can also iterate over tree leaves using a simple syntax
for leaf in t1:
    print(leaf.name)


         /-aaaaaaaaah
      /-|
     |  |   /-aaaaaaaaai
     |   \-|
     |     |   /-aaaaaaaaaj
     |      \-|
   /-|         \-aaaaaaaaak
  |  |
  |  |      /-aaaaaaaaal
  |  |   /-|
  |  |  |   \-aaaaaaaaam
  |   \-|
  |     |   /-aaaaaaaaan
--|      \-|
  |         \-aaaaaaaaao
  |
  |      /-aaaaaaaaaa
  |   /-|
  |  |  |   /-aaaaaaaaab
  |  |   \-|
  |  |      \-aaaaaaaaac
   \-|
     |      /-aaaaaaaaad
     |   /-|
     |  |   \-aaaaaaaaae
      \-|
        |   /-aaaaaaaaaf
         \-|
            \-aaaaaaaaag
[Tree node '' (0x13f42d075c), Tree node '' (-0x7fffffec0bd2f89a)]
[Tree node '' (0x13f42d075c), Tree node '' (-0x7fffffec0bd2f89a)]
None

1.0
False

         /-aaaaaaaaah
      /-|
     |  |   /-aaaaaaaaai
     |   \-|
     |     |   /-aaaaaaaaaj
     |      \-|
   /-|         \-aaaaaaaaak
  |  |
  |  |      /-aaaaaaaaal
  |  |   /-|
  |  |  |   \-aaaaaaaaam
  |   \-|
  |     |   /-aaaaaaaaan
--|      \-|
  |         \-aaaaaaaaao
  |
  |      /-aaaaaaaaaa
  |   /-|
  

In [340]:
from matplotlib import cm
for x in range(20, 25, 0.25):
    print(x, cm.viridis(x))

TypeError: 'float' object cannot be interpreted as an integer

In [251]:
import colorsys
def hsv2rgb(h,s,v):
    return tuple(round(i * 255) for i in colorsys.hsv_to_rgb(h,s,v))

In [252]:
hsv2rgb(0.5,0.5,0.5)

(64, 128, 128)