In [1]:
import numpy as np
from tqdm import tqdm
from copy import copy, deepcopy

In [2]:
import random


class TicTacToe:
    """Tic-tac-toe on a 3x3 NumPy board whose cells are ``'-'``, ``'X'`` or ``'O'``.

    ``start_random`` plays a fully random game from a fixed start position.
    ``start_MCTS`` lets Monte Carlo tree search choose X's moves against a
    random O; it needs the ``Node`` and ``MCTS`` classes to be defined.
    Outcomes are reported as ``[win, reward]``:
    ``[1, 1]`` X wins, ``[0, 0]`` draw, ``[0, -1]`` O wins.
    """

    # Inspiration: https://geekflare.com/tic-tac-toe-python-code/
    def __init__(self):
        self.board = []              # replaced by a 3x3 array in create_board()
        self.last_player = None      # player whose move produced self.board
        self.current_player = None   # player to move next

    def create_board(self):
        """Reset ``self.board`` to an empty 3x3 board of ``'-'`` cells."""
        # Build a fresh array. Appending to self.board failed on a second call
        # because after the first call it is an ndarray (which has no .append).
        self.board = np.full((3, 3), '-')

    def get_random_first_player(self):
        """Return 0 or 1 uniformly at random."""
        return random.randint(0, 1)

    def fix_spot(self, board, row, col, player):
        """Write ``player`` at ``(row, col)`` of ``board`` IN PLACE and return ``board``."""
        # No copy on purpose: rollouts and the game loop rely on the in-place
        # write to self.board. Copy beforehand if the original must survive.
        board[row][col] = player
        return board

    def is_player_win(self, player):
        """Return True if ``player`` owns a complete row, column or diagonal."""
        owned = np.asarray(self.board) == player
        return bool(
            owned.all(axis=1).any()                    # any row
            or owned.all(axis=0).any()                 # any column
            or np.diagonal(owned).all()                # main diagonal
            or np.diagonal(np.fliplr(owned)).all()     # anti-diagonal
        )

    def is_board_filled(self):
        """Return True when no empty ``'-'`` cell remains."""
        return not (np.asarray(self.board) == '-').any()

    def is_board_almost_filled(self):
        """Return True when exactly one empty cell remains."""
        return int(np.count_nonzero(np.asarray(self.board) == '-')) == 1

    def swap_player_turn(self, player):
        """Return the opponent of ``player`` (anything other than 'O' maps to 'O')."""
        return 'X' if player == 'O' else 'O'

    def show_board(self):
        """Print the board, one row per line, followed by a blank line."""
        for row in self.board:
            for item in row:
                print(item, end=" ")
            print()
        print()

    def create_start_position(self):
        """Place the fixed opening position used by ``start``."""
        self.board[1][0] = 'X'
        self.board[1][1] = 'X'
        self.board[0][1] = 'O'
        self.board[1][2] = 'O'

    def create_start_almost_full_position(self):
        """Place a position with a single empty cell (2, 2), for debugging."""
        self.board[0][0] = 'X'
        self.board[0][1] = 'O'
        self.board[0][2] = 'O'
        self.board[1][0] = 'X'
        self.board[1][1] = 'X'
        self.board[1][2] = 'X'
        self.board[2][0] = 'O'
        self.board[2][1] = 'O'

    def start_full(self, do_print = False):
        """Debug helper: set up the almost-full position and print its moves."""
        self.create_board() # empty board
        self.last_player = 'O' # x starts, so current position is a result of 'O'
        self.current_player = 'X'
        self.create_start_almost_full_position() # create start position
        print(self.valid_moves())
        print(self.pick_random_move())
        if do_print: self.show_board()

    def valid_moves(self):
        """Return an ``(n, 2)`` array of ``(row, col)`` for every empty cell, row-major."""
        return np.argwhere(np.asarray(self.board) == '-')

    def pick_random_move(self):
        """Return one empty cell ``[row, col]`` chosen uniformly at random.

        Raises:
            ValueError: if the board has no empty cell.
        """
        moves = self.valid_moves()
        if len(moves) == 0:
            raise ValueError("No valid moves left on the board")
        random_index = np.random.choice(moves.shape[0], size=1, replace=False)
        return moves[random_index][0]

    def _game_result(self):
        """Return ``[win, reward]`` if the game on ``self.board`` is over, else None."""
        if self.is_player_win('X'):
            return [1, 1]
        if self.is_player_win('O'):
            return [0, -1]
        if self.is_board_filled():
            return [0, 0]
        return None

    def return_random_rollout(self, board, last_player, current_player, do_print=False):
        """Play uniformly random moves from ``board`` until the game ends.

        Args:
            board: board to play on; it becomes ``self.board`` and is modified in place.
            last_player: player who made the previous move. The outcome does not
                depend on it; it is kept for interface compatibility.
            current_player: player who makes the first rollout move.
            do_print: print every position and the result.

        Returns:
            ``[win, reward]``: ``[1, 1]`` X wins, ``[0, 0]`` draw, ``[0, -1]`` O wins.
        """
        # pick_random_move / is_player_win / is_board_filled all read self.board,
        # so make sure it is the board being played on.
        self.board = board
        # The position may already be decided (a win or a full board). Picking a
        # move then would either raise or play on past the end of the game.
        result = self._game_result()
        while result is None:
            move = self.pick_random_move()
            self.fix_spot(self.board, move[0], move[1], current_player)
            if self.is_player_win(current_player):
                if do_print: print(f"Player {current_player} wins the game!")
                result = [1, 1] if current_player == 'X' else [0, -1]
            elif self.is_board_filled():
                if do_print: print("Match Draw!")
                result = [0, 0]
            else:
                current_player = self.swap_player_turn(current_player)
                if do_print: self.show_board()
        if do_print:
            self.show_board()  # show final board
            print(result)
        return result

    def start(self, do_print = False):
        """Reset to the fixed start position with X to move."""
        self.create_board() # empty board
        self.last_player = 'O' # x starts, so current position is a result of 'O'
        self.current_player = 'X'
        self.create_start_position() # create start position
        if do_print: self.show_board()

    def start_random(self, do_print = False):
        """Play a fully random game from the start position; return ``[win, reward]``."""
        self.start(do_print=do_print)
        return self.return_random_rollout(self.board, self.last_player, self.current_player, do_print=do_print)

    def start_MCTS(self, do_print=False, n_iterations=1):
        """Play a game where MCTS chooses X's moves and O moves uniformly at random.

        Requires the ``Node`` and ``MCTS`` classes to be defined.

        Args:
            do_print: print every position and the result.
            n_iterations: number of ``MCTS.take_step`` calls before each X move.
                The default of 1 matches the previous behaviour.

        Returns:
            ``[win, reward]`` as in ``return_random_rollout``.
        """
        self.start(do_print=do_print)
        # Node keeps a reference to the board it receives and fix_spot writes in
        # place, so give the tree snapshots. Sharing self.board made tree nodes'
        # boards change as the game went on.
        root_node = Node(self.board.copy(), self.last_player, self.current_player)
        mcts = MCTS(root_node)

        while True:
            if self.current_player == 'X':
                for _ in range(n_iterations):
                    mcts.take_step()
                # With a single iteration the highest-UCB child is typically a
                # still-unvisited one (UCB = inf), not a well-evaluated move.
                mcts.choose_new_root_node()
                move = mcts.root_nodes[-1].move
            else:
                move = self.pick_random_move()
            self.board = self.fix_spot(self.board, move[0], move[1], self.current_player)
            if self.current_player == 'O':
                # Make O's actual move the new search root. Players are swapped
                # the same way Node children are created.
                enemy_node = Node(self.board.copy(),
                                  self.swap_player_turn(self.last_player),
                                  self.swap_player_turn(self.current_player),
                                  parent=mcts.root_nodes[-1], move=move)
                enemy_node.played = True
                mcts.root_nodes.append(enemy_node)

            if self.is_player_win(self.current_player):
                if do_print: print(f"Player {self.current_player} wins the game!")
                result = [1, 1] if self.current_player == 'X' else [0, -1]
                break
            if self.is_board_filled():
                if do_print: print("Match Draw!")
                result = [0, 0]
                break
            # Swap turn
            self.current_player = self.swap_player_turn(self.current_player)
            self.last_player = self.swap_player_turn(self.last_player)
            if do_print: self.show_board()
        if do_print:
            self.show_board()  # show final board
            print(result)
        return result




# Starting the game: one fully random game from the fixed start position.
tic_tac_toe = TicTacToe()
winner = tic_tac_toe.start_random(do_print=False)

In [3]:
# One game with MCTS playing X against a random O, printing every position.
# For a fully random game use: tic_tac_toe.start_random(do_print=True)
# Needs the Node and MCTS cells to have been run first; otherwise this raises
# NameError: name 'Node' is not defined (see the recorded output below).
tic_tac_toe = TicTacToe()
winner = tic_tac_toe.start_MCTS(do_print=True)

- O - 
X X O 
- - - 



NameError: name 'Node' is not defined

In [None]:
# Debug: set up the almost-full position. start_full prints the valid moves and
# one random pick; it returns None, so the outer print shows "None".
tic_tac_toe = TicTacToe()
print(tic_tac_toe.start_full(do_print=False))

TODO:
- Check waarom self.board van de 2e root-node al vol is.
- Maak expliciet onderscheid tussen X- en O-nodes; haal dubbele functies uit TicTacToe en MCTS. Maak een duidelijk onderscheid tussen de MCTS-boom en het spel TicTacToe.

In [None]:
# Tally the outcomes of many fully random games from the fixed start position.
N_GAMES = 10000
# reward returned by start_random -> column of `scores`
OUTCOME_INDEX = {1: 0, 0: 1, -1: 2}
scores = np.zeros(3)  # counts of: X wins, draws, O wins
for i in tqdm(range(N_GAMES)):
    tic_tac_toe = TicTacToe()
    win_rew = tic_tac_toe.start_random()
    scores[OUTCOME_INDEX[win_rew[1]]] += 1

print("X, draw, O")
print(scores)
    

In [None]:
class Node(TicTacToe):
    """A position in the MCTS search tree.

    ``last_player`` made ``move``, which produced ``board``; ``current_player``
    is the player to move next. Playout statistics are stored from X's point of
    view: ``wins`` counts X wins, and ``reward`` sums +1 (X win), 0 (draw) and
    -1 (O win) over every playout that passed through this node.
    """
    def __init__(self, board, last_player, current_player,  parent=None, move=None):
        self.board = board  # stored by reference, not copied
        self.last_player = last_player
        self.current_player = current_player
        self.played = False
        self.children = dict()  # child Node -> the same child Node
        self.parent = parent
        self.reward = 0
        self.wins = 0
        self.visits = 0
        # Unvisited nodes get infinite UCB so every child is tried once before
        # any sibling is revisited (standard UCB1 initialisation).
        self.UCB = np.inf
        self.move = move  # the (row, col) move that led to this board

    def is_terminal(self):
        """Return True when either player has three in a row or the board is full."""
        # The player who just moved (last_player) is the one who can have won.
        # Checking only current_player missed those wins.
        return self.is_player_win('X') or self.is_player_win('O') or self.is_board_filled()

    def reward_board(self):
        """Return ``[win, reward]`` for a finished game, or None if it is not over.

        ``[1, 1]`` X has won, ``[0, -1]`` O has won, ``[0, 0]`` draw.
        """
        if not self.is_terminal():
            return None
        if self.is_player_win('X'):
            return [1, 1]
        if self.is_player_win('O'):
            return [0, -1]
        return [0, 0]

    def generate_all_possible_children(self):
        """Create one child per empty cell, with ``current_player`` making the move.

        Does nothing if the node has already been expanded, so repeated calls do
        not add duplicate children.
        """
        if self.children:
            return
        # In every child the roles are swapped: the mover becomes last_player.
        child_last = self.swap_player_turn(self.last_player)
        child_current = self.swap_player_turn(self.current_player)
        for move in self.valid_moves():
            # deepcopy so children never share (and mutate) this node's board
            board_child = self.fix_spot(deepcopy(self.board), move[0], move[1], self.current_player)
            node_child = Node(board_child, child_last, child_current, parent=self, move=move)
            node_child.played = True
            self.children[node_child] = node_child  # register in this node's children

    def update_UCB(self):
        """Recompute UCB1 for this node; the root (no parent) and unvisited nodes are left alone.

        The exploitation term is the win-rate of the player who moved INTO this
        node (``last_player``). Each player then picks its own best move during
        selection, instead of O choosing the moves that are best for X.
        """
        if self.parent is None or self.visits == 0:
            return
        if self.last_player == 'O':
            # reward = X wins - O wins, hence O wins = wins - reward
            own_wins = self.wins - self.reward
        else:
            own_wins = self.wins
        exploitation = own_wins / self.visits
        # Guard log(0): a parent that was never visited itself (e.g. a game root
        # chosen without a playout) would otherwise give NaN.
        parent_visits = max(self.parent.visits, 1)
        exploration = np.sqrt(2) * np.sqrt(np.log(parent_visits) / self.visits)
        self.UCB = exploitation + exploration

    def return_max_UCB(self):
        """Return the child with the highest UCB (the first one on ties), or None if there are none."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.UCB)

class MCTS:
    """Monte Carlo tree search over ``Node`` positions.

    ``root_nodes`` is the history of roots as the game advances; each search
    iteration starts from ``root_nodes[-1]``. ``nodes`` registers the initial
    root and every node returned by expansion.
    """
    def __init__(self, root):
        self.root_nodes = [root]
        self.nodes = {root: root}

    def choose_new_root_node(self):
        """Advance the root to its highest-UCB child, if it has any children."""
        root = self.root_nodes[-1]
        if root.children:
            # child with highest UCB score (unvisited children sit at inf)
            self.root_nodes.append(root.return_max_UCB())

    def take_step(self):
        """Run one iteration: selection, expansion, simulation and backpropagation."""
        path = self.selection(self.root_nodes[-1])  # 1
        leaf_node = path[-1]
        if leaf_node.is_terminal():
            # Game already over at this leaf: a full board has no children and a
            # won board must not be played on, so score it directly.
            win_rew = leaf_node.reward_board()
        else:
            child = self.expansion(leaf_node)  # 2
            path.append(child)
            if child.is_terminal():
                win_rew = child.reward_board()
            else:
                win_rew = self.simulation(child)  # 3
        self.backtrack(path, win_rew)  # 4

    def selection(self, node):
        """Follow highest-UCB children from ``node`` down to a leaf; return the whole path."""
        path = [node]
        while node.children:
            node = node.return_max_UCB()
            path.append(node)
        return path

    def expansion(self, node):
        """Create all children of ``node``; register and return the highest-UCB one."""
        node.generate_all_possible_children()
        max_child = node.return_max_UCB()
        self.nodes[max_child] = max_child
        return max_child

    def simulation(self, node):
        """Play a random game from ``node``'s position and return ``[win, reward]``."""
        # Roll out on a detached node holding a copy of the board. deepcopy(node)
        # would also copy the whole tree through the parent/children links.
        scratch = Node(deepcopy(node.board), node.last_player, node.current_player)
        # node.current_player is already the player to move (players are swapped
        # when a child is created), so it must make the first rollout move.
        # Swapping again here let the same player move twice in a row.
        return scratch.return_random_rollout(scratch.board, scratch.last_player, scratch.current_player)

    def backtrack(self, path, win_rew):
        """Add the playout result to every node on ``path``, then refresh their UCB values."""
        win, reward = win_rew
        for node in reversed(path):
            node.wins += win
            node.reward += reward
            node.visits += 1
        # UCB reads the parent's visit count, so all counts are updated first.
        for node in reversed(path):
            node.update_UCB()