In [38]:
import torch

seed = 42

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [39]:
import torch


class Board:
    def __init__(self) -> None:
        self.current_player = 1
        self.board = torch.zeros(9, device=device)
        self.game_history = []

    def reset_board(self):
        self.current_player = 1
        self.board = torch.zeros(9, device=device)
        self.game_history = []

    def get_game_state(self) -> tuple:
        curr_player = torch.Tensor([self.current_player]).to(device)
        board = self.board

        legality_matrix = torch.zeros_like(self.board, device=device)
        for i, element in enumerate(self.board):
            if torch.round(torch.abs(element)).item() == 0:
                legality_matrix[i] = 1
            else:
                legality_matrix[i] = 0

        return torch.cat([curr_player, board]), legality_matrix

    def check_if_game_over(self) -> tuple:
        game_over = False

        vis_board = self.board.reshape(3, 3)

        row_sums = torch.round(torch.sum(vis_board, dim=1))
        col_sums = torch.round(torch.sum(vis_board, dim=0))
        diag1_sum = torch.round(torch.sum(torch.diag(vis_board)))
        diag2_sum = torch.round(torch.sum(torch.diag(torch.flip(vis_board, dims=[1]))))

        all_sums = torch.cat((row_sums, col_sums, diag1_sum.unsqueeze(0), diag2_sum.unsqueeze(0)), dim=0)

        for i, element in enumerate(all_sums):
            if abs(element.item()) == 3:
                game_over = True
                return True, element.item() / 3
        if not game_over:
            draw = True
            for i, element in enumerate(self.board):
                if torch.round(element).item() == 0:
                    draw = False
            if draw:
                return True, 0
            else:
                return False, 0

    def player_make_turn(self, where_player_went: int) -> None:
        g_state, l_matrix = self.get_game_state()
        self.game_history.append((g_state, where_player_went, l_matrix))
        self.board[where_player_went] = self.current_player
        self.current_player = -self.current_player

    def set_winner(self, winner: int) -> tuple:
        good_moves, bad_moves, neutral_moves = [], [], []
        if winner == 0:
            neutral_moves = self.game_history.copy()
        else:
            for i, (g_state, p_choice, l_matrix) in enumerate(self.game_history):
                if winner == g_state[0]:
                    good_moves.append((g_state, p_choice, l_matrix))
                else:
                    bad_moves.append((g_state, p_choice, l_matrix))

        return good_moves, bad_moves, neutral_moves


TTT = Board()

In [40]:
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=0)

        self.layers = nn.Sequential(
            nn.Linear(in_features=10, out_features=32),
            self.tanh,
            nn.Linear(in_features=32, out_features=32),
            self.tanh,
            nn.Linear(in_features=32, out_features=32),
            self.tanh,
            nn.Linear(in_features=32, out_features=9),
        )

    def forward(self, x: torch.Tensor, leg_matrix: torch.Tensor) -> torch.Tensor:
        x = self.layers(x)
        x = (self.tanh(x) + 1) / 2
        x = x.masked_fill(leg_matrix == 0, float('-inf'))
        return x


model = Model().to(device)

In [41]:
def play_machine_only():
    TTT.reset_board()
    gameover = False
    while not gameover:
        is_over, player = TTT.check_if_game_over()
        if not is_over:
            g_state, l_matrix = TTT.get_game_state()
            raw_logits = model(g_state, l_matrix)
            softmax_logits = model.softmax(raw_logits)
            choice = torch.argmax(softmax_logits).item()
            TTT.player_make_turn(choice)
        else:
            # gameover = True
            g_moves, b_moves, n_moves = TTT.set_winner(player)
            return g_moves, b_moves, n_moves

In [42]:
def play_with_player(playerturn: bool = True):
    TTT.reset_board()
    gameover = False
    print(TTT.board.reshape(3, 3), end="\n\n")
    while not gameover:
        is_over, player = TTT.check_if_game_over()
        if not is_over:
            if playerturn:
                player_index = int(input(f"{TTT.board.reshape(3, 3)}\n\nWhere to go? "))
                TTT.player_make_turn(player_index)
                print(TTT.board.reshape(3, 3), end="\n\n")
                playerturn = False
            else:
                g_state, l_matrix = TTT.get_game_state()
                raw_logits = model(g_state, l_matrix)
                softmax_logits = model.softmax(raw_logits)
                choice = torch.argmax(softmax_logits).item()
                TTT.player_make_turn(choice)
                # print(TTT.board.reshape(3, 3), end="\n\n")
                playerturn = True
        else:
            gameover = True
            g_moves, b_moves, n_moves = TTT.set_winner(player)
            if playerturn:
                print("Machine won")
            else:
                print("Human won")
            # return g_moves, b_moves, n_moves

In [48]:
import random

def play_against_bot(rounds):
    wins, draws, losses = 0, 0, 0
    for i in range(rounds):
        print(f"Round {i}")
        machinefirst = random.choice([True, False])
        machineturn = True if machinefirst else False
        TTT.reset_board()
        while True:
            is_over, player = TTT.check_if_game_over()
            if is_over:
                break
            else:
                g_state, l_matrix = TTT.get_game_state()
                if machineturn:
                    raw_logits = model(g_state, l_matrix)
                    softmax_logits = model.softmax(raw_logits)
                    choice = torch.argmax(softmax_logits).item()
                    TTT.player_make_turn(choice)
                    machineturn = False
                else:
                    proper_choice = False
                    while not proper_choice:
                        random_index = random.randint(0,8)
                        if l_matrix[random_index] == 1:
                            proper_choice = True
                    TTT.player_make_turn(random_index)
                    machineturn = True
        is_over, player = TTT.check_if_game_over()
        if player == 0:
            draws += 1
        elif (player == 1 and machinefirst) or (player == -1 and not machinefirst):
            wins += 1
        else:
            losses += 1

    print(f"Wins: {wins:,} | Draws: {draws:,} | Losses: {losses:,}")

In [43]:
import torch

EPOCHS = 1000
LEARNING_RATE = 0.01

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    print(f"E {epoch + 1:,} - {((epoch + 1) / EPOCHS) * 100:.2f}%")
    model.eval()

    with torch.no_grad():
        g, b, n = play_machine_only()
        good, bad, neutral = [], [], []
        for index, (g_state, p_choice, l_matrix) in enumerate(g):
            onehot = torch.zeros(9, device=device)
            onehot[p_choice] = 1.0
            good.append((g_state, onehot, l_matrix))
        for index, (g_state, p_choice, l_matrix) in enumerate(b):
            onehot = torch.ones(9, device=device)
            onehot[p_choice] = 0.0
            bad.append((g_state, onehot, l_matrix))
        for index, (g_state, p_choice, l_matrix) in enumerate(n):
            onehot = torch.full(size=(9,), fill_value=0.5, device=device)
            onehot[p_choice] = 0.0
            neutral.append((g_state, onehot, l_matrix))
        game_positions = good + bad + neutral

    model.train()
    optimizer.zero_grad()

    for index, (g, p, l) in enumerate(game_positions):
        outputs = model(g, l)
        loss = loss_fn(outputs, p)
        loss.backward()

    optimizer.step()

E 1 - 0.00%
E 2 - 0.00%
E 3 - 0.00%
E 4 - 0.00%
E 5 - 0.01%
E 6 - 0.01%
E 7 - 0.01%
E 8 - 0.01%
E 9 - 0.01%
E 10 - 0.01%
E 11 - 0.01%
E 12 - 0.01%
E 13 - 0.01%
E 14 - 0.01%
E 15 - 0.01%
E 16 - 0.02%
E 17 - 0.02%
E 18 - 0.02%
E 19 - 0.02%
E 20 - 0.02%
E 21 - 0.02%
E 22 - 0.02%
E 23 - 0.02%
E 24 - 0.02%
E 25 - 0.03%
E 26 - 0.03%
E 27 - 0.03%
E 28 - 0.03%
E 29 - 0.03%
E 30 - 0.03%
E 31 - 0.03%
E 32 - 0.03%
E 33 - 0.03%
E 34 - 0.03%
E 35 - 0.03%
E 36 - 0.04%
E 37 - 0.04%
E 38 - 0.04%
E 39 - 0.04%
E 40 - 0.04%
E 41 - 0.04%
E 42 - 0.04%
E 43 - 0.04%
E 44 - 0.04%
E 45 - 0.04%
E 46 - 0.05%
E 47 - 0.05%
E 48 - 0.05%
E 49 - 0.05%
E 50 - 0.05%
E 51 - 0.05%
E 52 - 0.05%
E 53 - 0.05%
E 54 - 0.05%
E 55 - 0.06%
E 56 - 0.06%
E 57 - 0.06%
E 58 - 0.06%
E 59 - 0.06%
E 60 - 0.06%
E 61 - 0.06%
E 62 - 0.06%
E 63 - 0.06%
E 64 - 0.06%
E 65 - 0.07%
E 66 - 0.07%
E 67 - 0.07%
E 68 - 0.07%
E 69 - 0.07%
E 70 - 0.07%
E 71 - 0.07%
E 72 - 0.07%
E 73 - 0.07%
E 74 - 0.07%
E 75 - 0.07%
E 76 - 0.08%
E 77 - 0.08%
E 78 - 0

In [49]:
play_against_bot(1000)

Round 0
Round 1
Round 2
Round 3
Round 4
Round 5
Round 6
Round 7
Round 8
Round 9
Round 10
Round 11
Round 12
Round 13
Round 14
Round 15
Round 16
Round 17
Round 18
Round 19
Round 20
Round 21
Round 22
Round 23
Round 24
Round 25
Round 26
Round 27
Round 28
Round 29
Round 30
Round 31
Round 32
Round 33
Round 34
Round 35
Round 36
Round 37
Round 38
Round 39
Round 40
Round 41
Round 42
Round 43
Round 44
Round 45
Round 46
Round 47
Round 48
Round 49
Round 50
Round 51
Round 52
Round 53
Round 54
Round 55
Round 56
Round 57
Round 58
Round 59
Round 60
Round 61
Round 62
Round 63
Round 64
Round 65
Round 66
Round 67
Round 68
Round 69
Round 70
Round 71
Round 72
Round 73
Round 74
Round 75
Round 76
Round 77
Round 78
Round 79
Round 80
Round 81
Round 82
Round 83
Round 84
Round 85
Round 86
Round 87
Round 88
Round 89
Round 90
Round 91
Round 92
Round 93
Round 94
Round 95
Round 96
Round 97
Round 98
Round 99
Round 100
Round 101
Round 102
Round 103
Round 104
Round 105
Round 106
Round 107
Round 108
Round 109
Round 110
