# In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from tqdm import tqdm
import logging
import copy
import math

# Configure the root logger at INFO level and create this module's logger.
# Progress is reported through `logger.info` rather than `print`, so the
# verbosity of a training run can be adjusted via the logging level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class UltimateTicTacToeBoard:
    """
    Represents the game state of Ultimate Tic-Tac-Toe.
    This class handles all game rules and board manipulations.

    Cell and section ownership is encoded as 1 (player X), -1 (player O) and
    0 (empty / still open). A drawn section, and a drawn game in `winner`,
    is encoded as `DRAW` (2).
    """
    PLAYER_MAP = {1: "X", -1: "O"}
    SECTION_MAP = {
        0: (0, 0), 1: (0, 3), 2: (0, 6),
        3: (3, 0), 4: (3, 3), 5: (3, 6),
        6: (6, 0), 7: (6, 3), 8: (6, 6)
    }
    # Marker stored in `section_grid` / `winner` for a draw.
    DRAW = 2

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all game boards to their initial state."""
        # The main 9x9 board
        self.grid = np.zeros((9, 9), dtype=np.int8)
        # The 3x3 board representing which sub-boards are won (or drawn)
        self.section_grid = np.zeros((3, 3), dtype=np.int8)
        # The active section for the next move (0-8), or -1 if any section is valid
        self.active_section = -1
        self.current_player = 1
        self.last_move = None
        self.game_over = False
        self.winner = 0  # 1 for Player X, -1 for Player O, 2 for a draw, 0 while ongoing

    def print_board(self):
        """Print the game board with symbols."""
        symbols = {0: ' ', 1: 'X', -1: 'O'}
        # Column header
        print("  | 0 1 2 | 3 4 5 | 6 7 8 ")
        print("--+-------+-------+-------")
        for i in range(9):
            row = self.grid[i]
            if i % 3 == 0 and i != 0:
                print("--+-------+-------+-------")

            row_str = " | ".join([
                " ".join([symbols[x] for x in row[j:j+3]]) for j in range(0, 9, 3)
            ])
            print(f"{i} | {row_str}")

    def get_valid_moves(self):
        """
        Return the list of valid move indices (0-80, row-major on the 9x9 grid).

        A move is valid if the cell is empty AND it lies in the active section.
        If there is no active section, or the active section is already won or
        drawn, every empty cell of every still-playable section is valid.
        """
        if self.active_section != -1 and self.is_section_playable(self.active_section):
            # Player is forced to play in a specific section
            sections = [self.active_section]
        else:
            sections = [s for s in range(9) if self.is_section_playable(s)]

        valid_moves = []
        for section in sections:
            row_start, col_start = self.SECTION_MAP[section]
            for global_row in range(row_start, row_start + 3):
                for global_col in range(col_start, col_start + 3):
                    if self.grid[global_row, global_col] == 0:
                        valid_moves.append(global_row * 9 + global_col)
        return valid_moves

    def is_section_playable(self, section_number):
        """Check if a section has not been won or drawn."""
        macro_row, macro_col = divmod(section_number, 3)
        return bool(self.section_grid[macro_row, macro_col] == 0)

    def make_move(self, index):
        """
        Place the current player's mark on cell `index` (0-80).

        Returns:
            tuple: (reward, done). An illegal move leaves the state untouched
            and returns (-10, False). Otherwise the reward is +/-10 for closing
            a section (signed by the player who won it), +5 for drawing a
            section, 100 * winner when the game is won, and 1 when the whole
            game ends in a draw.
        """
        if index not in self.get_valid_moves():
            return -10, False  # Penalize for illegal moves

        row, col = divmod(index, 9)
        self.grid[row, col] = self.current_player
        self.last_move = index

        reward = 0

        # Section in which the move was made
        section_number = (row // 3) * 3 + (col // 3)
        macro_row, macro_col = divmod(section_number, 3)

        # The move was legal, so the section was open until now; close it if
        # this mark completed a line or filled the last empty cell.
        if self.is_section_won_or_drawn(section_number):
            section_winner = self._check_section_winner(section_number)
            if section_winner != 0:
                self.section_grid[macro_row, macro_col] = section_winner
                reward += 10 * self.current_player
            else:  # It's a draw
                self.section_grid[macro_row, macro_col] = self.DRAW
                reward += 5

        # Check for a final game winner
        self.winner = self._check_game_winner()
        if self.winner != 0:
            self.game_over = True
            reward = 100 * self.winner

        # The cell's position inside its section selects the next section
        next_section = (row % 3) * 3 + (col % 3)
        if self.is_section_playable(next_section):
            self.active_section = next_section
        else:
            self.active_section = -1  # Any section is valid

        # Check for a draw on the main board (nothing left to play)
        if not self.game_over and not self.get_valid_moves():
            self.game_over = True
            self.winner = self.DRAW
            reward = 1

        self.current_player *= -1

        return reward, self.game_over

    @staticmethod
    def _line_sums(grid3):
        """Return the sums of the 3 rows, 3 columns and 2 diagonals of a 3x3 grid."""
        return np.concatenate([
            grid3.sum(axis=1),               # Rows
            grid3.sum(axis=0),               # Columns
            [np.trace(grid3)],               # Main diagonal
            [np.trace(np.fliplr(grid3))],    # Anti-diagonal
        ])

    def _check_section_winner(self, section_number):
        """Return 1 / -1 if that player has three in a row in the section, else 0."""
        row_start, col_start = self.SECTION_MAP[section_number]
        section = self.grid[row_start:row_start + 3, col_start:col_start + 3]

        sums = self._line_sums(section)
        if np.any(sums == 3):
            return 1
        if np.any(sums == -3):
            return -1
        return 0

    def is_section_won_or_drawn(self, section_number):
        """Checks if a section is either won, or is a draw."""
        # A section is won if a player has 3 in a row
        if self._check_section_winner(section_number) != 0:
            return True
        # A section is a draw if all cells are filled and no one has won
        row_start, col_start = self.SECTION_MAP[section_number]
        section = self.grid[row_start:row_start + 3, col_start:col_start + 3]
        return bool(np.all(section != 0))

    def _check_game_winner(self):
        """Return 1 / -1 if that player owns three sections in a row, else 0."""
        # Drawn sections carry the DRAW marker (2); count them as unowned so
        # they cannot add up to a fake three-in-a-row (e.g. 2 + 1 + 0 == 3).
        owned = np.where(np.abs(self.section_grid) == 1, self.section_grid, 0)
        sums = self._line_sums(owned)
        if np.any(sums == 3):
            return 1
        if np.any(sums == -3):
            return -1
        return 0


class UltimateTicTacToeEnv:
    """Gym-like wrapper (reset / step / render) around an UltimateTicTacToeBoard."""

    def __init__(self):
        self.game = UltimateTicTacToeBoard()
        self.action_space = 81
        self.observation_space = (9, 9)
        self.winner = 0

    def reset(self):
        """Restart the underlying game and return the first observation."""
        self.game.reset()
        return self._get_obs()

    def step(self, action):
        """
        Apply `action` (a cell index) for the player whose turn it is.

        Returns:
            tuple: (observation, reward, done, info) where info carries the
            keys 'winner' and 'valid'.
        """
        mover = self.game.current_player

        # An illegal action ends the episode at once: -100 for the mover and
        # the opponent is reported as the winner. The board is not modified.
        legal_actions = self.game.get_valid_moves()
        if action not in legal_actions:
            return self._get_obs(), -100, True, {'winner': -mover, 'valid': False}

        reward, done = self.game.make_move(action)
        obs = self._get_obs()
        info = {'winner': self.game.winner, 'valid': True}

        if not done:
            return obs, reward, done, info

        # On a finished game the board's reward is replaced by a fixed
        # outcome reward: +100 if X won, -100 if O won, 10 for anything else.
        outcome_reward = {1: 100, -1: -100}.get(self.game.winner, 10)
        return obs, outcome_reward, done, info

    def render(self):
        """Print the board to stdout."""
        self.game.print_board()

    def _get_obs(self):
        """
        Build the observation tuple handed to the agent:
        (copy of the 9x9 grid, copy of the 3x3 section grid, 9x9 boolean
        mask of currently valid moves).
        """
        mask = np.zeros(self.action_space, dtype=bool)
        mask[self.game.get_valid_moves()] = True
        grid_copy = np.copy(self.game.grid)
        section_copy = np.copy(self.game.section_grid)
        return (grid_copy, section_copy, mask.reshape(9, 9))

# Helper functions for state processing
def convert_state_to_tensor(state, player):
    """
    Converts a state tuple into a PyTorch tensor with the correct
    channel-first format for the network.

    Channels are: the 9x9 board, the 3x3 section board blown up to 9x9, and
    the valid-move mask. Ownership marks are multiplied by `player` so the
    side to move always sees its own marks as +1.

    Args:
        state (tuple): (main_board, section_board, valid_mask)
        player (int): 1 or -1, the current player's turn.

    Returns:
        torch.Tensor: A float32 tensor of shape (1, 3, 9, 9).
    """
    board, macroboard, valid_mask = state
    board = np.asarray(board)
    macroboard = np.asarray(macroboard)

    # Flip the perspective so the current player is always '1'
    # This simplifies the learning for the neural network
    p_board = board * player
    # Only ownership marks (+1 / -1) change sign. The draw marker (2) is
    # player-neutral and must encode identically for both sides; multiplying
    # it by `player` would turn it into -2 for player -1.
    p_macroboard = np.where(np.abs(macroboard) == 1, macroboard * player, macroboard)

    # Create a 9x9 macroboard representation (each section value fills its 3x3 area)
    expanded_macro = np.repeat(p_macroboard, 3, axis=0).repeat(3, axis=1)

    stacked_state = np.stack(
        [
            p_board.astype(np.float32),
            expanded_macro.astype(np.float32),
            np.asarray(valid_mask).astype(np.float32),
        ],
        axis=0,
    )

    # Leading batch dimension of size 1
    return torch.from_numpy(stacked_state).unsqueeze(0)


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest transition once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct transitions, or None if too few are stored."""
        enough_stored = len(self.buffer) >= batch_size
        return random.sample(self.buffer, batch_size) if enough_stored else None

    def __len__(self):
        return len(self.buffer)


class DQNetwork(nn.Module):
    """Convolutional Q-network: two 3x3 conv layers followed by two linear layers."""

    def __init__(self, input_shape=(3, 9, 9), n_actions=81):
        super().__init__()
        in_channels, height, width = input_shape
        conv_channels = 128
        hidden_units = 512

        # padding=1 with 3x3 kernels keeps the spatial size, so the flattened
        # feature count below is conv_channels * height * width.
        conv_stack = [
            nn.Conv2d(in_channels, conv_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(conv_channels, conv_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.conv_layers = nn.Sequential(*conv_stack)

        head_stack = [
            nn.Flatten(),
            nn.Linear(conv_channels * height * width, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
        ]
        self.fc_layers = nn.Sequential(*head_stack)

        # The module moves itself to CUDA when available, otherwise CPU;
        # callers read `.device` to place their input tensors.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.to(self.device)

    def forward(self, x):
        """Map a (batch, C, H, W) tensor to (batch, n_actions) Q-values."""
        features = self.conv_layers(x)
        return self.fc_layers(features)

# --- New MCTS Classes ---
class MCTSNode:
    """One search-tree node: the action that led here plus visit statistics."""

    def __init__(self, parent, action=None, state=None):
        # Tree linkage and the position this node stands for
        self.parent, self.action, self.state = parent, action, state
        self.children = []
        # Running statistics, updated during backpropagation
        self.visits = 0
        self.total_reward = 0
        # Filled in by the search once the node's position is known
        self.is_terminal = False
        self.untried_moves = None

    def ucb1_score(self, c_param):
        """Return the UCB1 value: mean reward plus a `c_param`-weighted exploration bonus."""
        if not self.visits:
            # Never-visited nodes are always tried first
            return float('inf')
        mean_reward = self.total_reward / self.visits
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_reward + c_param * exploration

class MCTS:
    """
    Monte Carlo Tree Search guided by a Q-network.

    Every tree node is paired with its own copy of the environment, so moves
    are expanded and evaluated on the position the node actually represents
    rather than on the root position. Node statistics are kept from the point
    of view of the player who made the move leading INTO the node, which is
    the quantity the parent has to maximise when it picks a child.
    """
    def __init__(self, agent_network, env, c_param=2.0):
        self.agent_network = agent_network
        self.env = env
        self.c_param = c_param
        self.device = agent_network.device
        # Maps a tree node to the environment snapshot of its position.
        self._node_envs = {}

    def search(self, root_state, num_simulations):
        """
        Run `num_simulations` MCTS iterations from the current `self.env`.

        Args:
            root_state: Observation of the root position (as from `env._get_obs()`).
            num_simulations (int): Number of select/expand/evaluate/backup rounds.

        Returns:
            The action of the most-visited root child. If no child could be
            created (e.g. `num_simulations` is 0), a random valid move, or
            None when there is no valid move at all.
        """
        root = MCTSNode(parent=None, action=None, state=root_state)
        root.untried_moves = list(self.env.game.get_valid_moves())
        # The live env is only read for the root; expand() always works on copies.
        self._node_envs = {root: self.env}

        try:
            for _ in range(num_simulations):
                node = self.select(root)
                if not node.is_terminal and node.untried_moves:
                    node = self.expand(node)
                # Terminal leaves are backed up too, so repeated visits
                # sharpen their statistics instead of being silently dropped.
                reward = self.simulate(node)
                self.backpropagate(node, reward)
        finally:
            # Drop the per-node snapshots; they are only valid for this tree.
            self._node_envs = {}

        if not root.children:
            fallback_moves = self.env.game.get_valid_moves()
            return random.choice(fallback_moves) if fallback_moves else None

        # After simulations, choose the best child based on visit counts
        best_child = max(root.children, key=lambda c: c.visits)
        return best_child.action

    def select(self, node):
        """
        Descend by UCB1 until reaching a node that can still be expanded
        (has untried moves), is terminal, or has no children.
        """
        while not node.is_terminal and not node.untried_moves and node.children:
            node = max(node.children, key=lambda child: child.ucb1_score(self.c_param))
        return node

    def _env_at(self, node):
        """Return the environment snapshot for `node`, rebuilding it by replay if unknown."""
        cached = self._node_envs.get(node)
        if cached is not None:
            return cached

        # Collect the actions on the path root -> node
        path = []
        cursor = node
        while cursor.parent is not None:
            path.append(cursor.action)
            cursor = cursor.parent
        if not path:
            return self.env

        env = copy.deepcopy(self.env)
        for action in reversed(path):
            env.step(action)
        self._node_envs[node] = env
        return env

    def expand(self, node):
        """Add one child to `node` by playing a random untried move on a copy of its env."""
        move = random.choice(node.untried_moves)
        node.untried_moves.remove(move)

        # Copy the env of THIS node (not the root) before applying the move
        child_env = copy.deepcopy(self._env_at(node))
        _, _, done, _ = child_env.step(move)

        child_state = child_env._get_obs()
        child_node = MCTSNode(parent=node, action=move, state=child_state)
        child_node.is_terminal = done
        if not done:
            child_node.untried_moves = list(child_env.game.get_valid_moves())
        self._node_envs[child_node] = child_env
        node.children.append(child_node)
        return child_node

    def simulate(self, node):
        """
        Evaluate `node` and return its value for the player who moved into it.

        Terminal nodes score +1 / -1 / 0 (win / loss / anything else, e.g. a
        draw) for that player. Other nodes are scored with the network: the
        best valid Q-value is the value for the side to move, so it is negated.
        """
        game = self._env_at(node).game
        # make_move flips current_player, so the previous mover is its negation
        mover = -game.current_player

        if node.is_terminal:
            if game.winner == mover:
                return 1
            if game.winner == -mover:
                return -1
            return 0

        valid_moves = game.get_valid_moves()
        if not valid_moves:
            return 0.0

        # Use the agent's network to evaluate the current state
        with torch.no_grad():
            state_tensor = convert_state_to_tensor(node.state, game.current_player)
            q_values = self.agent_network(state_tensor.to(self.device)).cpu().squeeze(0)
            # The value of the state is the maximum Q-value over valid moves
            value = q_values[valid_moves].max().item()

        # NOTE(review): Q-values live on the environment's reward scale while
        # terminal results are +/-1; consider normalising if that matters.
        return -value

    def backpropagate(self, node, reward):
        """Add `reward` to `node` and, with alternating sign, to each ancestor."""
        current_node = node
        while current_node is not None:
            current_node.visits += 1
            current_node.total_reward += reward
            # The reward is inverted for the opponent's turn
            reward *= -1
            current_node = current_node.parent

class DQNAgent:
    """
    Deep Q-Learning Agent for Ultimate Tic-Tac-Toe.

    Actions are chosen by an MCTS that uses the online Q-network as its
    evaluator; the Q-network itself is trained from replayed self-play
    transitions against a periodically synchronised target network.
    """
    def __init__(self, env):
        self.env = env
        self.n_actions = self.env.action_space
        self.gamma = 0.99
        # Epsilon is decayed by update_epsilon() and can be reported by the
        # caller; nothing in this class consults it when selecting an action.
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.batch_size = 64
        self.learning_rate = 0.0001
        self.target_update_frequency = 1000

        self.q_network = DQNetwork()
        self.target_network = DQNetwork()
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)
        self.loss_fn = nn.MSELoss()

        self.replay_buffer = ReplayBuffer(capacity=100000)
        self.steps_done = 0
        self.mcts = MCTS(self.q_network, env)

    def get_action(self, state, player, valid_actions, num_simulations=100):
        """
        Select an action using MCTS.

        Args:
            state: Observation of the current position, used as the MCTS root state.
            player: Unused; kept for interface compatibility (MCTS reads the
                side to move from the environment).
            valid_actions: Unused; kept for interface compatibility.
            num_simulations (int): Number of MCTS simulations to run.

        Returns:
            The action chosen by the search.
        """
        self.mcts.env = self.env  # Ensure MCTS uses the current game state
        return self.mcts.search(state, num_simulations)

    @staticmethod
    def _player_to_move(state):
        """Infer the side to move from a state tuple: X (1) moves first and turns alternate."""
        board = np.asarray(state[0])
        return 1 if np.count_nonzero(board) % 2 == 0 else -1

    def learn(self):
        """
        Perform a single learning step from a batch of experiences.

        Stored rewards are from the acting player's point of view, and the
        next state belongs to the opponent, so the bootstrap target is
        `reward - gamma * max_valid Q_target(next_state)` (zero-sum backup).
        """
        if len(self.replay_buffer) < self.batch_size:
            return  # Not enough experiences to train yet

        experiences = self.replay_buffer.sample(self.batch_size)
        if not experiences:
            return

        # Separate components of the experience batch
        states, actions, rewards, next_states, dones = zip(*experiences)
        device = self.q_network.device

        # The buffer holds transitions of both players, so each state is
        # encoded from its own mover's perspective and each next state from
        # the opponent's.
        movers = [self._player_to_move(s) for s in states]
        states_t = torch.cat(
            [convert_state_to_tensor(s, p) for s, p in zip(states, movers)]
        ).to(device)
        next_states_t = torch.cat(
            [convert_state_to_tensor(s, -p) for s, p in zip(next_states, movers)]
        ).to(device)
        actions_t = torch.tensor(actions, dtype=torch.long, device=device).unsqueeze(-1)
        rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
        dones_t = torch.tensor(dones, dtype=torch.float32, device=device)
        # Valid-move masks of the next states, flattened to (batch, n_actions)
        next_valid_t = torch.from_numpy(
            np.stack([np.asarray(s[2], dtype=bool).reshape(-1) for s in next_states])
        ).to(device)

        # Q-values of the actions actually taken; squeeze only the gathered
        # column so a batch of one keeps its batch dimension.
        q_values = self.q_network(states_t).gather(1, actions_t).squeeze(1)

        # Compute target Q-values using the target network
        with torch.no_grad():
            next_q_all = self.target_network(next_states_t)
            # Illegal actions must not contribute to the opponent's best reply
            next_q_all = next_q_all.masked_fill(~next_valid_t, float("-inf"))
            next_q_values = next_q_all.max(1)[0]
            # Rows without any valid move would be -inf (and -inf * 0 is NaN)
            next_q_values = torch.where(
                torch.isfinite(next_q_values),
                next_q_values,
                torch.zeros_like(next_q_values),
            )
            target_q_values = rewards_t - self.gamma * next_q_values * (1 - dones_t)

        # Calculate loss and update the online network
        loss = self.loss_fn(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.steps_done += 1

        # Update target network
        if self.steps_done % self.target_update_frequency == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

    def update_epsilon(self):
        """Multiply epsilon by its decay factor while it is above epsilon_min."""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

# Main training loop
def train(agent, env, num_episodes=50000, num_simulations=200,
          save_path="ultimate_ttt_agent.pt"):
    """
    Train `agent` by self-play on `env` and save the Q-network weights.

    Both sides are played by the same agent. Each transition is stored with
    its reward expressed from the acting player's point of view, and one
    learning step is run after every move.

    Args:
        agent: Agent exposing get_action, learn, update_epsilon, epsilon,
            replay_buffer and q_network.
        env: Environment exposing reset, step and `game`.
        num_episodes (int): Number of self-play games.
        num_simulations (int): MCTS simulations per move; lower values trade
            playing strength for speed.
        save_path (str): Where the trained state dict is written.
    """
    logger.info("Starting training...")
    for episode in tqdm(range(num_episodes)):
        state = env.reset()
        done = False
        total_reward = 0

        while not done:
            # 1 or -1; the environment alternates it after every valid move
            player = env.game.current_player
            valid_moves = env.game.get_valid_moves()
            action = agent.get_action(state, player, valid_moves, num_simulations=num_simulations)
            next_state, reward, done, _info = env.step(action)

            # The stored reward is multiplied by the acting player's sign, so
            # player -1's transitions carry the negated environment reward
            # (zero-sum game), exactly as for the separate per-player code.
            agent.replay_buffer.push(state, action, reward * player, next_state, done)
            agent.learn()

            state = next_state
            # The running total keeps the raw environment reward for both sides
            total_reward += reward

        agent.update_epsilon()
        if (episode + 1) % 1000 == 0:
            logger.info(
                "Episode %d: Total Reward = %s, Epsilon = %.3f",
                episode + 1, total_reward, agent.epsilon,
            )

    logger.info("Training complete.")
    torch.save(agent.q_network.state_dict(), save_path)
    logger.info("Model saved as %s", save_path)


# In [None]:

if __name__ == "__main__":
    # Script entry point: train by self-play, then play one game against a human.
    env = UltimateTicTacToeEnv()
    agent = DQNAgent(env)

    # Training is slow: every move runs a full MCTS search. Reduce
    # num_episodes for a quick trial run.
    train(agent, env, num_episodes=10000)

    # --- Interactive test game: the human is player 1 (X), the agent is -1 (O) ---
    logger.info("\n--- Agent vs. Human Player Test ---")
    agent.q_network.load_state_dict(torch.load("ultimate_ttt_agent.pt"))

    # The final policy is still MCTS, with a larger simulation budget than
    # during training to get a better move.
    num_test_simulations = 800

    env.reset()
    env.render()

    while not env.game.game_over:
        agent_to_move = env.game.current_player != 1
        if agent_to_move:
            logger.info("Agent is thinking with MCTS...")
            state = env._get_obs()
            valid_moves = env.game.get_valid_moves()
            action = agent.get_action(
                state,
                env.game.current_player,
                valid_moves,
                num_simulations=num_test_simulations,
            )
            logger.info(f"Agent plays at: {action}")
            env.step(action)
        else:
            # Human turn: keep asking until a legal cell index is entered.
            try:
                move = int(input("Enter your move (0-80): "))
                if move not in env.game.get_valid_moves():
                    logger.info("Invalid move. Try again.")
                    continue
                env.step(move)
            except ValueError:
                logger.info("Invalid input. Please enter a number.")
                continue

        env.render()

    # Report the result from the human's side; any outcome other than a win
    # for 1 or -1 is announced as a draw.
    outcome_messages = {
        1: "Congratulations! You won!",
        -1: "The agent won. Better luck next time!",
    }
    logger.info(outcome_messages.get(env.game.winner, "It's a draw!"))

# Captured output of the run above:
# INFO:__main__:Starting training...
#   0%|▏                                                                           | 21/10000 [03:13<26:01:26,  9.39s/it]