In [127]:
import random

import torch

seed = 42

# Seed every RNG this notebook draws from so a fresh "Restart & Run All"
# reproduces the same numbers: Python's `random` drives the random opponent
# and the who-moves-first coin flip in `play_against_bot`, while torch
# drives the model's weight initialisation.
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [128]:
import torch


class Board:
    """Tic-tac-toe board stored as a flat 9-element tensor.

    Cells hold ``1`` or ``-1`` for the two players and ``0`` when empty.
    ``current_player`` is the mark that moves next; it starts at ``1`` and
    flips sign after every move.  Each move is logged in ``game_history`` as
    ``(game_state, chosen_index, legality_matrix)`` so that a finished game
    can be split into training examples by :meth:`set_winner`.
    """

    def __init__(self) -> None:
        # A brand-new board is exactly a reset board; keep the setup in one place.
        self.reset_board()

    def reset_board(self):
        """Empty the board, give the move to player ``1`` and clear the history."""
        self.current_player = 1
        self.board = torch.zeros(9, device=device)
        self.game_history = []

    def get_game_state(self) -> tuple:
        """Return ``(state, legality_matrix)`` for the current position.

        ``state`` is a 10-element tensor: the current player followed by the
        nine cells.  ``legality_matrix`` has the board's shape, dtype and
        device, with ``1`` for empty cells and ``0`` for occupied ones.
        """
        # Build the helper tensors next to the board itself so they always
        # live on the same device and can be concatenated with it.
        curr_player = torch.tensor(
            [self.current_player], dtype=self.board.dtype, device=self.board.device
        )
        # Vectorised emptiness test (no per-cell Python loop / .item() call).
        legality_matrix = (torch.round(torch.abs(self.board)) == 0).to(self.board.dtype)

        return torch.cat([curr_player, self.board]), legality_matrix

    def check_if_game_over(self) -> tuple:
        """Return ``(is_over, winner)`` for the current position.

        ``winner`` is ``1.0`` or ``-1.0`` when a row, column or diagonal is
        completely owned by one player, and ``0`` both for a draw (board
        full, nobody won) and for a game that is still running.
        """
        vis_board = self.board.reshape(3, 3)

        row_sums = torch.round(torch.sum(vis_board, dim=1))
        col_sums = torch.round(torch.sum(vis_board, dim=0))
        diag1_sum = torch.round(torch.sum(torch.diag(vis_board)))
        diag2_sum = torch.round(torch.sum(torch.diag(torch.flip(vis_board, dims=[1]))))

        all_sums = torch.cat(
            (row_sums, col_sums, diag1_sum.unsqueeze(0), diag2_sum.unsqueeze(0)), dim=0
        )

        # A line sums to +3 / -3 only when all three cells belong to one player.
        # Boolean indexing keeps the rows -> columns -> diagonals order, so the
        # first hit is the same line a sequential scan would report.
        decided = all_sums[torch.abs(all_sums) == 3]
        if decided.numel() > 0:
            return True, decided[0].item() / 3

        # Nobody has won: the game continues while any cell is still empty.
        if bool((torch.round(self.board) == 0).any()):
            return False, 0
        return True, 0

    def player_make_turn(self, where_player_went: int) -> None:
        """Record and apply the current player's move, then pass the turn.

        Args:
            where_player_went: flat cell index in ``0..8``.

        Raises:
            ValueError: if the index is outside the board or the cell is
                already occupied.  Nothing is modified in that case.
        """
        if not 0 <= where_player_went < self.board.numel():
            raise ValueError(
                f"Cell index must be between 0 and {self.board.numel() - 1}, "
                f"got {where_player_went}"
            )
        if torch.round(torch.abs(self.board[where_player_went])).item() != 0:
            raise ValueError(f"Cell {where_player_went} is already occupied")

        # Log the position *before* the move: that is what the player saw.
        g_state, l_matrix = self.get_game_state()
        self.game_history.append((g_state, where_player_went, l_matrix))
        self.board[where_player_went] = self.current_player
        self.current_player = -self.current_player

    def set_winner(self, winner: int) -> tuple:
        """Split the logged moves into ``(good_moves, bad_moves, neutral_moves)``.

        With ``winner == 0`` every move is neutral.  Otherwise moves made by
        the winner are good and moves made by the other player are bad.  The
        mover is read from the first element of each logged game state.
        """
        good_moves, bad_moves, neutral_moves = [], [], []
        if winner == 0:
            neutral_moves = self.game_history.copy()
        else:
            for entry in self.game_history:
                mover = entry[0][0].item()
                if mover == winner:
                    good_moves.append(entry)
                else:
                    bad_moves.append(entry)

        return good_moves, bad_moves, neutral_moves


# Single shared board instance; the play_* functions reset and reuse it for every game.
TTT = Board()

In [129]:
import torch
from torch import nn


class Model(nn.Module):
    """Small MLP that scores the nine cells of a tic-tac-toe position.

    The input is the 10-element game state (current player followed by the
    nine cells); the output is one logit per cell, with the logits of
    illegal cells overwritten by :attr:`MASKED_LOGIT`.
    """

    # Value written over the logits of illegal cells: large and negative so
    # that softmax gives those cells (numerically) zero probability.
    MASKED_LOGIT = float('-99999999999')

    def __init__(self):
        super().__init__()

        # `tanh` is currently not used by `forward`; kept so the attribute
        # remains available to existing code.
        self.tanh = nn.Tanh()
        self.lrelu = nn.LeakyReLU(negative_slope=0.2)
        # Normalise over the last (cell) axis.  For the 1-D inputs used in
        # this notebook that is the same as dim=0, but unlike dim=0 it also
        # stays correct when a leading batch dimension is present.
        self.softmax = nn.Softmax(dim=-1)

        self.layers = nn.Sequential(
            nn.Linear(in_features=10, out_features=64),
            self.lrelu,
            nn.Linear(in_features=64, out_features=64),
            self.lrelu,
            nn.Linear(in_features=64, out_features=64),
            self.lrelu,
            nn.Linear(in_features=64, out_features=9),
        )

    def forward(self, x: torch.Tensor, leg_matrix: torch.Tensor) -> torch.Tensor:
        """Return per-cell logits with illegal cells masked out.

        Args:
            x: game state whose last dimension has size 10.
            leg_matrix: legality mask broadcastable to the output, ``0``
                marking cells that must not be chosen.

        Returns:
            Tensor of logits whose last dimension has size 9.
        """
        x = self.layers(x)
        x = x.masked_fill(leg_matrix == 0, self.MASKED_LOGIT)
        return x


# Module-level network instance, moved to the selected device; the play_* functions and the training loop all use it.
model = Model().to(device)

In [130]:
def play_machine_only():
    """Let the model play one full game against itself on the shared board.

    Both sides always take the legal cell with the highest softmax
    probability.  When the game ends, the logged moves are returned as the
    ``(good_moves, bad_moves, neutral_moves)`` tuple from ``TTT.set_winner``.
    """
    TTT.reset_board()
    while True:
        finished, winner = TTT.check_if_game_over()
        if finished:
            # The three move lists are passed through unchanged.
            return TTT.set_winner(winner)

        state, legal_cells = TTT.get_game_state()
        probabilities = model.softmax(model(state, legal_cells))
        TTT.player_make_turn(torch.argmax(probabilities).item())

In [131]:
def _ask_human_move(l_matrix):
    """Prompt on stdin until the human names a free cell; return its index."""
    while True:
        answer = input(f"{TTT.board.reshape(3, 3)}\n\nWhere to go? ")
        try:
            cell = int(answer)
        except ValueError:
            print(f"Please enter a whole number between 0 and {l_matrix.numel() - 1}.")
            continue
        # Reject out-of-range indices (including negatives, which would
        # otherwise wrap around) and cells that are already taken.
        if 0 <= cell < l_matrix.numel() and l_matrix[cell].item() == 1:
            return cell
        print("That cell is not available, pick an empty one.")


def play_with_player(playerturn: bool = True):
    """Play one interactive game between a human (stdin) and the model.

    Args:
        playerturn: ``True`` when the human makes the first move, ``False``
            when the model does.  Whoever moves first plays as ``1``.

    The board is printed after each human move and the result (win for
    either side, or a draw) is printed when the game ends.
    """
    TTT.reset_board()
    print(TTT.board.reshape(3, 3), end="\n\n")

    while True:
        is_over, player = TTT.check_if_game_over()
        if is_over:
            break

        g_state, l_matrix = TTT.get_game_state()
        if playerturn:
            TTT.player_make_turn(_ask_human_move(l_matrix))
            print(TTT.board.reshape(3, 3), end="\n\n")
        else:
            # Pure inference: no autograd graph is needed to pick a move.
            with torch.no_grad():
                raw_logits = model(g_state, l_matrix)
                softmax_logits = model.softmax(raw_logits)
            choice = torch.argmax(softmax_logits).item()
            TTT.player_make_turn(choice)
        playerturn = not playerturn

    # `playerturn` now names the side that would move next, so the side that
    # made the final move is the other one.  A winner value of 0 is a draw.
    if player == 0:
        print("Draw")
    elif playerturn:
        print("Machine won")
    else:
        print("Human won")

In [132]:
import random


def play_against_bot(rounds):
    """Evaluate the model against an opponent that plays uniformly random legal moves.

    Args:
        rounds: number of games to play.  Who moves first is chosen at
            random for every game.

    Prints the number of wins, draws and losses from the model's point of
    view, followed by the same figures as percentages (all ``0.00%`` when
    ``rounds`` is not positive).
    """
    wins, draws, losses = 0, 0, 0
    for i in range(rounds):
        if (i + 1) % 10 == 0: print(f"Round {i + 1}")
        machinefirst = random.choice([True, False])
        machineturn = machinefirst
        TTT.reset_board()

        while True:
            is_over, player = TTT.check_if_game_over()
            if is_over:
                break

            g_state, l_matrix = TTT.get_game_state()
            if machineturn:
                # Pure inference: no autograd graph is needed to pick a move.
                with torch.no_grad():
                    raw_logits = model(g_state, l_matrix)
                    softmax_logits = model.softmax(raw_logits)
                choice = torch.argmax(softmax_logits).item()
                TTT.player_make_turn(choice)
            else:
                # Draw directly from the free cells instead of retrying
                # random indices until one happens to be legal.
                legal_cells = torch.nonzero(l_matrix == 1).flatten().tolist()
                TTT.player_make_turn(random.choice(legal_cells))
            machineturn = not machineturn

        # `player` still holds the result reported when the loop ended.
        # The side that moved first plays as 1, the other side as -1.
        if player == 0:
            draws += 1
        elif (player == 1 and machinefirst) or (player == -1 and not machinefirst):
            wins += 1
        else:
            losses += 1

    def _rate(count):
        # Avoid dividing by zero when no games were requested.
        return (count / rounds) * 100 if rounds > 0 else 0.0

    print(f"Wins: {wins:,} | Draws: {draws:,} | Losses: {losses:,}")
    print(f"WR: {_rate(wins):.2f}% | DR: {_rate(draws):.2f}% | LR: {_rate(losses):.2f}%")

In [133]:
import torch

EPOCHS = 10000
LEARNING_RATE = 0.01

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def _make_target(fill, chosen_value, chosen_cell):
    """Return a 9-element target equal to `fill` everywhere except `chosen_value` at `chosen_cell`."""
    target = torch.full(size=(9,), fill_value=fill, device=device)
    target[chosen_cell] = chosen_value
    return target


# (fill, chosen_value) per bucket, in the order play_machine_only() returns
# them: good moves, bad moves, neutral (drawn-game) moves.
#   good    -> 1 on the played cell, 0 elsewhere
#   bad     -> 0 on the played cell, 1 elsewhere
#   neutral -> 0 on the played cell, 0.5 elsewhere
# NOTE(review): the neutral recipe pushes the played cell towards 0 just like
# a losing move does -- confirm that this is intended for draws.
_TARGET_RECIPES = ((0.0, 1.0), (1.0, 0.0), (0.5, 0.0))

for epoch in range(EPOCHS):
    if (epoch + 1) % 10 == 0: print(f"E {epoch + 1:,} - {((epoch + 1) / EPOCHS) * 100:.2f}%")

    # Self-play one game without tracking gradients and turn every logged
    # move into a (state, target, legality mask) training example.
    model.eval()
    with torch.no_grad():
        move_buckets = play_machine_only()
        game_positions = [
            (state, _make_target(fill, chosen_value, move), legal)
            for (fill, chosen_value), bucket in zip(_TARGET_RECIPES, move_buckets)
            for state, move, legal in bucket
        ]

    # One optimiser step per game: gradients of all positions are accumulated
    # by the successive backward() calls before step() is applied.
    model.train()
    optimizer.zero_grad()

    for state, target, legal in game_positions:
        outputs = model(state, legal)
        loss = loss_fn(outputs, target)
        loss.backward()

    optimizer.step()

E 10 - 0.10%
E 20 - 0.20%
E 30 - 0.30%
E 40 - 0.40%
E 50 - 0.50%
E 60 - 0.60%
E 70 - 0.70%
E 80 - 0.80%
E 90 - 0.90%
E 100 - 1.00%
E 110 - 1.10%
E 120 - 1.20%
E 130 - 1.30%
E 140 - 1.40%
E 150 - 1.50%
E 160 - 1.60%
E 170 - 1.70%
E 180 - 1.80%
E 190 - 1.90%
E 200 - 2.00%
E 210 - 2.10%
E 220 - 2.20%
E 230 - 2.30%
E 240 - 2.40%
E 250 - 2.50%
E 260 - 2.60%
E 270 - 2.70%
E 280 - 2.80%
E 290 - 2.90%
E 300 - 3.00%
E 310 - 3.10%
E 320 - 3.20%
E 330 - 3.30%
E 340 - 3.40%
E 350 - 3.50%
E 360 - 3.60%
E 370 - 3.70%
E 380 - 3.80%
E 390 - 3.90%
E 400 - 4.00%
E 410 - 4.10%
E 420 - 4.20%
E 430 - 4.30%
E 440 - 4.40%
E 450 - 4.50%
E 460 - 4.60%
E 470 - 4.70%
E 480 - 4.80%
E 490 - 4.90%
E 500 - 5.00%
E 510 - 5.10%
E 520 - 5.20%
E 530 - 5.30%
E 540 - 5.40%
E 550 - 5.50%
E 560 - 5.60%
E 570 - 5.70%
E 580 - 5.80%
E 590 - 5.90%
E 600 - 6.00%
E 610 - 6.10%
E 620 - 6.20%
E 630 - 6.30%
E 640 - 6.40%
E 650 - 6.50%
E 660 - 6.60%
E 670 - 6.70%
E 680 - 6.80%
E 690 - 6.90%
E 700 - 7.00%
E 710 - 7.10%
E 720 - 7.20%
E

In [134]:
# Evaluate the model over 1,000 games against the random-move opponent.
play_against_bot(1000)

Round 10
Round 20
Round 30
Round 40
Round 50
Round 60
Round 70
Round 80
Round 90
Round 100
Round 110
Round 120
Round 130
Round 140
Round 150
Round 160
Round 170
Round 180
Round 190
Round 200
Round 210
Round 220
Round 230
Round 240
Round 250
Round 260
Round 270
Round 280
Round 290
Round 300
Round 310
Round 320
Round 330
Round 340
Round 350
Round 360
Round 370
Round 380
Round 390
Round 400
Round 410
Round 420
Round 430
Round 440
Round 450
Round 460
Round 470
Round 480
Round 490
Round 500
Round 510
Round 520
Round 530
Round 540
Round 550
Round 560
Round 570
Round 580
Round 590
Round 600
Round 610
Round 620
Round 630
Round 640
Round 650
Round 660
Round 670
Round 680
Round 690
Round 700
Round 710
Round 720
Round 730
Round 740
Round 750
Round 760
Round 770
Round 780
Round 790
Round 800
Round 810
Round 820
Round 830
Round 840
Round 850
Round 860
Round 870
Round 880
Round 890
Round 900
Round 910
Round 920
Round 930
Round 940
Round 950
Round 960
Round 970
Round 980
Round 990
Round 1000
Wins: 60