# Use AlphaZero to Play Tic-Tac-Toe

PyTorch version

In [1]:
import collections
import math
import logging
import sys

import numpy as np
np.random.seed(0)
import pandas as pd
import gym
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as functional

import boardgame2
from boardgame2 import BLACK, WHITE

# Route INFO-and-above log records to stdout, each prefixed with an
# HH:MM:SS timestamp and the level name.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt='%H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
)

Environment

In [2]:
# Create the Tic-Tac-Toe environment, then log every attribute of the
# environment object followed by every attribute of its spec.
env = gym.make('TicTacToe-v0')
for inspected in (env, env.spec):
    for key, value in vars(inspected).items():
        logging.info('%s: %s', key, value)

00:10:57 [INFO] allow_pass: True
00:10:57 [INFO] illegal_equivalent_action: [-1  0]
00:10:57 [INFO] render_characters: {0: '+', 1: 'o', -1: 'x'}
00:10:57 [INFO] board: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
00:10:57 [INFO] observation_space: Tuple(Box(-1, 1, (3, 3), int8), Box(-1, 1, (), int8))
00:10:57 [INFO] action_space: Box(-1, 2, (2,), int8)
00:10:57 [INFO] target_length: 3
00:10:57 [INFO] spec: EnvSpec(TicTacToe-v0)
00:10:57 [INFO] id: TicTacToe-v0
00:10:57 [INFO] entry_point: boardgame2:KInARowEnv
00:10:57 [INFO] reward_threshold: None
00:10:57 [INFO] nondeterministic: False
00:10:57 [INFO] max_episode_steps: None
00:10:57 [INFO] _kwargs: {'board_shape': 3, 'target_length': 3}
00:10:57 [INFO] _env_name: TicTacToe


Agent

In [3]:
class AlphaZeroReplayer:
    """Replay memory that keeps self-play records in a pandas DataFrame.

    Each row holds the four columns listed in ``self.fields``:
    ``player``, ``board``, ``prob`` and ``winner``.
    """

    def __init__(self):
        self.fields = ['player', 'board', 'prob', 'winner']
        self.memory = pd.DataFrame(columns=self.fields)

    def store(self, df):
        """Append the ``self.fields`` columns of ``df`` to the memory.

        Any other column of ``df`` is dropped; the row index of the memory
        is renumbered after the append.
        """
        new_rows = df[self.fields]
        self.memory = pd.concat([self.memory, new_rows], ignore_index=True)

    def sample(self, size):
        """Draw ``size`` rows uniformly at random (with replacement).

        Returns a generator that yields one stacked numpy array per entry
        of ``self.fields``, in that order, so the result can be unpacked
        as ``players, boards, probs, winners``.
        """
        row_count = self.memory.shape[0]
        indices = np.random.choice(row_count, size=size)
        return (np.stack(self.memory.loc[indices, column])
                for column in self.fields)

In [4]:
class AlphaZeroNet(nn.Module):
    """Convolutional policy/value network for a single-channel board.

    The network has a shared convolutional trunk followed by two heads:

    * a probability head that outputs, for every board in the batch, a
      distribution over board positions (same shape as the board), and
    * a value head that outputs one scalar per board, squashed by ``tanh``.

    Args:
        input_shape: shape of one board, e.g. ``(3, 3)``.
    """

    def __init__(self, input_shape):
        super().__init__()

        self.input_shape = input_shape
        # Number of positions on one board; a plain int for nn.Linear/view.
        self.board_size = int(np.prod(input_shape))

        # common net
        self.input_net = nn.Sequential(
                nn.Conv2d(1, 256, kernel_size=3, padding="same"),
                nn.BatchNorm2d(256), nn.ReLU())
        # nn.ModuleList (rather than a plain Python list) registers the
        # sub-modules, so their weights show up in parameters() and
        # state_dict() and are therefore trained and moved with the module.
        self.residual_nets = nn.ModuleList([nn.Sequential(
                nn.Conv2d(256, 256, kernel_size=3, padding="same"),
                nn.BatchNorm2d(256)) for _ in range(2)])

        # probability net
        self.prob_net = nn.Sequential(
                nn.Conv2d(256, 256, kernel_size=3, padding="same"),
                nn.BatchNorm2d(256), nn.ReLU(),
                nn.Conv2d(256, 1, kernel_size=3, padding="same"))

        # value net
        self.value_net0 = nn.Sequential(
                nn.Conv2d(256, 1, kernel_size=3, padding="same"),
                nn.BatchNorm2d(1), nn.ReLU())
        self.value_net1 = nn.Sequential(
                nn.Linear(self.board_size, 1), nn.Tanh())

    def forward(self, board_tensor):
        """Evaluate a batch of boards.

        Args:
            board_tensor: tensor that can be viewed as
                ``(batch, 1, *input_shape)``.

        Returns:
            Tuple ``(prob_tensor, v_tensor)`` where ``prob_tensor`` has shape
            ``(batch, *input_shape)`` and sums to 1 over each board, and
            ``v_tensor`` has shape ``(batch, 1)``.
        """
        # common net
        input_tensor = board_tensor.view(-1, 1, *self.input_shape)
        x = self.input_net(input_tensor)
        last_index = len(self.residual_nets) - 1
        for i_net, residual_net in enumerate(self.residual_nets):
            y = residual_net(x)
            if i_net == last_index:
                # skip connection is applied on the last block only
                y = y + x
            x = torch.clamp(y, 0)
        common_feature_tensor = x

        # probability net
        logit_tensor = self.prob_net(common_feature_tensor)
        # Keep the batch dimension so that softmax normalizes each board
        # separately instead of over the whole batch.
        logit_flatten_tensor = logit_tensor.view(-1, self.board_size)
        prob_flatten_tensor = functional.softmax(logit_flatten_tensor, dim=-1)
        prob_tensor = prob_flatten_tensor.view(-1, *self.input_shape)

        # value net
        v_feature_tensor = self.value_net0(common_feature_tensor)
        v_flatten_tensor = v_feature_tensor.view(-1, self.board_size)
        v_tensor = self.value_net1(v_flatten_tensor)

        return prob_tensor, v_tensor

In [5]:
class AlphaZeroAgent:
    """AlphaZero-style agent: Monte Carlo tree search guided by a network.

    Search statistics are stored per position, keyed by the string form of
    the canonical board (the board multiplied by the player to move, so the
    mover is always BLACK). In ``'train'`` mode every decision is recorded,
    and when enough records have been collected the network is trained and
    the search statistics and replay memory are reset.

    Args:
        env: board game environment. The agent reads ``env.board`` and calls
            ``env.get_winner(state)``, ``env.get_valid(state)`` and
            ``env.next_step(state, action)``.
        simulation_count: number of visits the root position must accumulate
            before a move is chosen.
        learn_threshold: minimum number of stored records that triggers
            learning in ``close()``.
        learn_batches: number of gradient steps taken each time learning is
            triggered.
        batch_size: number of records sampled for each gradient step.
    """

    def __init__(self, env, *, simulation_count=200, learn_threshold=1000,
            learn_batches=2, batch_size=64):
        self.env = env
        self.board = np.zeros_like(env.board)
        self.simulation_count = simulation_count
        self.learn_threshold = learn_threshold
        self.learn_batches = learn_batches
        self.batch_size = batch_size
        self.last_move = None  # (board, player, action) of the latest step
        self.reset_mcts()

        self.replayer = AlphaZeroReplayer()

        self.net = AlphaZeroNet(input_shape=self.board.shape)
        self.prob_loss = nn.BCELoss()
        self.v_loss = nn.MSELoss()
        self.optimizer = optim.Adam(self.net.parameters(), 1e-3,
                weight_decay=1e-4)

    def reset_mcts(self):
        """Discard all search statistics and cached position data."""
        def zero_board_factory():  # for construct default_dict
            return np.zeros_like(self.board, dtype=float)
        self.q = collections.defaultdict(zero_board_factory)
                # q estimates: board -> board
        self.count = collections.defaultdict(zero_board_factory)
                # q count visitation: board -> board
        self.policy = {}  # policy: board -> board
        self.valid = {}  # valid position: board -> board
        self.winner = {}  # winner: board -> None or int

    def reset(self, mode):
        """Start a new episode in the given mode (``'train'`` records it)."""
        self.mode = mode
        if mode == "train":
            self.trajectory = []
            self.last_move = None

    def step(self, observation, winner, _):
        """Choose a move for the observed position.

        Args:
            observation: ``(board, player)`` pair.
            winner: result reported by the environment before this move; it
                is stored with the record in train mode.
            _: unused.

        Returns:
            The chosen board location as a tuple of indices.
        """
        board, player = observation
        canonical_board = player * board
        s = boardgame2.strfboard(canonical_board)
        # run MCTS until the root has been visited often enough
        while self.count[s].sum() < self.simulation_count:
            self.search(canonical_board, prior_noise=True)
        prob = self.count[s] / self.count[s].sum()

        # sample a location in proportion to the visit counts
        location_index = np.random.choice(prob.size, p=prob.reshape(-1))
        action = np.unravel_index(location_index, prob.shape)

        if self.mode == 'train':
            self.trajectory += [player, board, prob, winner]
            # remembered so that the game result can be derived in close()
            self.last_move = (board, player, action)
        return action

    def close(self):
        """Finish the episode; in train mode store it and maybe learn."""
        if self.mode != 'train':
            return
        self.save_trajectory_to_replayer()
        if len(self.replayer.memory) >= self.learn_threshold:
            for _ in range(self.learn_batches):  # learn multiple times
                self.learn()
            # reset replayer after the agent changes itself
            self.replayer = AlphaZeroReplayer()
            self.reset_mcts()

    def save_trajectory_to_replayer(self):
        """Label the recorded episode with its result and store it."""
        if not self.trajectory:  # nothing was recorded in this episode
            return
        df = pd.DataFrame(
                np.array(self.trajectory, dtype=object).reshape(-1, 4),
                columns=['player', 'board', 'prob', 'winner'], dtype=object)
        df['winner'] = self._final_winner()
        self.replayer.store(df)

    def _final_winner(self):
        """Return the result of the episode that has just been recorded.

        The ``winner`` values passed to ``step()`` are reported before the
        corresponding move is played, so the last recorded one does not
        include the outcome of the final move. The outcome is therefore
        recomputed by applying the last recorded move to its board. If the
        resulting position has no winner yet (``get_winner`` returns None),
        the last recorded value is used instead.
        """
        fallback = self.trajectory[-1]
        if self.last_move is None:
            return fallback
        board, player, action = self.last_move
        next_state, _, _, _ = self.env.next_step((board, player),
                np.array(action))
        winner = self.env.get_winner(next_state)
        return fallback if winner is None else winner

    def search(self, board, prior_noise=False):  # MCTS
        """Run one MCTS simulation from a canonical board.

        Args:
            board: canonical board (the player to move is BLACK).
            prior_noise: whether to mix random noise into the prior used for
                selecting the move at this node (not at deeper nodes).

        Returns:
            The value of ``board`` for the player to move: the cached game
            result if the position is decided, the network's value estimate
            for a newly expanded position, otherwise the value backed up
            from the selected child.
        """
        s = boardgame2.strfboard(board)

        if s not in self.winner:
            self.winner[s] = self.env.get_winner((board, BLACK))
        if self.winner[s] is not None:  # if there is a winner
            return self.winner[s]

        if s not in self.policy:  # leaf that has not calculate the policy
            board_tensor = torch.as_tensor(board, dtype=torch.float).view(1, 1,
                    *self.board.shape)
            with torch.no_grad():  # inference only, no graph needed
                pi_tensor, v_tensor = self.net(board_tensor)
            pi = pi_tensor.numpy()[0]
            # plain float, so it can be written into single elements of q
            v = v_tensor.item()
            valid = self.env.get_valid((board, BLACK))
            masked_pi = pi * valid
            total_masked_pi = np.sum(masked_pi)
            if total_masked_pi <= 0:
                # all valid actions do not have probabilities. rarely occur
                masked_pi = valid  # workaround
                total_masked_pi = np.sum(masked_pi)
            self.policy[s] = masked_pi / total_masked_pi
            self.valid[s] = valid
            return v

        # calculate PUCT
        count_sum = self.count[s].sum()
        c_init = 1.25
        c_base = 19652.
        coef = (c_init + np.log1p((1 + count_sum) / c_base)) * \
                math.sqrt(count_sum) / (1. + self.count[s])
        if prior_noise:
            alpha = 1. / self.valid[s].sum()
            noise = np.random.gamma(alpha, 1., board.shape)
            noise *= self.valid[s]
            noise /= noise.sum()
            prior_exploration_fraction = 0.25
            prior = (1. - prior_exploration_fraction) * self.policy[s] \
                    + prior_exploration_fraction * noise
        else:
            prior = self.policy[s]
        # invalid locations become NaN so that nanargmax ignores them
        ub = np.where(self.valid[s], self.q[s] + coef * prior, np.nan)
        location_index = np.nanargmax(ub)
        location = np.unravel_index(location_index, board.shape)

        (next_board, next_player), _, _, _ = self.env.next_step(
                (board, BLACK), np.array(location))
        next_canonical_board = next_player * next_board
        next_v = self.search(next_canonical_board)  # recursive
        # the child's value is from the next mover's point of view
        v = next_player * next_v

        # incremental mean update of the action value
        self.count[s][location] += 1
        self.q[s][location] += (v - self.q[s][location]) / \
                self.count[s][location]
        return v

    def learn(self):
        """Take one gradient step on a batch sampled from the replay memory."""
        players, boards, probs, winners = self.replayer.sample(
                self.batch_size)
        canonical_boards = players[:, np.newaxis, np.newaxis] * boards
        # game result from the point of view of the player who moved
        targets = (players * winners)[:, np.newaxis]

        target_prob_tensor = torch.as_tensor(probs, dtype=torch.float)
        canonical_board_tensor = torch.as_tensor(canonical_boards,
                dtype=torch.float)
        target_tensor = torch.as_tensor(targets, dtype=torch.float)

        prob_tensor, v_tensor = self.net(canonical_board_tensor)

        flatten_target_prob_tensor = target_prob_tensor.view(-1,
                self.board.size)
        flatten_prob_tensor = prob_tensor.view(-1, self.board.size)
        prob_loss_tensor = self.prob_loss(flatten_prob_tensor,
                flatten_target_prob_tensor)
        v_loss_tensor = self.v_loss(v_tensor, target_tensor)
        loss_tensor = prob_loss_tensor + v_loss_tensor
        self.optimizer.zero_grad()
        loss_tensor.backward()
        self.optimizer.step()


# Build the agent for the Tic-Tac-Toe environment created above.
agent = AlphaZeroAgent(env=env)

In [6]:
def play_boardgame2_episode(env, agent, mode=None, verbose=False):
    """Play one episode of a board game with ``agent`` choosing every move.

    Args:
        env: environment whose ``reset()`` returns ``(observation, info)``
            and whose ``step(action)`` returns a 5-tuple
            ``(observation, winner, terminated, truncated, info)``.
        agent: object providing ``reset(mode=...)``, ``step(observation,
            winner, terminated)`` and ``close()``.
        mode: passed through to ``agent.reset``.
        verbose: when true, print the board before every move and after the
            last one, and log every chosen action.

    Returns:
        Tuple ``(winner, elapsed_steps)``: the value reported by the final
        ``env.step`` call and the number of non-final moves played.
    """
    observation, _ = env.reset()
    winner, terminated, truncated = 0, False, False
    agent.reset(mode=mode)

    elapsed_steps = 0
    while True:
        if verbose:
            board, player = observation
            print(boardgame2.strfboard(board))

        action = agent.step(observation, winner, terminated)
        if verbose:
            logging.info('step %d：player %d, action %s', elapsed_steps, player,
                    action)

        observation, winner, terminated, truncated, _ = env.step(action)
        if not (terminated or truncated):
            elapsed_steps += 1
            continue

        # episode is over: show the final position if requested
        if verbose:
            board, _ = observation
            print(boardgame2.strfboard(board))
        break

    agent.close()
    return winner, elapsed_steps


# Self-play training. After each training episode, an empty replay memory is
# taken as the sign that the agent has just learned, and one verbose test
# episode is then played.
for episode in range(5000):
    winner, elapsed_steps = play_boardgame2_episode(env, agent, mode='train')
    logging.info('train episode %d: winner = %d, steps = %d', episode, winner,
            elapsed_steps)

    if len(agent.replayer.memory) != 0:  # still collecting experience
        continue

    # just finish learning
    logging.info('test episode %d:', episode)
    winner, elapsed_steps = play_boardgame2_episode(
            env, agent, mode='test', verbose=True)
    logging.info('test episode %d: winner = %d, steps = %d',
            episode, winner, elapsed_steps)

00:11:21 [INFO] train episode 0: winner = -1, steps = 5
00:11:38 [INFO] train episode 1: winner = 1, steps = 6
00:11:50 [INFO] train episode 2: winner = 1, steps = 6
00:11:53 [INFO] train episode 3: winner = 1, steps = 6
00:12:12 [INFO] train episode 4: winner = 0, steps = 8
00:12:15 [INFO] train episode 5: winner = 1, steps = 4
00:12:35 [INFO] train episode 6: winner = 0, steps = 8
00:12:47 [INFO] train episode 7: winner = 1, steps = 6
00:12:56 [INFO] train episode 8: winner = 1, steps = 6
00:13:07 [INFO] train episode 9: winner = 1, steps = 6
00:13:14 [INFO] train episode 10: winner = 1, steps = 4
00:13:22 [INFO] train episode 11: winner = 1, steps = 4
00:13:23 [INFO] train episode 12: winner = 0, steps = 8
00:13:27 [INFO] train episode 13: winner = 0, steps = 8
00:13:38 [INFO] train episode 14: winner = 0, steps = 8
00:13:41 [INFO] train episode 15: winner = 1, steps = 6
00:13:44 [INFO] train episode 16: winner = 1, steps = 6
00:13:49 [INFO] train episode 17: winner = -1, steps = 5
