In [1]:
# 環境構築
# pip3 install torch

In [2]:
# ゲームの実装 ここを他のドメインの実装に置き換えれば色々なゲームで動かせる
# 高速な言語でこの部分のみ実装すると手軽に高速化出来る(C++ならpybind11やBoost.Pythonを利用可)

import numpy as np

BLACK, WHITE =1, -1 # 先手後手

class State:
    '''○×ゲームの盤面実装'''
    X, Y = 'ABC',  '123'
    C = {0: '_', BLACK: 'O', WHITE: 'X'}

    def __init__(self):
        self.board = np.zeros((3, 3)) # (x, y)
        self.color = 1
        self.win_color = 0
        self.record = []

    def action2str(self, a):
        return self.X[a // 3] + self.Y[a % 3]

    def str2action(self, s):
        return self.X.find(s[0]) * 3 + self.Y.find(s[1])

    def record_string(self):
        return ' '.join([self.action2str(a) for a in self.record])

    def __str__(self):
        # 表示
        s = '   ' + ' '.join(self.Y) + '\n'
        for i in range(3):
            s += self.X[i] + ' ' + ' '.join([self.C[self.board[i, j]] for j in range(3)]) + '\n'
        s += 'record = ' + self.record_string()
        return s

    def play(self, action):
        # 行動で状態を進める関数
        # action は board 上の位置 (0 ~ 8) または行動系列の文字列
        if isinstance(action, str):
            for astr in action.split():
                self.play(self.str2action(astr))
            return self

        x, y = action // 3, action % 3
        self.board[x, y] = self.color

        # 3つ揃ったか調べる
        if self.board[x, :].sum() == 3 * self.color \
          or self.board[:, y].sum() == 3 * self.color \
          or (x == y and np.diag(self.board, k=0).sum() == 3 * self.color) \
          or (x == 2 - y and np.diag(self.board[::-1,:], k=0).sum() == 3 * self.color):
            self.win_color = self.color

        self.color = -self.color
        self.record.append(action)
        return self

    def terminal(self):
        # 終端状態かどうか返す
        return self.win_color != 0 or len(self.record) == 3 * 3

    def terminal_reward(self):
        # 終端状態での勝敗による報酬を返す
        return self.win_color if self.color == BLACK else -self.win_color

    def legal_actions(self):
        # 可能な行動リストを返す
        return [a for a in range(3 * 3) if self.board[a // 3, a % 3] == 0]

    def action_length(self):
        # 行動ラベルの総数(policyの出力サイズを決める)
        return 3 * 3

    def feature(self):
        # ニューラルネットに入力する状態表現を返す
        return np.stack([self.board == self.color, self.board == -self.color]).astype(np.float64)


state = State().play('B1')
print(state)
print('input feature')
print(state.feature())
state = State().play('B2 A1 C2')
print('input feature')
print(state.feature())

   1 2 3
A _ _ _
B O _ _
C _ _ _
record = B1
input feature
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [1. 0. 0.]
  [0. 0. 0.]]]
input feature
[[[1. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 1. 0.]
  [0. 1. 0.]]]


In [3]:
# ニューラルネットの実装(PyTorch)
# AlphaZeroの論文のネットワーク構成を小さく再現

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBN(nn.Module):
    def __init__(self, filters0, filters1, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(filters0, filters1, kernel_size, stride=1, padding=kernel_size//2, bias=False)
        self.bn = nn.BatchNorm2d(filters1)

    def forward(self, x):
        return self.bn(self.conv(x))

class ResidualBlock(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.conv1 = ConvBN(filters, filters, 3)
        self.conv2 = ConvBN(filters, filters, 3)

    def forward(self, x):
        return F.relu(x + self.conv2(F.relu(self.conv1(x))))

In [4]:
num_filters = 16
num_blocks = 2

class Net(nn.Module):
    '''ニューラルネット計算を行うクラス'''
    def __init__(self):
        super().__init__()
        state = State()
        self.input_shape = state.feature().shape
        self.board_size = self.input_shape[1] * self.input_shape[2]

        self.layer0 = ConvBN(self.input_shape[0], num_filters, 3)
        self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])
        
        self.conv_p = ConvBN(num_filters, 2, 1)
        self.fc_p = nn.Linear(self.board_size * 2, state.action_length())
        
        self.conv_v = ConvBN(num_filters, 1, 1)
        self.fc_v1 = nn.Linear(self.board_size * 1, 4)
        self.fc_v2 = nn.Linear(4, 1, bias=False)

    def forward(self, x):
        h = F.relu(self.layer0(x))
        for block in self.blocks:
            h = block(h)

        h_p = F.relu(self.conv_p(h))
        h_p = self.fc_p(h_p.view(-1, self.board_size * 2))

        h_v = F.relu(self.conv_v(h))
        h_v = F.relu(self.fc_v1(h_v.view(-1, self.board_size * 1)))
        h_v = self.fc_v2(h_v)

        # value(状態価値)にtanhを適用するので負け -1 ~ 勝ち 1
        return F.softmax(h_p, dim=-1), torch.tanh(h_v)

    def predict(self, state):
        # 探索中に呼ばれる推論関数
        self.eval()
        x = torch.FloatTensor(state.feature()).view(-1, *self.input_shape)
        with torch.no_grad():
            p, v = self.forward(x)
        return p.cpu().numpy()[0], v.cpu().numpy()[0][0]

In [5]:
def show_net(net, state):
    '''方策 (p)　と　状態価値 (v) を表示'''
    print(state)
    p, v = net.predict(state)
    print('p = ')
    print((p *1000).astype(int).reshape((-1, *net.input_shape[1:3])))
    print('v = ', v)

show_net(Net(), State())

   1 2 3
A _ _ _
B _ _ _
C _ _ _
record = 
p = 
[[[102  93 134]
  [131 118 128]
  [ 93  95 103]]]
v =  -0.010430467


In [6]:
# モンテカルロ木探索の実装

class Node:
    '''ある1状態の探索結果を保存するクラス'''
    def __init__(self, p, v):
        self.p, self.v = p, v
        self.n, self.q_sum = np.zeros_like(p), np.zeros_like(p)
        self.n_all, self.q_sum_all = 1, v / 2 # 事前分布(見解が分かれる点)

    def update(self, action, q_new):
        # 行動のスタッツを更新
        self.n[action] += 1
        self.q_sum[action] += q_new

        # ノード全体のスタッツも更新
        self.n_all += 1
        self.q_sum_all += q_new

In [7]:
import time, copy

class Tree:
    '''探索木を保持してモンテカルロ木探索を行うクラス'''
    def __init__(self, net):
        self.net = net
        self.nodes = {}
    
    def search(self, state, depth):
        # 終端状態の場合は末端報酬を返す
        if state.terminal():
            return state.terminal_reward()

        # まだ未到達の状態はニューラルネットを計算して推定価値を返す
        key = str(state)
        if key not in self.nodes:
            p, v = self.net.predict(state)
            self.nodes[key] = Node(p, v)
            return v

        # 到達済みの状態はバンディットで行動を選んで状態を進める
        node = self.nodes[key]
        p = node.p
        if depth == 0:
            # ルートノード(現局面)では方策にノイズを加える
            p = 0.75 * p + 0.25 * np.random.dirichlet([0.1] * len(p))

        best_action, best_ucb = None, -float('inf')
        for action in state.legal_actions():
            n, q_sum = 1 + node.n[action], node.q_sum_all / node.n_all + node.q_sum[action]
            ucb = q_sum / n + 2.0 * node.p[action] * np.sqrt(node.n_all) / n # PUCBの式

            if ucb > best_ucb:
                best_action, best_ucb = action, ucb

        # 一手進めて再帰で探索
        state.play(best_action)
        q_new = -self.search(state, depth + 1) # 1手ごとの手番交代を想定
        node.update(best_action, q_new)

        return q_new

    def think(self, state, num_simulations, temperature = 0, show=False):
        # 探索のエンドポイント
        if show:
            print(state)
        start, prev_time = time.time(), 0
        for _ in range(num_simulations):
            self.search(copy.deepcopy(state), depth=0)

            # 1秒ごとに探索結果を表示
            if show:
                tmp_time = time.time() - start
                if int(tmp_time) > int(prev_time):
                    prev_time = tmp_time
                    root, pv = self.nodes[str(state)], self.pv(state)
                    print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s'
                          % (tmp_time, state.action2str(pv[0]), root.q_sum[pv[0]] / root.n[pv[0]],
                             root.n[pv[0]], root.n_all, ' '.join([state.action2str(a) for a in pv])))

        #  訪問回数で重みつけた確率分布を返す
        root = self.nodes[str(state)]
        n = (root.n / np.max(root.n)) ** (1 / (temperature + 1e-8))
        return n / n.sum()
        
    def pv(self, state):
        # 最善応手列（読み筋）を返す
        s, pv_seq = copy.deepcopy(state), []
        while True:
            key = str(s)
            if key not in self.nodes or self.nodes[key].n.sum() == 0:
                break
            best_action = sorted([(a, self.nodes[key].n[a]) for a in s.legal_actions()], key=lambda x: -x[1])[0][0]
            pv_seq.append(best_action)
            s.play(best_action)
        return pv_seq

In [8]:
# 初期ネットワークで探索を行う
tree = Tree(Net())
tree.think(State(), 1000, show=True)

tree = Tree(Net())
tree.think(State().play('A1 C1 A2 C2'), 10000, show=True)

tree = Tree(Net())
tree.think(State().play('B2 A2 A3 C1 B3'), 10000, show=True)

tree = Tree(Net())
tree.think(State().play('B2 A2 A3 C1'), 10000, show=True)

   1 2 3
A _ _ _
B _ _ _
C _ _ _
record = 
1.00 sec. best C3. q = -0.0102. n = 94 / 567. pv = C3 B3 A1 C1
   1 2 3
A O O _
B _ _ _
C X X _
record = A1 C1 A2 C2
1.00 sec. best A3. q = 1.0000. n = 4752 / 4981. pv = A3
   1 2 3
A _ X O
B _ O O
C X _ _
record = B2 A2 A3 C1 B3
1.00 sec. best C3. q = -0.9839. n = 1310 / 4531. pv = C3 B1
2.00 sec. best B1. q = -0.9889. n = 2437 / 8537. pv = B1 C3
   1 2 3
A _ X O
B _ O _
C X _ _
record = B2 A2 A3 C1
1.00 sec. best B3. q = 0.9831. n = 2938 / 3021. pv = B3 B1 C3
2.00 sec. best B3. q = 0.9885. n = 5875 / 6138. pv = B3 B1 C3
3.00 sec. best B3. q = 0.9913. n = 8936 / 9325. pv = B3 B1 C3


array([0., 0., 0., 0., 0., 1., 0., 0., 0.], dtype=float32)

In [9]:
# ニューラルネットの学習

import torch.optim as optim

batch_size = 32
num_epochs = 40

def gen_target(ep):
    '''ニューラルネットの学習用 input, targets を生成'''
    turn_idx = np.random.randint(len(ep[0]))
    state = State()
    for a in ep[0][:turn_idx]:
        state.play(a)
    return state.feature(), ep[1][turn_idx], [ep[2] if turn_idx % 2 == 0 else -ep[2]]

def train(episodes):
    net = Net()
    optimizer = optim.SGD(net.parameters(), lr=3e-3, weight_decay=1e-4, momentum=0.9)
    for epoch in range(num_epochs):
        p_loss_sum, v_loss_sum = 0, 0
        net.train()
        for i in range(0, len(episodes), batch_size):
            x, p_target, v_target = zip(*[gen_target(episodes[np.random.randint(len(episodes))]) for j in range(batch_size)])
            x = torch.FloatTensor(np.array(x))
            p_target = torch.FloatTensor(np.array(p_target))
            v_target = torch.FloatTensor(np.array(v_target))

            p, v = net(x)
            p_loss = torch.sum(-p_target * torch.log(p))
            v_loss = torch.sum((v_target - v) ** 2)

            p_loss_sum += p_loss.item()
            v_loss_sum += v_loss.item()

            optimizer.zero_grad()
            (p_loss + v_loss).backward()
            optimizer.step()

        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.85
    print('p_loss %f v_loss %f' % (p_loss_sum / len(episodes), v_loss_sum / len(episodes)))
    return net

In [10]:
# AlphaZeroのアルゴリズムメイン

num_games = 200
num_train_steps = 20
num_simulations = 80

net = Net()
episodes = []
result_distribution = {1:0, 0:0, -1:0}
for g in range(num_games):
    # 1対戦のエピソード生成
    record, p_targets = [], []
    state = State()
    tree = Tree(net)
    temperature = 1.0 # 探索結果から方策のtargetを作るときの温度
    while not state.terminal():
        p_target = tree.think(state, num_simulations, temperature)
        # 行動をランダムに選んで進める
        action = np.random.choice(np.arange(len(p_target)), p=p_target)
        state.play(action)
        record.append(action)
        p_targets.append(p_target)
        temperature *= 0.8
    # 先手視点の報酬
    reward = state.terminal_reward() * (1 if len(record) % 2 == 0 else -1)
    result_distribution[reward] += 1
    episodes.append((record, p_targets, reward))
    if g % num_train_steps == 0:
        print('game ', end='')
    print(g, ' ', end='')

    # ニューラルネットの学習
    if (g + 1) % num_train_steps == 0:
        print(result_distribution)
        net = train(episodes)
print('finished')

game 0  1  2  3  4  5  6  7  8  9  10  11  12  13  14  15  16  17  18  19  {1: 5, 0: 10, -1: 5}
p_loss 3.043722 v_loss 0.491622
game 20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  {1: 12, 0: 21, -1: 7}
p_loss 2.977945 v_loss 0.332617
game 40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  {1: 18, 0: 31, -1: 11}
p_loss 1.997322 v_loss 0.236930
game 60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  {1: 33, 0: 36, -1: 11}
p_loss 2.358468 v_loss 0.296939
game 80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99  {1: 46, 0: 42, -1: 12}
p_loss 1.807588 v_loss 0.365034
game 100  101  102  103  104  105  106  107  108  109  110  111  112  113  114  115  116  117  118  119  {1: 51, 0: 57, -1: 12}
p_loss 1.614088 v_loss 0.374058
game 120  121  122  123  124  125  126  127  128  129  130  131  132  133  134  135  136  137  138  139  {1: 58, 0: 68, -1: 14}
p_loss 1.387779 v_loss 0.

In [11]:
# ニューラルネットの出力を見てみる

#　初期状態
print('initial state')
show_net(net, State())

# 置けば勝ち
print('WIN by put')
show_net(net, State().play('A1 C1 A2 C2'))

# ダブルリーチにされているので負け
print('LOSE by opponent\'s double reach')
show_net(net, State().play('B2 A2 A3 C1 B3'))

#　ダブルリーチにすれば勝ち
print('WIN through double reach')
show_net(net, State().play('B2 A2 A3 C1'))


# 難問: A1に置けば次の手番でダブルリーチにできて勝ち
print('strategic WIN by following double')
show_net(net, State().play('B1 A3'))


initial state
   1 2 3
A _ _ _
B _ _ _
C _ _ _
record = 
p = 
[[[572  33  97]
  [ 39  94  26]
  [ 59  27  48]]]
v =  0.40222245
WIN by put
   1 2 3
A O O _
B _ _ _
C X X _
record = A1 C1 A2 C2
p = 
[[[  1   1 244]
  [ 19  23  17]
  [ 51   1 639]]]
v =  0.4048041
LOSE by opponent's double reach
   1 2 3
A _ X O
B _ O O
C X _ _
record = B2 A2 A3 C1 B3
p = 
[[[  0  34   0]
  [882   0   2]
  [  6   3  70]]]
v =  0.1724575
WIN through double reach
   1 2 3
A _ X O
B _ O _
C X _ _
record = B2 A2 A3 C1
p = 
[[[ 20  86   3]
  [452  16 269]
  [ 13  14 122]]]
v =  0.36645132
strategic WIN by following double
   1 2 3
A _ _ X
B O _ _
C _ _ _
record = B1 A3
p = 
[[[152  14  61]
  [  4 539 101]
  [ 58  30  38]]]
v =  -0.7442839


In [12]:
# 学習済みモデルでの探索

tree = Tree(net)
tree.think(State(), 100000, show=True)

   1 2 3
A _ _ _
B _ _ _
C _ _ _
record = 
1.00 sec. best C1. q = 0.2096. n = 292 / 602. pv = C1 B2 C3 C2 A2 A1 A3 B3 B1
2.00 sec. best C1. q = 0.1951. n = 341 / 896. pv = C1 B2 C3 C2 A2 A1 A3 B3 B1
3.00 sec. best C3. q = 0.1933. n = 379 / 1217. pv = C3 B2 A3 B3 B1 A1 C2 C1 A2
4.00 sec. best C3. q = 0.1901. n = 424 / 1528. pv = C3 B2 A3 B3 B1 A1 C2 C1 A2
5.00 sec. best C3. q = 0.1655. n = 526 / 1859. pv = C3 B2 C1 C2 A2 A1 A3 B3 B1
6.00 sec. best C3. q = 0.1488. n = 622 / 2206. pv = C3 B2 A3 B3 B1 A1 C2 C1 A2
7.00 sec. best C3. q = 0.1379. n = 777 / 2566. pv = C3 B2 C1 C2 A2 A1 A3 B3 B1
8.00 sec. best C3. q = 0.1240. n = 826 / 2972. pv = C3 B2 C1 C2 A2 A1 A3 B3 B1
9.00 sec. best C3. q = 0.1129. n = 898 / 3324. pv = C3 B2 C1 C2 A2 A1 A3 B3 B1
10.00 sec. best C3. q = 0.1030. n = 998 / 3719. pv = C3 B2 C1 C2 A2 A1 A3 B3 B1
11.00 sec. best C3. q = 0.0988. n = 1088 / 4118. pv = C3 B2 C1 C2 A2 A1 A3 B3 B1
12.00 sec. best C1. q = 0.0918. n = 1212 / 4526. pv = C1 B2 C3 C2 A2 A1 A3 B3 B1
13.00 

101.00 sec. best A1. q = 0.0110. n = 12424 / 50884. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
102.00 sec. best A1. q = 0.0109. n = 12603 / 51344. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
103.00 sec. best A1. q = 0.0109. n = 12653 / 51800. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
104.00 sec. best A1. q = 0.0108. n = 12729 / 52270. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
105.00 sec. best A1. q = 0.0108. n = 12853 / 52704. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
106.00 sec. best A1. q = 0.0107. n = 12960 / 53142. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
107.00 sec. best A1. q = 0.0106. n = 13081 / 53535. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
108.00 sec. best A1. q = 0.0105. n = 13221 / 54018. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
109.00 sec. best A1. q = 0.0106. n = 13337 / 54497. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
110.01 sec. best A1. q = 0.0105. n = 13467 / 54995. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
111.00 sec. best A1. q = 0.0103. n = 13604 / 55498. pv = A1 B2 A3 A2 C2 B3 B1 C1 C3
112.00 sec. best A1. q = 0.0103. n = 13722 / 55968. pv = A1 B2 A3 A2 C2 B3 B

199.00 sec. best B2. q = 0.0245. n = 29789 / 98644. pv = B2 A1 B3 B1 C1 A3 A2 C2 C3
200.00 sec. best B2. q = 0.0243. n = 30175 / 99706. pv = B2 A1 B3 B1 C1 A3 A2 C2 C3


array([0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)