### 数据表示
* 状态 state: 例如 state = np.zeros((3,3),dtype=np.float32)
* 网络策略 p: 例如 p = np.array([0.1,0.2,0.05,0.1,0.4,0.15],dtype=np.float32)
* MCTS策略 pi: 同p格式一致
* 网络值 v: 为float型标量 例如 v = 0.75
* 模拟器结果 z: 为float型标量 例如 v = -1.0
* 动作 action: 为int型标量 例如 action = 2

### 超参数

* 问题规模 size: 例如 size = 19
* 输入历史表示规模 history: 例如 history = 8
* 每层卷积层核的个数 filters: 例如 filters = 256
* 残差模块的个数 n_modules: n_modules = 19
* 正则化参数 lambda_c: lambda_c = .0001
* 动作空间的大小 n_actions: n_actions = 9
* 探索与利用的平衡参数 c_puct: c_puct = 5
* 探索与利用的惯性因子 epsilon: epsilon = .25
* 探索与利用的Dirchlet分布参数 eta: eta = .03
* 退火算法参数 tau: tau = .7
* 搜索次数 n_search: n_search = 1600
* 搜索最大深度 max_search_depth: max_search_depth = 20
* 执行最大步数 max_play_steps: max_play_steps = 500
* 估值阈值 v_resign: v_resign = -.95
* 网络损失函数的迭代次数 n_iters: n_iters = 50

### 数学公式

1. annealing(arr,tau) # 退火公式，arr为计数值
2. sampling_from_dirchlet(alpha,size) # 从狄利克雷分布中采样,alpha为噪声参数，size为采样大小
3. add_noise_to_p(p,epsilon,eta) # 为策略p加入噪声eta
4. ucb(p,n_visits,c_puct) # 上置信界计算，p为策略，n_visits为访问数
5. cross_entropy(p,q) # p与q的交叉熵

### 内部函数及参数命名规范

#### Net类

1. p_op,v_op = inference(state)
2. p = get_var_value(p_op,feed_dict) 
3. v = get_var_value(v_op,feed_dict)
4. l_op = loss_op(z,v_op,pi,p_op,lambda_c) # 单个的状态的loss
5. optimizer = minimize(l_op)
6. l = get_var_value(l_op,feed_dict)
7. update() # 将 T个state，T-1个pi，1个z ,迭代次数n_iters传入，更新网络参数
8. 初始化state_placeholder, z_placeholder,pi_placeholder为placeholder类型
9. save_params()
10. restore_params()

#### TreeNode类
1. action, sub_node = select()
2. expand(action_priors)
3. update(leaf_value,c_puct)
4. backup(leaf_value,c_puct)

#### MCTS类
1. search() # 单次搜索
2. pi = policy(state,tao) # 多次搜索后返回pi
3. action = choice(tao) # 在当前state下给出action，是外界调用的接口
4. next_prepare() # 为下次搜索做准备，在choice中调用

#### Simulator类
1. move(action)
2. get_state(state,action)
3. render()
4. is_done() # 判断是否结束
5. who_win() # 1 代表赢， -1 代表输， 0 代表未结束 

## 流程

0. 初始化
1. 模拟一局
2. 显示过程和结果并保存记录
3. 训练网络
4. 更新网络
5. 跳转 1.
6. 效果评估

In [117]:
p = np.array([0.1,0.5,0.05,.2,.15])
q = np.array([.2,.3,.005,.095,.4])

In [119]:
p = [0.1,0.5,0.05,.2,.15]
q = [.2,.3,.005,.095,.4]

In [120]:
cross_entropy(p,q)

-1.6360653489912225

In [16]:
def annealing(arr,tau):
    return np.power(arr,1.0/tau) / np.sum(np.power(arr,1.0/tau))

In [70]:
def sampling_from_dirchlet(alpha,size):
    eta = np.random.dirichlet([alpha]*size)
    return eta

In [92]:
def add_noise_to_p(p,epsilon=.25,eta=0.03):
    return (1-epsilon) * p + epsilon * sampling_from_dirchlet(eta,p.size)
    

In [97]:
def ucb(p,n_visits,c_puct):
    return c_puct * p * np.sqrt(np.sum(n_visits)) / (1+np.array(n_visits))

In [112]:
def cross_entropy(p,q): 
    return np.dot(np.array(p),np.log(np.array(q)))

In [7]:
import numpy as np

In [89]:
sampling_from_dirchlet(0.03, 5)

array([4.65918461e-21, 1.31356672e-01, 8.68416201e-01, 2.24254998e-04,
       2.87167695e-06])

In [5]:
np.array([0.1,0.2,0.05,0.1,0.4,0.15],dtype=np.float32)

array([0.1 , 0.2 , 0.05, 0.1 , 0.4 , 0.15], dtype=float32)

In [None]:
import datetime
import numpy as np
import tensorflow as tf

In [None]:
class Net:
    def __init__(self, size = 3, model_path="./Model_save/model_net_demo.ckpt"):
        self.s = tf.placeholder(dtype=tf.float32,shape=(size,size))
        self.shape= self.s.shape
        self.n_actions = size*size ## +1?
        self.model_path = model_path
        self.p, self.v = self.inference(self.s)
        self.train_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(self.train_vars)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
   
    def save(self):
        self.saver.save(self.sess,save_path=self.model_path)
        print(str(datetime.datetime.now())+\
              ": latest model and params saved to path:",self.model_path)
        
    def restore(self):
        self.saver.restore(self.sess,save_path=self.model_path)
        print(str(datetime.datetime.now())+\
              ": last model and params restored from path",self.model_path)
        
    def inference(self,inputs,filters=4,fc_units_v=32,num_res_module=4):
        fc_units_p = self.n_actions
        
        inputs = self.input_module(inputs,filters) # 输入模块
        
        mid_values = inputs
        for i_module in range(num_res_module):
            mid_values = self.residual_module(mid_values,filters)
            
        outputs_p, outputs_v = self.output_module(mid_values,fc_units_v,fc_units_p)
        
        
        return outputs_p, outputs_v

    def update(self,v_op_list,p_op_list,z_value,pi_value_list,c=0.01,num_iters=20):
        loss = tf.Variable(0.,trainable=False,dtype=tf.float32)
        for v in v_op_list:
            loss = tf.add(loss,tf.square(v-z_value)) 

        for p,pi in zip(p_op_list,pi_value_list):
            loss = tf.add(loss,-tf.matmul(tf.log(p),tf.transpose(pi))) 

        var = tf.trainable_variables()
        regularizer = tf.contrib.layers.l2_regularizer(scale=c)
        l2_var = tf.contrib.layers.apply_regularization(regularizer,var)
        
        loss = tf.add(loss, l2_var)
        
        self.loss = loss
        self.optimizer = tf.train.AdamOptimizer().minimize(self.loss)
        
        self.sess.run(tf.global_variables_initializer())
        try:
            self.restore()
        except Exception as e:
            print(e.message)
            
        finally:
            for it in range(num_iters):
                self.sess.run(self.optimizer)
                
            self.save()
        
        
    def get_var_value(self,var,feed_dict):
#         print("get var value:",var)
        return self.sess.run(var,feed_dict)
    
    def input_module(self,inputs,filters=256):
        outputs = tf.reshape(inputs,[1,self.shape[0],self.shape[1],-1]) # 后续改进
        outputs = tf.layers.conv2d(outputs,filters=filters,kernel_size=3,padding='same')
        outputs = tf.layers.batch_normalization(outputs)
        outputs = tf.nn.relu(outputs)
        return outputs
    
    def residual_module(self,inputs,filters=256):
        outputs = tf.layers.conv2d(inputs,filters=filters,kernel_size=3,padding='same')
        outputs = tf.layers.batch_normalization(outputs)
        outputs = tf.nn.relu(outputs)

        outputs = tf.layers.conv2d(outputs,filters=filters,kernel_size=3,padding='same')
        outputs = tf.layers.batch_normalization(outputs)

        outputs += inputs
        outputs = tf.nn.relu(outputs)  
        return outputs
    
    def output_module(self,inputs,fc_units_v=256,fc_units_p=19*19+1):
        outputs_v = tf.layers.conv2d(inputs,filters=1,kernel_size=1,padding='same')
        outputs_v = tf.layers.batch_normalization(outputs_v)
        outputs_v = tf.nn.relu(outputs_v)

        outputs_v = tf.layers.flatten(outputs_v)

        outputs_v = tf.layers.dense(outputs_v,fc_units_v,activation=tf.nn.relu)

        outputs_v = tf.layers.dense(outputs_v,1,activation=tf.nn.tanh)
        
        outputs_p = tf.layers.conv2d(inputs,filters=2,kernel_size=1,padding='same')
        outputs_p = tf.layers.batch_normalization(outputs_p)
        outputs_p = tf.nn.relu(outputs_p)
        
        outputs_p = tf.layers.flatten(outputs_p)
        
        outputs_p = tf.layers.dense(outputs_p,fc_units_p,activation=tf.nn.softmax)
        return outputs_p, outputs_v
   

In [None]:
class MCTS:
    def __init__(self,net,simulator,c_puct=5,max_depth=20,num_sims=30):
        self.net = net
        self.simulator = simulator
        self._root = TreeNode(None,1.0)
        self.c_puct = c_puct
        self.max_depth = max_depth
        self.num_sims = num_sims
        self.cur_depth = 0
        self.cur_node = self._root
        self.root_state = simulator.cur_state
        self.cur_state = self.root_state.copy()
        
    def policy(self,state,tao):
        for it in range(self.num_sims):
            print("第%d次搜索"%(it+1))
            self.rollout()
            self.cur_node = self._root
            self.cur_state = self.root_state
            self.cur_depth = 0
            
        pi = [x._n_visits for _,x in self._root._children.items()]
        print(pi)
   
        # 需考虑退火 温度系数，
        pi = (np.power(pi,1/tao)/np.sum(np.power(pi,1/tao))).tolist()
        return pi
      
    def next_search(self):
        self._root = self._root.childern[self.action]
        self._root._parent = None
        
        self.root_state = self.simulator.move(self.action)
        self.cur_state = self.simulator.move(self.action)
        
        self.cur_depth = 0
        self.cur_node = self._root

        return self
        
    def choice(self,tao):
        
        pi = self.policy(self.cur_state,tao)
#         print(pi)
        action = np.random.choice(range(len(pi)),p=pi)
        self.action = action
        return action
        
        
    def rollout(self):
        while self.cur_depth < self.max_depth:
            if not self.cur_node.is_leaf():
#                 print("No leaf")
                
                action, self.cur_node = self.cur_node.select()
                
                self.cur_state = self.simulator.get_state(self.cur_state,action)
#                 print(self.cur_state)
                self.cur_depth += 1
            else:
#                 print("yes leaf")
                action_priors = enumerate(self.net.get_var_value(self.net.p,{self.net.s:self.cur_state}).reshape(-1).tolist())
                self.cur_node.expand(action_priors)
                
                leaf_value = self.net.get_var_value (self.net.v,{self.net.s:self.cur_state})[0,0]       
                self.cur_node.backup(leaf_value,self.c_puct)

In [None]:
class TreeNode:
    def __init__(self,parent,prior_p):
        self._parent = parent
        self._children = dict()
        self._n_visits = 0
        
        self._Q = 0
        self._W = 0
        self._u = prior_p
        self._p = prior_p
        
    def expand(self, action_priors):
        for action, prob in action_priors:
#             print(prob)
            if action not in self._children:
                self._children[action] = TreeNode(self,prob)
                
    def select(self):
        return max(self._children.items(), key=lambda act_node: act_node[1].get_value())
    
    def get_value(self):
        print(self._u,self._Q,self._p)
        return self._Q + self._u
    
    def update(self,leaf_value,c_puct):
        self._n_visits += 1
        self._W += leaf_value
#         self._Q = (leaf_value - self._Q) / self._n_visits
        self._Q = self._W / self._n_visits
        if not self.is_root():
            self._u = c_puct * self._p * np.sqrt(self._parent._n_visits) / (1+self._n_visits)
            
    def is_root(self):
        return self._parent is None
    
    def is_leaf(self):
        return self._children == dict()
     
    def backup(self,leaf_value,c_puct):
        self.update(leaf_value,c_puct)
        if self._parent:
            self._parent.backup(leaf_value,c_puct)
            
        

In [None]:
class Simulator:
    def __init__(self,size=3):
        self.size = size
        self.cur_state = np.zeros((size,size),dtype=np.float32)
        self.records = [self.cur_state]
        self.action_records = []
        self.cur_hands = 0
        self.n_actions = size**2
        
    def transform(self,state):
        pass
    
    def move(self,action):
        self.cur_hands += 1
        state = self.cur_state
        
        if action>self.n_actions or action<0:
            print("pass")
        elif action == self.n_actions:
            print("pass")
        else:
            i = action//self.size
            j = action % self.size
            state[i,j] += 1
        self.cur_state = state
        self.records.append(state.copy())
        self.action_records.append(action)
        
    def get_state(self,state,action):
        if  not (action>self.n_actions or action<0):
            i = action//self.size
            j = action % self.size
            state[i,j] += 1
            
        return state

    def is_done(self,state):
        if abs(np.linalg.det(state)) >= 1:
            return True
        else:
            return False
        
    def is_win(self,state):
        if np.linalg.det(state) == 1:
            return True
        else:
            return False
        
    def render(self):
        print("hands:",self.cur_hands)
        print("state: \n")
        print(self.cur_state)
        
    def save_records(self):
        pass
    
    def reload_record():
        pass

In [None]:
net = Net()
simulator = Simulator()


In [None]:
mcts_tree = MCTS(net,simulator,c_puct=5,max_depth=20,num_sims=30)

In [None]:
max_steps = 10

for i in range(max_steps):
    if simulator.is_win(simulator.cur_state):
        print("Win, steps is %d"%(i+1))
        break
        
    else:
        action = mcts_tree.choice(1.)
        simulator.move(action)
        if i == max_steps -1 :
            print("Lose, steps is %d"%(i+1))
            
print("Records:\n",simulator.records)
            

In [None]:
simulator.render()
action = mcts_tree.choice(1.)
simulator.move(action)
simulator.render()

simulator.cur_state


In [None]:
mcts_tree.cur_depth