In [None]:
import numpy as np
import logging
from game import Board, Game
from MCTS import MCTSPlayer
from network import PolicyValueNet

In [None]:
# Write INFO-and-above records to train_log.txt; each line carries a
# timestamp, the emitting function's name and the record's level.
logging.basicConfig(
    level=logging.INFO,
    filename="train_log.txt",
    format="%(asctime)s - %(funcName)s - %(levelname)s: %(message)s",
)
# No name given, so this is the root logger that basicConfig configured.
logger = logging.getLogger()

In [None]:
class Train:
    """Self-play training loop for the policy-value network.

    Each epoch plays ``game_rounds`` self-play games with an MCTS player, expands
    every recorded sample into its 8 board symmetries, fits the network on them
    and discards them. Every ``check_point`` epochs the current network is played
    against the baseline model stored at ``initial_model``.
    """

    def __init__(self, initial_model=None):
        """Create the game, the network and the MCTS player.

        Args:
            initial_model: Path of a saved model to start training from; it is
                also the baseline used by ``evaluate``. ``None`` (or empty)
                starts from a freshly constructed network.
        """
        self.board_length = 19
        self.game = Game(self.board_length)
        self.board = self.game.board
        self.initial_model = initial_model
        # Learning rates indexed by the epoch-based schedule in ``train``.
        self.learning_rate = [1e-2, 1e-3, 1e-4]
        self.batch_size = 512
        self.game_rounds = 1  # self-play games per policy update (update after every game)
        self.check_point = self.game_rounds * 500  # evaluate every 500 * game_rounds epochs
        self.train_data = []

        if initial_model:
            self.PolicyValueNet = PolicyValueNet(self.board_length, initial_model=initial_model)
        else:
            self.PolicyValueNet = PolicyValueNet(self.board_length)
        self.MCTSPlayer = MCTSPlayer(PolicyValueF=self.PolicyValueNet.PolicyValueFunction)

    def augment(self, data):
        """Expand one training sample into its 8 symmetric equivalents.

        The board is rotated by 0, 90, 180 and 270 degrees, and each rotation is
        additionally mirrored left-right; the sample's ``winner`` is kept as-is.

        Args:
            data: ``(state, prob, winner)``. ``state`` is a stack of board
                planes (axes 1 and 2 are the board). ``prob`` holds
                ``board_length ** 2`` per-point move probabilities followed by
                one trailing entry that is not tied to a board point (presumably
                a pass move -- confirm against ``Game``); that entry is copied
                unchanged.

        Returns:
            A list of 8 ``(state, prob, winner)`` tuples: for each rotation, the
            rotated sample followed by its mirror image.
        """
        state, prob, winner = data
        side = self.board_length
        board_probs = prob[:-1].reshape(side, side)
        extra = prob[-1]
        augmented = []
        for k in range(4):
            rot_state = np.rot90(state, k=k, axes=(1, 2))
            rot_probs = np.rot90(board_probs, k=k)
            augmented.append((rot_state, np.append(rot_probs.flatten(), extra), winner))
            # Mirror the *rotated* position. Mirroring the unrotated input on
            # every pass would emit the same sample four times (5 distinct
            # symmetries instead of 8).
            mir_state = np.array([np.fliplr(plane) for plane in rot_state])
            mir_probs = np.fliplr(rot_probs)
            augmented.append((mir_state, np.append(mir_probs.flatten(), extra), winner))
        return augmented

    def collect_data(self):
        """Run self-play once and append the augmented samples to ``train_data``."""
        for sample in self.game.self_play(self.MCTSPlayer, show=False):
            self.train_data.extend(self.augment(sample))

    def policy_updata(self, learning_rate):
        """Fit the network on the samples accumulated in ``train_data``.

        Args:
            learning_rate: Learning rate forwarded to ``PolicyValueNet.fit``.

        Does nothing when ``train_data`` is empty (unpacking ``zip(*[])`` into
        three arrays would otherwise raise ``ValueError``).
        """
        if not self.train_data:
            return
        states, probs, winner = zip(*self.train_data)
        self.PolicyValueNet.fit(np.array(states), np.array(probs), np.array(winner),
                                learning_rate, batch_size=self.batch_size)

    def eva(self, model, games_per_side=5):
        """Return the current player's win rate against the model saved at ``model``.

        Plays ``games_per_side`` games with the current player passed first to
        ``Game.play`` and the same number with it passed second.

        Args:
            model: Path of the saved opponent model.
            games_per_side: Games played in each seating order (default 5, for
                10 games in total).

        Returns:
            Fraction of the games won by the current player, in ``[0, 1]``.

        Raises:
            ValueError: If ``games_per_side`` is less than 1.
        """
        if games_per_side < 1:
            raise ValueError("games_per_side must be at least 1")
        old_net = PolicyValueNet(self.board_length, initial_model=model)
        old_player = MCTSPlayer(old_net.PolicyValueFunction)
        # NOTE(review): the sign convention of Game.play's result is not visible
        # here; these checks assume the first-seated player wins as -1 and the
        # second as 1 -- confirm against Game.play.
        wins = 0
        for _ in range(games_per_side):
            wins += self.game.play(self.MCTSPlayer, old_player, show=False) == -1
        for _ in range(games_per_side):
            wins += self.game.play(old_player, self.MCTSPlayer, show=False) == 1
        return wins / (2 * games_per_side)

    def evaluate(self):
        """Return the current network's win rate against ``initial_model``.

        Returns 1.0 without playing when there is no baseline model.
        """
        if not self.initial_model:
            return 1.0
        return self.eva(self.initial_model)

    def train(self):
        """Run the self-play / training loop until it is interrupted.

        On ``KeyboardInterrupt`` the network is saved as ``'last_model'``. Any
        other exception is logged and re-raised. Whatever ends the loop, the
        network is finally saved as ``'error'``.
        """
        epoch = 0
        try:
            while True:
                epoch += 1
                for _ in range(self.game_rounds):
                    self.collect_data()
                # Step schedule: index 0 before epoch 400, 1 before 600, then 2.
                lr = 0 if epoch < 400 else 1 if epoch < 600 else 2
                # augment emits 8 entries per position, so entries 0, 8 and 16
                # are the unrotated versions of the first three positions.
                logger.info("Policy update using %s, length %s", self.train_data[:17:8], len(self.train_data))
                self.policy_updata(learning_rate=self.learning_rate[lr])
                self.train_data.clear()

                if epoch % self.check_point == 0:
                    ratio = self.evaluate()
                    logger.info("Policy evaluate: %s", ratio)
                    if ratio > 0.5:
                        # Without a baseline path there is nothing to overwrite;
                        # saving to None would fail at the first checkpoint.
                        # NOTE(review): confirm how save_model treats file
                        # extensions, so the baseline that evaluate reloads is
                        # actually the file written here.
                        if self.initial_model:
                            self.PolicyValueNet.save_model(self.initial_model)
                        self.PolicyValueNet.save_model(f"models/epoch_{epoch}")
        except KeyboardInterrupt:
            self.PolicyValueNet.save_model('last_model')
            logger.info("Done")
        except Exception:
            logger.exception("Training aborted at epoch %s", epoch)
            raise
        finally:
            # The loop only ends through an exception, so this always runs as
            # a last checkpoint.
            self.PolicyValueNet.save_model('error')

In [None]:
# Resume from the checkpoint file "error.keras", which also serves as the evaluation baseline.
train = Train(initial_model="error.keras")

In [None]:
# Start the (endless) self-play training loop; interrupt the kernel to stop it.
train.train()