In [1]:
import random
from game import Game, Move, Player
from copy import deepcopy
from tqdm import tqdm
import numpy as np
import math

In [2]:
class RandomPlayer(Player):
    '''Player that proposes a random board position and a random slide direction.'''

    def __init__(self) -> None:
        super().__init__()

    def make_move(self, game: 'Game') -> tuple[tuple[int, int], Move]:
        '''Return a random ``(position, slide)`` pair.

        Both coordinates are drawn uniformly from 0..4 and the slide uniformly
        from the four directions. ``game`` is not inspected, so the proposal is
        not checked for legality here.
        '''
        first_coord = random.randrange(5)
        second_coord = random.randrange(5)
        directions = (Move.TOP, Move.BOTTOM, Move.LEFT, Move.RIGHT)
        return (first_coord, second_coord), random.choice(directions)

In [3]:
class MyGame(Game):
    '''Game subclass adding a coloured board printer and legal-move enumeration.'''

    def __init__(self) -> None:
        super().__init__()

    def print(self):
        '''Prints the board. -1 are neutral pieces, 0 are pieces of player 0, 1 pieces of player 1'''
        for row in self._board:
            for cell in row:
                if cell == -1:
                    print('\033[90m', "-", '\033[0m', end=' ')  # Grey color for -1
                elif cell == 0:
                    print('\033[91m', "X", '\033[0m', end=' ')  # Red color for 0
                elif cell == 1:
                    print('\033[92m', "O", '\033[0m', end=' ')  # Green color for 1
                else:
                    print(cell, end=' ')
            print()

    def get_legal_moves(self) -> list[tuple[tuple[int, int], Move]]:
        '''List the moves available to the current player.

        Returns:
            A list of ``((row, col), slide)`` pairs, where ``self._board[row][col]``
            is either neutral (-1) or owned by ``self.get_current_player()`` and
            ``slide`` is one of the directions ``correct_slide`` allows for that cell.

        Positions are in board-index order (row first). MyPlayer swaps them to
        ``(col, row)`` before handing a move to ``_Game__move``.

        NOTE(review): only cells on the outer ring are enumerated, which assumes
        the base Game follows the usual Quixo rule that interior pieces cannot be
        taken -- confirm against Game.
        '''
        player = self.get_current_player()
        legal_moves = []
        for i in range(5):
            for j in range(5):
                if i not in (0, 4) and j not in (0, 4):
                    continue  # interior cell
                if self._board[i][j] == player or self._board[i][j] == -1:
                    # correct_slide expects (row, col): the same order used to
                    # index the board above. Passing (j, i) would compute the
                    # slides of the transposed cell.
                    for move in self.correct_slide((i, j)):
                        legal_moves.append(((i, j), move))
        return legal_moves

    def correct_slide(self, from_pos):
        '''Return the slide directions that do not push a piece off its own edge.

        Args:
            from_pos: ``(row, col)`` with row 0 the uppermost row and col 0 the
                leftmost column.

        Returns:
            The subset of ``[Move.BOTTOM, Move.TOP, Move.LEFT, Move.RIGHT]`` left
            after removing the direction of every edge the cell lies on.
        '''
        acceptable_slides = [Move.BOTTOM, Move.TOP, Move.LEFT, Move.RIGHT]
        axis_0 = from_pos[0]    # axis_0 = 0 means uppermost row
        axis_1 = from_pos[1]    # axis_1 = 0 means leftmost column

        if axis_0 == 0:  # can't move upwards if in the top row...
            acceptable_slides.remove(Move.TOP)
        elif axis_0 == 4:  # ...nor downwards from the bottom row
            acceptable_slides.remove(Move.BOTTOM)

        if axis_1 == 0:
            acceptable_slides.remove(Move.LEFT)
        elif axis_1 == 4:
            acceptable_slides.remove(Move.RIGHT)
        return acceptable_slides


In [4]:
class Node:
    '''One node of the search tree: a move together with its usage statistics.'''

    def __init__(self, move=None, parent=None, player=None) -> None:
        self.children = []      # child nodes appended during expansion
        self.parent = parent    # None for a node created without a parent (the root)
        self.n_use = 0  # how many times this move has been used
        self.n_wins = 0         # how many of those uses were counted as wins
        self.move = move        # the move stored for this node (None by default)
        self.player = player    # the player associated with this move (None by default)

    def get_percentage(self):   # gives the percentage of wins
        '''Return ``n_wins / n_use``, or 0.0 for a node that has never been used.

        The zero-use guard avoids a ZeroDivisionError when an unvisited node is
        ranked together with visited ones.
        '''
        if self.n_use == 0:
            return 0.0
        return self.n_wins / self.n_use

In [6]:
class MyPlayer(Player):
    '''Monte Carlo Tree Search player.

    A search tree is grown from a fresh ``MyGame`` when the player is built.
    At play time the current board is looked up among the positions explored
    during training: if it is known, the child with the best win rate is
    played, otherwise a random move is proposed.

    Conventions used inside this class:
      * moves stored in the tree are ``((row, col), slide)``, i.e. in the order
        used to index the board;
      * moves handed to the game are ``((col, row), slide)``, the order in
        which ``_Game__move`` is called here;
      * the player to move is tracked explicitly and alternated with
        ``1 - player``. NOTE(review): whether ``_Game__move`` also advances
        ``game.get_current_player()`` is not visible from this file.
    '''

    def __init__(self, n_training=200, n_simulations=10) -> None:
        '''Build the player and run the training loop.

        Args:
            n_training: number of outer training rounds (progress-bar steps).
            n_simulations: number of tree iterations run in each round.
        '''
        super().__init__()
        self.n_train = n_training
        self.n_simulations = n_simulations
        self._root_player = None      # player to move at the root, set by training()
        self._nodes_by_state = {}     # state key -> list of tree nodes for that position
        self.tree = self.training()

    def make_move(self, game: 'Game') -> tuple[tuple[int, int], Move]:
        '''Return the move to play in ``game`` as ``((col, row), slide)``.'''
        return self.get_best_move(game)

    def training(self):
        '''Grow the search tree and return its root.

        Runs ``n_train * n_simulations`` iterations. Each iteration starts from
        the root with a fresh game, walks down with UCT while replaying the
        chosen moves on the game, adds one untried move, plays a random
        rollout from there and backpropagates the winner.
        '''
        root = Node()
        start = MyGame()
        self._root_player = start.get_current_player()
        self._nodes_by_state = {}
        self._register(root, start)

        for _ in tqdm(range(self.n_train), total=self.n_train):
            for _ in range(self.n_simulations):
                game = MyGame()
                node = self.select(root, game)
                winner = game.check_winner()
                if winner == -1:
                    untried = self._untried_moves(node, game)
                    if untried:
                        mover = self._player_to_move(node)
                        node = self.expand(node, random.choice(untried), game, mover)
                        self._register(node, game)
                        winner = self.simulate(game, self._player_to_move(node))
                self.backpropagate(node, winner)
        return root

    def select(self, node, game=None):
        '''Walk down from ``node`` following the best UCT child.

        Without ``game`` the walk continues until a node without children is
        reached. With ``game`` every chosen move is replayed on it, so the game
        always matches the returned node, and the walk also stops at the first
        node that still has untried legal moves.
        '''
        while node.children:
            if game is not None and self._untried_moves(node, game):
                break
            node = max(node.children, key=self.uct)
            if game is not None:
                self._apply(game, node.move, node.player)
        return node

    def expand(self, node, move, game, player=None):
        '''Add ``move`` as a child of ``node``, play it on ``game`` and return the child.

        Args:
            move: ``((row, col), slide)``.
            player: who makes the move; defaults to ``game.get_current_player()``.
        '''
        if player is None:
            player = game.get_current_player()
        new_node = Node(move, parent=node, player=player)
        node.children.append(new_node)
        # Simulate the move in the game
        self._apply(game, move, player)
        return new_node

    def simulate(self, game, player=None):
        '''Play random moves on a copy of ``game`` and return ``check_winner()``.

        Args:
            player: who moves first in the rollout; defaults to
                ``game.get_current_player()``.

        The rollout stops as soon as ``check_winner()`` is no longer -1 or the
        player to move has no legal move. ``game`` itself is not modified.
        '''
        cloned_game = deepcopy(game)
        current_player = cloned_game.get_current_player() if player is None else player

        winner = cloned_game.check_winner()
        while winner == -1:
            legal_moves = self._legal_moves(cloned_game, current_player)
            if not legal_moves:
                break
            self._apply(cloned_game, random.choice(legal_moves), current_player)
            current_player = 1 - current_player  # Switch player
            winner = cloned_game.check_winner()

        return winner

    def backpropagate(self, node, winner):
        '''Update visit and win counts from ``node`` up to and including the root.

        The walk tests ``node is not None`` rather than the truthiness of
        ``node.player``: player 0 is falsy and would otherwise stop the update.
        '''
        while node is not None:
            node.n_use += 1
            if node.player is not None and winner == node.player:
                node.n_wins += 1
            node = node.parent

    def uct(self, node):
        '''UCT score of ``node``; unvisited nodes score infinity so they are tried first.'''
        if node.n_use == 0 or node.parent.n_use == 0:
            return float('inf')

        return node.get_percentage() + 1.41 * (2 * math.log(node.parent.n_use) / node.n_use) ** 0.5

    def get_best_move(self, game):
        '''Pick a move for the current position of ``game``.

        The board is looked up among the positions explored during training.
        When a matching node with children exists, the child with the highest
        win rate is returned, otherwise ``RandomPlayer`` supplies the move.

        Returns:
            ``((col, row), slide)``. NOTE(review): this assumes the game passes
            the position to ``_Game__move`` unchanged, which is the order used
            for the direct calls in this class -- confirm against Game.play.
        '''
        current_player = game.get_current_player()
        key = self._state_key(game.get_board(), current_player)
        candidates = [n for n in self._nodes_by_state.get(key, ()) if n.children]

        if not candidates:
            # If the node does not exist, choose a random move
            return RandomPlayer().make_move(game)

        # Several move orders can reach the same board: trust the most visited node.
        current_node = max(candidates, key=lambda n: n.n_use)
        best_child = max(
            current_node.children,
            key=lambda c: c.n_wins / c.n_use if c.n_use else 0.0,
        )
        (row, col), slide = best_child.move
        return (col, row), slide

    # ----------------------------------------------------------------- helpers

    def _player_to_move(self, node):
        '''Player whose turn it is in the position represented by ``node``.'''
        if node.player is None:     # root: no move has been played yet
            return self._root_player
        return 1 - node.player

    def _untried_moves(self, node, game):
        '''Legal moves at ``node`` (whose position is ``game``) that have no child yet.'''
        tried = [child.move for child in node.children]
        legal = self._legal_moves(game, self._player_to_move(node))
        return [move for move in legal if move not in tried]

    def _register(self, node, game):
        '''Index ``node`` under the key of the position currently held by ``game``.'''
        key = self._state_key(game.get_board(), self._player_to_move(node))
        self._nodes_by_state.setdefault(key, []).append(node)

    @staticmethod
    def _legal_moves(game, player):
        '''``((row, col), slide)`` moves available to ``player`` on ``game``.

        ``game`` must provide ``correct_slide`` (a MyGame). The mover is an
        explicit argument because rollouts alternate players by hand.

        NOTE(review): only outer-ring cells are considered, assuming the usual
        Quixo rule that interior pieces cannot be taken -- confirm against Game.
        '''
        board = game.get_board()
        moves = []
        for row in range(5):
            for col in range(5):
                if row not in (0, 4) and col not in (0, 4):
                    continue  # interior cell
                if board[row][col] == player or board[row][col] == -1:
                    for slide in game.correct_slide((row, col)):
                        moves.append(((row, col), slide))
        return moves

    @staticmethod
    def _apply(game, move, player):
        '''Play a tree move on ``game``; the position is swapped to ``(col, row)``.'''
        (row, col), slide = move
        return game._Game__move((col, row), slide, player)

    @staticmethod
    def _state_key(board, player):
        '''Hashable encoding of ``board`` as seen by ``player``.

        Cells become 0 (neutral), 1 (owned by ``player``) or 2 (anything else),
        so the key does not depend on which numeric id the mover has.
        NOTE(review): sharing statistics this way assumes the rules are
        symmetric between the two players.
        '''
        key = []
        for row in board:
            for cell in row:
                if cell == -1:
                    key.append(0)
                elif cell == player:
                    key.append(1)
                else:
                    key.append(2)
        return tuple(key)

In [7]:
# Build the MCTS player. The constructor runs the whole training loop
# (n_training defaults to 200), so this cell takes a while to finish.
player1 = MyPlayer()

100%|██████████| 200/200 [01:31<00:00,  2.20it/s]


In [None]:
# Play one game between the trained player and a random opponent, then show
# the final board and the value returned by play().
# NOTE(review): this uses the base Game, not MyGame, so g.print() is the
# base-class printer rather than the coloured one defined in MyGame.
g = Game()
player2 = RandomPlayer()
winner = g.play(player1, player2)  # player1 is passed first, player2 second
g.print()
print(f"Winner: Player {winner}")