In [17]:
from collections import namedtuple, deque
from random import shuffle
from torch import nn
from torch import optim
from tqdm import tqdm
import math
import numpy as np
import random
import torch
import torch.nn.functional as F

In [16]:
# Board geometry: a square three-by-three grid.
WIDTH = 3
HEIGHT = 3
# One action per cell of the grid.
ACTION_COUNT = WIDTH * HEIGHT

# Player id -> printable symbol. Insertion order (-1 first) is kept because
# code iterating over the keys sees the players in this order.
PLAYERS = {-1: "O", 1: "X"}

def get_action_coords(action_index):
    """Translate a flat action index into grid coordinates.

    Indices count cells row by row over a grid that is three cells wide, so
    the column is the remainder and the row the quotient of a division by 3.

    Returns:
        An ``(x, y)`` tuple: column first, then row.
    """
    column = action_index % 3
    row = action_index // 3
    return column, row

class Board(object):
    """Mutable tic-tac-toe position stored in a ``HEIGHT x WIDTH`` tensor.

    A cell holds one of the ``PLAYERS`` keys (``1`` or ``-1``) when occupied
    and ``0`` when empty. Actions are flat indices resolved through
    ``get_action_coords``.
    """

    def __init__(self, pieces):
        # The tensor is kept by reference (no copy): executing an action
        # writes into the tensor the caller handed in.
        assert pieces.size() == (HEIGHT, WIDTH)
        self.pieces = pieces

    @classmethod
    def initial(cls):
        """Return a board with every cell empty."""
        empty = torch.zeros((3, 3))
        return cls(empty)

    def is_valid_action(self, action_index):
        """Tell whether the cell addressed by ``action_index`` is empty.

        The result is the element-wise comparison of a single cell with 0,
        i.e. a zero-dimensional boolean tensor rather than a Python bool.
        """
        column, row = get_action_coords(action_index)
        return self.pieces[row, column] == 0

    def get_valid_actions(self):
        """Return the indices of all empty cells, in ascending order."""
        return list(filter(self.is_valid_action, range(ACTION_COUNT)))

    def get_valid_action_mask(self):
        """Return an integer tensor with 1 for playable actions and 0 otherwise."""
        flags = [int(bool(self.is_valid_action(index))) for index in range(ACTION_COUNT)]
        return torch.tensor(flags)

    def has_empty_slots(self):
        """Return True while at least one cell is still empty."""
        return bool(self.get_valid_actions())

    def execute_action(self, player, action_index):
        """Place ``player``'s mark on the cell addressed by ``action_index``.

        Both the player id and the legality of the move are checked with
        assertions; the board tensor is modified in place.
        """
        assert player in PLAYERS
        assert self.is_valid_action(action_index)
        column, row = get_action_coords(action_index)
        self.pieces[row, column] = player

    def has_won(self, player):
        """Return True if ``player`` owns a complete row, column or diagonal.

        A line belongs to ``player`` exactly when its three cells sum to
        ``3 * player``, because cells only ever hold -1, 0 or 1.
        """
        assert player in PLAYERS
        target = 3 * player

        # Columns first, then rows - same order of inspection as the checks
        # are conceptually listed in the docstring.
        straight_lines = [self.pieces[:, x] for x in range(3)]
        straight_lines += [self.pieces[y, :] for y in range(3)]
        if any(torch.sum(line) == target for line in straight_lines):
            return True

        main_diagonal = sum(self.pieces[i, i].item() for i in range(3))
        anti_diagonal = sum(self.pieces[2 - i, i].item() for i in range(3))
        return main_diagonal == target or anti_diagonal == target

    def get_score(self):
        """Return the outcome of the position.

        Returns:
            The id of the winning player, ``0`` when the board is full with
            no winner, or ``None`` while the game is still undecided.
        """
        for candidate in PLAYERS:
            if self.has_won(candidate):
                return candidate
        return None if self.has_empty_slots() else 0

    def get_random_action(self):
        """Return a random playable action, or ``None`` when the board is full."""
        candidates = self.get_valid_actions()
        if not candidates:
            return None
        return candidates[random.randrange(len(candidates))]


def get_next_state(pieces, player, action):
    """Apply ``action`` for ``player`` without touching the input tensor.

    The move is played on a clone of ``pieces``.

    Returns:
        A ``(new_pieces, next_player)`` tuple, where ``next_player`` is the
        negation of ``player``.
    """
    successor = Board(pieces.clone())
    successor.execute_action(player, action)
    return successor.pieces, -player


def get_canonical_form(pieces, player):
    """Return the board as seen by ``player``.

    Every cell is multiplied by ``player``, so for ``player == -1`` the two
    sides swap signs and for ``player == 1`` the values are unchanged.
    """
    return pieces * player


# Cell value -> character used in the textual form of a board.
_SLOTS = {-1: "O", 0: "-", 1: "X"}


def get_repr(pieces):
    """Serialise a board tensor into a 9-character string.

    The tensor is viewed as a flat vector of nine cells and each cell value
    is mapped through ``_SLOTS``.
    """
    cells = pieces.view(9).tolist()
    return "".join(_SLOTS[cell] for cell in cells)

# TODO: Symmetries


In [5]:
# Number of MonteCarloTreeSearch.search calls made per get_actions call.
SIMULATION_COUNT = 25
# Multiplier of the prior/visit-count exploration term in the search's action selection.
C_PUCT = 1.0
# Added to the state visit count inside the square root for edges that have no Q value yet.
EPS = 1e-8

class MonteCarloTreeSearch(object):
    """Monte Carlo tree search guided by a policy/value model.

    All boards handed to this class are expected in canonical form, i.e. the
    player to move is always represented by ``1``. Statistics are stored in
    dictionaries keyed by the string produced by ``get_repr`` (``s``) or by
    ``(s, action_index)`` pairs.
    """

    def __init__(self, model):
        self.model = model
        self.Qsa = {}  # (s, a) -> running mean of the values backed up through the edge
        self.Nsa = {}  # (s, a) -> number of times the edge was traversed
        self.Ns = {}  # s -> number of traversals through s after it was expanded
        self.Ps = {}  # s -> prior policy from the model, masked and renormalised
        self.Es = {}  # s -> cached Board.get_score() (None while the game is undecided)
        self.Vs = {}  # s -> cached valid-action mask

    def get_actions(self, pieces, temp=1):
        """Run ``SIMULATION_COUNT`` simulations and return a move distribution.

        Args:
            pieces: canonical board tensor of the position to analyse.
            temp: temperature applied to the visit counts. ``temp == 0``
                selects the most visited action greedily (ties are broken
                at random) and returns a one-hot distribution.

        Returns:
            A list of ``ACTION_COUNT`` probabilities.

        Raises:
            ValueError: if no action of the position was ever visited, which
                happens when ``pieces`` is already a finished game.
        """
        for _ in range(SIMULATION_COUNT):
            self.search(pieces)

        s = get_repr(pieces)
        counts = [self.Nsa.get((s, a), 0) for a in range(ACTION_COUNT)]
        if sum(counts) == 0:
            raise ValueError("no visit counts for position %r (is the game already over?)" % s)

        if temp == 0:
            # 1 / temp is undefined here; the limit of the tempered
            # distribution is a point mass on the most visited action.
            most_visits = max(counts)
            best_actions = [a for a, n in enumerate(counts) if n == most_visits]
            chosen = random.choice(best_actions)
            return [1.0 if a == chosen else 0.0 for a in range(ACTION_COUNT)]

        counts = [n ** (1. / temp) for n in counts]
        total = float(sum(counts))
        return [n / total for n in counts]

    def search(self, pieces):
        """Perform one simulation starting from the canonical board ``pieces``.

        The search descends along the actions with the highest upper
        confidence bound until it reaches a finished game or a position that
        has not been expanded yet, then updates ``Qsa``/``Nsa``/``Ns`` on the
        way back.

        Returns:
            The negated value of ``pieces`` from the point of view of the
            player to move, i.e. the value as seen by the player who moved
            into this position.
        """
        s = get_repr(pieces)
        board = Board(pieces)

        if s not in self.Es:
            self.Es[s] = board.get_score()
        if self.Es[s] is not None:
            # Terminal node: the score is relative to player 1, who is the
            # player to move in canonical form.
            return -self.Es[s]

        if s not in self.Ps:
            # Leaf node: expand it with the model's prediction.
            self.Ps[s], v = self.model.predict(pieces)
            valids = board.get_valid_action_mask()
            self.Ps[s] = self.Ps[s] * valids  # masking invalid actions
            sum_Ps_s = torch.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # The model put no mass on any valid action: fall back to a
                # uniform distribution over the valid ones.
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= torch.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        valids = self.Vs[s]
        best_upper_confidence = -float('inf')
        best_action = -1

        # Pick the valid action with the highest upper confidence bound.
        for a in range(ACTION_COUNT):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + C_PUCT * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
                else:
                    # Unvisited edge: Q is taken as 0; EPS keeps the prior
                    # relevant while Ns[s] is still 0.
                    u = C_PUCT * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)

                if u > best_upper_confidence:
                    best_upper_confidence = u
                    best_action = a

        a = best_action
        # The mover is always player 1 in canonical form; flip the result so
        # the child is again seen from its own player to move.
        next_pieces, next_player = get_next_state(pieces, 1, a)
        next_pieces = get_canonical_form(next_pieces, next_player)

        v = self.search(next_pieces)

        if (s, a) in self.Qsa:
            # Incremental update of the mean value of the edge.
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v


In [14]:
# Number of self-play + training rounds executed by Learner.learn.
# NOTE(review): the name looks like a typo for ITERATION_COUNT, and the trailing
# "#00" presumably marks a value reduced from 1000 - confirm before renaming/raising.
INTERATION_COUNT = 10 #00
# Number of self-play games per round.
# NOTE(review): the trailing "#0" presumably marks a value reduced from 100 - confirm.
EPISODE_COUNT = 10 #0

# One recorded decision of a self-play game: the board, the player to move
# and the move distribution returned by the search.
Observation = namedtuple('Observation', ('board', 'player', 'pi'))

class Learner(object):
    """Self-play training loop: generate games with MCTS, then train the model.

    ``history`` keeps the samples of the most recent 20 iterations; each
    iteration trains on everything that is still in the history.
    """

    def __init__(self, model):
        self.model = model
        self.mcts = None
        self.history = deque([], maxlen=20)

    def play_game(self):
        """Play one self-play game and return its training samples.

        Returns:
            A list of ``(board, pi, value)`` tuples, one per move played.
            ``board`` is the canonical board the search was run on, ``pi``
            the move distribution returned by the search, and ``value`` the
            final outcome from the point of view of the player who was to
            move on that board (1 win, -1 loss, 0 draw).
        """
        if self.mcts is None:
            # Allow calling play_game directly, without going through learn().
            self.mcts = MonteCarloTreeSearch(self.model)

        pieces = Board.initial().pieces
        player = 1
        observations = []
        while True:
            canonical_pieces = get_canonical_form(pieces, player)
            # TODO: Adjust temperature
            pi = self.mcts.get_actions(canonical_pieces)
            # TODO: Regularize using symmetries
            # Store the canonical board: it is the input the policy `pi`
            # was computed for and the input the model is queried with.
            observations.append(Observation(canonical_pieces, player, pi))
            # `p=` is required: the second positional argument of
            # np.random.choice is the output size, not the probabilities.
            action = int(np.random.choice(len(pi), p=pi))
            pieces, player = get_next_state(pieces, player, action)
            score = Board(pieces).get_score()
            if score is not None:
                # `score` is the id of the winner (or 0), so multiplying by
                # the id of the player to move yields +1 for the winner's
                # samples and -1 for the loser's.
                return [(obs.board, obs.pi, score * obs.player) for obs in observations]

    def learn(self):
        """Run ``INTERATION_COUNT`` rounds of self-play followed by training."""
        for i in range(INTERATION_COUNT):
            print(f"Starting iteration {i} ...")
            iteration_samples = deque([], maxlen=20000)
            for _ in tqdm(range(EPISODE_COUNT), desc="Self Play"):
                # Fresh tree for every game so statistics do not leak across games.
                self.mcts = MonteCarloTreeSearch(self.model)
                iteration_samples += self.play_game()
            self.history.append(iteration_samples)

            dataset = [sample for samples in self.history for sample in samples]
            shuffle(dataset)
            self.model.train(dataset)



In [20]:
class Network(nn.Module):
    """Small fully connected network with a policy head and a value head.

    The input is flattened to 9 features per sample, passed through two
    hidden layers of 32 units with ReLU activations, and then fed to two
    separate linear heads.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk.
        self.fc1 = nn.Linear(9, 32)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(32, 32)
        self.relu2 = nn.ReLU(inplace=True)
        # Heads: one logit per action, and a single scalar.
        self.fc3 = nn.Linear(32, ACTION_COUNT)
        self.fc4 = nn.Linear(32, 1)

    def forward(self, x):
        """Return ``(log_policy, value)`` for a batch of boards.

        ``log_policy`` is the log-softmax over the action logits (along
        dimension 1) and ``value`` is the tanh of the value head output.
        """
        batch_size = x.size(0)
        features = x.view(batch_size, -1)
        hidden = self.relu1(self.fc1(features))
        hidden = self.relu2(self.fc2(hidden))
        logits = self.fc3(hidden)
        raw_value = self.fc4(hidden)
        return F.log_softmax(logits, dim=1), torch.tanh(raw_value)

In [24]:
class Model(object):
    """Thin wrapper that exposes single-board prediction on top of a network."""

    def __init__(self, nnet):
        self.nnet = nnet

    def train(self, dataset):
        """Placeholder: training is not implemented, the dataset is ignored."""
        return None

    def predict(self, pieces):
        """Evaluate a single board.

        The network is switched to evaluation mode, the board gets a leading
        batch dimension, and the forward pass runs without gradient tracking.

        Returns:
            ``(policy, value)``: the exponentiated first row of the network's
            first output, and the first entry of its second output as a
            Python number.
        """
        self.nnet.eval()
        batch = pieces.unsqueeze(0)
        with torch.no_grad():
            log_policy, value = self.nnet(batch)
        policy = torch.exp(log_policy)[0]
        return policy, value[0].item()

In [25]:
# Smoke test: wrap a freshly constructed (untrained) network in a Model and
# query it on the empty board.
nnet = Network()
model = Model(nnet)
board = Board.initial()
# Last expression of the cell: its (policy, value) result is what the
# notebook displays as the cell output.
model.predict(board.pieces)

# Self-play training is currently switched off.
# learner = Learner(model)
# learner.learn()

(tensor([0.0969, 0.0979, 0.1316, 0.1287, 0.0978, 0.1055, 0.1163, 0.1298, 0.0956]),
 0.010271133854985237)

In [4]:
import unittest

class TestBoard(unittest.TestCase):
    """Checks of Board: legal moves, move execution and win detection."""

    def test_legal_actions(self):
        # A fresh board offers every action and nobody has won yet.
        fresh = Board.initial()
        self.assertTrue(fresh.has_empty_slots())
        self.assertListEqual(list(range(ACTION_COUNT)), fresh.get_valid_actions())
        for player in (1, -1):
            self.assertFalse(fresh.has_won(player))

    def test_execute_action(self):
        # Playing a cell makes exactly that cell unavailable.
        board = Board.initial()
        self.assertTrue(board.is_valid_action(1))
        board.execute_action(1, 1)
        self.assertFalse(board.is_valid_action(1))
        remaining = board.get_valid_actions()
        self.assertEqual(ACTION_COUNT - 1, len(remaining))

    def test_has_won(self):
        # Player 1 fills the top row (actions 0, 1, 2).
        board = Board.initial()
        for action in (0, 1):
            board.execute_action(1, action)
        self.assertFalse(board.has_won(1))
        self.assertFalse(board.has_won(-1))
        self.assertIsNone(board.get_score())

        board.execute_action(1, 2)
        self.assertTrue(board.has_won(1))
        self.assertFalse(board.has_won(-1))
        self.assertEqual(1, board.get_score())

    def test_has_won_diagonal(self):
        # Player -1 fills the main diagonal (actions 0, 4, 8).
        board = Board.initial()
        for action in (0, 4):
            board.execute_action(-1, action)
        self.assertFalse(board.has_won(1))
        self.assertFalse(board.has_won(-1))
        self.assertIsNone(board.get_score())

        board.execute_action(-1, 8)
        self.assertFalse(board.has_won(1))
        self.assertTrue(board.has_won(-1))
        self.assertEqual(-1, board.get_score())


class RandomModel(object):
    """Test double for ``Model`` that returns random predictions."""

    def predict(self, pieces):
        """Return a random ``(policy, value)`` pair; ``pieces`` is ignored.

        The policy is a tensor of ``ACTION_COUNT`` uniform samples from
        [0, 1) (not normalised) and the value is a float drawn uniformly
        from [-1, 1].
        """
        # random.uniform spans the whole [-1, 1] range. An integer draw with
        # random.randrange(-1, 1) excludes its upper bound and could only
        # ever produce -1 or 0, i.e. never a favourable evaluation.
        return torch.rand(ACTION_COUNT), random.uniform(-1.0, 1.0)


class TestMCTS(unittest.TestCase):
    """Smoke test of MonteCarloTreeSearch driven by a random model."""

    def test_get_actions(self):
        # Searching from the empty board yields one probability per action.
        search = MonteCarloTreeSearch(RandomModel())
        start = Board.initial()
        distribution = search.get_actions(start.pieces)
        self.assertEqual(ACTION_COUNT, len(distribution))


# Run the tests inside the notebook: a placeholder argv is passed instead of
# the process's real command line, and exit=False keeps unittest from calling
# sys.exit() when the run finishes.
unittest.main(argv=[''], verbosity=2, exit=False)

test_execute_move (__main__.TestBoard) ... ok
test_has_won (__main__.TestBoard) ... ok
test_has_won_diagonal (__main__.TestBoard) ... ok
test_legal_moves (__main__.TestBoard) ... ok
test_get_actions (__main__.TestMCTS) ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.017s

OK


<unittest.main.TestProgram at 0x7f5b1d0f9c40>