In [1]:
!pip install stable-baselines3


Collecting stable-baselines3
  Downloading stable_baselines3-2.5.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3.0,>=2.3->stable-baselines3)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (

In [9]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import math
from stable_baselines3 import DQN


In [4]:
class Connect4Game:
    """Connect4 game environment."""
    def __init__(self):
        self.rows = 6
        self.cols = 7
        self.board = np.zeros((self.rows, self.cols), dtype=int)
        self._current_player = 1  # Internal attribute to track the current player
        self.history = []  # Track move history for delete_last_action

    def get_state(self):
        """Returns the current state of the game as a hashable object."""
        return tuple(map(tuple, self.board))

    def number_of_players(self):
        """Returns the number of players."""
        return 2

    def current_player(self):
        """Returns the current player."""
        return self._current_player

    def possible_actions(self):
        """Returns a list of valid actions (columns that are not full)."""
        return [col for col in range(self.cols) if self.board[0, col] == 0]

    def take_action(self, action):
        """
        Applies an action (dropping a piece in a column) to the game state.
        Returns True if the action was valid, False otherwise.
        """
        if action not in self.possible_actions():
            return False
        for row in reversed(range(self.rows)):
            if self.board[row, action] == 0:
                self.board[row, action] = self._current_player
                self.history.append((row, action))  # Track the move
                break
        self._current_player = 2 if self._current_player == 1 else 1
        return True

    def delete_last_action(self):
        """Reverts the last action (used for MCTS rollouts)."""
        if self.history:
            row, col = self.history.pop()
            self.board[row, col] = 0
            self._current_player = 2 if self._current_player == 1 else 1

    def has_outcome(self):
        """Checks if the game has ended (win or draw)."""
        return self.check_winner() != 0 or np.all(self.board != 0)

    def check_winner(self):
        """Checks if there is a winner and returns the winning player (1 or 2)."""
        b = self.board
        rows, cols = self.rows, self.cols

        # Check horizontal, vertical, and diagonal wins
        for r in range(rows):
            for c in range(cols - 3):
                if b[r, c] != 0 and b[r, c] == b[r, c+1] == b[r, c+2] == b[r, c+3]:
                    return b[r, c]

        for r in range(rows - 3):
            for c in range(cols):
                if b[r, c] != 0 and b[r, c] == b[r+1, c] == b[r+2, c] == b[r+3, c]:
                    return b[r, c]

        for r in range(rows - 3):
            for c in range(cols - 3):
                if b[r, c] != 0 and b[r, c] == b[r+1, c+1] == b[r+2, c+2] == b[r+3, c+3]:
                    return b[r, c]

        for r in range(3, rows):
            for c in range(cols - 3):
                if b[r, c] != 0 and b[r, c] == b[r-1, c+1] == b[r-2, c+2] == b[r-3, c+3]:
                    return b[r, c]

        return 0

    def winner(self):
        """Returns the winner of the game (1 or 2) or 0 if no winner yet."""
        return self.check_winner()

    def render(self):
        """Prints the current state of the board."""
        print(self.board)
        print()

In [5]:
class Node:
    """Node class for MCTS."""
    def __init__(self, game, parent=None, action_index=None):
        self.game = game  # Game state at this node
        self.parent = parent  # Parent node
        self.action_index = action_index  # Action that led to this node
        self.children = []  # Child nodes
        self.T = 0  # Total value of the node
        self.N = 0  # Visit count

    def ucb(self, c=1.4):
        """
        Upper Confidence Bound (UCB) formula to balance exploration and exploitation.
        """
        if self.N == 0:
            return float('inf')  # Prioritize unvisited nodes
        return (self.T / self.N) + c * math.sqrt(math.log(self.parent.N) / self.N)

    def expand(self):
        """
        Expands the node by creating child nodes for all possible actions.
        """
        for action in self.game.possible_actions():
            # Create a copy of the game and apply the action
            new_game = Connect4Game()
            new_game.board = np.copy(self.game.board)
            new_game._current_player = self.game._current_player
            new_game.take_action(action)
            # Create a child node
            child = Node(new_game, parent=self, action_index=action)
            self.children.append(child)

    def is_leaf(self):
        """
        Checks if the node is a leaf (no children).
        """
        return len(self.children) == 0

    def best_child(self):
        """
        Returns the child with the highest UCB value.
        """
        return max(self.children, key=lambda child: child.ucb())

    def rollout(self):
        """
        Simulates a random game from the current state and returns the result.
        """
        game = Connect4Game()
        game.board = np.copy(self.game.board)
        game._current_player = self.game._current_player

        while not game.has_outcome():
            action = np.random.choice(game.possible_actions())
            game.take_action(action)

        if game.winner() == 1:  # Assuming player 1 is the MCTS agent
            return 1
        elif game.winner() == 2:  # Assuming player 2 is the opponent
            return -1
        else:
            return 0  # Draw


In [6]:
class MCTS:
    """Monte Carlo Tree Search implementation."""
    def __init__(self, game):
        self.root = Node(game)  # Root node of the tree

    def search(self, iterations):
        """
        Performs MCTS for a given number of iterations.
        """
        for _ in range(iterations):
            node = self.select(self.root)  # Selection
            value = self.simulate(node)  # Simulation
            self.backpropagate(node, value)  # Backpropagation

    def select(self, node):
        """
        Selects a leaf node using the UCB formula.
        """
        while not node.is_leaf():
            node = node.best_child()
        if node.N > 0:  # Expand if the node has been visited before
            node.expand()
            if node.children:
                node = node.best_child()
        return node

    def simulate(self, node):
        """
        Simulates a random game from the node and returns the result.
        """
        return node.rollout()

    def backpropagate(self, node, value):
        """
        Backpropagates the result of the simulation up the tree.
        """
        while node is not None:
            node.N += 1
            node.T += value
            node = node.parent

    def best_action(self):
        """
        Returns the best action based on the most visited child.
        """
        return max(self.root.children, key=lambda child: child.N).action_index

In [11]:
def play_mcts_vs_dqn(mcts, dqn_model, num_games=10):
    """
    Plays a series of games between MCTS and DQN, alternating who starts first.
    """
    mcts_wins = 0
    dqn_wins = 0
    draws = 0

    for game_num in range(num_games):
        game = Connect4Game()

        # Alternate starting player
        if game_num % 2 == 0:
            print(f"Game {game_num + 1}: MCTS starts first!")
            current_player = 1  # MCTS starts first
        else:
            print(f"Game {game_num + 1}: DQN starts first!")
            current_player = 2  # DQN starts first

        while not game.has_outcome():
            if game.current_player() == 1:
                # MCTS's turn
                mcts.root = Node(game)  # Reset MCTS tree
                mcts.search(iterations=1000)  # Run MCTS
                action = mcts.best_action()
            else:
                # DQN's turn
                obs = np.array(game.board).flatten()
                action, _ = dqn_model.predict(obs, deterministic=True)
            game.take_action(action)
            game.render()

        winner = game.winner()
        if winner == 1:
            mcts_wins += 1
            print(f"Game {game_num + 1}: MCTS wins!")
        elif winner == 2:
            dqn_wins += 1
            print(f"Game {game_num + 1}: DQN wins!")
        else:
            draws += 1
            print(f"Game {game_num + 1}: It's a draw!")

    print(f"\nTournament Results:")
    print(f"MCTS Wins: {mcts_wins}")
    print(f"DQN Wins: {dqn_wins}")
    print(f"Draws: {draws}")

In [12]:
# Initialize the Connect 4 game
game = Connect4Game()

# Initialize MCTS
mcts = MCTS(game)

# Load the trained DQN model
dqn_model = DQN.load("connect4_dqn_model")

# Run the tournament
play_mcts_vs_dqn(mcts, dqn_model, num_games=10)

Game 1: MCTS starts first!
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 2 0 0 0]
 [0 0 0 1 0 0 0]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 2 0 0 0]
 [0 0 0 1 1 0 0]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 2 2 0 0]
 [0 0 0 1 1 0 0]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 2 2 0 0]
 [0 0 0 1 1 1 0]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 2 2 2 0]
 [0 0 0 1 1 1 0]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 2 2 2 0]
 [0 0 0 1 1 1 1]]

Game 1: MCTS wins!
Game 2: DQN starts first!
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1]]

[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 2 0 0 0