In [319]:
import random
import copy

class Connect4:
    def __init__(self, agent_player='O'):
        self.board = [[' ' for _ in range(7)] for _ in range(6)]
        self.current_player = 'X'  # Player X always starts
        self.game_over = False
        self.agent_player = agent_player

    def display_board(self):
        # Display the board
        for row in self.board:
            print('|' + '|'.join(row) + '|')
        print('+-' * 7 + '+')
        
    def is_board_full(self):
        # Check if the board is full
        return all(self.board[0][col] != ' ' for col in range(7))
    
    # Check if top most cell in that column is empty
    def get_legal_moves(self):
        return [col for col in range(7) if self.board[0][col] == ' ']

    def drop_piece(self, column):
        # Check for valid column
        if not 0 <= column < 7 or self.board[0][column] != ' ':
            return False  # Invalid move

        # Find the lowest empty space in the column
        for row in range(5, -1, -1):
            if self.board[row][column] == ' ':
                self.board[row][column] = self.current_player
                return True
        return False

    def check_winner(self):
        # Check horizontal, vertical and diagonal lines for a win
        for row in range(6):
            for col in range(7):
                if self.board[row][col] != ' ':
                    if self.check_line(row, col, 1, 0) or \
                       self.check_line(row, col, 0, 1) or \
                       self.check_line(row, col, 1, 1) or \
                       self.check_line(row, col, 1, -1):
                        return self.board[row][col]
        return None

    def check_line(self, start_row, start_col, d_row, d_col):
        # Check a line of 4 pieces starting from (start_row, start_col) in direction (d_row, d_col)
        for i in range(1, 4):
            r = start_row + d_row * i
            c = start_col + d_col * i
            if not (0 <= r < 6 and 0 <= c < 7) or self.board[r][c] != self.board[start_row][start_col]:
                return False
        return True

    def switch_player(self):
        self.current_player = 'O' if self.current_player == 'X' else 'X'


    def terminal_state(self):
        winner = self.check_winner()
        if winner:
            self.game_over = True
            return True
        elif self.is_board_full():
            self.game_over = True
            return True
        
        return False
        
    
        return False
    
    def clone(self):
        return copy.deepcopy(self)


# Inspired by: 
# https://webdocs.cs.ualberta.ca/~hayward/396/jem/mcts.html#mcts
class MCTS_Strategy():
    
    def __init__(self, computation_budget=10, player='0'):
      self.computation_budget = computation_budget
      self.player = player
    
    def policy(self, game):
        # simulation steps, depth
        root_node = Node(None, None)
        
        self.computation_budget = 10
        while self.computation_budget > 0:
            self.computation_budget -= 1
            game_state = game.clone()
            node = root_node
            
            # selection (select leaf)
            while not node.is_leaf():
                if node.move is not None: 
                    game_state.drop_piece(node.move)
                node = self.tree_policy(node) 
            
            # expansion
            node.expand_node(game_state)
            
            if not game_state.terminal_state():
                node = self.tree_policy(node)
            
            # simulation (rollout)
            while not game_state.terminal_state():
                game_state = self.simulation_policy(game_state)
                
            result = self.evaluate(game_state)
            
            # backup (backpropagation)
            while node.has_parent():
                node.update(result)
                node = node.parent
                
        return root_node.best_move().move

    def tree_policy(self, node):
        # random selection
        # TODO : Add UCB Selection
        return random.choice(node.children)
        
    def simulation_policy(self, game):
        # Pick random move
        rand_move = random.choice(game.get_legal_moves())
        
        # Play move
        game.drop_piece(rand_move)

        # Switch player
        game.switch_player()
        
        return game

    def evaluate(self, game):
        winner = game.check_winner()
        if winner == self.player:
            return True
        
def play_game():
    game = Connect4()
    
    while not game.game_over:
        game.display_board()
        
        if game.current_player == 'X': # random player
            strategy_random_agent(game)
        else:
            # game.strategy_random_agent()
            mcts = MCTS_Strategy(10, 'O')
            move = mcts.policy(game)
            game.drop_piece(move)
                  
        winner = game.check_winner()
        if winner:
            game.display_board()
            print(f"Player {winner} wins!")
            game.game_over = True
        elif game.is_board_full():
            game.display_board()
            print("The game is a draw!")
            game.game_over = True
        else:
            game.switch_player()
            
def strategy_player(game):
    valid_move = False
    while not valid_move:
        valid_move = True
        try:
            column = int(input(f"Player {game.current_player}, choose a column (0-6): "))
        except ValueError:
            print("Invalid input. Please enter a number between 0 and 6.")
            valid_move = False
            continue

        if not game.drop_piece(column):
            print("Invalid move. Try again.")
            valid_move = False
            continue
        
def strategy_random_agent(game):
    # Random agent's turn
    valid_move = False
    while not valid_move:
        column = random.randint(0, 6)
        valid_move = game.drop_piece(column)
    print(f"Random agent (Player {game.agent_player}) chose column: {column}")


# Create a game instance and start playing

play_game()
# game.strategy_mcts()


# Gym api
# Obs, Reward, Terminated

| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
+-+-+-+-+-+-+-+
Random agent (Player O) chose column: 1
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| |X| | | | | |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|O|X| | | | | |
+-+-+-+-+-+-+-+
Random agent (Player O) chose column: 5
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|O|X| | | |X| |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|O|X| | |O|X| |
+-+-+-+-+-+-+-+
Random agent (Player O) chose column: 1
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| |X| | | | | |
|O|X| | |O|X| |
+-+-+-+-+-+-+-+
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
|O|X| | | | | |
|O|X| | |O|X| |
+-+-+-+-+-+-+-+
Random agent (Player O) chose column: 5
| | | | | | | |
| | | | | | | |
| | | | | | | |
| | | | 

In [317]:
class Node:
    def __init__(self, m, p):
        self.move = m
        self.parent = p
        self.children = []
        self.wins = 0
        self.visits = 0
    
    def expand_node(self, game_state):
        if not game_state.terminal_state():
            legal_moves = game_state.get_legal_moves()
            for i in range(len(legal_moves)):
                node_child = Node(legal_moves[i], self)
                self.children.append(node_child)
    
    def update(self, r):
        self.visits += 1
        if r == True:
            self.wins += 1
            
    def is_leaf(self):
        return len(self.children) == 0
    
    def has_parent(self):
        return self.parent is not None
    
    def best_move(self):
        if not self.children:
            return None
        
        best_child = None
        best_win_ratio = -1
        
        for child in self.children:
            if child.visits > 0:
                win_ratio = child.wins / child.visits
                if win_ratio > best_win_ratio:
                    best_win_ratio = win_ratio
                    best_child = child
        return best_child
        

In [320]:
import random
import copy
import numpy as np
 
 
class Connect4:
    def __init__(self):
        self.state = np.zeros((6, 7))
        self.game_over = False
        self._random_opponent_move(self.state)
 
    def step(self, action):
        valid_move = self._drop_piece(self.state, action, 1)
 
        if not valid_move:
            raise Exception("Invalid move")
 
        self._random_opponent_move(self.state)
        reward = self._check_winner(self.state)
 
        if reward != 0 or self._is_board_full(self.state):
            self.game_over = True
 
        return reward, self.game_over
 
    @classmethod
    def get_legal_moves(cls, state):
        # Check if top most cell in that column is empty
        return [col for col in range(7) if state[0][col] == 0]
 
    def random_rollout(self):
        # Random rollout policy
        rollout_state = copy.deepcopy(self.state)
        rollout_game_over = False
        while not rollout_game_over:
            rollout_valid_move = False
            while not rollout_valid_move:
                rollout_column = random.randint(0, 6)
                rollout_valid_move, rollout_state = self._drop_piece(
                    rollout_state, rollout_column, 1
                )
            rollout_winner = self._check_winner(rollout_state)
            if rollout_winner or self._is_board_full(rollout_state):
                rollout_game_over = True
        return rollout_winner
 
    @classmethod
    def _random_opponent_move(cls, state):
        # Random agent's turn
        valid_move = False
        while not valid_move:
            column = random.randint(0, 6)
            valid_move = cls._drop_piece(state, column, -1)
 
    @classmethod
    def _drop_piece(cls, state, column, current_player):
        # Check for valid column
        if not 0 <= column < 7 or state[0][column] != 0:
            return False  # Invalid move
 
        # Find the lowest empty space in the column
        for row in range(5, -1, -1):
            if state[row][column] == 0:
                state[row][column] = current_player
                return True
        return False
 
    # @classmethod
    # def _is_board_full(cls, state):
    #     # Check if the board is full
    #     return all(state[0][col] != 0 for col in range(7))
 
    @classmethod
    def _check_winner(cls, state):
        # Check horizontal, vertical and diagonal lines for a win
        for row in range(6):
            for col in range(7):
                if state[row][col] != 0:
                    if (
                        cls._check_line(state, row, col, 1, 0)
                        or cls._check_line(state, row, col, 0, 1)
                        or cls._check_line(state, row, col, 1, 1)
                        or cls._check_line(state, row, col, 1, -1)
                    ):
                        return state[row][col]
        return 0
 
    @classmethod
    def _check_line(cls, state, start_row, start_col, d_row, d_col):
        # Check a line of 4 pieces starting from (start_row, start_col) in direction (d_row, d_col)
        for i in range(1, 4):
            r = start_row + d_row * i
            c = start_col + d_col * i
            if (
                not (0 <= r < 6 and 0 <= c < 7)
                or state[r][c] != state[start_row][start_col]
            ):
                return False
        return True
 