In [3]:
import random
import torch
import math
from torch import nn
from torch import optim

In [19]:
# Board geometry: the grid is WIDTH columns by HEIGHT rows.
WIDTH = 3
HEIGHT = 3

# One candidate move per square; moves are flat indices 0 .. MOVE_COUNT - 1.
MOVE_COUNT = WIDTH * HEIGHT

# Player id -> display symbol. The ids are negations of each other, so the
# opponent of `player` is simply `-player`.
PLAYERS = {-1: "O", 1: "X"}

def get_move_coords(move_index):
    """Translate a flat move index into board coordinates.

    Squares are numbered row by row with three squares per row, so the
    remainder of ``move_index`` by 3 is the column and the floor quotient
    is the row.

    Args:
        move_index: flat index of the square.

    Returns:
        A ``(x, y)`` tuple, i.e. ``(column, row)``.
    """
    column = move_index % 3
    row = move_index // 3
    return column, row

class Board(object):
    """Mutable tic-tac-toe position stored as a ``HEIGHT`` x ``WIDTH`` tensor.

    ``pieces[y][x]`` is ``0`` for an empty square and otherwise holds the id
    of the player occupying it (a key of ``PLAYERS``). Moves are flat indices
    in ``range(MOVE_COUNT)`` which ``get_move_coords`` maps to ``(x, y)``.
    """

    def __init__(self, pieces):
        """Wrap ``pieces``; the tensor is stored by reference, not copied.

        Args:
            pieces: tensor whose ``size()`` equals ``(HEIGHT, WIDTH)``.

        Raises:
            AssertionError: if ``pieces`` has a different shape.
        """
        assert(pieces.size() == (HEIGHT, WIDTH))
        self.pieces = pieces

    @classmethod
    def initial(cls):
        """Return a board on which every square is empty."""
        # Built from the shared constants so the shape always satisfies the
        # check performed in __init__.
        return cls(torch.zeros(HEIGHT, WIDTH))

    def is_legal_move(self, move_index):
        """Tell whether ``move_index`` names an empty square of the board.

        Args:
            move_index: flat index of the square to test.

        Returns:
            ``True`` if the index is inside ``range(MOVE_COUNT)`` and the
            square is empty, ``False`` otherwise (always a plain ``bool``).
        """
        # Reject out-of-range indices explicitly: a negative index would
        # otherwise wrap around through negative tensor indexing and alias a
        # real square, and an index >= MOVE_COUNT would raise IndexError.
        if not 0 <= move_index < MOVE_COUNT:
            return False
        x, y = get_move_coords(move_index)
        # bool() converts the 0-d comparison tensor into a Python bool.
        return bool(self.pieces[y][x] == 0)

    def get_legal_moves(self):
        """Return the flat indices of all empty squares, in ascending order."""
        return [index for index in range(MOVE_COUNT) if self.is_legal_move(index)]

    def has_legal_moves(self):
        """Return ``True`` if at least one square is still empty."""
        # any() stops at the first empty square instead of listing them all.
        return any(self.is_legal_move(index) for index in range(MOVE_COUNT))

    def execute_move(self, player, move_index):
        """Place ``player``'s piece on the square ``move_index`` (in place).

        Args:
            player: player id, must be a key of ``PLAYERS``.
            move_index: flat index of an empty square.

        Raises:
            AssertionError: if ``player`` is unknown or the move is not legal
                (occupied square or index outside the board).
        """
        assert(player in PLAYERS.keys())
        assert(self.is_legal_move(move_index))
        x, y = get_move_coords(move_index)
        self.pieces[y][x] = player

    def has_won(self, player):
        """Tell whether ``player`` owns a complete row, column or diagonal.

        Args:
            player: player id, must be a key of ``PLAYERS``.

        Returns:
            ``True`` if some line of three squares sums to ``3 * player``,
            ``False`` otherwise.

        Raises:
            AssertionError: if ``player`` is not a key of ``PLAYERS``.
        """
        assert(player in PLAYERS.keys())
        # With squares restricted to -1 / 0 / 1, a line sums to 3 * player
        # only when all three of its squares belong to `player`.
        target = 3 * player
        # Columns.
        for x in range(3):
            if torch.sum(self.pieces[:, x]) == target:
                return True
        # Rows.
        for y in range(3):
            if torch.sum(self.pieces[y, :]) == target:
                return True
        # Main diagonal (top-left to bottom-right).
        if sum([self.pieces[i, i].item() for i in range(3)]) == target:
            return True
        # Anti-diagonal (bottom-left to top-right).
        if sum([self.pieces[2 - i, i].item() for i in range(3)]) == target:
            return True
        return False



def get_next_state(pieces, player, action):
    """Apply a move to a copy of ``pieces`` and hand the turn to the opponent.

    Args:
        pieces: board tensor accepted by ``Board`` (left unmodified).
        player: id of the player making the move.
        action: flat index of the square to play.

    Returns:
        A ``(new_pieces, next_player)`` tuple, where ``new_pieces`` is a new
        tensor containing the move and ``next_player`` is ``-player``.

    Raises:
        AssertionError: propagated from ``Board`` when the shape, the player
            or the move is invalid.
    """
    # torch tensors provide clone(), not copy(); cloning keeps the caller's
    # tensor untouched because execute_move writes in place.
    board = Board(pieces.clone())
    board.execute_move(player, action)
    return board.pieces, -player

In [22]:
import unittest

class TestBoard(unittest.TestCase):
    """Sanity checks for ``Board`` move legality and win detection."""

    # The test methods carry comments rather than docstrings on purpose:
    # a verbose unittest report prints a test's docstring next to its name.

    def test_legal_moves(self):
        # A fresh board offers every square and has no winner yet.
        fresh = Board.initial()
        self.assertTrue(fresh.has_legal_moves())
        self.assertListEqual(list(range(MOVE_COUNT)), fresh.get_legal_moves())
        for player in (1, -1):
            self.assertFalse(fresh.has_won(player))

    def test_execute_move(self):
        # Playing a square makes exactly that square unavailable.
        square = 1
        board = Board.initial()
        self.assertTrue(board.is_legal_move(square))
        board.execute_move(1, square)
        self.assertFalse(board.is_legal_move(square))
        remaining = board.get_legal_moves()
        self.assertEqual(MOVE_COUNT - 1, len(remaining))

    def test_has_won(self):
        # Player 1 fills squares 0, 1, 2; the win appears only on the third.
        board = Board.initial()
        for square in (0, 1):
            board.execute_move(1, square)
        for player in (1, -1):
            self.assertFalse(board.has_won(player))
        board.execute_move(1, 2)
        self.assertTrue(board.has_won(1))
        self.assertFalse(board.has_won(-1))

    def test_has_won_diagonal(self):
        # Player -1 fills squares 0, 4, 8 (the main diagonal).
        board = Board.initial()
        for square in (0, 4):
            board.execute_move(-1, square)
        for player in (1, -1):
            self.assertFalse(board.has_won(player))
        board.execute_move(-1, 8)
        self.assertFalse(board.has_won(1))
        self.assertTrue(board.has_won(-1))


# Run the tests inside the notebook kernel. argv=[''] keeps unittest from
# parsing the kernel's own command-line arguments, and exit=False prevents the
# sys.exit() call that unittest.main() otherwise makes when the run finishes.
# The trailing semicolon stops the returned TestProgram object from being
# echoed as the cell's output.
unittest.main(argv=[''], verbosity=2, exit=False);

test_execute_move (__main__.TestBoard) ... ok
test_has_won (__main__.TestBoard) ... ok
test_has_won_diagonal (__main__.TestBoard) ... ok
test_legal_moves (__main__.TestBoard) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.009s

OK


<unittest.main.TestProgram at 0x7fbb6e6eac10>