In [1]:
import random
import torch
import math
import numpy as np
from torch import nn
from torch import optim

In [2]:
# Board geometry: cells are addressed by a flat move index in range(MOVE_COUNT).
WIDTH = 3
HEIGHT = 3
MOVE_COUNT = WIDTH * HEIGHT

# Display symbol for each player id that can be stored in a board cell.
PLAYERS = {-1: "O", 1: "X"}

def get_move_coords(move_index):
    """Translate a flat move index into ``(x, y)`` board coordinates.

    ``x`` is the remainder and ``y`` the floored quotient of dividing the
    index by 3.
    """
    column = move_index % 3
    row = move_index // 3
    return column, row

class Board(object):
    """Tic-tac-toe position stored as a ``(HEIGHT, WIDTH)`` tensor.

    A cell holds ``0`` while it is empty and the id of a player (a key of
    ``PLAYERS``) once a move has been executed on it. Moves are flat indices
    in ``range(MOVE_COUNT)`` that ``get_move_coords`` turns into ``(x, y)``.
    """

    def __init__(self, pieces):
        """Wrap ``pieces`` directly; no copy is made, so moves mutate it."""
        assert(pieces.size() == (HEIGHT, WIDTH))
        self.pieces = pieces

    @classmethod
    def initial(cls):
        """Return a board on which every cell is empty."""
        # Sized from the same constants the constructor asserts on.
        return cls(torch.zeros(HEIGHT, WIDTH))

    def is_valid_move(self, move_index):
        """Return True if ``move_index`` names an empty cell on the board.

        Indices outside ``range(MOVE_COUNT)`` are reported as invalid. Without
        this check a negative index would wrap around through negative
        indexing and address a different cell, while an index that is too
        large would raise ``IndexError``.
        """
        if not 0 <= move_index < MOVE_COUNT:
            return False
        x, y = get_move_coords(move_index)
        # bool() turns the 0-dim comparison tensor into a plain Python bool.
        return bool(self.pieces[y][x] == 0)

    def get_valid_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        return [index for index in range(MOVE_COUNT) if self.is_valid_move(index)]

    def get_valid_move_mask(self):
        """Return an integer tensor of length ``MOVE_COUNT``: 1 = empty cell, 0 = taken."""
        return torch.tensor([1 if self.is_valid_move(index) else 0 for index in range(MOVE_COUNT)])

    def has_empty_slots(self):
        """Return True while at least one cell is still empty."""
        return any(self.is_valid_move(index) for index in range(MOVE_COUNT))

    def execute_move(self, player, move_index):
        """Write ``player`` into the cell addressed by ``move_index`` (in place).

        Raises:
            AssertionError: If ``player`` is not a key of ``PLAYERS`` or the
                move is not valid.
        """
        assert(player in PLAYERS.keys())
        assert(self.is_valid_move(move_index))
        x, y = get_move_coords(move_index)
        self.pieces[y][x] = player

    def has_won(self, player):
        """Return True if ``player`` fills a whole column, row or diagonal.

        A line counts as filled when its three cells sum to ``3 * player``,
        which assumes cells only ever hold -1, 0 or 1.
        """
        assert(player in PLAYERS.keys())
        target = 3 * player
        for x in range(3):
            if torch.sum(self.pieces[:, x]) == target:
                return True
        for y in range(3):
            if torch.sum(self.pieces[y, :]) == target:
                return True
        # Main diagonal, then the anti-diagonal.
        if sum([self.pieces[i, i].item() for i in range(3)]) == target:
            return True
        if sum([self.pieces[2 - i, i].item() for i in range(3)]) == target:
            return True
        return False

    def get_score(self):
        """Return the outcome of the game.

        Returns:
            The id of the player who has won, ``0`` if nobody has won and no
            empty cell is left, or ``None`` if the game is still running.
        """
        for player in PLAYERS.keys():
            if self.has_won(player):
                return player
        if not self.has_empty_slots():
            return 0
        return None

    def get_random_move(self):
        """Return a uniformly chosen valid move, or ``None`` if there is none."""
        legal_moves = self.get_valid_moves()
        if not legal_moves:
            return None
        return random.choice(legal_moves)


def get_next_state(pieces, player, action):
    """Play ``action`` for ``player`` on a copy of ``pieces``.

    The tensor passed in is cloned first, so it is left unchanged.

    Returns:
        Tuple ``(next_pieces, next_player)`` where ``next_player`` is
        ``-player``.
    """
    successor = Board(torch.clone(pieces))
    successor.execute_move(player, action)
    opponent = -player
    return successor.pieces, opponent


def get_canonical_form(pieces, player):
    """Return ``pieces`` multiplied by ``player``.

    For ``player == -1`` this swaps the sign of every cell, so the given
    player's own pieces are the ones stored as ``1`` in the result.
    """
    return pieces * player


# Display symbol for each cell value (0 marks an empty cell).
_SLOTS = {
    -1: "O",
    0: "-",
    1: "X",
}

def get_repr(pieces):
    """Return the board as a string with one symbol per cell, row by row.

    ``flatten`` is used rather than ``view`` so that non-contiguous tensors
    (for example a transposed board) are accepted as well.

    Raises:
        KeyError: If a cell holds a value other than -1, 0 or 1.
    """
    # tolist() yields plain Python numbers; 1.0, -1.0, 0.0 and -0.0 hash and
    # compare equal to the integer keys of _SLOTS.
    return "".join(_SLOTS[value] for value in pieces.flatten().tolist())

# TODO: Symmetries


In [3]:
# Number of search() calls made by each get_actions() call.
SIMULATION_COUNT = 10
# Weight of the prior-driven exploration term in the edge score computed by search().
C_PUCT = 0.1
# Small constant added under the square root when scoring edges without statistics yet.
EPS = 1e-8

class MonteCarloTreeSearch(object):
    """Monte Carlo tree search guided by a policy/value model.

    Statistics live in dictionaries keyed by the string ``get_repr`` returns
    for a position ``s``, or by ``(s, a)`` for a move ``a`` played from it.
    Positions are handled in canonical form: ``search`` always moves as
    player ``1`` and re-canonicalises the successor position.
    """

    def __init__(self, model):
        """Create an empty search tree.

        Args:
            model: Object whose ``predict(pieces)`` returns a
                ``(policy, value)`` pair; the policy is indexed by move.
        """
        self.model = model
        self.Qsa = {}  # (s, a) -> running mean of the values backed up through the edge
        self.Nsa = {}  # (s, a) -> number of times the edge was traversed
        self.Ns = {}  # s -> number of descents through s since it was expanded
        self.Ps = {}  # s -> model policy, masked to valid moves and renormalised
        self.Es = {}  # s -> cached Board.get_score() result (None while the game runs)
        self.Vs = {}  # s -> cached valid-move mask

    def get_actions(self, pieces, temp=1):
        """Run ``SIMULATION_COUNT`` searches from ``pieces`` and return move probabilities.

        Args:
            pieces: Position to search from, in canonical form (player ``1``
                is the side to move).
            temp: Temperature. Visit counts are raised to ``1 / temp`` before
                being normalised; ``0`` puts all probability on a single
                most-visited move, with ties broken at random.

        Returns:
            List of ``MOVE_COUNT`` numbers that sum to 1. If no move from
            ``pieces`` was visited at all (for example because the position
            is already terminal) every entry is ``0.0``.
        """
        for _ in range(SIMULATION_COUNT):
            self.search(pieces)

        s = get_repr(pieces)
        counts = [self.Nsa.get((s, a), 0) for a in range(MOVE_COUNT)]

        if sum(counts) == 0:
            # Nothing to normalise; dividing would raise ZeroDivisionError.
            return [0.0] * MOVE_COUNT

        if temp == 0:
            # Limit of n ** (1 / temp) as temp -> 0: all mass on the maximum.
            most_visited = max(counts)
            chosen = random.choice([a for a, n in enumerate(counts) if n == most_visited])
            return [1.0 if a == chosen else 0.0 for a in range(MOVE_COUNT)]

        counts = [n ** (1. / temp) for n in counts]
        total = float(sum(counts))
        return [n / total for n in counts]

    def search(self, pieces):
        """Run one simulation from ``pieces`` and update the tree statistics.

        The simulation follows the edge with the highest upper confidence
        score until it reaches a finished game or a position that has not been
        expanded yet; the latter is expanded with the model's prediction.

        Args:
            pieces: Position in canonical form (player ``1`` to move).

        Returns:
            The negated value of ``pieces``, i.e. the value from the point of
            view of the opponent of the side to move.

        Raises:
            RuntimeError: If no valid move has a score greater than ``-inf``
                (for example because the model produced NaN priors).
        """
        s = get_repr(pieces)
        board = Board(pieces)

        if s not in self.Es:
            self.Es[s] = board.get_score()
        if self.Es[s] is not None:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.model.predict(pieces)
            valids = board.get_valid_move_mask()
            self.Ps[s] = self.Ps[s] * valids  # masking invalid moves
            sum_Ps_s = torch.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # No positive mass left on valid moves: add the mask itself
                # before normalising so valid moves get non-zero weight.
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= torch.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        valids = self.Vs[s]
        best_upper_confidence = -float('inf')
        best_action = -1

        for a in range(MOVE_COUNT):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + C_PUCT * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
                else:
                    u = C_PUCT * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)  # Q = 0 ?

                if u > best_upper_confidence:
                    best_upper_confidence = u
                    best_action = a

        if best_action < 0:
            # Fail loudly instead of handing the -1 sentinel on as a move index.
            raise RuntimeError("MCTS found no selectable move for position %s" % s)

        a = best_action
        next_pieces, next_player = get_next_state(pieces, 1, a)
        next_pieces = get_canonical_form(next_pieces, next_player)

        # v is the successor's value as seen by the side to move in `pieces`.
        v = self.search(next_pieces)

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v



In [4]:
import unittest

class TestBoard(unittest.TestCase):
    """Checks move bookkeeping and win detection on ``Board``."""

    def test_legal_moves(self):
        """A fresh board offers every move and has no winner."""
        fresh = Board.initial()
        self.assertTrue(fresh.has_empty_slots())
        self.assertListEqual(list(range(MOVE_COUNT)), fresh.get_valid_moves())
        for player in (1, -1):
            self.assertFalse(fresh.has_won(player))

    def test_execute_move(self):
        """Executing a move removes it from the set of valid moves."""
        board = Board.initial()
        cell = 1
        self.assertTrue(board.is_valid_move(cell))
        board.execute_move(1, cell)
        self.assertFalse(board.is_valid_move(cell))
        remaining = board.get_valid_moves()
        self.assertEqual(MOVE_COUNT - 1, len(remaining))

    def test_has_won(self):
        """Player 1 wins once moves 0, 1 and 2 are all taken."""
        board = Board.initial()
        for cell in (0, 1):
            board.execute_move(1, cell)
        self.assertFalse(board.has_won(1))
        self.assertFalse(board.has_won(-1))
        self.assertIsNone(board.get_score())

        board.execute_move(1, 2)
        self.assertTrue(board.has_won(1))
        self.assertFalse(board.has_won(-1))
        self.assertEqual(1, board.get_score())

    def test_has_won_diagonal(self):
        """Player -1 wins once moves 0, 4 and 8 are all taken."""
        board = Board.initial()
        for cell in (0, 4):
            board.execute_move(-1, cell)
        self.assertFalse(board.has_won(1))
        self.assertFalse(board.has_won(-1))
        self.assertIsNone(board.get_score())

        board.execute_move(-1, 8)
        self.assertFalse(board.has_won(1))
        self.assertTrue(board.has_won(-1))
        self.assertEqual(-1, board.get_score())


class RandomModel(object):
    """Stand-in for a trained model that answers with random predictions."""

    def predict(self, pieces):
        """Return a random ``(policy, value)`` pair; ``pieces`` is ignored.

        The policy is a tensor of ``MOVE_COUNT`` uniform samples from
        ``[0, 1)``; the value is one of -1, 0 or 1.
        """
        # randint includes both end points; randrange(-1, 1) never yields 1.
        return torch.rand(MOVE_COUNT), random.randint(-1, 1)


class TestMCTS(unittest.TestCase):
    """Smoke test for the tree search driven by ``RandomModel``."""

    def test_get_actions(self):
        """Searching the empty board yields one entry per move."""
        search = MonteCarloTreeSearch(RandomModel())
        start = Board.initial()
        probabilities = search.get_actions(start.pieces)
        self.assertEqual(MOVE_COUNT, len(probabilities))


# Run the tests inside the notebook kernel: argv=[''] stops unittest from
# parsing the kernel's own command-line arguments, and exit=False keeps it
# from calling sys.exit() once the run has finished.
unittest.main(argv=[''], verbosity=2, exit=False)

test_execute_move (__main__.TestBoard) ... ok
test_has_won (__main__.TestBoard) ... ok
test_has_won_diagonal (__main__.TestBoard) ... ok
test_legal_moves (__main__.TestBoard) ... ok
test_get_actions (__main__.TestMCTS) ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.017s

OK


<unittest.main.TestProgram at 0x7f5b1d0f9c40>