In [254]:
import torch

# Fix every torch RNG (CPU generator and all visible CUDA devices) so runs are reproducible.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [255]:
import torch


class Board:
    """Tic-tac-toe board stored as a flat tensor of 9 cells.

    ``player_make_turn`` writes ``current_player`` (1 or -1) into a cell, so cells
    hold 1, -1 or 0 (empty). ``current_player`` starts at 1 and alternates.
    ``game_history`` records ``(state_before_move, cell_index)`` for every move,
    where ``state_before_move`` comes from ``get_game_state``.
    """

    def __init__(self) -> None:
        self.current_player = 1

        self.board = torch.zeros(9)

        self.game_history = []

    def get_legality_matrix(self) -> torch.Tensor:
        """Return a mask (same dtype as the board) with 1 for empty cells and 0 for occupied ones."""
        return (torch.round(torch.abs(self.board)) == 0).to(self.board.dtype)

    def get_game_state(self) -> torch.Tensor:
        """Return a new 10-element tensor ``[current_player, *board]``.

        ``torch.cat`` allocates a fresh tensor, so the snapshot stored in
        ``game_history`` is not affected by later moves.
        """
        curr_player = torch.tensor([self.current_player], dtype=self.board.dtype)
        return torch.cat([curr_player, self.board])

    def check_if_game_over(self) -> tuple:
        """Return ``(is_over, winner)``.

        ``winner`` is 1 or -1 when that player owns a full row, column or diagonal,
        and 0 for a draw or an unfinished game. Wins are checked before the
        "board is full" condition so that a line completed on the ninth move is
        reported as a win rather than a draw.
        """
        grid = self.board.reshape(3, 3)
        line_sums = torch.round(torch.cat((
            torch.sum(grid, dim=1),                                   # rows
            torch.sum(grid, dim=0),                                   # columns
            torch.sum(torch.diag(grid)).unsqueeze(0),                 # main diagonal
            torch.sum(torch.diag(torch.flip(grid, dims=[1]))).unsqueeze(0),  # anti-diagonal
        ), dim=0))

        for line_sum in line_sums.tolist():
            # A sum of +/-3 means all three cells belong to the same player.
            if abs(line_sum) == 3:
                return True, 1 if line_sum > 0 else -1

        # No completed line: the game is a draw only once every cell is occupied.
        if not torch.any(torch.round(self.board) == 0):
            return True, 0
        return False, 0

    def player_make_turn(self, where_player_went: int) -> None:
        """Place ``current_player``'s mark on cell ``where_player_went`` and pass the turn.

        The pre-move state and the chosen cell are appended to ``game_history``.

        Raises:
            ValueError: if the index is outside ``0..8`` or the cell is already occupied.
        """
        n_cells = self.board.numel()
        # Reject negative indices too: they would wrap around silently and be
        # recorded in the history under the wrong cell number.
        if not 0 <= where_player_went < n_cells:
            raise ValueError(f"cell index {where_player_went} is out of range 0..{n_cells - 1}")
        if self.get_legality_matrix()[where_player_went] == 0:
            raise ValueError(f"cell {where_player_went} is already occupied")

        self.game_history.append((self.get_game_state(), where_player_went))
        self.board[where_player_went] = self.current_player
        self.current_player = -self.current_player

    def set_winner(self, winner: int) -> tuple:
        """Split ``game_history`` by outcome into ``(good_moves, bad_moves, neutral_moves)``.

        For ``winner == 0`` every move is neutral. Otherwise, moves whose recorded
        player (``game_state[0]``) equals ``winner`` are good and the rest are bad.
        """
        if winner == 0:
            return [], [], self.game_history.copy()

        good_moves, bad_moves = [], []
        for game_state, player_choice in self.game_history:
            target = good_moves if game_state[0].item() == winner else bad_moves
            target.append((game_state, player_choice))
        return good_moves, bad_moves, []


TTT = Board()

In [256]:
import torch
from torch import nn


class Model(nn.Module):
    """Small MLP policy mapping a 10-feature state to a distribution over 9 cells.

    In this notebook the input is ``Board.get_game_state()`` (current player
    followed by the 9 cells). Illegal cells, as given by the legality mask,
    receive probability 0.
    """

    def __init__(self):
        super().__init__()

        self.tanh = nn.Tanh()
        # Normalise explicitly over the last (cell) axis. An implicit dim is
        # deprecated, and the axis it picks depends on the input's rank.
        self.softmax = nn.Softmax(dim=-1)

        self.layers = nn.Sequential(
            nn.Linear(in_features=10, out_features=32),
            self.tanh,
            nn.Linear(in_features=32, out_features=32),
            self.tanh,
            nn.Linear(in_features=32, out_features=32),
            self.tanh,
            nn.Linear(in_features=32, out_features=9),
        )

    def forward(self, x: torch.Tensor, legality_matrix: torch.Tensor) -> torch.Tensor:
        """Return move probabilities over the last dimension (size 9).

        Args:
            x: state tensor whose last dimension is 10; leading batch dims are allowed.
            legality_matrix: 1 for legal cells, 0 for illegal ones; must broadcast
                against the ``(..., 9)`` scores.

        Returns:
            Probabilities summing to 1 over the last dimension. They are exactly 0 on
            illegal cells whenever at least one cell is legal; if every cell is
            masked, the result is uniform rather than NaN.
        """
        scores = (self.tanh(self.layers(x)) + 1) / 2
        # Multiplying by the mask would leave illegal cells at score 0. Softmax still
        # maps that to a non-zero probability, and it ties with legal cells whose
        # score saturates to exactly 0, so argmax could choose an illegal move.
        # Filling with the most negative finite value sends those cells to
        # probability 0 without producing NaNs.
        scores = scores.masked_fill(legality_matrix == 0, torch.finfo(scores.dtype).min)
        return self.softmax(scores)


model = Model()

In [257]:
def play_one_game():
    """Play one self-play game with the global ``model`` on a fresh board.

    The global ``TTT`` is replaced with a new ``Board`` before play, so every call
    starts from an empty board with an empty history. ``model`` plays both sides,
    taking the ``argmax`` cell each turn. The final 3x3 board is printed.

    Returns:
        ``(good_moves, bad_moves, neutral_moves)`` as returned by ``TTT.set_winner``.
    """
    global TTT
    TTT = Board()
    while True:
        # Inference only: no autograd graph is needed to pick a move.
        with torch.no_grad():
            probs = model(TTT.get_game_state(), TTT.get_legality_matrix())
        choice = torch.argmax(probs).item()
        TTT.player_make_turn(choice)
        is_over, winner = TTT.check_if_game_over()
        if is_over:
            g_moves, b_moves, n_moves = TTT.set_winner(winner)
            print(TTT.board.reshape(3, 3))
            return g_moves, b_moves, n_moves

In [258]:
# Play a single game and show the good, bad and neutral move lists, each followed by a blank line.
g, b, n = play_one_game()
print(g, b, n, sep='\n\n', end='\n\n')

tensor([[ 1.,  1., -1.],
        [ 0.,  1.,  1.],
        [-1., -1., -1.]])
[(tensor([-1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.]), 7), (tensor([-1.,  1.,  0.,  0.,  0.,  0.,  1.,  0., -1.,  0.]), 8), (tensor([-1.,  1.,  0.,  0.,  0.,  1.,  1.,  0., -1., -1.]), 2), (tensor([-1.,  1.,  1., -1.,  0.,  1.,  1.,  0., -1., -1.]), 6)]

[(tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 5), (tensor([ 1.,  0.,  0.,  0.,  0.,  0.,  1.,  0., -1.,  0.]), 0), (tensor([ 1.,  1.,  0.,  0.,  0.,  0.,  1.,  0., -1., -1.]), 4), (tensor([ 1.,  1.,  0., -1.,  0.,  1.,  1.,  0., -1., -1.]), 1)]

[]

