MCTS完成以下功能
    >  pi = policy(net,state)
    >  action = choice(pi)
    >  tree = next_tree(state)

In [5]:
import datetime

import numpy as np
import tensorflow as tf

In [None]:
class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Collaborator interfaces, as used by this class:

    * ``net`` exposes ``get_var_value(tensor, feed_dict)`` and the attributes
      ``s`` (input fed with a state), ``p`` (action priors, flattened to one
      value per action index) and ``v`` (value, read as ``[0, 0]``).
    * ``simulator`` exposes ``get_state(state, action)`` (look-ahead used during
      search) and ``move(action)`` (commits a real move and returns the new
      state).

    Typical use: ``a = mcts.choice(tao)`` then ``mcts.next_search()``.
    """

    def __init__(self, net, simulator, init_data, c_puct=5, max_depth=20, num_sims=30):
        self.net = net
        self.simulator = simulator
        self._root = TreeNode(None, 1.0)
        self.c_puct = c_puct
        self.max_depth = max_depth
        # Previously dropped, which made policy() fail with a NameError.
        self.num_sims = num_sims
        # Last action picked by choice(); consumed by next_search().
        self.action = None
        # Action keys matching the entries of the most recent pi.
        self._pi_actions = []
        self.cur_depth = 0
        self.cur_node = self._root
        self.root_state = init_data
        self.cur_state = init_data.copy()

    @staticmethod
    def _copy_state(state):
        """Return ``state.copy()`` when available, else ``state`` itself."""
        copy = getattr(state, "copy", None)
        return copy() if callable(copy) else state

    def _reset_cursor(self):
        """Move the search cursor (node, state, depth) back to the root."""
        self.cur_node = self._root
        # Copy so a simulator that mutates its input cannot corrupt root_state.
        self.cur_state = self._copy_state(self.root_state)
        self.cur_depth = 0

    def policy(self, state, tao):
        """Run ``num_sims`` simulations from the root and return a move distribution.

        Args:
            state: Unused; the search always starts from ``self.root_state``.
                Kept for backward compatibility.
            tao: Temperature. ``pi[i]`` is proportional to ``N_i ** (1 / tao)``,
                where ``N_i`` is the visit count of the i-th root child (ordered
                by action key). ``tao == 0`` puts all mass on the most-visited
                action(s).

        Returns:
            A list of probabilities, one per root child, ordered by action key.
            The list is empty if the root has no children.
        """
        for _ in range(self.num_sims):
            # Every simulation must start at depth 0 from the root.
            self._reset_cursor()
            self.rollout()
        self._reset_cursor()

        children = self._root._children
        actions = sorted(children)
        self._pi_actions = actions
        if not actions:
            return []

        counts = np.array([children[a]._n_visits for a in actions], dtype=float)
        if counts.max() <= 0:
            # No child has been visited yet (e.g. num_sims == 1); the visit
            # distribution is undefined, so fall back to the priors.
            counts = np.array([children[a]._p for a in actions], dtype=float)
            if counts.max() <= 0:
                counts = np.ones(len(actions))

        if tao == 0:
            best = counts == counts.max()
            probs = best / best.sum()
        else:
            # Annealing by temperature; scale by the max first so that
            # N ** (1 / tao) cannot overflow for small tao.
            scaled = np.power(counts / counts.max(), 1.0 / tao)
            probs = scaled / scaled.sum()
        return probs.tolist()

    def next_search(self):
        """Commit ``self.action``: descend the root to that child and move the simulator.

        The chosen child's subtree is kept for reuse.

        Returns:
            ``self`` (fluent interface).

        Raises:
            RuntimeError: if ``choice()`` has not been called yet.
        """
        if self.action is None:
            raise RuntimeError("choice() must be called before next_search()")

        child = self._root._children.get(self.action)
        if child is None:
            # The chosen action was never expanded; start a fresh subtree.
            child = TreeNode(None, 1.0)
        child._parent = None
        self._root = child

        # Move the real simulator exactly once (it was previously moved twice).
        self.root_state = self.simulator.move(self.action)
        self._reset_cursor()
        return self

    def choice(self, tao):
        """Search, sample an action from the resulting distribution and remember it.

        Args:
            tao: Temperature forwarded to ``policy``.

        Returns:
            The sampled action key (also stored in ``self.action``).

        Raises:
            ValueError: if the root has no children to choose from.
        """
        pi = self.policy(self.cur_state, tao)
        if not pi:
            raise ValueError("root has no children to choose from")
        index = np.random.choice(len(pi), p=pi)
        self.action = self._pi_actions[index]
        return self.action

    def rollout(self):
        """Run one simulation starting from the current cursor.

        The simulation proceeds in four steps:

        * Select children by PUCT until reaching a leaf or ``max_depth``.
        * Expand the leaf with the network priors, but only below the depth
          limit.
        * Evaluate the final node with the value head.
        * Back that value up to the root.
        """
        while self.cur_depth < self.max_depth and not self.cur_node.is_leaf():
            action, self.cur_node = self.cur_node.select()
            self.cur_state = self.simulator.get_state(self.cur_state, action)
            self.cur_depth += 1

        feed = {self.net.s: self.cur_state}
        if self.cur_node.is_leaf() and self.cur_depth < self.max_depth:
            priors = self.net.get_var_value(self.net.p, feed).reshape(-1).tolist()
            self.cur_node.expand(enumerate(priors))

        # Indexed [0, 0]: presumably a single-state batch with a (1, 1) value
        # output -- TODO confirm against the network definition.
        leaf_value = self.net.get_var_value(self.net.v, feed)[0, 0]
        self.cur_node.backup(leaf_value, self.c_puct)
                
        

In [19]:
class TreeNode:
    """Search-tree node holding PUCT statistics for the edge leading into it.

    Per-node state set in ``__init__``:

    * ``_parent``: the parent node, or ``None`` for the root.
    * ``_children``: mapping of action key to child ``TreeNode``.
    * ``_n_visits``: visit count N.
    * ``_W`` and ``_Q``: total backed-up value and its mean, ``_W / _n_visits``.
    * ``_p``: prior probability given at creation.
    * ``_u``: exploration bonus ``c_puct * _p * sqrt(N_parent) / (1 + N)``.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}
        self._n_visits = 0

        self._Q = 0
        self._W = 0
        # Placeholder bonus; recomputed as soon as the parent is updated.
        self._u = prior_p
        self._p = prior_p

    def expand(self, action_priors):
        """Add a child for every ``(action, prior)`` pair not already present."""
        for action, prob in action_priors:
            if action not in self._children:
                # TODO: consider adding noise to prob (e.g. at the root).
                self._children[action] = TreeNode(self, prob)

    def select(self):
        """Return the ``(action, child)`` pair with the highest ``Q + u``."""
        return max(self._children.items(), key=lambda act_node: act_node[1].get_value())

    def get_value(self):
        """Return the PUCT score ``Q + u`` of this node."""
        return self._Q + self._u

    def update(self, leaf_value, c_puct):
        """Record one visit with value ``leaf_value`` and refresh exploration bonuses.

        Args:
            leaf_value: Value added to this node's total.
            c_puct: Exploration constant used in the bonus formula.
        """
        self._n_visits += 1
        self._W += leaf_value
        self._Q = self._W / self._n_visits
        if not self.is_root():
            self._u = c_puct * self._p * np.sqrt(self._parent._n_visits) / (1 + self._n_visits)
        # Each child's bonus depends on this node's visit count, which just
        # changed. Recompute them all; otherwise unvisited children keep their
        # raw prior as the bonus and visited ones keep a stale value.
        sqrt_n = np.sqrt(self._n_visits)
        for child in self._children.values():
            child._u = c_puct * child._p * sqrt_n / (1 + child._n_visits)

    def is_root(self):
        """Return True if this node has no parent."""
        return self._parent is None

    def is_leaf(self):
        """Return True if this node has not been expanded."""
        return not self._children

    def backup(self, leaf_value, c_puct):
        """Propagate ``leaf_value`` from this node up to the root.

        Ancestors are updated before this node, so this node's bonus is
        computed with its parent's already-incremented visit count.

        NOTE(review): the same ``leaf_value`` is added at every level. For
        two-player zero-sum games the sign usually alternates per ply; confirm
        the intended setting.
        """
        if self._parent is not None:
            self._parent.backup(leaf_value, c_puct)

        self.update(leaf_value, c_puct)