In [None]:
import numpy as np
import tensorflow as tf

In [None]:
class Net:
    """Small residual policy/value network built with TF1 graph-mode layers.

    The graph is built once in ``__init__`` from the array passed as ``data``.
    ``self.p`` (policy head) and ``self.v`` (value head) are the output tensors
    of that graph.
    """

    def __init__(self,data,model_path="./Model_save/model_mtcs.ckpt"):
        """Build the network for ``data``.

        Args:
            data: array with a ``.shape`` of at least two dimensions
                (``shape[0]`` and ``shape[1]`` are read by ``input_module``).
            model_path: checkpoint path stored on the instance as ``model_path``.
        """
        self.s = tf.constant(data,dtype=tf.float32)
        self.model_path = model_path
        self.shape = data.shape
        self.p, self.v = self.inference(self.s)

    def __call__(self,data):
        """Store ``data`` via ``reset`` and return this same instance."""
        self.reset(data)
        return self

    def reset(self,data):
        """Replace the stored input constant and shape with those of ``data``.

        NOTE(review): ``self.p`` / ``self.v`` are NOT rebuilt here, so they still
        refer to the graph built from the constructor's data. Calling
        ``inference`` again (the disabled line below) would add a second set of
        layers to the graph -- confirm the intended behaviour before enabling.
        """
        self.s = tf.constant(data,dtype=tf.float32)
        self.shape = data.shape
#         self.p, self.v = self.inference(self.s)

    def input_module(self,inputs,filters=256):
        """Reshape to one NHWC sample, then 3x3 conv -> batch norm -> ReLU."""
        # Batch size is fixed at 1 here; marked by the author as "to be improved later".
        outputs = tf.reshape(inputs,[1,self.shape[0],self.shape[1],-1])
        outputs = tf.layers.conv2d(outputs,filters=filters,kernel_size=3,padding='same')
        # NOTE(review): no `training` flag is passed to batch_normalization
        # anywhere in this class -- confirm that its default mode is intended.
        outputs = tf.layers.batch_normalization(outputs)
        outputs = tf.nn.relu(outputs)
        return outputs

    def residual_module(self,inputs,filters=256):
        """Two (3x3 conv -> batch norm) stages with a skip connection.

        ``filters`` must match the channel count of ``inputs`` because the
        result is added element-wise to ``inputs`` before the final ReLU.
        """
        outputs = tf.layers.conv2d(inputs,filters=filters,kernel_size=3,padding='same')
        outputs = tf.layers.batch_normalization(outputs)
        outputs = tf.nn.relu(outputs)

        outputs = tf.layers.conv2d(outputs,filters=filters,kernel_size=3,padding='same')
        outputs = tf.layers.batch_normalization(outputs)

        outputs += inputs
        outputs = tf.nn.relu(outputs)
        return outputs

    def output_module(self,inputs,fc_units_v=256,fc_units_p=19*19+1):
        """Build the value and policy heads on top of ``inputs``.

        Value head: 1x1 conv (1 filter) -> batch norm -> ReLU -> flatten ->
        dense(``fc_units_v``, ReLU) -> dense(1, tanh).
        Policy head: 1x1 conv (2 filters) -> batch norm -> ReLU -> flatten ->
        dense(``fc_units_p``, softmax).

        Returns:
            ``(outputs_p, outputs_v)`` -- policy tensor first, value tensor second.
        """
        outputs_v = tf.layers.conv2d(inputs,filters=1,kernel_size=1,padding='same')
        outputs_v = tf.layers.batch_normalization(outputs_v)
        outputs_v = tf.nn.relu(outputs_v)

        outputs_v = tf.layers.flatten(outputs_v)

        outputs_v = tf.layers.dense(outputs_v,fc_units_v,activation=tf.nn.relu)

        outputs_v = tf.layers.dense(outputs_v,1,activation=tf.nn.tanh)

        outputs_p = tf.layers.conv2d(inputs,filters=2,kernel_size=1,padding='same')
        outputs_p = tf.layers.batch_normalization(outputs_p)
        outputs_p = tf.nn.relu(outputs_p)

        outputs_p = tf.layers.flatten(outputs_p)

        outputs_p = tf.layers.dense(outputs_p,fc_units_p,activation=tf.nn.softmax)
        return outputs_p, outputs_v

    def inference(self,inputs,filters=4,fc_units_v=32,fc_units_p=4,num_res_module=4):
        """Assemble input module, ``num_res_module`` residual modules and the heads.

        Side effects: sets ``self.train_vars`` to everything returned by
        ``tf.trainable_variables()`` at this point and creates ``self.saver``
        over those variables.

        Returns:
            ``(outputs_p, outputs_v)`` as produced by ``output_module``.
        """
        inputs = self.input_module(inputs,filters)

        mid_values = inputs

        for i_module in range(num_res_module):
            mid_values = self.residual_module(mid_values,filters)

        outputs_p, outputs_v = self.output_module(mid_values,fc_units_v,fc_units_p)

        self.train_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(self.train_vars)

        return outputs_p, outputs_v

    # Draft of a `policy` method, kept disabled by the author:
#     def policy(self,state):
#         action_probs = dict()
#         value = None

#         p, v = self.inference(state)
#         with tf.Session() as sess:
#             sess.run(tf.global_variables_initializer())
#             p_, value = sess.run([p,v])
#             print(p_)

    def pre_train(self,v_list,p_list,z,mcts_list,c=0.0001,log_eps=1e-10):
        """Build ``self.loss`` and the Adam minimize op ``self.optimizer``.

        loss = sum_v (v - z)^2  -  sum_(p, m) log(p) @ m  +  L2(trainable vars)

        Args:
            v_list: value tensors to regress towards ``z``.
            p_list: policy tensors; each is matrix-multiplied (after log) with
                the matching entry of ``mcts_list``.
            z: target value tensor.
            mcts_list: target distributions, paired element-wise with ``p_list``.
            c: scale of the L2 regulariser applied to all trainable variables.
            log_eps: lower clip bound applied to ``p`` before ``tf.log`` so that
                a softmax output that underflows to 0 cannot yield ``log(0)``.
        """
        # A plain constant is enough as the accumulator seed; no graph variable needed.
        loss = tf.constant(0.,dtype=tf.float32)
        for v in v_list:
            loss = tf.add(loss,tf.square(v-z))

        for p,m in zip(p_list,mcts_list):
            # Clip before the log to keep the cross-entropy term finite.
            safe_p = tf.clip_by_value(p,log_eps,1.0)
            loss = tf.add(loss,-tf.matmul(tf.log(safe_p),m))

        var = tf.trainable_variables()
        regularizer = tf.contrib.layers.l2_regularizer(scale=c)
        l2_var = tf.contrib.layers.apply_regularization(regularizer,var)

        loss = tf.add(loss, l2_var)

        self.loss = loss
        self.optimizer = tf.train.AdamOptimizer().minimize(self.loss)


In [None]:
# Build the network once on an all-zero 2x2 float32 state.
net = Net(data=np.zeros(shape=(2, 2), dtype=np.float32))

In [None]:
# `Net` as defined in this notebook has no `policy` method, so an unconditional
# call raises AttributeError and stops a top-to-bottom run. Only call it when
# it is actually available.
if hasattr(net, "policy"):
    net.policy(np.zeros((2,2),dtype=np.float32))
else:
    print("Net.policy is not implemented; skipping.")

In [None]:
# Hand the network a random 2x2 state; calling the instance returns the same object.
net = net(np.random.randn(2, 2) * 2)

In [None]:
# Toy training inputs: the same output tensors repeated, a constant value
# target, and one fixed 4x1 target distribution repeated three times.
v_list = [net.v] * 4
p_list = [net.p] * 3
z = tf.constant(1.0, dtype=tf.float32)
const_p = tf.reshape(
    tf.constant([0.2, 0.1, 0.6, 0.1], dtype=tf.float32),
    shape=[-1, 1],
)
mcts_list = [const_p] * 3

In [None]:
# Attach the loss and the optimizer op to `net`.
net.pre_train(v_list=v_list, p_list=p_list, z=z, mcts_list=mcts_list, c=1e-4)

In [None]:
# Initialise variables, try to restore the last checkpoint, then save and
# print the current loss / policy / value.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        net.saver.restore(sess,save_path=net.model_path)
        print("last model and params restore path",net.model_path)
    except Exception as e:
        # Built-in exceptions have no `.message` attribute on Python 3, so the
        # handler itself used to raise; print the exception object instead.
        print(e)
        print("not found restore path!")
    finally:
        # Training loop currently disabled by the author:
        # for it in range(2):
        #     sess.run(net.optimizer)

        net.saver.save(sess,save_path=net.model_path)
        print("latest model and prams save path:",net.model_path)

        print(sess.run(net.loss))
        print(sess.run(net.p))
        print(sess.run(net.v))
              

In [None]:
class TreeNode:
    """A search-tree node reached from its parent by one action.

    Stores the visit count, the accumulated value ``_W``, the mean value
    ``_Q``, the prior ``_p`` and the exploration term ``_u`` (which starts out
    equal to the prior).
    """

    def __init__(self,parent,prior_p):
        self._parent = parent
        self._children = {}
        self._n_visits = 0

        self._W = 0
        self._Q = 0
        self._p = prior_p
        self._u = prior_p

    def expand(self, action_priors):
        """Add a child for each ``(action, prior)`` pair that is not present yet."""
        for action, prob in action_priors:
            if action in self._children:
                continue
            self._children[action] = TreeNode(self,prob)

    def select(self):
        """Return the ``(action, child)`` pair whose child has the largest value."""
        def child_score(entry):
            _, child = entry
            return child.get_value()

        return max(self._children.items(), key=child_score)

    def get_value(self):
        """Mean value plus exploration term."""
        return self._Q + self._u

    def update(self,leaf_value,c_puct):
        """Record one visit with ``leaf_value`` and refresh ``_Q`` and ``_u``."""
        self._n_visits += 1
        self._W += leaf_value
        # Q is the running mean of every value backed up through this node.
        self._Q = self._W / self._n_visits
        if self.is_root():
            return
        exploration = c_puct * self._p * np.sqrt(self._parent._n_visits)
        self._u = exploration / (1+self._n_visits)

    def is_root(self):
        """True when the node has no parent."""
        return self._parent is None

    def is_leaf(self):
        """True when the node has no children."""
        return len(self._children) == 0

    def backup(self,leaf_value,c_puct):
        """Apply ``update`` along the path to the root, ancestors first."""
        parent = self._parent
        if parent:
            parent.backup(leaf_value,c_puct)

        self.update(leaf_value,c_puct)

In [None]:
# Stand-alone root node (no parent) with prior 1.0, used by the quick checks below.
root = TreeNode(parent=None, prior_p=1.0)

In [None]:
class MCTS:
    """Draft search driver: a root ``TreeNode`` plus policy/value taken from a net.

    NOTE(review): another ``class MCTS`` is defined further down in this
    notebook and replaces this one once its cell has run.
    """

    def __init__(self,net,init_data,c_puct=5,max_depth=20):
        """Create the root node and obtain policy/value from ``net``.

        NOTE(review): ``_policy`` and ``_value`` are whatever ``inference``
        returns, yet ``play`` calls both as functions of ``state`` -- confirm
        the intended types.
        """
        self._root = TreeNode(None,1.0)
        net_outputs = net(init_data).inference(net.s)
        self._policy, self._value = net_outputs
        self.c_puct = c_puct
        self.max_depth = max_depth

    def play(self,state,leaf_depth):
        """Descend ``leaf_depth`` steps from the root, applying each chosen move.

        Leaf nodes met on the way are expanded with ``self._policy(state)``.
        The value computed at the end is not used or returned yet.
        """
        current = self._root

        for _ in range(leaf_depth):
            if current.is_leaf():
                priors = self._policy(state)
                current.expand(priors)

            move, current = current.select()
            state.do_move(move)

        leaf_value = self._value(state)
        

In [None]:
# Display whether `root` currently has no children.
root.is_leaf()

In [None]:
class MCTS:
    """Prototype tree node whose state is a square matrix.

    ``root_data``, ``net`` and ``actions`` are class attributes, so a single
    ``Net`` instance is shared by every node. ``expand``, ``evaluate``,
    ``backup`` and ``play`` are still empty stubs.
    """
    root_data = np.zeros((2,2),np.float32)
    net = Net(root_data)
    actions = np.arange(0,4)

    def __init__(self,data):
        """Create a node for ``data`` and point the shared net at it.

        Args:
            data: square array; its determinant is stored in ``self.det``.
        """
        # Node information
        self.data = data # data representation of the entity
        self.det = np.linalg.det(self.data)
        self.num_layer = 0 # depth (layer index) in the tree
        self.status = 0 # 0 = leaf, 1 = partially expanded, 2 = fully expanded
        self.label = self.is_done() # whether this is a terminal node (win/loss can be decided)

        # `Net.reset` requires the new data; calling it without arguments raised TypeError.
        self.net.reset(self.data)
#         self.value = self.net.get_value(self.data) # the v in (p, v) = f_theta(data): current player's win rate

        # Edge information
        self.num_visited = 0
        self.cum_value = 0
        self.average_value = 0

        # Tree information
        self.parent = None
        self.children = []

    def is_done(self):
        """Return True when ``abs(self.det) >= 1``."""
        return bool(abs(self.det) >= 1)

    def select(self,c_puct,epsilon):
        """Choose a child to descend into (not implemented yet).

        Returns ``None`` when the node has no children. With children present
        the scoring logic is still missing, so this raises instead of failing
        on an accidental name error.

        Raises:
            NotImplementedError: if the node has at least one child.
        """
        if not self.children:
            return None
        # TODO: score each child (the draft read `self.net.p` for the priors)
        # using `c_puct` / `epsilon` and return the selected one.
        raise NotImplementedError("MCTS.select: child scoring is not implemented")

    def expand(self):
        """Stub: not implemented."""
        pass
    def evaluate(self):
        """Stub: not implemented."""
        pass
    def backup(self):
        """Stub: not implemented."""
        pass
    def play(self):
        """Stub: not implemented."""
        pass
    
    


In [None]:
class T:
    """Scratch class showing how ``+=`` on an instance interacts with a class attribute."""
    a = 1

    def __init__(self):
        # The read falls back to the class attribute; the write binds an
        # attribute on the instance only, leaving `T.a` untouched.
        self.a = self.a + 10
     
# Class attribute vs. instance attribute, before and after rebinding `T.a`.
t = T()
print(T.a, t.a, T.a, sep="\n")
T.a = 2
t = T()
print(T.a, t.a, T.a, sep="\n")

In [None]:
# Create a tree node for an all-zero 2x2 state.
tree = MCTS(data=np.zeros(shape=(2, 2)))

In [None]:
# Display the `actions` attribute as seen through the instance.
tree.actions

In [None]:
# Display the `v` attribute of the net reachable from `tree`.
tree.net.v