# HW8: Monte Carlo Tree Search (MCTS) for 3×3 Tic-Tac-Toe

In this notebook, we explore how to build intelligent agents to play a simple 3×3 game with a win condition of 3 in a row. We'll test multiple agent strategies, including:

- Heuristics
- An MCTS (Monte Carlo Tree Search) agent guided by a neural network

We begin by implementing the core of the MCTS algorithm.

In [None]:
import os
import math
import numpy as np
from random import shuffle
import torch
import torch.optim as optim
from abc import ABC, abstractmethod
import torch.nn as nn
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import pandas as pd
from typing import Dict, List, Tuple, Optional
import json
import random
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

## 🔍 What’s in this cell?

The following code defines two key components:
- `Node`: a class representing nodes in the MCTS search tree
- `MCTS`: a class that performs the MCTS algorithm, which repeatedly simulates games and backs up evaluations to guide better decisions

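The algorithm uses a neural network model for policy priors and state evaluation. During the search, it chooses which child to explore using a UCB (Upper Confidence Bound)-style score that combines the child's value estimate with its prior. The final move is then chosen from the children's visit counts.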

We will use this implementation throughout the notebook to build a powerful learning-based agent.

In [None]:
def ucb_score(parent, child):
    """
    Score the move from `parent` to `child` for tree-descent selection.

    The score is an exploitation term plus an exploration term. The
    exploitation term is the negated mean value of the child (0 if it has
    never been visited). The exploration term is the child's prior scaled
    by sqrt(parent visits) / (child visits + 1).
    """
    exploration = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    # A child's value is stored from the opponent's point of view, hence the sign flip.
    exploitation = -child.value() if child.visit_count > 0 else 0
    return exploitation + exploration


class Node:
    """
    A node of the MCTS search tree.

    Attributes:
        visit_count: number of simulations whose search path included this node.
        to_play: player label (+1 / -1) stored for this node.
        prior: prior probability assigned to this node by its parent's policy.
        value_sum: sum of values backed up into this node by the search.
        children: dict mapping action index -> child Node (empty until expanded).
        state: board stored when the node is expanded (None before that).
    """

    def __init__(self, prior, to_play):
        self.visit_count = 0
        self.to_play = to_play
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.state = None

    def expanded(self):
        """Return True once the node has at least one child."""
        return len(self.children) > 0

    def value(self):
        """Mean backed-up value, or 0 for an unvisited node."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def select_action(self, temperature):
        """
        Select an action according to the children's visit counts and `temperature`.

        temperature == 0    -> the most-visited action (first one on ties).
        temperature == inf  -> a uniformly random action.
        otherwise           -> sample with probability proportional to
                               visit_count ** (1 / temperature).

        If no child has been visited, a uniformly random action is returned
        instead of producing a NaN distribution.
        """
        actions = list(self.children.keys())
        visit_counts = np.array(
            [child.visit_count for child in self.children.values()], dtype=np.float64
        )
        if temperature == 0:
            return actions[int(np.argmax(visit_counts))]
        if temperature == float("inf") or visit_counts.sum() == 0:
            return np.random.choice(actions)
        # See paper appendix Data Generation. Dividing by the maximum first keeps
        # the largest term at 1, so small temperatures cannot overflow to inf
        # (inf / inf would otherwise give NaN probabilities).
        scaled = (visit_counts / visit_counts.max()) ** (1.0 / temperature)
        return np.random.choice(actions, p=scaled / scaled.sum())

    def select_child(self):
        """
        Return (action, child) for the child with the highest UCB score.

        Ties go to the first child in insertion order. Returns (-1, None) when
        there are no children.
        """
        best_score = -np.inf
        best_action = -1
        best_child = None

        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child

        return best_action, best_child

    def expand(self, state, to_play, action_probs):
        """
        Expand the node using the prior policy from the neural network.

        A child is created for every action whose probability is non-zero.
        Children get the opposite `to_play` label.
        """
        self.to_play = to_play
        self.state = state
        for a, prob in enumerate(action_probs):
            if prob != 0:
                self.children[a] = Node(prior=prob, to_play=self.to_play * -1)

    def __repr__(self):
        """
        Debugger pretty print node info
        """
        prior = "{0:.2f}".format(self.prior)
        return "{} Prior: {} Count: {} Value: {}".format(self.state.__str__(), prior, self.visit_count, self.value())


class MCTS:
    """
    Monte Carlo Tree Search guided by a policy/value model.

    Boards handed to `run` are canonical boards for the player to move. Each
    simulated move is played as player 1 and the result is re-canonicalised
    for the opponent.
    """

    def __init__(self, game, model, args):
        self.game = game
        self.model = model  # kept for reference; `run` uses the model passed to it
        self.args = args    # `run` reads args['num_simulations']

    def _masked_policy(self, model, state):
        """
        Evaluate `state` with `model` and restrict the policy to legal moves.

        Returns (priors, value). `priors` is a float array that sums to 1 over
        the legal moves. If the model assigns zero total mass to every legal
        move, `priors` is uniform over the legal moves, which avoids a 0/0
        division. If there are no legal moves, `priors` is all zeros.
        """
        action_probs, value = model.predict(state)
        valid_moves = np.asarray(self.game.get_valid_moves(state), dtype=np.float64)
        action_probs = np.asarray(action_probs, dtype=np.float64) * valid_moves  # mask invalid moves
        total = action_probs.sum()
        if total > 0:
            action_probs = action_probs / total
        elif valid_moves.sum() > 0:
            action_probs = valid_moves / valid_moves.sum()
        return action_probs, value

    def run(self, model, state, to_play):
        """
        Run args['num_simulations'] simulations from `state` and return the root.

        Args:
            model: object exposing predict(board) -> (action_probs, value).
            state: canonical board for the player to move.
            to_play: player label stored on the root node.

        Returns:
            The root Node; its children's visit counts form the search policy.
            If `state` has no legal moves, the root is returned unexpanded.
        """
        root = Node(0, to_play)

        # EXPAND root (the root's own value estimate is not needed)
        action_probs, _ = self._masked_policy(model, state)
        root.expand(state, to_play, action_probs)
        if not root.expanded():
            # Nothing to search: descending would fail on search_path[-2].
            return root

        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]

            # SELECT: the root is expanded, so this loop runs at least once
            # and `action` is always bound afterwards.
            while node.expanded():
                action, node = node.select_child()
                search_path.append(node)

            parent = search_path[-2]
            parent_state = parent.state
            # Now we're at a leaf node and we would like to expand
            # Players always play from their own perspective
            next_state, _ = self.game.get_next_state(parent_state, player=1, action=action)
            # Get the board from the perspective of the other player
            next_state = self.game.get_canonical_board(next_state, player=-1)

            # The value of the new state from the perspective of the other player
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # If the game has not ended: EXPAND
                action_probs, value = self._masked_policy(model, next_state)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, value, parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """
        Propagate a simulation's evaluation up the tree to the root.

        `value` is from the perspective of `to_play`. It is added to nodes
        whose to_play matches and subtracted from the others. Every node on
        the path gets one more visit.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1

## 🎮 Game Environment: Generalized Tic-Tac-Toe (K-in-a-Row)

To experiment with intelligent agents, we define a flexible game environment: a 2D grid where the goal is to connect **K pieces in a row** (horizontally, vertically, or diagonally).

This implementation supports:
- Custom board size (e.g., 3×3, 4×4)
- Custom win condition (e.g., 3 in a row, 4 in a row)
- All necessary game mechanics: state transition, valid moves, win detection, and canonical board representation.

We'll use this class to simulate our 3×3 game with a win condition of 3.

In [None]:
class TicTacToeK:
    """
    Generalized Connect-K on a 2D grid.
    Default: 2x2 board, win condition = 2 in a row

    Boards are integer arrays of shape (rows, cols): 0 marks an empty cell,
    +1 / -1 mark the two players. Action `a` is the cell divmod(a, cols).
    """

    def __init__(self, rows=2, cols=2, win=2):
        self.rows = rows
        self.cols = cols
        self.win = win

    def get_init_board(self):
        """Return an empty (rows, cols) integer board."""
        return np.zeros((self.rows, self.cols), dtype=int)

    def get_board_size(self):
        """Return the board shape as (rows, cols)."""
        return (self.rows, self.cols)

    def get_action_size(self):
        """Return the number of cells, i.e. the number of possible actions."""
        return self.rows * self.cols

    def get_next_state(self, board, player, action):
        """Return (copy of `board` with `player` placed at `action`, next player)."""
        row, col = divmod(action, self.cols)
        updated = np.copy(board)
        updated[row, col] = player
        return updated, -player

    def has_legal_moves(self, board):
        """True if at least one cell is still empty."""
        return (board == 0).any()

    def get_valid_moves(self, board):
        """Return a row-major list with 1 for each empty cell and 0 otherwise."""
        return (board == 0).astype(int).ravel().tolist()

    def is_win(self, board, player):
        """Return True if `player` owns `win` consecutive cells in any direction."""
        k = self.win
        owned = board == player
        row_starts = range(self.rows - k + 1)
        col_starts = range(self.cols - k + 1)

        # Horizontal segments
        if any(owned[r, c:c + k].all() for r in range(self.rows) for c in col_starts):
            return True
        # Vertical segments
        if any(owned[r:r + k, c].all() for c in range(self.cols) for r in row_starts):
            return True
        # Diagonal and anti-diagonal segments, read off every k x k sub-square
        for r in row_starts:
            for c in col_starts:
                square = owned[r:r + k, c:c + k]
                if square.diagonal().all() or np.fliplr(square).diagonal().all():
                    return True
        return False

    def get_reward_for_player(self, board, player):
        """Return 1 / -1 if `player` / the opponent has won, None if play continues, 0 for a draw."""
        if self.is_win(board, player):
            return 1
        if self.is_win(board, -player):
            return -1
        return None if self.has_legal_moves(board) else 0

    def get_canonical_board(self, board, player):
        """Return the board as seen by `player` (their pieces become +1)."""
        return player * board

## Agent Implementations and Exploration Task

In the cell below, we define a set of agent classes, each representing a different strategy for playing the game. All agents inherit from the `BasePlayer` interface, which requires two methods:

- `get_action`: returns the selected move given the current game state and player.
- `get_action_probs`: returns a tuple `(action_probs, extra)`: a probability distribution over actions, plus the MCTS root node for `MCTSPlayer` (or `None` for the other agents). The trainer uses it to collect policy targets.


---

### ✏️ Your Task (3 points)

Carefully read the implementations of the four agents defined in the cell below (`MCTSPlayer`, `Player1`, `Player2`, `Player3`). For each agent:

1. Explain **in your own words** how the agent makes decisions.
2. What are the advantages and limitations of each strategy?
3. Which agent do you expect to perform best in a 3×3 Connect-3 game, and why?

---





In [None]:


class BasePlayer(ABC):
    """Abstract base class for all players"""

    @abstractmethod
    def get_action(self, game, state, player):
        """Return the action index `player` chooses in `state`."""

    @abstractmethod
    def get_action_probs(self, game, state, player):
        """
        Return action probabilities for `state`.

        Implementations in this notebook return a tuple
        (action_probs, extra), where `extra` is the MCTS root or None.
        """


class MCTSPlayer(BasePlayer):
    """Agent that runs a fresh MCTS search from the current position on every call."""

    def __init__(self, model, args):
        self.model = model
        self.args = args
        self.mcts = None  # search object from the most recent call

    def _search_root(self, game, state, player):
        """Build a new MCTS and search from `player`'s canonical view of `state`; return the root."""
        self.mcts = MCTS(game, self.model, self.args)
        canonical = game.get_canonical_board(state, player)
        return self.mcts.run(self.model, canonical, to_play=1)

    def get_action(self, game, state, player):
        """Return the most-visited root action (temperature 0)."""
        return self._search_root(game, state, player).select_action(temperature=0)

    def get_action_probs(self, game, state, player):
        """
        Return (visit-count distribution over every action, search root).

        The distribution is the children's visit counts normalised to sum to 1.
        It is left as raw (all-zero) counts if no child was visited.
        """
        root = self._search_root(game, state, player)

        counts = np.zeros(game.get_action_size(), dtype=int)
        for action, child in root.children.items():
            counts[action] = child.visit_count

        total = np.sum(counts)
        distribution = counts / total if total > 0 else counts
        return distribution, root


class Player1(BasePlayer):
    """Uniform-random agent: every legal move is equally likely."""

    def get_action(self, game, state, player):
        """Return a uniformly random legal action (raises ValueError if there is none)."""
        valid_moves = game.get_valid_moves(game.get_canonical_board(state, player))
        valid_actions = [i for i, valid in enumerate(valid_moves) if valid]
        return np.random.choice(valid_actions)

    def get_action_probs(self, game, state, player):
        """
        Return (uniform distribution over legal moves, None).

        The distribution is always a float ndarray. With no legal moves it is
        all zeros; previously a plain list was returned in that case.
        """
        valid_moves = np.asarray(
            game.get_valid_moves(game.get_canonical_board(state, player)), dtype=np.float64
        )
        total = valid_moves.sum()
        action_probs = valid_moves / total if total > 0 else valid_moves
        return action_probs, None


class Player2(BasePlayer):
    """
    Greedy one-move look-ahead agent.

    Among legal moves that end the game immediately, it picks the one with the
    highest reward for itself. If no move ends the game, it plays the
    lowest-indexed legal move.
    """

    def get_action(self, game, state, player):
        """
        Return the legal action whose terminal reward for `player` is highest.

        Moves that do not end the game (reward None) are ignored. The first
        legal action is the default. Raises IndexError if there are no legal
        moves.
        """
        canonical_board = game.get_canonical_board(state, player)
        valid_moves = game.get_valid_moves(canonical_board)
        valid_actions = [i for i, valid in enumerate(valid_moves) if valid]

        best_action = valid_actions[0]
        best_reward = float('-inf')

        for action in valid_actions:
            # Simulate on the real (non-canonical) board, moving as `player`.
            next_state, _ = game.get_next_state(state, player, action)
            reward = game.get_reward_for_player(next_state, player)
            if reward is not None and reward > best_reward:
                best_reward = reward
                best_action = action

        return best_action

    def get_action_probs(self, game, state, player):
        """Return (one-hot vector on get_action's choice, None)."""
        action_probs = np.zeros(game.get_action_size())
        action_probs[self.get_action(game, state, player)] = 1.0
        return action_probs, None

class Player3(BasePlayer):
    """
    Rule-based agent that scans every length-`win_size` line segment ("window").

    Rules, tried in order by get_action:
      1) opponent windows holding win_size-2 opponent pieces and >=2 empty cells,
      2) opponent windows holding win_size-1 opponent pieces and >=1 empty cell,
      3) this player's windows with the most of its own pieces,
      4) a random legal move.
    """

    def __init__(self, win_size):
        # Window length used for scanning (independent of game.win).
        self.win = win_size

    def get_action(self, game, state, player):
        """
        Return an action index for `player` by applying the rules in order.

        Random choices use Python's `random` module.
        NOTE(review): rule 1 runs before rule 2. If the opponent already has
        win_size-1 in a window (an immediate threat), rule 1 can still fire on
        some other window and leave that threat unblocked. There is also no
        explicit check for this player's own immediate win. Presumably this
        is intentional for a weak fixed opponent; confirm before changing.
        """
        b = game.get_canonical_board(state, player)
        # ensure 2D for uniform windowing
        mat = b if b.ndim==2 else b.reshape(1, -1)
        rows, cols = mat.shape
        valid = game.get_valid_moves(b)
        # Set of legal action indices.
        actions = {i for i,v in enumerate(valid) if v}

        def windows():
            # yield list of (r,c) coords for each length-win window in all directions
            # Directions: horizontal, vertical, diagonal, anti-diagonal.
            dirs = [(0,1),(1,0),(1,1),(1,-1)]
            for dr,dc in dirs:
                for r in range(rows):
                    for c in range(cols):
                        coords = [(r+i*dr, c+i*dc) for i in range(self.win)]
                        if all(0<=rr<rows and 0<=cc<cols for rr,cc in coords):
                            yield coords

        # On the canonical board this player is treated as 1 and the opponent as -1.
        opp = -1; me = 1

        # 1) WIN_SIZE-2 threats
        # Pick a random qualifying window, then a random empty cell inside it.
        threats2 = []
        for coords in windows():
            vals = [mat[r,c] for r,c in coords]
            if vals.count(opp)==self.win-2 and vals.count(0)>=2:
                empties = [(r,c) for r,c in coords if mat[r,c]==0]
                threats2.append(empties)
        if threats2:
            empties = random.choice(threats2)
            r,c = random.choice(empties)
            # Convert (row, col) back to a flat action index.
            a = r*cols + c if b.ndim>1 else c
            if a in actions: return a

        # 2) WIN_SIZE-1 threats
        # Empty cells from all such windows are pooled, so a cell shared by
        # several threatening windows is proportionally more likely to be picked.
        blocks1 = []
        for coords in windows():
            vals = [mat[r,c] for r,c in coords]
            if vals.count(opp)==self.win-1 and vals.count(0)>=1:
                empties = [(r,c) for r,c in coords if mat[r,c]==0]
                blocks1.extend(empties)
        if blocks1:
            r,c = random.choice(blocks1)
            a = r*cols + c if b.ndim>1 else c
            if a in actions: return a

        # 3)
        # Extend: count this player's pieces per window and play a random empty
        # cell from the windows with the maximum (non-zero) count.
        # NOTE(review): windows that also contain opponent pieces are not
        # excluded, so this may extend a line that can no longer be completed.
        my_counts = []
        win_windows = list(windows())
        for coords in win_windows:
            vals = [mat[r,c] for r,c in coords]
            my_counts.append(vals.count(me))
        max_me = max(my_counts) if my_counts else 0

        extend_cands = []
        for coords, cnt in zip(win_windows, my_counts):
            if cnt==max_me and cnt>0:
                empties = [(r,c) for r,c in coords if mat[r,c]==0]
                extend_cands.extend(empties)
        if extend_cands:
            r,c = random.choice(extend_cands)
            a = r*cols + c if b.ndim>1 else c
            if a in actions: return a

        # 4) fallback random
        # random.choice raises IndexError if there are no legal moves.
        return random.choice(list(actions))

    def get_action_probs(self, game, state, player):
        """Return (one-hot vector on the action chosen by get_action, None)."""
        a = self.get_action(game, state, player)
        probs = np.zeros(game.get_action_size())
        probs[a] = 1.0
        return probs, None






## Training Against a Fixed Opponent

The `FixedOpponentTrainer` class provides functionality to train a neural network policy/value model against a fixed opponent. This setup simulates a player that continually improves while playing against a static agent.

The training procedure consists of the following components:

- **Training episodes**: The trainable agent (using MCTS) plays full games against the fixed opponent. This is not self-play, since the opponent never learns. On each of the trainable player's turns, one training example is recorded, consisting of:
  - The canonical board state.
  - The action probabilities from MCTS.
  - The final game outcome (used as the value target).

- **Reward assignment**: After the game ends, every recorded move is labeled with the final outcome (+1 win, −1 loss, 0 draw) from the trainable player's perspective. Only the trainable player's moves are recorded.

- **Neural network updates**: Collected examples are used to update the trainable model's policy and value heads using standard loss functions (cross-entropy for policy, MSE for value).



In [None]:
class FixedOpponentTrainer:
    """Trainer that allows training against a fixed opponent"""

    def __init__(self, game, trainable_model, fixed_player, args, trainable_player=1):
        """
        Args:
            game: Game instance
            trainable_model: Model to be trained
            fixed_player: Instance of BasePlayer for the fixed opponent
            args: Training arguments (numIters, numEps, num_simulations, epochs, batch_size)
            trainable_player: Which player to train (1 or -1)
        """
        self.game = game
        self.trainable_model = trainable_model
        self.fixed_player = fixed_player
        self.args = args
        self.trainable_player = trainable_player
        self.trainable_mcts_player = MCTSPlayer(trainable_model, args)

    def execute_episode(self):
        """
        Play one full game between the MCTS-driven trainable model and the fixed opponent.

        Only the trainable player's positions are recorded.

        Returns:
            A list of (canonical_board, mcts_action_probs, outcome) tuples.
            `outcome` is the final result from the trainable player's
            perspective: +1 win, -1 loss, 0 draw.
        """
        train_examples = []
        current_player = 1
        state = self.game.get_init_board()

        while True:
            if current_player == self.trainable_player:
                # Trainable player's turn - collect training data
                canonical_board = self.game.get_canonical_board(state, current_player)
                action_probs, root = self.trainable_mcts_player.get_action_probs(self.game, state, current_player)
                train_examples.append((canonical_board, current_player, action_probs))
                action = root.select_action(temperature=0)
            else:
                # Fixed player's turn - no training data collected
                action = self.fixed_player.get_action(self.game, state, current_player)

            state, current_player = self.game.get_next_state(state, current_player, action)
            reward = self.game.get_reward_for_player(state, current_player)

            if reward is not None:
                # `reward` is from the perspective of `current_player` (the side
                # that would move next). Multiplying by current_player and
                # trainable_player (both +/-1) converts it to the trainable
                # player's perspective. This is loop-invariant, so compute it once.
                final_reward = reward * current_player * self.trainable_player
                return [
                    (
                        hist_state,
                        hist_action_probs,
                        # Flip the sign for moves not made by the trainable player.
                        final_reward if hist_player == self.trainable_player else -final_reward,
                    )
                    for hist_state, hist_player, hist_action_probs in train_examples
                ]

    def learn(self):
        """Main training loop: play numEps episodes per iteration, then train on them."""
        for i in range(1, self.args['numIters'] + 1):
            print(f"Iteration {i}/{self.args['numIters']}")

            train_examples = []
            for _ in range(self.args['numEps']):
                train_examples.extend(self.execute_episode())

            shuffle(train_examples)
            self.train(train_examples)

    def train(self, examples):
        """
        Train the neural network on (board, target_pi, target_v) examples.

        Each epoch runs max(1, len(examples) // batch_size) mini-batches,
        sampled with replacement. At least one batch runs even when fewer than
        batch_size examples are available; previously that case silently
        skipped training. Tensors are placed on the device of the model's
        parameters.

        Returns:
            (pi_losses, v_losses): per-batch loss values, for monitoring.
        """
        if not examples:
            return [], []

        optimizer = optim.Adam(self.trainable_model.parameters(), lr=5e-4)
        batch_size = self.args['batch_size']
        num_batches = max(1, len(examples) // batch_size)

        first_param = next(iter(self.trainable_model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        pi_losses = []
        v_losses = []

        for _ in range(self.args['epochs']):
            self.trainable_model.train()

            for _ in range(num_batches):
                sample_ids = np.random.randint(len(examples), size=batch_size)
                boards, pis, vs = zip(*(examples[i] for i in sample_ids))
                boards = torch.FloatTensor(np.array(boards).astype(np.float64)).to(device)
                target_pis = torch.FloatTensor(np.array(pis)).to(device)
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float64)).to(device)

                # Compute output
                out_pi, out_v = self.trainable_model(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi + l_v

                pi_losses.append(l_pi.item())
                v_losses.append(l_v.item())

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

        return pi_losses, v_losses

    def loss_pi(self, targets, outputs):
        """Policy loss: cross-entropy between target and predicted distributions (batch mean)."""
        loss = -(targets * torch.log(outputs + 1e-8)).sum(dim=1)
        return loss.mean()

    def loss_v(self, targets, outputs):
        """Value loss: mean squared error between target and predicted values."""
        loss = torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]
        return loss

## Neural Network Model for Connect Game

The `ConnectPolicyValueNet` class defines a simple feedforward neural network used to approximate both the **policy** (action probabilities) and **value** (expected game outcome) for a given game board.

- The input is a flattened board state.
- It outputs:
  - A **softmax** probability vector over all actions (policy head); illegal moves are masked out afterwards by MCTS.
  - A scalar value between **−1** and **1** representing the expected outcome for the current player (value head).

This model is used by the MCTS algorithm to guide simulations and evaluate board positions during training and play.

---

### 🧠 Task (2 points)
Implement your own version of `ConnectPolicyValueNet`:
- Use at least **two hidden layers**.
- Use **ReLU** activations and appropriate final activations (`softmax` for policy, `tanh` for value).
- Your class should define a `forward` method and a `predict(board)` method that returns both policy and value outputs.

```python
# Your code here:
class ConnectPolicyValueNet(nn.Module):
    ...
```

In [None]:
class ConnectPolicyValueNet(nn.Module):
    """
    Two-hidden-layer MLP with a softmax policy head and a tanh value head.

    Args:
        board_size: board shape, e.g. (rows, cols); the input is flattened to prod(board_size) features.
        action_size: length of the policy vector.
        device: torch device the network is moved to and evaluated on.
        hidden_size: width of each hidden layer.
    """

    def __init__(self, board_size, action_size, device, hidden_size=128):
        super().__init__()
        self.board_size = board_size
        self.action_size = action_size
        self.device = device
        in_features = int(np.prod(board_size))

        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.policy_head = nn.Linear(hidden_size, action_size)
        self.value_head = nn.Linear(hidden_size, 1)
        self.to(device)

    def forward(self, x):
        """
        Evaluate a batch of boards.

        Args:
            x: float tensor whose first dimension is the batch; the remaining
               dimensions are flattened (e.g. (batch, rows, cols)).

        Returns:
            (pi, v): pi has shape (batch, action_size) and each row sums to 1;
            v has shape (batch, 1) with values in [-1, 1].
        """
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        pi = F.softmax(self.policy_head(x), dim=1)
        v = torch.tanh(self.value_head(x))
        return pi, v

    def predict(self, board):
        """
        Evaluate a single board without tracking gradients.

        Returns:
            (pi, v): pi is a NumPy float array of length action_size and v is a
            Python float. The module's train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            x = torch.as_tensor(np.asarray(board, dtype=np.float32), device=self.device).reshape(1, -1)
            pi, v = self.forward(x)
        self.train(was_training)
        return pi.squeeze(0).cpu().numpy(), float(v.item())

## Training a Policy-Value Network on 3×3 Tic-Tac-Toe

We now set up and train our agent to play a generalized 3×3 Tic-Tac-Toe game with a win condition of 3 in a row.

#### Key Components:
- **Game Setup**:
  - The board is 3×3 and the player must align 3 pieces to win.
- **Model**:
  - We instantiate our neural network (`ConnectPolicyValueNet`) with appropriate input/output dimensions.
- **Training Arguments**:
  - `numIters`: number of training iterations.
  - `numEps`: number of self-play episodes per iteration.
  - `num_simulations`: number of MCTS simulations per move.
  - `epochs`, `batch_size`: used to train the model from generated data.

#### Opponent:
The `FixedOpponentTrainer` class orchestrates games between the trainable MCTS agent and the fixed opponent, collecting training data and updating the model accordingly.

---

### ✅ Your Task: (5 points)

- Try training your model against different opponent agents (e.g., `Player1`, `Player2`, `Player3`).
- Experiment with different training configurations by tuning the values in the `args` dictionary.

The key hyperparameters are listed above under **Training Arguments**. After trying different values, **explain in your own words** what each parameter controls and how it influences learning.

🔍 **Your goal** is to find a configuration that trains a model which performs well against provided opponents. Reflect on which hyperparameters were most impactful and why.

We expect your model to win at least 80 percent of its games against the provided players.

In [None]:
BOARD_ROWS = 3
BOARD_COLS = 3
WIN_SIZE = 3

game = TicTacToeK(rows=BOARD_ROWS, cols=BOARD_COLS, win=WIN_SIZE)
board_size = game.get_board_size()
action_size = game.get_action_size()

model = ConnectPolicyValueNet(board_size, action_size, device)

# These values are only a quick smoke test; increase them (especially numEps
# and num_simulations) when tuning for real performance.
args = {
    'numIters': 1,
    'numEps': 1,
    'num_simulations': 1,
    'epochs': 1,
    'batch_size': 64
}

# Choose a fixed opponent: Player1(), Player2(), or Player3(WIN_SIZE).
fixed_opponent = Player1()
# Create trainer (trainable_player=1: the model moves first; -1: it moves second)
trainer = FixedOpponentTrainer(
    game=game,
    trainable_model=model,
    fixed_player=fixed_opponent,
    args=args,
    trainable_player=1,
)

# Start training
trainer.learn()


### Classes for evaluating model performance

In [None]:


class GameStats:
    """Accumulates win/loss/draw counts and per-game metrics across evaluation games."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every counter and start fresh per-game history lists."""
        self.wins = self.losses = self.draws = 0
        self.total_games = self.total_moves = 0
        self.game_lengths, self.move_times, self.rewards = [], [], []

    def add_game_result(self, result, game_length, avg_move_time, final_reward):
        """
        Record the outcome of one finished game.

        result: 1 for win, -1 for loss; any other value is counted as a draw.
        """
        if result == 1:
            counter = 'wins'
        elif result == -1:
            counter = 'losses'
        else:
            counter = 'draws'
        setattr(self, counter, getattr(self, counter) + 1)

        self.total_games += 1
        self.total_moves += game_length
        self.game_lengths.append(game_length)
        self.move_times.append(avg_move_time)
        self.rewards.append(final_reward)

    def get_win_rate(self):
        """Fraction of recorded games that were wins (0 if nothing recorded)."""
        if not self.total_games > 0:
            return 0
        return self.wins / self.total_games

    def get_stats_dict(self):
        """Summarise the counters and the averages of the per-game histories."""
        def mean_or_zero(history):
            return np.mean(history) if history else 0

        return {
            'wins': self.wins,
            'losses': self.losses,
            'draws': self.draws,
            'total_games': self.total_games,
            'win_rate': self.get_win_rate(),
            'avg_game_length': mean_or_zero(self.game_lengths),
            'avg_move_time': mean_or_zero(self.move_times),
            'avg_reward': mean_or_zero(self.rewards),
        }


class ModelTester:
    """Evaluate a model (wrapped in an MCTSPlayer) against a set of opponents.

    Collects win/loss/draw statistics, keeps one sample game per opponent for
    replay, and provides reporting, plotting and JSON persistence helpers.
    """

    def __init__(self, game, model, args):
        """
        Args:
            game: game object providing get_init_board, get_next_state and
                get_reward_for_player.
            model: network handed to MCTSPlayer when the model is evaluated.
            args: settings dict handed to MCTSPlayer (e.g. 'num_simulations').
        """
        self.game = game
        self.model = model
        self.args = args
        self.test_results = {}  # opponent name -> results dict
        self.sample_games = {}  # opponent name -> recorded game trace (list of dicts)

    @staticmethod
    def _snapshot(state):
        """Copy a board for the trace (falls back to its str() if it has no .copy())."""
        return state.copy() if hasattr(state, 'copy') else str(state)

    def play_single_game(self, player1, player2, verbose=False):
        """
        Play one game: ``player1`` moves first as player 1, ``player2`` as -1.

        Returns:
            (winner, game_length, move_times, final_state, game_trace), where
            ``winner`` is 1, -1 or 0 (draw), ``move_times`` holds the seconds
            spent choosing each move, and ``game_trace`` is a list with one dict
            per move ('move', 'player', 'action', 'state_before', 'move_time')
            followed by one dict with 'final_state' and 'winner'.
        """
        state = self.game.get_init_board()
        current_player = 1
        move_count = 0
        move_times = []
        game_trace = []  # Store game moves for replay

        if verbose:
            print("Starting new game...")

        while True:
            mover = player1 if current_player == 1 else player2
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            action = mover.get_action(self.game, state, current_player)
            move_time = time.perf_counter() - start_time

            move_times.append(move_time)
            move_count += 1
            game_trace.append({
                'move': move_count,
                'player': current_player,
                'action': action,
                'state_before': self._snapshot(state),
                'move_time': move_time,
            })

            if verbose:
                print(f"Player {current_player} chose action {action} in {move_time:.3f}s")

            state, current_player = self.game.get_next_state(state, current_player, action)
            # The reward is expressed from the perspective of the player now to move.
            reward = self.game.get_reward_for_player(state, current_player)

            if reward is not None:
                winner = current_player if reward > 0 else (-current_player if reward < 0 else 0)
                game_trace.append({
                    'final_state': self._snapshot(state),
                    'winner': winner,
                })
                if verbose:
                    print(f"Game ended. Winner: {winner if winner != 0 else 'Draw'}")
                return winner, move_count, move_times, state, game_trace

    def _play_series(self, model_player, opponent, num_games, model_seat, verbose,
                     progress_label, sample_trace):
        """Play ``num_games`` games with the model in seat ``model_seat``.

        ``model_seat`` is 1 when the model moves first and -1 when it moves
        second. Results are recorded from the model's perspective: 1 for a win,
        -1 for a loss and 0 for a draw.

        If ``sample_trace`` is None, each game has a 10% chance of becoming the
        sample. The chosen trace is prefixed with a
        ``{'model_player': model_seat}`` entry so replays can label the seats
        correctly.

        Returns:
            ``(stats, sample_trace)``.
        """
        stats = GameStats()
        for i in range(num_games):
            if verbose or i % 10 == 0:
                print(f"Game {i+1}/{num_games}{progress_label}")

            if model_seat == 1:
                winner, length, times, _, trace = self.play_single_game(model_player, opponent, verbose)
            else:
                winner, length, times, _, trace = self.play_single_game(opponent, model_player, verbose)

            # Convert the absolute winner (1 / -1 / 0) to the model's perspective.
            if winner == model_seat:
                result = 1
            elif winner == -model_seat:
                result = -1
            else:
                result = 0
            stats.add_game_result(result, length, np.mean(times), result)

            if sample_trace is None and random.random() < 0.1:  # 10% chance
                sample_trace = [{'model_player': model_seat}] + trace

        return stats, sample_trace

    @staticmethod
    def _merge_stats(*parts):
        """Return a new GameStats pooling every counter and per-game list of ``parts``."""
        merged = GameStats()
        for part in parts:
            merged.wins += part.wins
            merged.losses += part.losses
            merged.draws += part.draws
            merged.total_games += part.total_games
            merged.total_moves += part.total_moves
            merged.game_lengths.extend(part.game_lengths)
            merged.move_times.extend(part.move_times)
            merged.rewards.extend(part.rewards)
        return merged

    def test_against_opponent(self, opponent, num_games=100, test_as_both_players=True,
                              verbose=False, opponent_name=None):
        """
        Evaluate the model against one opponent.

        Args:
            opponent: object with get_action(game, state, player).
            num_games: total number of games. With ``test_as_both_players``,
                the model moves first in ceil(num_games / 2) games and second
                in the remaining floor(num_games / 2) games.
            test_as_both_players: play from both seats instead of always
                letting the model move first.
            verbose: print every move.
            opponent_name: key under which the sample game is stored in
                ``self.sample_games``. Defaults to the opponent's class name.
                Pass a distinct name when several opponents share a class.

        Returns:
            dict with 'combined' stats and, when both seats are tested,
            'as_player_1' and 'as_player_2' stats (GameStats.get_stats_dict()).
        """
        if opponent_name is None:
            opponent_name = opponent.__class__.__name__
        print(f"Testing against {opponent_name} ({num_games} games)...")

        model_player = MCTSPlayer(self.model, self.args)
        results = {}

        if test_as_both_players:
            # Give the extra game of an odd total to the first seat instead of dropping it.
            games_as_p1 = num_games - num_games // 2
            games_as_p2 = num_games // 2
            stats_as_p1, sample_game_trace = self._play_series(
                model_player, opponent, games_as_p1, 1, verbose, " (as Player 1)", None)
            stats_as_p2, sample_game_trace = self._play_series(
                model_player, opponent, games_as_p2, -1, verbose, " (as Player 2)", sample_game_trace)

            results['as_player_1'] = stats_as_p1.get_stats_dict()
            results['as_player_2'] = stats_as_p2.get_stats_dict()
            results['combined'] = self._merge_stats(stats_as_p1, stats_as_p2).get_stats_dict()
        else:
            stats, sample_game_trace = self._play_series(
                model_player, opponent, num_games, 1, verbose, "", None)
            results['combined'] = stats.get_stats_dict()

        if sample_game_trace:
            self.sample_games[opponent_name] = sample_game_trace

        return results

    def comprehensive_test(self, opponents_config, num_games_per_opponent=100, save_results=True):
        """
        Run test_against_opponent (both seats) for every configured opponent.

        Args:
            opponents_config: dict mapping opponent_name -> opponent instance.
            num_games_per_opponent: games played against each opponent.
            save_results: also write the results with save_results().

        Returns:
            dict mapping opponent_name -> results dict (also kept in self.test_results).
        """
        print("Starting comprehensive model testing...")
        print("=" * 50)

        all_results = {}

        for opponent_name, opponent in opponents_config.items():
            results = self.test_against_opponent(
                opponent,
                num_games=num_games_per_opponent,
                test_as_both_players=True,
                opponent_name=opponent_name,
            )
            all_results[opponent_name] = results

            combined = results['combined']
            print(f"\n{opponent_name} Results:")
            print(f"  Win Rate: {combined['win_rate']:.1%}")
            print(f"  Games: {combined['wins']}W - {combined['losses']}L - {combined['draws']}D")
            print(f"  Avg Game Length: {combined['avg_game_length']:.1f} moves")
            print(f"  Avg Move Time: {combined['avg_move_time']:.3f}s")

            if 'as_player_1' in results:
                p1_wr = results['as_player_1']['win_rate']
                p2_wr = results['as_player_2']['win_rate']
                print(f"  As Player 1: {p1_wr:.1%} | As Player 2: {p2_wr:.1%}")

        self.test_results = all_results

        self.print_all_sample_games()

        if save_results:
            self.save_results()

        return all_results

    def print_all_sample_games(self):
        """Print the recorded sample game for every tested opponent."""
        if not self.sample_games:
            print("\nNo sample games recorded.")
            return

        print("\n" + "=" * 80)
        print("SAMPLE GAMES")
        print("=" * 80)

        for opponent_name, game_trace in self.sample_games.items():
            self.print_game_trace(game_trace, opponent_name)

    def print_game_trace(self, game_trace, opponent_name):
        """Pretty-print a recorded game trace.

        Traces stored by test_against_opponent start with a
        ``{'model_player': 1 or -1}`` entry giving the model's seat. Traces
        without that entry (e.g. loaded from older result files) are assumed
        to have the model as player 1.
        """
        print(f"\n{'='*60}")
        print(f"SAMPLE GAME vs {opponent_name}")
        print(f"{'='*60}")

        if not game_trace:
            print("No game trace available")
            return

        moves = [entry for entry in game_trace if 'move' in entry]
        final_entry = [entry for entry in game_trace if 'final_state' in entry]
        model_player = next(
            (entry['model_player'] for entry in game_trace if 'model_player' in entry), 1)

        if not moves:
            print("No moves recorded in trace")
            return

        print(f"Game with {len(moves)} moves (MODEL played as Player {model_player}):")
        print()

        for move_info in moves:
            player = move_info['player']
            player_name = "MODEL" if player == model_player else opponent_name
            print(f"Move {move_info['move']}: {player_name} (Player {player})")
            print(f"  Action: {move_info['action']}")
            print(f"  Time: {move_info['move_time']:.3f}s")
            print("  State before move:")
            print(f"  {move_info['state_before']}")
            print()

        if final_entry:
            winner = final_entry[0]['winner']
            print("FINAL STATE:")
            print(f"{final_entry[0]['final_state']}")
            print()
            print("RESULT:")
            if winner == 0:
                print("🤝 DRAW 🤝")
            elif winner == model_player:
                print("🎉 MODEL WINS! 🎉")
            else:
                print(f"😞 {opponent_name.upper()} WINS 😞")

        print(f"{'='*60}")

    def analyze_playing_style(self, opponent, num_games=50):
        """Summarize the model's move choices over ``num_games`` games in which it moves first.

        Returns:
            dict with action frequencies, the 5 most common actions, the average
            game length (0 if no games), the number of distinct opening moves
            and the most common opening as an (action, count) pair (or None).
        """
        print(f"Analyzing playing style against {opponent.__class__.__name__}...")

        model_player = MCTSPlayer(self.model, self.args)
        action_frequency = Counter()
        game_lengths = []
        opening_moves = []

        for _ in range(num_games):
            _, length, _, _, trace = self.play_single_game(model_player, opponent)
            # The model is player 1 in play_single_game, so its moves are player-1 entries.
            model_actions = [entry['action'] for entry in trace
                             if 'move' in entry and entry['player'] == 1]
            action_frequency.update(model_actions)
            game_lengths.append(length)
            opening_moves.append(model_actions[0] if model_actions else None)

        style_analysis = {
            'action_frequency': dict(action_frequency),
            'most_common_actions': action_frequency.most_common(5),
            'avg_game_length': np.mean(game_lengths) if game_lengths else 0,
            'opening_move_diversity': len(set(opening_moves)),
            'most_common_opening': Counter(opening_moves).most_common(1)[0] if opening_moves else None
        }

        return style_analysis

    def create_performance_report(self):
        """Print overall and per-opponent statistics from self.test_results."""
        if not self.test_results:
            print("No test results available. Run tests first.")
            return

        print("\n" + "=" * 60)
        print("COMPREHENSIVE PERFORMANCE REPORT")
        print("=" * 60)

        all_win_rates = [results['combined']['win_rate'] for results in self.test_results.values()]

        print("\nOVERALL PERFORMANCE:")
        print(f"Average Win Rate: {np.mean(all_win_rates):.1%}")
        print(f"Best Performance: {np.max(all_win_rates):.1%}")
        print(f"Worst Performance: {np.min(all_win_rates):.1%}")
        print(f"Win Rate Standard Deviation: {np.std(all_win_rates):.1%}")

        print("\nDETAILED RESULTS:")
        for opponent, results in self.test_results.items():
            combined = results['combined']
            print(f"\nvs {opponent}:")
            print(f"  Win Rate: {combined['win_rate']:.1%}")
            print(f"  Record: {combined['wins']}W-{combined['losses']}L-{combined['draws']}D")
            print(f"  Avg Game Length: {combined['avg_game_length']:.1f} moves")
            print(f"  Avg Move Time: {combined['avg_move_time']:.3f}s")

            if 'as_player_1' in results:
                p1_wr = results['as_player_1']['win_rate']
                p2_wr = results['as_player_2']['win_rate']
                print(f"  As Player 1: {p1_wr:.1%} | As Player 2: {p2_wr:.1%}")

    def plot_results(self, save_plots=True):
        """Plot overall win rates, win rates per seat, game lengths, and win rate vs game length.

        Args:
            save_plots: also write the figure to 'model_performance_analysis.png'.
        """
        if not self.test_results:
            print("No test results to plot.")
            return

        opponents = list(self.test_results.keys())
        win_rates = [self.test_results[opp]['combined']['win_rate'] for opp in opponents]
        game_lengths = [self.test_results[opp]['combined']['avg_game_length'] for opp in opponents]

        # Win rates per seat; opponents tested only as player 1 get 0 for seat 2.
        p1_win_rates = []
        p2_win_rates = []
        for opp, overall_win_rate in zip(opponents, win_rates):
            opp_results = self.test_results[opp]
            if 'as_player_1' in opp_results:
                p1_win_rates.append(opp_results['as_player_1']['win_rate'])
                p2_win_rates.append(opp_results['as_player_2']['win_rate'])
            else:
                p1_win_rates.append(overall_win_rate)
                p2_win_rates.append(0)  # No data for player 2

        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle('Model Performance Analysis', fontsize=16)

        # Overall win rates bar chart
        axes[0, 0].bar(opponents, win_rates, color='skyblue')
        axes[0, 0].set_title('Overall Win Rates by Opponent')
        axes[0, 0].set_ylabel('Win Rate')
        axes[0, 0].set_ylim(0, 1)
        axes[0, 0].tick_params(axis='x', rotation=45)
        for i, v in enumerate(win_rates):
            axes[0, 0].text(i, v + 0.01, f'{v:.1%}', ha='center', va='bottom')

        # Win rates by starting position
        x = np.arange(len(opponents))
        width = 0.35
        axes[0, 1].bar(x - width/2, p1_win_rates, width, label='Starting First', color='lightgreen')
        axes[0, 1].bar(x + width/2, p2_win_rates, width, label='Starting Second', color='lightcoral')
        axes[0, 1].set_title('Win Rates by Starting Position')
        axes[0, 1].set_ylabel('Win Rate')
        axes[0, 1].set_ylim(0, 1)
        axes[0, 1].set_xticks(x)
        axes[0, 1].set_xticklabels(opponents, rotation=45)
        axes[0, 1].legend()

        # Value labels (zero-height bars are left unlabeled)
        for i, (v1, v2) in enumerate(zip(p1_win_rates, p2_win_rates)):
            if v1 > 0:
                axes[0, 1].text(i - width/2, v1 + 0.01, f'{v1:.1%}', ha='center', va='bottom', fontsize=8)
            if v2 > 0:
                axes[0, 1].text(i + width/2, v2 + 0.01, f'{v2:.1%}', ha='center', va='bottom', fontsize=8)

        # Game lengths
        axes[1, 0].bar(opponents, game_lengths)
        axes[1, 0].set_title('Average Game Length by Opponent')
        axes[1, 0].set_ylabel('Moves per Game')
        axes[1, 0].tick_params(axis='x', rotation=45)
        for i, v in enumerate(game_lengths):
            axes[1, 0].text(i, v + 0.5, f'{v:.1f}', ha='center', va='bottom')

        # Win rate vs game length scatter
        axes[1, 1].scatter(game_lengths, win_rates, s=100, alpha=0.7, color='purple')
        for i, opp in enumerate(opponents):
            axes[1, 1].annotate(opp, (game_lengths[i], win_rates[i]),
                                xytext=(5, 5), textcoords='offset points', fontsize=8)
        axes[1, 1].set_xlabel('Average Game Length')
        axes[1, 1].set_ylabel('Win Rate')
        axes[1, 1].set_title('Win Rate vs Game Length')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_plots:
            plt.savefig('model_performance_analysis.png', dpi=300, bbox_inches='tight')
            print("Plots saved as 'model_performance_analysis.png'")

        plt.show()

    @staticmethod
    def _json_default(obj):
        """json.dump fallback: numpy arrays become nested lists, numpy scalars
        become Python scalars, and anything else becomes its str()."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    def save_results(self, filename='test_results.json'):
        """Write self.test_results and self.sample_games to ``filename`` as JSON."""
        results_to_save = {
            'test_results': self.test_results,
            'sample_games': self.sample_games
        }
        with open(filename, 'w') as f:
            # _json_default keeps numpy boards/actions as real JSON lists/numbers
            # instead of lossy strings.
            json.dump(results_to_save, f, indent=2, default=self._json_default)
        print(f"Results saved to {filename}")

    def load_results(self, filename='test_results.json'):
        """Load results written by save_results; reports (does not raise) a missing file."""
        try:
            with open(filename, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f"File {filename} not found.")
            return
        self.test_results = data.get('test_results', {})
        self.sample_games = data.get('sample_games', {})
        print(f"Results loaded from {filename}")


# The opponent classes used below (Player1, Player2, Player3) and the WIN_SIZE
# constant are expected to be defined in earlier cells of this notebook.

def create_test_suite(game, model, args):
    """Return the default mapping of opponent name -> opponent instance.

    ``game``, ``model`` and ``args`` are accepted so extra opponents (for
    example MCTS players built from ``model`` and ``args``) can be added
    here; the default suite itself does not use them.
    """
    suite = {}
    suite['Player1'] = Player1()
    suite['Player2'] = Player2()
    suite['Player3'] = Player3(WIN_SIZE)

    # Further opponents can be registered the same way, e.g.
    #   suite['Pretrained_Model'] = MCTSPlayer(pretrained_model, args)
    #   suite['Weaker_MCTS'] = MCTSPlayer(model, {**args, 'num_simulations': 10})

    return suite

def play_against_agent(game, agent_player, human_first=True):
    """
    Play an interactive stdin/stdout game: a human versus any agent.

    The human plays as -1 (shown as "O") and the agent as 1 (shown as "X").

    Args:
        game: game object providing get_init_board, get_valid_moves,
            get_next_state and get_reward_for_player.
        agent_player: object with get_action(game, state, player). If it also
            provides get_action_probs(game, state, player) (e.g. an MCTS
            player), those probabilities are displayed after each agent move.
        human_first: if True the human makes the first move.
    """
    state = game.get_init_board()
    current_player = -1 if human_first else 1

    is_2d = len(state.shape) == 2
    rows, cols = state.shape if is_2d else (1, state.shape[0])
    symbols = {0: ".", 1: "X", -1: "O"}

    def display_board(board):
        print("\nBoard:")
        for r in range(rows):
            row = [symbols[board[r, c] if is_2d else board[c]] for c in range(cols)]
            print(" " + " | ".join(row))
        print()

    def get_move_from_human(valid):
        """Prompt until the human enters a legal move; return its flat action index."""
        while True:
            try:
                if is_2d:
                    inp = input("Your move (row col): ").strip().split()
                    if len(inp) != 2:
                        raise ValueError
                    r, c = map(int, inp)
                    # Check row and column separately: an out-of-range column
                    # would otherwise wrap into another row's cell via r * cols + c.
                    if not (0 <= r < rows and 0 <= c < cols):
                        raise ValueError
                    a = r * cols + c
                else:
                    a = int(input(f"Your move (0 to {cols - 1}): "))
            except ValueError:
                # Only malformed input is retried; EOFError / KeyboardInterrupt
                # propagate so the user can still abort the game.
                print("Invalid input. Try again.")
                continue
            if 0 <= a < len(valid) and valid[a] == 1:
                return a
            print("Illegal move. Try again.")

    def display_probs(probs):
        print("Agent move probabilities:")
        for r in range(rows):
            row_probs = [f"{probs[r * cols + c]:.2f}" for c in range(cols)]
            print(" " + " | ".join(row_probs))
        print()

    print("Welcome! You're O. Agent is X.")
    display_board(state)

    while True:
        if current_player == -1:
            move = get_move_from_human(game.get_valid_moves(state))
        else:
            print("Agent is thinking...")
            probs = None
            # Heuristic agents may only implement get_action.
            if hasattr(agent_player, "get_action_probs"):
                probs, _ = agent_player.get_action_probs(game, state, current_player)
            move = agent_player.get_action(game, state, current_player)
            if probs is not None:
                display_probs(probs)

        state, current_player = game.get_next_state(state, current_player, move)
        display_board(state)

        # The reward is from the perspective of the player now to move.
        reward = game.get_reward_for_player(state, current_player)
        if reward is not None:
            print("Game over!")
            if reward == 0:
                print("It's a draw!")
            else:
                winner = current_player if reward > 0 else -current_player
                print("You win!" if winner == -1 else "Agent wins!")
            break


### Test Results

In [None]:
# Example usage:
# Evaluates `model` (trained in the cell above) with MCTS search against the
# default opponent suite, then reports, plots and saves the results.

# Setup
game = TicTacToeK(BOARD_ROWS, BOARD_COLS, WIN_SIZE)
args = {'num_simulations': 25, 'batch_size': 64}

# Create tester
tester = ModelTester(game, model, args)

# Create test opponents
test_opponents = create_test_suite(game, model, args)

# Run comprehensive tests
# (save_results defaults to True here, so this already writes test_results.json)
results = tester.comprehensive_test(test_opponents, num_games_per_opponent=100)

# Generate report and visualizations
tester.create_performance_report()
tester.plot_results()

# Analyze playing style
# (in analyze_playing_style the model always moves first)
style_analysis = tester.analyze_playing_style(RandomPlayer(), num_games=50)
print("Playing style analysis:", style_analysis)

# Save results for later analysis
# (a second copy of the same results, under a custom filename)
tester.save_results('my_model_results.json')

## Play against any agent yourself

In [None]:
# TODO(student): set `agent` to the agent you want to play against, for
# example MCTSPlayer(model, args). This line is a placeholder and the cell
# will not run until it is filled in.
agent = # ...
# human_first=True: you (O) make the first move.
play_against_agent(game, agent, human_first=True)

### ⭐ Bonus Task (3 extra points)

Implement an agent that **never loses** in the 3×3 Connect-3 game:
- It should at least guarantee a draw against any opponent.
- Tip: you may hard-code the logic, use exhaustive (minimax) search, or run MCTS with perfect (exhaustive) rollouts instead of a learned evaluation.

Once implemented, evaluate this agent on **larger boards** (e.g., 4×4 or 5×5) and reflect on its limitations in more complex settings.