# 基于强化学习的五子棋AI训练

* Content：采用强化学习（策略价值网络）训练五子棋 AI；棋盘大小 10 × 10（注意：训练代码中设置 n_in_row = 4，即以四子连珠判定胜负）
* Author:  HuiHui
* Date:    2020-06-04
* Reference:

## 介绍

* **强化学习：**  
    &emsp;1. 强化学习四要素：状态（state）、动作（action）、策略（policy）、奖励（reward）  
    &emsp;2. RL考虑的是个体（Agent）与环境（Environment）的交互问题，目标是找到一个最优策略，使Agent获得尽可能多的来自环境的奖励  
    &emsp;3. 强化学习 Agent 的三类方法：基于价值的强化学习、基于策略的强化学习、结合策略梯度与价值函数的强化学习（Actor-Critic）
    
* **实现思路：**  
    &emsp;下棋对弈场景是环境，每一种棋盘布局就是一个状态，所有可能的落子位置构成动作空间，各位置的落子概率即策略，对局的胜负结果就是奖励。采用策略价值网络训练 AI，并使用蒙特卡洛树搜索（MCTS）进行策略优化：MCTS 通过自我对弈（self-play）生成数据供深度神经网络学习；神经网络的输入为当前棋盘，输出分为两个头，分别给出当前棋盘的状态值（value）和当前棋盘各个位置的落子概率。

* **项目文件结构：**  
    &emsp;1. game.py：定义了游戏的棋盘、获取棋盘状态、下棋、判断是否有人胜利、绘制人机对弈的可视化棋盘、自我对弈；  
    &emsp;2. human_play.py：人机对弈，在可视化界面下实现与AI的对弈；  
    &emsp;3. mcts_alphaZero.py：实现 AlphaZero 中的 MCTS，使用策略价值网络指导树搜索并评估叶节点；  
    &emsp;4. mcts_pure.py：实现随机走子策略的MCTS；  
    &emsp;5. policy_value_net_pytorch.py：策略价值网络，用来指导 MCTS 搜索并评估叶子节点；  
    &emsp;6. train_10x10.ipynb：训练 AI 的主程序；  
    &emsp;7. best_policy.model：在 10×10 棋盘上训练得到的 AI 最优策略模型；  
    &emsp;8. demo_10x10.ipynb：人机对弈 demo。

## 导入相关库

In [1]:
from __future__ import print_function
import random
import numpy as np
# deque 是一个双端队列
from collections import defaultdict, deque
from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure # 随机走子策略的AI
from mcts_alphaZero import MCTSPlayer # AlphaGo方式的AI
from policy_value_net_pytorch import PolicyValueNet  # Pytorch

## AI训练主程序

In [None]:
class TrainPipeline():
    """AlphaZero-style self-play training loop for the policy-value network.

    Each iteration plays ``play_batch_size`` self-play games with the current
    MCTS player, augments the resulting samples with the 8 board symmetries
    and stores them in a bounded replay buffer.  Once the buffer holds more
    than ``batch_size`` samples, every iteration also performs one policy
    update.  Every ``check_freq`` iterations the network is evaluated against
    a pure-MCTS opponent and saved to disk.
    """

    def __init__(self, init_model=None):
        """Set up the board, game, training hyper-parameters and players.

        Args:
            init_model: optional model file passed to
                ``PolicyValueNet(..., model_file=init_model)``; when falsy a
                freshly initialised network is trained instead.
        """
        # Board / game settings.
        self.board_width = 10
        self.board_height = 10
        self.n_in_row = 4  # number of pieces in a row needed to win
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # Training hyper-parameters.
        self.learn_rate = 2e-3  # base learning rate
        self.lr_multiplier = 1.0  # adapted automatically from the observed KL
        self.temp = 1.0  # temperature used during self-play
        self.n_playout = 400  # MCTS simulations per move
        self.c_puct = 5  # exploitation / exploration trade-off in MCTS
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        # Replay buffer: once full, the oldest samples are discarded.
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1  # self-play games per iteration
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02  # target KL divergence used for early stopping / lr tuning
        self.check_freq = 50  # evaluate and save the model every 50 iterations
        self.game_batch_num = 500  # total number of training iterations
        self.best_win_ratio = 0.0  # best evaluation score seen so far
        # Playouts of the pure-MCTS opponent used to evaluate the network.
        self.pure_mcts_playout_num = 1000
        # Number of samples produced by the most recent self-play game; set
        # by collect_selfplay_data and printed by run().
        self.episode_len = 0
        if init_model:
            # Start from a previously saved network.
            self.policy_value_net = PolicyValueNet(self.board_width, self.board_height, model_file=init_model)
        else:
            # Train a brand-new network.
            self.policy_value_net = PolicyValueNet(self.board_width, self.board_height)
        # Player used to generate training data; is_selfplay=1 because it
        # plays against itself during training.
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment self-play samples with the 8 symmetries of the board.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` tuples.
                Each ``state`` is iterated plane by plane, and every plane is
                rotated/flipped as a 2-D array.  ``mcts_prob`` must be
                reshapeable to ``(board_height, board_width)``.

        Returns:
            A list with 8 ``(state, mcts_prob, winner_z)`` tuples per input
            sample: the 4 counter-clockwise rotations (90/180/270/360 degrees),
            each followed by its left-right mirror.  ``mcts_prob`` is returned
            flattened, and ``winner_z`` is copied unchanged.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            # The probability grid is flipped upside down before rotating and
            # flipped back before flattening, presumably so that it lines up
            # with the orientation of the state planes (TODO confirm in game.py).
            prob_grid = np.flipud(mcts_prob.reshape(self.board_height, self.board_width))
            for i in [1, 2, 3, 4]:
                # Rotate counter-clockwise by i * 90 degrees.
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(prob_grid, i)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # Mirror the rotated sample left-right.
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and push augmented samples into the buffer.

        ``self.episode_len`` is set to the number of samples the last game
        produced (before the 8x symmetry augmentation).
        """
        for _ in range(n_games):
            _winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            self.data_buffer.extend(self.get_equi_data(play_data))

    @staticmethod
    def _explained_variance(target, pred):
        """Return ``1 - Var(target - pred) / Var(target)``.

        ``pred`` is flattened first, so an ``(N, 1)`` value-head output works.
        Returns NaN when ``target`` has zero variance (e.g. every sampled game
        ended the same way) instead of dividing by zero with a RuntimeWarning.
        """
        target = np.asarray(target, dtype=float)
        pred = np.asarray(pred, dtype=float).flatten()
        target_var = np.var(target)
        if target_var == 0:
            return float('nan')
        return float(1 - np.var(target - pred) / target_var)

    def policy_update(self):
        """Train the network on one random mini-batch from the replay buffer.

        Up to ``self.epochs`` training steps are run on the same mini-batch.
        Training stops early once the mean KL divergence between the policy
        before the update and the current policy exceeds ``4 * kl_targ``.
        The learning-rate multiplier is then adapted from the final KL value.

        Returns:
            ``(loss, entropy)`` as reported by the last ``train_step`` call.

        Raises:
            ValueError: if ``self.epochs`` is smaller than 1 (no step would run).
        """
        if self.epochs < 1:
            raise ValueError("epochs must be >= 1, got {}".format(self.epochs))
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        # Snapshot the predictions before any training step on this batch.
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # Mean KL(old || new) over the batch, i.e. how far the updates so
            # far have moved the policy; 1e-10 guards against log(0).
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # Adapt the lr multiplier: shrink it (only while > 0.1) when the policy
        # moved too much, grow it (only while < 10) when it barely moved.
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        # How much of the outcome variance the value head explains, before
        # and after this update.
        explained_var_old = self._explained_variance(winner_batch, old_v)
        explained_var_new = self._explained_variance(winner_batch, new_v)
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the current network against a pure-MCTS opponent.

        Plays ``n_games`` games without rendering, alternating the starting
        player via ``start_player=i % 2``.  Scoring: win = 1, tie = 0.5,
        loss = 0.

        Returns:
            The total score divided by ``n_games``.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct, n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num)
        # Keyed by the value start_play returns: 1 is counted as a win for the
        # network, 2 as a loss and -1 as a tie (matching the print below).
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player, pure_mcts_player, start_player=i % 2, is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """Run the training loop; Ctrl-C (KeyboardInterrupt) stops it cleanly.

        Each of the ``game_batch_num`` iterations does the following:
        - collects ``play_batch_size`` self-play games;
        - updates the network once the buffer holds more than ``batch_size``
          samples;
        - every ``check_freq`` iterations, evaluates the network and saves it
          to ``./current_policy.model``.

        A model that beats the best score so far is also saved to
        ``./best_policy.model``.  A perfect score raises the pure-MCTS
        opponent's playouts by 1000 (while it is below 5000) and resets the
        best score.
        """
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(i+1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    self.policy_update()
                # Periodically evaluate, and keep both the latest and best model.
                if (i+1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("发现新的最优策略，进行策略更新")
                        self.best_win_ratio = win_ratio
                        self.policy_value_net.save_model('./best_policy.model')
                        # Make the opponent stronger once it is beaten 100%.
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')


if __name__ == '__main__':
    # Entry point: train from scratch (no init_model) until the configured
    # number of iterations completes or the user presses Ctrl-C.
    training_pipeline = TrainPipeline(init_model=None)
    training_pipeline.run()

  x_act = F.log_softmax(self.act_fc1(x_act))


batch i:1, episode_len:22
batch i:2, episode_len:33
batch i:3, episode_len:17
kl:0.02070,lr_multiplier:1.000,loss:5.559216499328613,entropy:4.589074611663818,explained_var_old:-0.001,explained_var_new:0.084
batch i:4, episode_len:22
kl:0.00769,lr_multiplier:1.500,loss:5.586826801300049,entropy:4.593297958374023,explained_var_old:-0.001,explained_var_new:0.006
batch i:5, episode_len:14
kl:0.00883,lr_multiplier:2.250,loss:5.5724921226501465,entropy:4.587553977966309,explained_var_old:-0.002,explained_var_new:0.031
batch i:6, episode_len:24
kl:0.01463,lr_multiplier:2.250,loss:5.482388496398926,entropy:4.586917877197266,explained_var_old:0.034,explained_var_new:0.139
batch i:7, episode_len:19
kl:0.05558,lr_multiplier:1.500,loss:5.329034328460693,entropy:4.532655239105225,explained_var_old:0.068,explained_var_new:0.246
batch i:8, episode_len:13
kl:0.00848,lr_multiplier:2.250,loss:5.277288913726807,entropy:4.496841907501221,explained_var_old:0.210,explained_var_new:0.281
batch i:9, episode_l

batch i:55, episode_len:13
kl:0.02070,lr_multiplier:0.444,loss:4.429287433624268,entropy:3.9447779655456543,explained_var_old:0.432,explained_var_new:0.493
batch i:56, episode_len:16
kl:0.02922,lr_multiplier:0.444,loss:4.428858280181885,entropy:3.8997788429260254,explained_var_old:0.361,explained_var_new:0.437
batch i:57, episode_len:16
kl:0.02962,lr_multiplier:0.444,loss:4.474386692047119,entropy:3.9629242420196533,explained_var_old:0.373,explained_var_new:0.452
batch i:58, episode_len:21
kl:0.02959,lr_multiplier:0.444,loss:4.357915878295898,entropy:3.7589383125305176,explained_var_old:0.417,explained_var_new:0.491
batch i:59, episode_len:17
kl:0.02514,lr_multiplier:0.444,loss:4.35282039642334,entropy:3.855694532394409,explained_var_old:0.396,explained_var_new:0.489
batch i:60, episode_len:12
kl:0.01918,lr_multiplier:0.444,loss:4.412149906158447,entropy:3.856865882873535,explained_var_old:0.386,explained_var_new:0.443
batch i:61, episode_len:13
kl:0.01557,lr_multiplier:0.444,loss:4.36

kl:0.02041,lr_multiplier:0.444,loss:3.7817678451538086,entropy:3.254063367843628,explained_var_old:0.396,explained_var_new:0.518
batch i:108, episode_len:13
kl:0.02435,lr_multiplier:0.444,loss:3.7486660480499268,entropy:3.2429635524749756,explained_var_old:0.398,explained_var_new:0.531
batch i:109, episode_len:18
kl:0.01756,lr_multiplier:0.444,loss:3.701927661895752,entropy:3.256883382797241,explained_var_old:0.400,explained_var_new:0.536
batch i:110, episode_len:9
kl:0.02112,lr_multiplier:0.444,loss:3.61443829536438,entropy:3.2025375366210938,explained_var_old:0.469,explained_var_new:0.589
batch i:111, episode_len:19
kl:0.02525,lr_multiplier:0.444,loss:3.7931008338928223,entropy:3.254236936569214,explained_var_old:0.373,explained_var_new:0.497
batch i:112, episode_len:24
kl:0.02908,lr_multiplier:0.444,loss:3.643279552459717,entropy:3.1417601108551025,explained_var_old:0.408,explained_var_new:0.533
batch i:113, episode_len:17
kl:0.02434,lr_multiplier:0.444,loss:3.6370487213134766,entro

batch i:159, episode_len:8
kl:0.01464,lr_multiplier:0.444,loss:3.484156608581543,entropy:3.0384504795074463,explained_var_old:0.437,explained_var_new:0.579
batch i:160, episode_len:17
kl:0.01104,lr_multiplier:0.444,loss:3.420957326889038,entropy:3.001589298248291,explained_var_old:0.472,explained_var_new:0.581
batch i:161, episode_len:9
kl:0.01455,lr_multiplier:0.444,loss:3.4133055210113525,entropy:3.0430362224578857,explained_var_old:0.488,explained_var_new:0.583
batch i:162, episode_len:8
kl:0.01703,lr_multiplier:0.444,loss:3.4141592979431152,entropy:2.9070510864257812,explained_var_old:0.403,explained_var_new:0.537
batch i:163, episode_len:11
kl:0.01699,lr_multiplier:0.444,loss:3.3170604705810547,entropy:2.913010835647583,explained_var_old:0.527,explained_var_new:0.613
batch i:164, episode_len:7
kl:0.02307,lr_multiplier:0.444,loss:3.4104509353637695,entropy:3.0014710426330566,explained_var_old:0.450,explained_var_new:0.579
batch i:165, episode_len:15
kl:0.02621,lr_multiplier:0.444,l

batch i:211, episode_len:11
kl:0.01767,lr_multiplier:0.667,loss:3.252023935317993,entropy:2.86856746673584,explained_var_old:0.464,explained_var_new:0.583
batch i:212, episode_len:8
kl:0.03684,lr_multiplier:0.667,loss:3.102418899536133,entropy:2.6546061038970947,explained_var_old:0.490,explained_var_new:0.620
batch i:213, episode_len:10
kl:0.03610,lr_multiplier:0.667,loss:3.199740171432495,entropy:2.8492119312286377,explained_var_old:0.470,explained_var_new:0.586
batch i:214, episode_len:11
kl:0.02364,lr_multiplier:0.667,loss:3.222397565841675,entropy:2.8733153343200684,explained_var_old:0.472,explained_var_new:0.609
batch i:215, episode_len:9
kl:0.04347,lr_multiplier:0.444,loss:3.091931104660034,entropy:2.644282579421997,explained_var_old:0.493,explained_var_new:0.635
batch i:216, episode_len:11
kl:0.01521,lr_multiplier:0.444,loss:3.2428836822509766,entropy:2.828047752380371,explained_var_old:0.454,explained_var_new:0.572
batch i:217, episode_len:9
kl:0.01767,lr_multiplier:0.444,loss:

batch i:263, episode_len:14
kl:0.01992,lr_multiplier:0.667,loss:3.0003201961517334,entropy:2.4518301486968994,explained_var_old:0.358,explained_var_new:0.494
batch i:264, episode_len:12
kl:0.02237,lr_multiplier:0.667,loss:3.0346839427948,entropy:2.5780646800994873,explained_var_old:0.393,explained_var_new:0.546
batch i:265, episode_len:8
kl:0.01864,lr_multiplier:0.667,loss:3.0126936435699463,entropy:2.60884165763855,explained_var_old:0.451,explained_var_new:0.553
batch i:266, episode_len:11
kl:0.02174,lr_multiplier:0.667,loss:3.0220463275909424,entropy:2.4787886142730713,explained_var_old:0.384,explained_var_new:0.501
batch i:267, episode_len:29
kl:0.02686,lr_multiplier:0.667,loss:3.0686628818511963,entropy:2.5634920597076416,explained_var_old:0.373,explained_var_new:0.497
batch i:268, episode_len:7
kl:0.02815,lr_multiplier:0.667,loss:3.0291459560394287,entropy:2.6212100982666016,explained_var_old:0.379,explained_var_new:0.534
batch i:269, episode_len:13
kl:0.03131,lr_multiplier:0.667,

batch i:315, episode_len:11
kl:0.02858,lr_multiplier:0.444,loss:2.6569175720214844,entropy:2.2795233726501465,explained_var_old:0.525,explained_var_new:0.634
batch i:316, episode_len:12
kl:0.02066,lr_multiplier:0.444,loss:2.8296217918395996,entropy:2.374901533126831,explained_var_old:0.489,explained_var_new:0.562
batch i:317, episode_len:9
kl:0.03280,lr_multiplier:0.444,loss:2.7447869777679443,entropy:2.3846004009246826,explained_var_old:0.528,explained_var_new:0.609
batch i:318, episode_len:9
kl:0.03596,lr_multiplier:0.444,loss:2.8415799140930176,entropy:2.406064033508301,explained_var_old:0.420,explained_var_new:0.513
batch i:319, episode_len:11
kl:0.01979,lr_multiplier:0.444,loss:2.9038608074188232,entropy:2.4538631439208984,explained_var_old:0.460,explained_var_new:0.572
batch i:320, episode_len:7
kl:0.01716,lr_multiplier:0.444,loss:2.829874277114868,entropy:2.3650383949279785,explained_var_old:0.478,explained_var_new:0.557
batch i:321, episode_len:8
kl:0.01763,lr_multiplier:0.444,

batch i:367, episode_len:7
kl:0.00582,lr_multiplier:0.296,loss:2.524541139602661,entropy:2.1939444541931152,explained_var_old:0.605,explained_var_new:0.638
batch i:368, episode_len:15
kl:0.01473,lr_multiplier:0.296,loss:2.6620209217071533,entropy:2.258345603942871,explained_var_old:0.520,explained_var_new:0.589
batch i:369, episode_len:9
kl:0.01327,lr_multiplier:0.296,loss:2.6376140117645264,entropy:2.2980637550354004,explained_var_old:0.603,explained_var_new:0.654
batch i:370, episode_len:8
kl:0.01206,lr_multiplier:0.296,loss:2.5385353565216064,entropy:2.2661726474761963,explained_var_old:0.625,explained_var_new:0.675
batch i:371, episode_len:7
kl:0.00857,lr_multiplier:0.444,loss:2.5996694564819336,entropy:2.21311616897583,explained_var_old:0.581,explained_var_new:0.627
batch i:372, episode_len:7
kl:0.02055,lr_multiplier:0.444,loss:2.561502456665039,entropy:2.1894474029541016,explained_var_old:0.601,explained_var_new:0.675
batch i:373, episode_len:7
kl:0.01610,lr_multiplier:0.444,loss

batch i:419, episode_len:7
kl:0.01513,lr_multiplier:0.296,loss:2.42922306060791,entropy:2.04325008392334,explained_var_old:0.579,explained_var_new:0.644
batch i:420, episode_len:8
kl:0.01305,lr_multiplier:0.296,loss:2.4912190437316895,entropy:2.1159090995788574,explained_var_old:0.506,explained_var_new:0.576
batch i:421, episode_len:8
kl:0.01064,lr_multiplier:0.296,loss:2.534266471862793,entropy:2.138031244277954,explained_var_old:0.585,explained_var_new:0.647
batch i:422, episode_len:16
kl:0.01349,lr_multiplier:0.296,loss:2.3451199531555176,entropy:1.9674537181854248,explained_var_old:0.572,explained_var_new:0.607
batch i:423, episode_len:10
kl:0.01728,lr_multiplier:0.296,loss:2.5271315574645996,entropy:2.10886549949646,explained_var_old:0.540,explained_var_new:0.596
batch i:424, episode_len:7
kl:0.01487,lr_multiplier:0.296,loss:2.4784939289093018,entropy:2.0936203002929688,explained_var_old:0.510,explained_var_new:0.586
batch i:425, episode_len:9
kl:0.01877,lr_multiplier:0.296,loss:2

batch i:471, episode_len:23
kl:0.02649,lr_multiplier:0.444,loss:2.4156930446624756,entropy:2.0292773246765137,explained_var_old:0.527,explained_var_new:0.631
batch i:472, episode_len:12
kl:0.03046,lr_multiplier:0.444,loss:2.4763786792755127,entropy:2.0184881687164307,explained_var_old:0.443,explained_var_new:0.551
batch i:473, episode_len:11
kl:0.01564,lr_multiplier:0.444,loss:2.349874258041382,entropy:2.056079626083374,explained_var_old:0.583,explained_var_new:0.674
batch i:474, episode_len:7
kl:0.01982,lr_multiplier:0.444,loss:2.4140679836273193,entropy:2.020801067352295,explained_var_old:0.557,explained_var_new:0.631
batch i:475, episode_len:9
kl:0.02399,lr_multiplier:0.444,loss:2.5137743949890137,entropy:2.037578821182251,explained_var_old:0.391,explained_var_new:0.525
batch i:476, episode_len:9
kl:0.02366,lr_multiplier:0.444,loss:2.315424919128418,entropy:1.9650179147720337,explained_var_old:0.528,explained_var_new:0.616
batch i:477, episode_len:18
kl:0.01919,lr_multiplier:0.444,l