In [272]:
import random
import copy

class Connect4:
    def __init__(self, agent_player='O'):
        self.board = [[' ' for _ in range(7)] for _ in range(6)]
        self.current_player = 'X'  # Player X always starts
        self.game_over = False
        self.agent_player = agent_player

    def display_board(self):
        # Display the board
        for row in self.board:
            print('|' + '|'.join(row) + '|')
        print('+-' * 7 + '+')
        
    def is_board_full(self):
        # Check if the board is full
        return all(self.board[0][col] != ' ' for col in range(7))
    
    # Check if top most cell in that column is empty
    def get_legal_moves(self):
        return [col for col in range(7) if self.board[0][col] == ' ']

    def drop_piece(self, column):
        # Check for valid column
        if not 0 <= column < 7 or self.board[0][column] != ' ':
            return False  # Invalid move

        # Find the lowest empty space in the column
        for row in range(5, -1, -1):
            if self.board[row][column] == ' ':
                self.board[row][column] = self.current_player
                return True
        return False

    def check_winner(self):
        # Check horizontal, vertical and diagonal lines for a win
        for row in range(6):
            for col in range(7):
                if self.board[row][col] != ' ':
                    if self.check_line(row, col, 1, 0) or \
                       self.check_line(row, col, 0, 1) or \
                       self.check_line(row, col, 1, 1) or \
                       self.check_line(row, col, 1, -1):
                        return self.board[row][col]
        return None

    def check_line(self, start_row, start_col, d_row, d_col):
        # Check a line of 4 pieces starting from (start_row, start_col) in direction (d_row, d_col)
        for i in range(1, 4):
            r = start_row + d_row * i
            c = start_col + d_col * i
            if not (0 <= r < 6 and 0 <= c < 7) or self.board[r][c] != self.board[start_row][start_col]:
                return False
        return True

    def switch_player(self):
        self.current_player = 'O' if self.current_player == 'X' else 'X'
        
    def strategy_player(self):
        valid_move = False
        while not valid_move:
            valid_move = True
            try:
                column = int(input(f"Player {self.current_player}, choose a column (0-6): "))
            except ValueError:
                print("Invalid input. Please enter a number between 0 and 6.")
                valid_move = False
                continue

            if not self.drop_piece(column):
                print("Invalid move. Try again.")
                valid_move = False
                continue
    
    def strategy_random_agent(self):
        # Random agent's turn
        valid_move = False
        while not valid_move:
            column = random.randint(0, 6)
            valid_move = self.drop_piece(column)
        print(f"Random agent (Player {self.agent_player}) chose column: {column}")
        
    def terminal_state(self):
        winner = self.check_winner()
        if winner:
            self.game_over = True
            return True
        elif self.is_board_full():
            self.game_over = True
            return True
        
        return False
        
        
    def tree_policy(self, node):
        # random selection
        # TODO : Add UCB Selection
        return random.choice(node.children)
    
    def simulation_policy(self, game_state):
        # Pick random move
        rand_move = random.choice(game_state.get_legal_moves())
        
        # Play move
        game_state.drop_piece(rand_move)
    
        # Switch player
        game_state.switch_player()
        
        return game_state
    
    def evaluate(self, game_state):
        winner = self.check_winner()
        if winner == 'O': # agent
            # print('Agent won')
            return True
        
        return False
        
    # Following pseudo-code : https://webdocs.cs.ualberta.ca/~hayward/396/jem/mcts.html#mcts
    def strategy_mcts(self):
        # simulation steps, depth
        root_node = Node(None, None)
        
        computation_budget = 10
        while computation_budget > 0:
            computation_budget -= 1
            game_state = copy.deepcopy(self)
            node = root_node
            
            # selection (select leaf)
            while not node.is_leaf():
                if node.move is not None: 
                    game_state.drop_piece(node.move)
                node = game_state.tree_policy(node) 
            
            # expansion
            node.expand_node(game_state)
            node = game_state.tree_policy(node)
            
            # simulation (rollout)
            while not game_state.terminal_state():
                game_state = game_state.simulation_policy(game_state)
                
            result = game_state.evaluate(game_state)
            
            # backup (backpropagation)
            while node.has_parent():
                node.update(result)
                node = node.parent
                
        return root_node.best_move().move

    def play_game(self):
        while not self.game_over:
            self.display_board()
            
            if self.current_player == 'O': # self.agent_player
                # self.strategy_random_agent()
                move = self.strategy_mcts()
                self.drop_piece(move)
            else:
                # self.strategy_random_agent()
                self.strategy_player()
                    
            winner = self.check_winner()
            if winner:
                self.display_board()
                print(f"Player {winner} wins!")
                self.game_over = True
            elif self.is_board_full():
                self.display_board()
                print("The game is a draw!")
                self.game_over = True
            else:
                self.switch_player()

# Create a game instance and start playing
game = Connect4()
game.play_game()
# game.strategy_mcts()

| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|X| | | | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|X|O| | | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|X| | | | | | |
|X|O| | | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|X| | | | | | |
|X|O|O| | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
|X| | | | | | |
|X| | | | | | |
|X|O|O| | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
|O| | | | | | |
|X| | | | | | |
|X| | | | | | |
|X|O|O| | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
|O| | | | | | |
|X| | | | | | |
|X|X| | | | | |
|X|O|O| | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
|O| | | | | | |
|X|O| | | | | |
|X|X| | | | | |
|X|O|O| | | | |
+-+-+-+-

IndexError: Cannot choose from an empty sequence

In [271]:
class Node:
    def __init__(self, m, p):
        self.move = m
        self.parent = p
        self.children = []
        self.wins = 0
        self.visits = 0
    
    def expand_node(self, game_state):
        if not game_state.terminal_state():
            legal_moves = game_state.get_legal_moves()
            for i in range(len(legal_moves)):
                node_child = Node(legal_moves[i], self)
                self.children.append(node_child)
    
    def update(self, r):
        self.visits += 1
        if r == True:
            self.wins += 1
            
    def is_leaf(self):
        return len(self.children) == 0
    
    def has_parent(self):
        return self.parent is not None
    
    def best_move(self):
        if not self.children:
            return None
        
        best_child = None
        best_win_ratio = -1
        
        for child in self.children:
            if child.visits > 0:
                win_ratio = child.wins / child.visits
                if win_ratio > best_win_ratio:
                    best_win_ratio = win_ratio
                    best_child = child
        return best_child
        