In [2]:
# 训练五子棋AI

from __future__ import print_function
import random
import numpy as np
# deque 是一个双端队列
from collections import defaultdict, deque
from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure # 随机走子策略的AI
from mcts_alphaZero import MCTSPlayer # AlphaGo方式的AI
from policy_value_net_pytorch import PolicyValueNet  # Pytorch

class TrainPipeline():
    """Self-play training pipeline for a policy-value network on an n-in-a-row board.

    Each training batch plays ``play_batch_size`` self-play games, augments the
    recorded positions with board symmetries, stores them in a bounded replay
    buffer and, once the buffer holds more than ``batch_size`` samples, runs a
    KL-monitored network update. Every ``check_freq`` batches the network is
    evaluated against a pure-MCTS opponent and the model files are saved.
    """

    def __init__(self, init_model=None):
        """Set up board, hyper-parameters, network and the self-play player.

        Args:
            init_model: optional model file handed to ``PolicyValueNet`` as
                ``model_file``; when falsy a new network is created.
        """
        # Board and game parameters.
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # Training parameters.
        self.learn_rate = 2e-3  # base learning rate
        self.lr_multiplier = 1.0  # adapted from the measured KL in policy_update
        self.temp = 1.0  # temperature passed to self-play
        self.n_playout = 400  # number of MCTS simulations per move
        self.c_puct = 5  # exploitation/exploration trade-off coefficient
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        # Bounded replay buffer: the oldest samples drop out once maxlen is hit.
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1  # self-play games per training batch
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02  # KL target used for early stopping and lr adaptation
        self.check_freq = 50  # evaluate and save the model every 50 batches
        self.game_batch_num = 500  # total number of self-play batches in run()
        self.best_win_ratio = 0.0  # best evaluation win ratio so far
        # Length of the most recent self-play game; defined up front so that
        # run() can report it even if no game has been played yet.
        self.episode_len = 0
        # Playouts of the weak pure-MCTS opponent used during evaluation.
        self.pure_mcts_playout_num = 1000
        if init_model:
            # Start from the given model file.
            self.policy_value_net = PolicyValueNet(self.board_width, self.board_height, model_file=init_model)
        else:
            # Train a new network.
            self.policy_value_net = PolicyValueNet(self.board_width, self.board_height)
        # Player used for self-play during training (is_selfplay=1).
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment self-play samples with rotations and mirror images.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner)`` where
                ``state`` is a stack of 2-D planes and ``mcts_prob`` is a flat
                array of ``board_height * board_width`` probabilities.

        Returns:
            A list of ``(state, mcts_prob, winner)`` tuples: 8 per input sample
            on a square board (4 rotations, each with its left-right mirror),
            4 per input sample otherwise (180/360 degree rotations and their
            mirrors), because a quarter turn would swap height and width.
        """
        # A 90/270 degree turn maps an H x W grid onto a W x H grid, which is
        # only a valid board of the same shape when the board is square.
        if self.board_height == self.board_width:
            quarter_turns = (1, 2, 3, 4)
        else:
            quarter_turns = (2, 4)

        extend_data = []
        for state, mcts_prob, winner in play_data:
            # The probability grid is flipped vertically before rotating and
            # flipped back before flattening -- presumably because the state
            # planes are stored upside-down relative to move indices; confirm
            # against game.Board.
            prob_grid = np.flipud(
                mcts_prob.reshape(self.board_height, self.board_width))
            for k in quarter_turns:
                # Rotate counter-clockwise by k quarter turns.
                equi_state = np.array([np.rot90(s, k) for s in state])
                equi_mcts_prob = np.rot90(prob_grid, k)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # Mirror the rotated sample left-right.
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and add the augmented samples to the buffer.

        Also records the (un-augmented) length of the last game in
        ``self.episode_len``.
        """
        for _ in range(n_games):
            # The MCTS player plays against itself.
            winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp)
            play_data = list(play_data)
            # Number of recorded positions in this game.
            self.episode_len = len(play_data)
            # Augment with symmetries before storing.
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    @staticmethod
    def _explained_variance(targets, predictions):
        """Return ``1 - Var(targets - predictions) / Var(targets)``.

        Returns nan when ``targets`` has zero variance (for example every
        sampled game had the same outcome) instead of dividing by zero.
        """
        target_var = np.var(targets)
        if target_var == 0:
            return float("nan")
        return 1 - np.var(targets - predictions) / target_var

    def policy_update(self):
        """Run one KL-monitored update of the policy-value network.

        Samples ``batch_size`` items from the replay buffer, performs up to
        ``epochs`` training steps (stopping early when the KL divergence from
        the pre-update policy exceeds ``4 * kl_targ``), adapts
        ``lr_multiplier`` and prints diagnostics.

        Returns:
            ``(loss, entropy)`` as returned by the last training step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        # Network outputs before the update; the reference for the KL check.
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            # One optimisation step; returns loss and entropy.
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            # Re-evaluate the same batch with the updated network.
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # Mean KL(old || new) over the batch; 1e-10 avoids log(0).
            kl = np.mean(np.sum(
                old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # Adapt the learning-rate multiplier, keeping it roughly within [0.1, 10].
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        winners = np.array(winner_batch)
        explained_var_old = self._explained_variance(winners, old_v.flatten())
        explained_var_new = self._explained_variance(winners, new_v.flatten())
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Play the current network against pure MCTS and return its score ratio.

        A win counts 1, a tie 0.5 and a loss 0; the sum is divided by
        ``n_games``. The starting player alternates between games.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct, n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            # No board display (is_shown=0); sides alternate via start_player.
            winner = self.game.start_play(current_mcts_player, pure_mcts_player, start_player=i % 2, is_shown=0)
            win_cnt[winner] += 1
        # Key 1 is reported as a win, 2 as a loss and -1 as a tie.
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """Run the full training loop until done or interrupted with Ctrl-C.

        Every ``check_freq`` batches the network is evaluated and written to
        ``./current_policy.model``; a new best win ratio also writes
        ``./best_policy.model``. Once the best ratio reaches 1.0 the pure-MCTS
        opponent gets 1000 more playouts (while below 5000) and the best ratio
        is reset.
        """
        try:
            # game_batch_num batches, each with play_batch_size self-play games.
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(i+1, self.episode_len))
                # Only train once the buffer can supply a full mini-batch.
                if len(self.data_buffer) > self.batch_size:
                    self.policy_update()
                # Periodically evaluate and checkpoint.
                if (i+1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    # Always save the latest network.
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("发现新的最优策略，进行策略更新")
                        self.best_win_ratio = win_ratio
                        # Save the new best network.
                        self.policy_value_net.save_model('./best_policy.model')
                        # Perfect score: strengthen the opponent and start over.
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')


if __name__ == "__main__":
    # Script entry point: build a pipeline with a freshly initialised network
    # (no pre-trained model file) and train until finished or interrupted.
    training_pipeline = TrainPipeline(init_model=None)
    training_pipeline.run()


  x_act = F.log_softmax(self.act_fc1(x_act))


batch i:1, episode_len:16
batch i:2, episode_len:13
batch i:3, episode_len:16
batch i:4, episode_len:9
batch i:5, episode_len:13
kl:0.00837,lr_multiplier:1.500,loss:4.569027900695801,entropy:3.574512004852295,explained_var_old:0.000,explained_var_new:0.010
batch i:6, episode_len:11
kl:0.00517,lr_multiplier:2.250,loss:4.463287353515625,entropy:3.5658254623413086,explained_var_old:0.009,explained_var_new:0.160
batch i:7, episode_len:20
kl:0.01638,lr_multiplier:2.250,loss:3.8992257118225098,entropy:3.58003306388855,explained_var_old:0.117,explained_var_new:0.731
batch i:8, episode_len:14
kl:0.03151,lr_multiplier:2.250,loss:3.738576650619507,entropy:3.5634446144104004,explained_var_old:0.712,explained_var_new:0.817
batch i:9, episode_len:16
kl:0.03283,lr_multiplier:2.250,loss:3.7645103931427,entropy:3.513538122177124,explained_var_old:0.645,explained_var_new:0.740
batch i:10, episode_len:9
kl:0.07023,lr_multiplier:1.500,loss:3.801973819732666,entropy:3.4803783893585205,explained_var_old:0.

kl:0.01489,lr_multiplier:0.296,loss:3.4429798126220703,entropy:2.860381603240967,explained_var_old:0.385,explained_var_new:0.448
batch i:57, episode_len:13
kl:0.02573,lr_multiplier:0.296,loss:3.3254921436309814,entropy:2.860454797744751,explained_var_old:0.435,explained_var_new:0.507
batch i:58, episode_len:15
kl:0.01505,lr_multiplier:0.296,loss:3.393432855606079,entropy:2.8720874786376953,explained_var_old:0.407,explained_var_new:0.496
batch i:59, episode_len:9
kl:0.02153,lr_multiplier:0.296,loss:3.379958391189575,entropy:2.898717164993286,explained_var_old:0.444,explained_var_new:0.490
batch i:60, episode_len:11
kl:0.01710,lr_multiplier:0.296,loss:3.384429454803467,entropy:2.9179115295410156,explained_var_old:0.433,explained_var_new:0.500
batch i:61, episode_len:17
kl:0.01688,lr_multiplier:0.296,loss:3.3567094802856445,entropy:2.871095657348633,explained_var_old:0.446,explained_var_new:0.523
batch i:62, episode_len:9
kl:0.01184,lr_multiplier:0.296,loss:3.343198776245117,entropy:2.886

batch i:108, episode_len:11
kl:0.01830,lr_multiplier:0.296,loss:3.2801766395568848,entropy:2.714512348175049,explained_var_old:0.349,explained_var_new:0.405
batch i:109, episode_len:11
kl:0.01714,lr_multiplier:0.296,loss:3.258305311203003,entropy:2.6420280933380127,explained_var_old:0.313,explained_var_new:0.393
batch i:110, episode_len:9
kl:0.02256,lr_multiplier:0.296,loss:3.2752175331115723,entropy:2.7378740310668945,explained_var_old:0.387,explained_var_new:0.453
batch i:111, episode_len:7
kl:0.02137,lr_multiplier:0.296,loss:3.288073778152466,entropy:2.7057204246520996,explained_var_old:0.344,explained_var_new:0.417
batch i:112, episode_len:17
kl:0.01943,lr_multiplier:0.296,loss:3.177194356918335,entropy:2.679326295852661,explained_var_old:0.399,explained_var_new:0.470
batch i:113, episode_len:7
kl:0.01743,lr_multiplier:0.296,loss:3.2704668045043945,entropy:2.6680240631103516,explained_var_old:0.331,explained_var_new:0.391
batch i:114, episode_len:9
kl:0.01635,lr_multiplier:0.296,lo

kl:0.02049,lr_multiplier:0.296,loss:2.999134063720703,entropy:2.387637138366699,explained_var_old:0.347,explained_var_new:0.412
batch i:160, episode_len:9
kl:0.00918,lr_multiplier:0.444,loss:2.958428382873535,entropy:2.4280755519866943,explained_var_old:0.392,explained_var_new:0.452
batch i:161, episode_len:13
kl:0.02595,lr_multiplier:0.444,loss:2.93707013130188,entropy:2.4047129154205322,explained_var_old:0.382,explained_var_new:0.454
batch i:162, episode_len:8
kl:0.03855,lr_multiplier:0.444,loss:2.8654470443725586,entropy:2.315258741378784,explained_var_old:0.385,explained_var_new:0.459
batch i:163, episode_len:11
kl:0.04697,lr_multiplier:0.296,loss:2.900712013244629,entropy:2.376889228820801,explained_var_old:0.389,explained_var_new:0.454
batch i:164, episode_len:7
kl:0.03022,lr_multiplier:0.296,loss:2.8962457180023193,entropy:2.414734363555908,explained_var_old:0.471,explained_var_new:0.512
batch i:165, episode_len:9
kl:0.01435,lr_multiplier:0.296,loss:2.922762393951416,entropy:2.4

kl:0.01910,lr_multiplier:0.296,loss:2.7182576656341553,entropy:2.2233548164367676,explained_var_old:0.400,explained_var_new:0.465
batch i:211, episode_len:8
kl:0.02096,lr_multiplier:0.296,loss:2.873323917388916,entropy:2.298790216445923,explained_var_old:0.335,explained_var_new:0.401
batch i:212, episode_len:15
kl:0.02355,lr_multiplier:0.296,loss:2.804321765899658,entropy:2.231110095977783,explained_var_old:0.372,explained_var_new:0.439
batch i:213, episode_len:15
kl:0.01889,lr_multiplier:0.296,loss:2.808680295944214,entropy:2.2624142169952393,explained_var_old:0.364,explained_var_new:0.435
batch i:214, episode_len:9
kl:0.02198,lr_multiplier:0.296,loss:2.8187131881713867,entropy:2.1855597496032715,explained_var_old:0.315,explained_var_new:0.387
batch i:215, episode_len:8
kl:0.02114,lr_multiplier:0.296,loss:2.8514342308044434,entropy:2.3254010677337646,explained_var_old:0.356,explained_var_new:0.433
batch i:216, episode_len:8
kl:0.01475,lr_multiplier:0.296,loss:2.740318775177002,entropy

batch i:262, episode_len:10
kl:0.02596,lr_multiplier:0.296,loss:2.776826858520508,entropy:2.1469390392303467,explained_var_old:0.298,explained_var_new:0.366
batch i:263, episode_len:9
kl:0.02144,lr_multiplier:0.296,loss:2.7173662185668945,entropy:2.121511220932007,explained_var_old:0.324,explained_var_new:0.394
batch i:264, episode_len:10
kl:0.02889,lr_multiplier:0.296,loss:2.761590003967285,entropy:2.1245436668395996,explained_var_old:0.335,explained_var_new:0.409
batch i:265, episode_len:13
kl:0.04118,lr_multiplier:0.198,loss:2.723532199859619,entropy:2.146169662475586,explained_var_old:0.314,explained_var_new:0.381
batch i:266, episode_len:10
kl:0.02775,lr_multiplier:0.198,loss:2.684473991394043,entropy:2.0542404651641846,explained_var_old:0.342,explained_var_new:0.386
batch i:267, episode_len:7
kl:0.01965,lr_multiplier:0.198,loss:2.6608424186706543,entropy:2.0805182456970215,explained_var_old:0.287,explained_var_new:0.356
batch i:268, episode_len:11
kl:0.01886,lr_multiplier:0.198,l

kl:0.01620,lr_multiplier:0.198,loss:2.442476511001587,entropy:1.9048348665237427,explained_var_old:0.360,explained_var_new:0.406
batch i:314, episode_len:13
kl:0.02463,lr_multiplier:0.198,loss:2.541721820831299,entropy:1.8817124366760254,explained_var_old:0.250,explained_var_new:0.302
batch i:315, episode_len:10
kl:0.01676,lr_multiplier:0.198,loss:2.5483837127685547,entropy:1.9198479652404785,explained_var_old:0.268,explained_var_new:0.326
batch i:316, episode_len:9
kl:0.01521,lr_multiplier:0.198,loss:2.6024131774902344,entropy:1.906521201133728,explained_var_old:0.245,explained_var_new:0.298
batch i:317, episode_len:9
kl:0.02611,lr_multiplier:0.198,loss:2.529456377029419,entropy:1.9302372932434082,explained_var_old:0.291,explained_var_new:0.348
batch i:318, episode_len:13
kl:0.02771,lr_multiplier:0.198,loss:2.5205252170562744,entropy:1.8272356986999512,explained_var_old:0.259,explained_var_new:0.328
batch i:319, episode_len:8
kl:0.02483,lr_multiplier:0.198,loss:2.434601306915283,entro

kl:0.03617,lr_multiplier:0.198,loss:2.378756523132324,entropy:1.6581573486328125,explained_var_old:0.200,explained_var_new:0.262
batch i:365, episode_len:11
kl:0.02769,lr_multiplier:0.198,loss:2.404292345046997,entropy:1.7224876880645752,explained_var_old:0.257,explained_var_new:0.327
batch i:366, episode_len:11
kl:0.02704,lr_multiplier:0.198,loss:2.2707438468933105,entropy:1.6103641986846924,explained_var_old:0.293,explained_var_new:0.356
batch i:367, episode_len:7
kl:0.03255,lr_multiplier:0.198,loss:2.2847580909729004,entropy:1.6692496538162231,explained_var_old:0.316,explained_var_new:0.392
batch i:368, episode_len:11
kl:0.04446,lr_multiplier:0.132,loss:2.3014307022094727,entropy:1.6169390678405762,explained_var_old:0.288,explained_var_new:0.340
batch i:369, episode_len:22
kl:0.01622,lr_multiplier:0.132,loss:2.2980422973632812,entropy:1.6436944007873535,explained_var_old:0.286,explained_var_new:0.336
batch i:370, episode_len:7
kl:0.01450,lr_multiplier:0.132,loss:2.355462074279785,en

kl:0.02936,lr_multiplier:0.198,loss:2.1916494369506836,entropy:1.6366709470748901,explained_var_old:0.356,explained_var_new:0.434
batch i:416, episode_len:7
kl:0.03226,lr_multiplier:0.198,loss:2.2173004150390625,entropy:1.6282845735549927,explained_var_old:0.330,explained_var_new:0.378
batch i:417, episode_len:11
kl:0.03097,lr_multiplier:0.198,loss:2.1824092864990234,entropy:1.6332037448883057,explained_var_old:0.369,explained_var_new:0.429
batch i:418, episode_len:17
kl:0.02169,lr_multiplier:0.198,loss:2.0618157386779785,entropy:1.565187692642212,explained_var_old:0.408,explained_var_new:0.466
batch i:419, episode_len:15
kl:0.01613,lr_multiplier:0.198,loss:2.1423234939575195,entropy:1.586427092552185,explained_var_old:0.342,explained_var_new:0.389
batch i:420, episode_len:12
kl:0.02308,lr_multiplier:0.198,loss:2.16903018951416,entropy:1.5797944068908691,explained_var_old:0.314,explained_var_new:0.370
batch i:421, episode_len:7
kl:0.02197,lr_multiplier:0.198,loss:2.2467994689941406,ent

kl:0.01803,lr_multiplier:0.198,loss:2.022700071334839,entropy:1.4912569522857666,explained_var_old:0.401,explained_var_new:0.457
batch i:467, episode_len:12
kl:0.02689,lr_multiplier:0.198,loss:2.031461715698242,entropy:1.5350595712661743,explained_var_old:0.385,explained_var_new:0.433
batch i:468, episode_len:7
kl:0.02384,lr_multiplier:0.198,loss:2.040299415588379,entropy:1.4909179210662842,explained_var_old:0.367,explained_var_new:0.417
batch i:469, episode_len:15
kl:0.02326,lr_multiplier:0.198,loss:2.1619138717651367,entropy:1.5150854587554932,explained_var_old:0.330,explained_var_new:0.385
batch i:470, episode_len:11
kl:0.02071,lr_multiplier:0.198,loss:2.0084261894226074,entropy:1.4876272678375244,explained_var_old:0.390,explained_var_new:0.450
batch i:471, episode_len:11
kl:0.02493,lr_multiplier:0.198,loss:2.086244583129883,entropy:1.5237441062927246,explained_var_old:0.392,explained_var_new:0.455
batch i:472, episode_len:11
kl:0.02569,lr_multiplier:0.198,loss:2.045355796813965,ent