In [1]:
import subprocess

try:
    # Check for NVIDIA GPU
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
    if result.returncode == 0:
        print("✓ NVIDIA GPU detected!")
        print("\nGPU Info:")
        print(result.stdout)
    else:
        print("✗ No NVIDIA GPU found")
except FileNotFoundError:
    print("✗ nvidia-smi not found - no NVIDIA GPU or drivers not installed")

# Check CUDA availability in PyTorch (if you have it)
try:
    import torch
    print(f"\nPyTorch CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
except ImportError:
    print("\nPyTorch not installed (not needed for GTM)")

✓ NVIDIA GPU detected!

GPU Info:
Sun Dec 14 18:08:04 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM3-32GB           On  |   00000000:57:00.0 Off |                    0 |
| N/A   29C    P0             49W /  350W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
              

# Methods - Optimized Version

In [2]:
from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
from GraphTsetlinMachine.graphs import Graphs
import numpy as np
import subprocess
import time
import os
from tqdm import tqdm

# Configuration
NUMBER_OF_CLAUSES = 400
T = 2000
S = 5.0
DEPTH = 8
MESSAGE_SIZE = 512
MESSAGE_BITS = 2
BOARD_DIM = 3
NOTEBOOK_DIR = os.path.dirname(os.path.abspath("Tsetlin.ipynb"))
HEX_DIR = os.path.join(NOTEBOOK_DIR, "TsetlinMachine/hex")

if not os.path.exists(HEX_DIR):
    raise FileNotFoundError(f"ERROR: Cannot find hex.c at {HEX_DIR}")

print("Building hex using make...")

try:
    result = subprocess.run(
        ["make"],
        cwd=HEX_DIR,
        capture_output=True,
        text=True
    )

    print("=== Make Output ===")
    print(result.stdout)
    if result.stderr.strip():
        print("=== Make Errors ===")
        print(result.stderr)

    if result.returncode == 0:
        print("\n✓ Build successful!")
    else:
        print("\n❌ Build failed! See errors above.")

except Exception as e:
    print("Exception when running make:", e)

def c_position_to_node_id(c_position, board_dim=BOARD_DIM):
    padded_dim = board_dim + 2
    i = c_position // padded_dim
    j = c_position % padded_dim
    node_id = (i - 1) * board_dim + (j - 1)

    if node_id < 0 or node_id >= board_dim * board_dim:
        return None
    return node_id

def get_hex_edges(board_dim=BOARD_DIM):
    edges = []
    neighbor_offsets = [(0, 1), (0, -1), (-1, 1), (1, -1), (-1, 0), (1, 0)]

    for i in range(board_dim):
        for j in range(board_dim):
            node_id = i * board_dim + j
            for di, dj in neighbor_offsets:
                ni, nj = i + di, j + dj
                if 0 <= ni < board_dim and 0 <= nj < board_dim:
                    neighbor_id = ni * board_dim + nj
                    edges.append((node_id, neighbor_id))

    return edges

def parse_game_output(output):
    games = []
    current_game = None

    for line in output.split('\n'):
        line = line.strip()

        if line == "GAME_START":
            current_game = {'moves': [], 'winner': -1}
        elif line.startswith("MOVE"):
            if current_game is not None:
                parts = line.split()
                if len(parts) >= 3:
                    position = int(parts[1])
                    player = int(parts[2])
                    current_game['moves'].append((position, player))
        elif line.startswith("WINNER"):
            if current_game is not None:
                parts = line.split()
                if len(parts) >= 2:
                    current_game['winner'] = int(parts[1])
        elif line == "GAME_END":
            if current_game and current_game['winner'] != -1:
                games.append(current_game)
            current_game = None

    return games

def create_training_data_from_game(moves, winner, board_dim=BOARD_DIM):
    """
    Create ONE training sample per game:
      - board_state: final board (0=empty, 1=player0, 2=player1)
      - label: winner of the game (0 or 1)

    We keep node_features as before for compatibility, but the main
    object we care about is the final board_state.
    """
    num_nodes = board_dim * board_dim
    board_state = np.zeros(num_nodes, dtype=np.int32)
    edges = get_hex_edges(board_dim)  # not used in new Graphs, but kept

    # Play through the whole game to get the final board
    for c_position, player in moves:
        node_id = c_position_to_node_id(c_position, board_dim)
        if node_id is None:
            print(f"Skipping invalid move: c_pos={c_position}")
            continue
        board_state[node_id] = player + 1  # 1 = player 0, 2 = player 1

    # Build node_features from the FINAL full board_state
    node_features = np.zeros((num_nodes, 3), dtype=np.int32)
    for nid in range(num_nodes):
        if board_state[nid] == 1:
            node_features[nid, 0] = 1  # player_0 stone
        elif board_state[nid] == 2:
            node_features[nid, 1] = 1  # player_1 stone
        else:
            node_features[nid, 2] = 1  # empty

    label = int(winner)  # 0 or 1

    sample = {
        'board_state': board_state.reshape(board_dim, board_dim),
        'node_features': node_features,   # still there if you need it
        'edges': edges,                   # unused in new Graph building
        'position': -1,                   # not used now
        'player': -1,                     # not used now
        'label': label
    }

    # Return as a list to match old API
    return [sample]


def prepare_training_data(games, board_dim=BOARD_DIM):
    """
    Turn a list of games into a list of FINAL-state → winner samples.
    Exactly one sample per game.
    """
    all_samples = []

    print(f"Processing {len(games)} games into training samples...")

    # Game-level statistics
    player_0_wins = sum(1 for g in games if g['winner'] == 0)
    player_1_wins = sum(1 for g in games if g['winner'] == 1)
    print(f"Game outcomes (per game):")
    print(f"  Player 0 wins: {player_0_wins}")
    print(f"  Player 1 wins: {player_1_wins}")

    for game in tqdm(games, desc="Processing games"):
        samples = create_training_data_from_game(game['moves'], game['winner'], board_dim)
        all_samples.extend(samples)

    if len(all_samples) == 0:
        print("ERROR: No training samples created! Check your logic.")
        return all_samples

    labels = [s['label'] for s in all_samples]
    unique, counts = np.unique(labels, return_counts=True)
    print(f"\nLabel distribution (winner classes, per FINAL board):")
    for label, count in zip(unique, counts):
        print(f"  Winner {label}: {count} games ({count/len(labels)*100:.1f}%)")

    # Quick sanity check of final board states
    print("\nSample final board state check (first 5 samples):")
    for i in range(min(5, len(all_samples))):
        sample = all_samples[i]
        pieces = np.sum(sample['node_features'][:, :2])
        empties = np.sum(sample['node_features'][:, 2])
        print(f"  Sample {i}: {pieces} stones, {empties} empty cells, label(winner)={sample['label']}")

    print(f"{'='*60}\n")

    return all_samples


def generate_game_data(num_games=1000, hex_dir=HEX_DIR):
    hex_executable = os.path.join(hex_dir, "hex")

    if not os.path.exists(hex_executable):
        print(f"Executable not found at {hex_executable}")
        return []

    print(f"Generating {num_games} games...")

    try:
        result = subprocess.run(
            [hex_executable, str(num_games)],
            cwd=hex_dir,
            capture_output=True,
            text=True,
            timeout=120
        )

        if result.returncode != 0:
            print(f"Error running hex executable:")
            print(result.stderr)
            return []

        games = parse_game_output(result.stdout)
        print(f"Successfully parsed {len(games)} games from output")
        return games

    except Exception as e:
        print(f"Error running hex executable: {e}")
        return []


def prepare_gtm_data_multinode(training_samples, board_dim=BOARD_DIM,
                                hypervector_size=1024, hypervector_bits=2,
                                init_with=None):
    """
    Multi-node graph representation: 121 nodes per graph with Hex connectivity.
    Args:
        init_with: Optional Graphs object to copy hypervectors from (for evaluation)
    """
    Y = np.array([s['label'] for s in training_samples], dtype=np.int32)
    num_graphs = len(training_samples)
    num_nodes_per_graph = board_dim * board_dim

    symbols = ["Empty", "Player0", "Player1"]

    print(f"\n{'='*70}")
    print("CREATING MULTI-NODE GRAPH REPRESENTATION")
    print(f"{'='*70}")
    print(f"  Graphs: {num_graphs}, Nodes per graph: {num_nodes_per_graph} ({board_dim}×{board_dim})")

    graphs = Graphs(
        num_graphs,
        symbols=symbols,
        hypervector_size=hypervector_size,
        hypervector_bits=hypervector_bits,
        init_with=init_with
    )

    # Hex neighbor offsets
    neighbor_offsets = [(0, 1), (0, -1), (-1, 1), (1, -1), (-1, 0), (1, 0)]

    # Pre-calculate edges per node (needed for add_graph_node)
    edges_per_node = np.zeros(num_nodes_per_graph, dtype=np.uint32)
    for i in range(board_dim):
        for j in range(board_dim):
            node_id = i * board_dim + j
            edge_count = sum(1 for di, dj in neighbor_offsets
                           if 0 <= i+di < board_dim and 0 <= j+dj < board_dim)
            edges_per_node[node_id] = edge_count

    # Step 1: Set node counts
    print("Step 1: Configuring nodes...")
    for graph_id in range(num_graphs):
        graphs.set_number_of_graph_nodes(graph_id, num_nodes_per_graph)

    graphs.prepare_node_configuration()

    # Step 2: Add nodes with edge counts
    print("Step 2: Adding nodes...")
    for graph_id in range(num_graphs):
        for node_id in range(num_nodes_per_graph):
            graphs.add_graph_node(graph_id, node_id, edges_per_node[node_id])

    graphs.prepare_edge_configuration()

    # Step 3: Add edges
    print("Step 3: Adding edges...")
    for graph_id in range(num_graphs):
        for i in range(board_dim):
            for j in range(board_dim):
                node_id = i * board_dim + j
                for di, dj in neighbor_offsets:
                    ni, nj = i + di, j + dj
                    if 0 <= ni < board_dim and 0 <= nj < board_dim:
                        neighbor_id = ni * board_dim + nj
                        graphs.add_graph_node_edge(graph_id, node_id, neighbor_id, "hex_edge")

    # Step 4: Add node properties
    print("Step 4: Adding properties...")
    for graph_id in range(num_graphs):
        node_features = training_samples[graph_id]['node_features']
        for node_id in range(num_nodes_per_graph):
            if node_features[node_id, 0] == 1:
                graphs.add_graph_node_property(graph_id, node_id, "Player0")
            elif node_features[node_id, 1] == 1:
                graphs.add_graph_node_property(graph_id, node_id, "Player1")
            else:
                graphs.add_graph_node_property(graph_id, node_id, "Empty")

    # Step 5: Encode
    print("Step 5: Encoding...")
    graphs.encode()

    print(f"✓ Multi-node graphs created!\n{'='*70}\n")
    return graphs, Y



def train_model(graphs, Y, epochs=100):
    """Train with hyperparameters from global config"""
    print("Initializing Graph Tsetlin Machine...")
    print(f"  Clauses: {NUMBER_OF_CLAUSES}")
    print(f"  T: {T}")
    print(f"  s: {S}")
    print(f"  Depth: {DEPTH}")
    print(f"  Message Size: {MESSAGE_SIZE}")

    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses=NUMBER_OF_CLAUSES,
        T=T,
        s=S,
        number_of_state_bits=8,
        depth=DEPTH,
        message_size=MESSAGE_SIZE,
        message_bits=MESSAGE_BITS,
        max_included_literals=32,
        grid=(16 * 13, 1, 1),
        block=(128, 1, 1)
    )

    # Class balancing: oversample minority class of winner
    class_0_indices = np.where(Y == 0)[0]
    class_1_indices = np.where(Y == 1)[0]

    print(f"\nClass distribution before balancing (winner classes):")
    print(f"  Winner 0: {len(class_0_indices)} states")
    print(f"  Winner 1: {len(class_1_indices)} states")

    if len(class_0_indices) > 0 and len(class_1_indices) > 0:
        # Balance by repeating minority class
        if len(class_0_indices) < len(class_1_indices):
            oversample_ratio = len(class_1_indices) // len(class_0_indices)
            class_0_indices = np.tile(class_0_indices, oversample_ratio)
        else:
            oversample_ratio = len(class_0_indices) // len(class_1_indices)
            class_1_indices = np.tile(class_1_indices, oversample_ratio)

        balanced_indices = np.concatenate([class_0_indices, class_1_indices])
        np.random.shuffle(balanced_indices)
        print(f"After balancing: {len(balanced_indices)} total samples")
    else:
        balanced_indices = np.arange(len(Y))
        print("Warning: only one winner class present; no balancing applied.")

    print(f"\nStarting training for {epochs} epochs...")
    print("="*60)

    start_total = time.time()

    for epoch in range(epochs):
        start_epoch = time.time()

        # NOTE: current GTM implementation uses full graph set; balancing
        # is mainly informational here. To actually subsample, GTM would need
        # support for graph subsets.
        tm.fit(graphs, Y, epochs=1, incremental=True)
        elapsed = time.time() - start_epoch

        predictions = tm.predict(graphs)
        accuracy = 100 * (predictions == Y).mean()

        class_0_mask = (Y == 0)
        class_1_mask = (Y == 1)
        class_0_acc = 100 * (predictions[class_0_mask] == 0).mean() if class_0_mask.any() else 0
        class_1_acc = 100 * (predictions[class_1_mask] == 1).mean() if class_1_mask.any() else 0

        print(f"Epoch {epoch+1}/{epochs} - Acc: {accuracy:.2f}% "
              f"(Winner 0 states: {class_0_acc:.1f}%, Winner 1 states: {class_1_acc:.1f}%) - {elapsed:.2f}s")

    total_time = time.time() - start_total
    print("\n" + "="*60)
    print(f"✓ Training completed in {total_time:.2f}s ({total_time/60:.2f} minutes)")

    print("\nFinal Evaluation...")
    predictions = tm.predict(graphs)

    accuracy = 100 * (predictions == Y).mean()
    print(f"\nOverall Accuracy: {accuracy:.2f}%")

    for class_id in [0, 1]:
        mask = Y == class_id
        if mask.any():
            class_acc = 100 * (predictions[mask] == class_id).mean()
            print(f"Winner {class_id} states: {class_acc:.2f}% "
                  f"({(predictions[mask] == class_id).sum()}/{mask.sum()})")

    return tm, predictions

def save_model(tm, filepath="TsetlinMachine/hex_tm_model.pkl",
               board_dim=11, additional_info=None):
    """Save trained model with metadata"""
    print(f"Saving model to {filepath}...")

    state_dict = tm.save(fname=filepath)

    print(f"✓ Model saved successfully to {filepath}")
    return state_dict

Building hex using make...
=== Make Output ===
make: 'hex' is up to date.


✓ Build successful!


## Generate Game Data

Run this cell to generate Hex games and create training samples.

In [3]:
# Generate games
NUM_GAMES = 10000

print(f"Generating {NUM_GAMES} Hex games...")
games = generate_game_data(NUM_GAMES)

if not games:
    raise Exception("No games generated! Check hex executable.")

print(f"\n✓ Successfully generated {len(games)} games")

# Process into training samples (FINAL state -> winner)
training_samples = prepare_training_data(games, BOARD_DIM)
print("\n" + "="*60)
print("DIAGNOSTIC: Checking final board states:")
print("="*60)
for i in range(min(10, len(training_samples))):
    sample = training_samples[i]
    non_zero = np.sum(sample['node_features'][:, :2])  # Count player pieces
    empty = np.sum(sample['node_features'][:, 2])      # Count empty cells
    print(f"Sample {i}: {non_zero} stones on board, {empty} empty cells, label(winner)={sample['label']}")

print("="*60 + "\n")
print(f"\n✓ Training data ready: {len(training_samples)} samples")


Generating 10000 Hex games...
Generating 10000 games...
Successfully parsed 10000 games from output

✓ Successfully generated 10000 games
Processing 10000 games into training samples...
Game outcomes (per game):
  Player 0 wins: 6676
  Player 1 wins: 3324


Processing games: 100%|██████████| 10000/10000 [00:00<00:00, 29120.40it/s]


Label distribution (winner classes, per FINAL board):
  Winner 0: 6676 games (66.8%)
  Winner 1: 3324 games (33.2%)

Sample final board state check (first 5 samples):
  Sample 0: 9 stones, 0 empty cells, label(winner)=0
  Sample 1: 7 stones, 2 empty cells, label(winner)=0
  Sample 2: 7 stones, 2 empty cells, label(winner)=0
  Sample 3: 7 stones, 2 empty cells, label(winner)=0
  Sample 4: 9 stones, 0 empty cells, label(winner)=0


DIAGNOSTIC: Checking final board states:
Sample 0: 9 stones on board, 0 empty cells, label(winner)=0
Sample 1: 7 stones on board, 2 empty cells, label(winner)=0
Sample 2: 7 stones on board, 2 empty cells, label(winner)=0
Sample 3: 7 stones on board, 2 empty cells, label(winner)=0
Sample 4: 9 stones on board, 0 empty cells, label(winner)=0
Sample 5: 7 stones on board, 2 empty cells, label(winner)=0
Sample 6: 9 stones on board, 0 empty cells, label(winner)=0
Sample 7: 6 stones on board, 3 empty cells, label(winner)=1
Sample 8: 6 stones on board, 3 empty cells, 




## Prepare Data for Graph Tsetlin Machine

Convert training samples into the GTM Graphs format.

In [4]:
graphs, Y = prepare_gtm_data_multinode(
    training_samples,
    board_dim=BOARD_DIM,
    hypervector_size=1024,
    hypervector_bits=2
)


CREATING MULTI-NODE GRAPH REPRESENTATION
  Graphs: 10000, Nodes per graph: 9 (3×3)
Step 1: Configuring nodes...
Step 2: Adding nodes...
Step 3: Adding edges...
Step 4: Adding properties...
Step 5: Encoding...
✓ Multi-node graphs created!



## Train the Graph Tsetlin Machine

Train the model on the prepared graph data.

In [5]:
#tm, predictions = train_model(graphs, Y, epochs=50)

In [6]:
#save_model(tm=tm, filepath="TsetlinMachine/hex_tm_model.pkl", board_dim=BOARD_DIM, additional_info=None)

## Load a Trained Model (Optional)

Use this to load a previously trained model.

In [7]:
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, f1_score
from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine

def evaluate_model(model_path="TsetlinMachine/hex_tm_model.pkl",
                   num_test_games=100000, verbose=True):
    """Complete model evaluation"""
    if verbose:
        print(f"\n{'='*70}")
        print("MODEL EVALUATION")
        print(f"{'='*70}")

    # Generate test data
    if verbose:
        print(f"\nGenerating {num_test_games} test games...")
    test_games = generate_game_data(num_test_games)

    test_samples = []
    for g in test_games:
        samples = create_training_data_from_game(g["moves"], g["winner"], BOARD_DIM)
        test_samples.extend(samples)

    Y_test = np.array([s["label"] for s in test_samples], dtype=np.int32)

    if verbose:
        print(f"✓ Created {len(test_samples)} test samples")
        print(f"  Winner 0: {np.sum(Y_test==0)}, Winner 1: {np.sum(Y_test==1)}")

    # Prepare graphs
    if verbose:
        print(f"Preparing multi-node graphs for evaluation...")

    graphs_test, _ = prepare_gtm_data_multinode(
        test_samples,
        board_dim=BOARD_DIM,
        hypervector_size=1024,
        hypervector_bits=2,
        init_with=graphs
    )

    if verbose:
        print(f"✓ Graphs prepared\nMaking predictions...")

    # Predict
    predictions = tm.predict(graphs_test).astype(int)

    # Calculate metrics
    overall_acc = 100 * (predictions == Y_test).mean()

    # Print results
    print(f"\n{'='*70}")
    print(f"RESULTS - {len(test_samples)} samples from {num_test_games} games")
    print(f"{'='*70}")
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")

    # Per-class accuracy
    for winner in [0, 1]:
        mask = Y_test == winner
        if mask.any():
            acc = 100 * (predictions[mask] == winner).mean()
            correct = (predictions[mask] == winner).sum()
            total = mask.sum()
            print(f"Winner {winner} Accuracy: {acc:.2f}% ({correct}/{total})")

    # Confusion matrix
    cm = confusion_matrix(Y_test, predictions)
    print(f"\nConfusion Matrix:")
    print(f"              Predicted")
    print(f"              0      1")
    print(f"Actual  0  [{cm[0,0]:4d}  {cm[0,1]:4d}]")
    print(f"        1  [{cm[1,0]:4d}  {cm[1,1]:4d}]")

    # Additional metrics
    p0 = precision_score(Y_test, predictions, pos_label=0, zero_division=0)
    r0 = recall_score(Y_test, predictions, pos_label=0, zero_division=0)
    f1_0 = f1_score(Y_test, predictions, pos_label=0, zero_division=0)

    p1 = precision_score(Y_test, predictions, pos_label=1, zero_division=0)
    r1 = recall_score(Y_test, predictions, pos_label=1, zero_division=0)
    f1_1 = f1_score(Y_test, predictions, pos_label=1, zero_division=0)

    print(f"\nDetailed Metrics:")
    print(f"  Winner 0: Precision={p0:.3f}, Recall={r0:.3f}, F1={f1_0:.3f}")
    print(f"  Winner 1: Precision={p1:.3f}, Recall={r1:.3f}, F1={f1_1:.3f}")

    return {
        'accuracy': overall_acc,
        'predictions': predictions,
        'labels': Y_test,
        'confusion_matrix': cm,
        'metrics': {'p0': p0, 'r0': r0, 'f1_0': f1_0, 'p1': p1, 'r1': r1, 'f1_1': f1_1}
    }

In [8]:
#results = evaluate_model(num_test_games=100000)


# Experiments and Scenarios

In [9]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
import time
from tqdm import tqdm
import json

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100

# Configuration
BOARD_DIM = 3
TRAIN_GAMES = 10000
TEST_GAMES = 10000
EPOCHS = 50

print("="*70)
print("3×3 HEX GTM EXPERIMENTS: MODEL CAPACITY & DEPTH ANALYSIS")
print("="*70)
print(f"\nConfiguration:")
print(f"  Board Size: {BOARD_DIM}×{BOARD_DIM}")
print(f"  Training Games: {TRAIN_GAMES:,}")
print(f"  Test Games: {TEST_GAMES:,}")
print(f"  Epochs: {EPOCHS}")
print("="*70)

print("\n[1/2] Generating training data...")
train_games = generate_game_data(TRAIN_GAMES)
train_samples = prepare_training_data(train_games, BOARD_DIM)
Y_train_label = np.array([s['label'] for s in train_samples], dtype=np.int32)

print(f"✓ Training samples: {len(train_samples)}")
print(f"  Winner 0: {np.sum(Y_train_label==0)} ({100*np.mean(Y_train_label==0):.1f}%)")
print(f"  Winner 1: {np.sum(Y_train_label==1)} ({100*np.mean(Y_train_label==1):.1f}%)")

print("\n[2/2] Generating test data...")
test_games = generate_game_data(TEST_GAMES)
test_samples = prepare_training_data(test_games, BOARD_DIM)
Y_test = np.array([s['label'] for s in test_samples], dtype=np.int32)

print(f"✓ Test samples: {len(test_samples)}")
print(f"  Winner 0: {np.sum(Y_test==0)} ({100*np.mean(Y_test==0):.1f}%)")
print(f"  Winner 1: {np.sum(Y_test==1)} ({100*np.mean(Y_test==1):.1f}%)")

print("\n✓ Data generation complete!")

3×3 HEX GTM EXPERIMENTS: MODEL CAPACITY & DEPTH ANALYSIS

Configuration:
  Board Size: 3×3
  Training Games: 10,000
  Test Games: 10,000
  Epochs: 50

[1/2] Generating training data...
Generating 10000 games...
Successfully parsed 10000 games from output
Processing 10000 games into training samples...
Game outcomes (per game):
  Player 0 wins: 6676
  Player 1 wins: 3324


Processing games: 100%|██████████| 10000/10000 [00:00<00:00, 28048.73it/s]



Label distribution (winner classes, per FINAL board):
  Winner 0: 6676 games (66.8%)
  Winner 1: 3324 games (33.2%)

Sample final board state check (first 5 samples):
  Sample 0: 9 stones, 0 empty cells, label(winner)=0
  Sample 1: 7 stones, 2 empty cells, label(winner)=0
  Sample 2: 7 stones, 2 empty cells, label(winner)=0
  Sample 3: 7 stones, 2 empty cells, label(winner)=0
  Sample 4: 9 stones, 0 empty cells, label(winner)=0

✓ Training samples: 10000
  Winner 0: 6676 (66.8%)
  Winner 1: 3324 (33.2%)

[2/2] Generating test data...
Generating 10000 games...
Successfully parsed 10000 games from output
Processing 10000 games into training samples...
Game outcomes (per game):
  Player 0 wins: 6676
  Player 1 wins: 3324


Processing games: 100%|██████████| 10000/10000 [00:00<00:00, 28453.59it/s]


Label distribution (winner classes, per FINAL board):
  Winner 0: 6676 games (66.8%)
  Winner 1: 3324 games (33.2%)

Sample final board state check (first 5 samples):
  Sample 0: 9 stones, 0 empty cells, label(winner)=0
  Sample 1: 7 stones, 2 empty cells, label(winner)=0
  Sample 2: 7 stones, 2 empty cells, label(winner)=0
  Sample 3: 7 stones, 2 empty cells, label(winner)=0
  Sample 4: 9 stones, 0 empty cells, label(winner)=0

✓ Test samples: 10000
  Winner 0: 6676 (66.8%)
  Winner 1: 3324 (33.2%)

✓ Data generation complete!





In [None]:
import gc

print("\n" + "="*70)
print("EXPERIMENT 2: MODEL CAPACITY (CLAUSE COUNT)")
print("="*70)

# Experiment configurations
exp2_configs = [
    {"name": "Minimal", "clauses": 100, "T": 500},
    {"name": "Small", "clauses": 200, "T": 1000},
    {"name": "Medium", "clauses": 400, "T": 2000},
    {"name": "Large", "clauses": 800, "T": 4000},
    {"name": "XLarge", "clauses": 1600, "T": 8000},
]

# Fixed parameters
EXP2_DEPTH = 8
EXP2_S = 5.0
EXP2_MSG_SIZE = 512

exp2_results = []
exp2_predictions = []
exp2_learning_curves = []

for config in exp2_configs:
    print(f"\n{'='*70}")
    print(f"Running: {config['name']} - {config['clauses']} clauses, T={config['T']}")
    print(f"{'='*70}")

    start_time = time.time()

    # Prepare graphs
    print("  Preparing graphs...")
    graphs_train, Y_train = prepare_gtm_data_multinode(
        train_samples,
        board_dim=BOARD_DIM,
        hypervector_size=1024,
        hypervector_bits=2
    )
    gc.collect()
    torch.cuda.empty_cache()
    # Initialize model
    print(f"  Initializing GTM ({config['clauses']} clauses)...")
    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses=config['clauses'],
        T=config['T'],
        s=EXP2_S,
        depth=EXP2_DEPTH,
        message_size=EXP2_MSG_SIZE,
        message_bits=2
    )

    # Training with epoch tracking
    print(f"  Training for {EPOCHS} epochs...")
    epoch_history = []

    for epoch in tqdm(range(EPOCHS), desc="  Progress"):
        tm.fit(graphs_train, Y_train, epochs=1, incremental=True)

        # Evaluate on training set
        preds_train = tm.predict(graphs_train)
        train_acc = 100 * (preds_train == Y_train).mean()

        # Per-class accuracy
        class_0_mask = Y_train == 0
        class_1_mask = Y_train == 1
        class_0_acc = 100 * (preds_train[class_0_mask] == 0).mean() if class_0_mask.any() else 0
        class_1_acc = 100 * (preds_train[class_1_mask] == 1).mean() if class_1_mask.any() else 0

        epoch_history.append({
            'epoch': epoch + 1,
            'train_acc': train_acc,
            'class_0_acc': class_0_acc,
            'class_1_acc': class_1_acc,
            'gap': abs(class_0_acc - class_1_acc)
        })

    # Prepare test graphs
    print("  Preparing test graphs...")
    graphs_test, _ = prepare_gtm_data_multinode(
        test_samples,
        board_dim=BOARD_DIM,
        init_with=graphs_train  # CRITICAL
    )

    # Evaluate on test set
    print("  Evaluating on test set...")
    test_preds = tm.predict(graphs_test)

    # Calculate metrics
    test_acc = 100 * accuracy_score(Y_test, test_preds)
    test_class_0_acc = 100 * (test_preds[Y_test==0] == 0).mean()
    test_class_1_acc = 100 * (test_preds[Y_test==1] == 1).mean()

    # Classification report
    precision_0 = precision_score(Y_test, test_preds, pos_label=0, zero_division=0)
    recall_0 = recall_score(Y_test, test_preds, pos_label=0, zero_division=0)
    f1_0 = f1_score(Y_test, test_preds, pos_label=0, zero_division=0)

    precision_1 = precision_score(Y_test, test_preds, pos_label=1, zero_division=0)
    recall_1 = recall_score(Y_test, test_preds, pos_label=1, zero_division=0)
    f1_1 = f1_score(Y_test, test_preds, pos_label=1, zero_division=0)

    training_time = time.time() - start_time

    # Store results
    result = {
        'experiment': config['name'],
        'clauses': config['clauses'],
        'T': config['T'],
        'depth': EXP2_DEPTH,
        's': EXP2_S,
        'test_acc': test_acc,
        'test_class_0_acc': test_class_0_acc,
        'test_class_1_acc': test_class_1_acc,
        'class_gap': abs(test_class_0_acc - test_class_1_acc),
        'precision_0': precision_0,
        'recall_0': recall_0,
        'f1_0': f1_0,
        'precision_1': precision_1,
        'recall_1': recall_1,
        'f1_1': f1_1,
        'training_time_sec': training_time,
        'epoch_history': epoch_history
    }

    exp2_results.append(result)
    exp2_predictions.append(test_preds)
    exp2_learning_curves.append(pd.DataFrame(epoch_history))

    # Print summary
    print(f"\n  Results Summary:")
    print(f"    Training Time: {training_time/60:.2f} minutes")
    print(f"    Test Accuracy: {test_acc:.2f}%")
    print(f"    Winner 0: {test_class_0_acc:.2f}% (P={precision_0:.3f}, R={recall_0:.3f}, F1={f1_0:.3f})")
    print(f"    Winner 1: {test_class_1_acc:.2f}% (P={precision_1:.3f}, R={recall_1:.3f}, F1={f1_1:.3f})")
    print(f"    Class Gap: {abs(test_class_0_acc - test_class_1_acc):.2f}%")

print("\n✓ Experiment 2 complete!")


EXPERIMENT 2: MODEL CAPACITY (CLAUSE COUNT)

Running: Minimal - 100 clauses, T=500
  Preparing graphs...

CREATING MULTI-NODE GRAPH REPRESENTATION
  Graphs: 10000, Nodes per graph: 9 (3×3)
Step 1: Configuring nodes...
Step 2: Adding nodes...
Step 3: Adding edges...
Step 4: Adding properties...
Step 5: Encoding...
✓ Multi-node graphs created!

  Initializing GTM (100 clauses)...
Initialization of sparse structure.
  Training for 50 epochs...


  Progress:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
print("\n" + "="*70)
print("EXPERIMENT 3: MESSAGE PASSING DEPTH")
print("="*70)

# Experiment configurations
exp3_configs = [
    {"name": "NoMP", "depth": 1},
    {"name": "Shallow", "depth": 3},
    {"name": "Medium", "depth": 5},
    {"name": "Deep", "depth": 8},
    {"name": "VeryDeep", "depth": 12},
]

# Fixed parameters
EXP3_CLAUSES = 400
EXP3_T = 2000
EXP3_S = 5.0
EXP3_MSG_SIZE = 512

exp3_results = []
exp3_predictions = []
exp3_learning_curves = []

for config in exp3_configs:
    print(f"\n{'='*70}")
    print(f"Running: {config['name']} - Depth={config['depth']}")
    print(f"{'='*70}")

    start_time = time.time()

    # Prepare graphs
    print("  Preparing graphs...")
    graphs_train, Y_train = prepare_gtm_data_multinode(
        train_samples,
        board_dim=BOARD_DIM,
        hypervector_size=1024,
        hypervector_bits=2
    )

    # Initialize model
    print(f"  Initializing GTM (depth={config['depth']})...")
    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses=EXP3_CLAUSES,
        T=EXP3_T,
        s=EXP3_S,
        depth=config['depth'],
        message_size=EXP3_MSG_SIZE,
        message_bits=2
    )

    # Training with epoch tracking
    print(f"  Training for {EPOCHS} epochs...")
    epoch_history = []

    for epoch in tqdm(range(EPOCHS), desc="  Progress"):
        tm.fit(graphs_train, Y_train, epochs=1, incremental=True)

        # Evaluate on training set
        preds_train = tm.predict(graphs_train)
        train_acc = 100 * (preds_train == Y_train).mean()

        # Per-class accuracy
        class_0_mask = Y_train == 0
        class_1_mask = Y_train == 1
        class_0_acc = 100 * (preds_train[class_0_mask] == 0).mean() if class_0_mask.any() else 0
        class_1_acc = 100 * (preds_train[class_1_mask] == 1).mean() if class_1_mask.any() else 0

        epoch_history.append({
            'epoch': epoch + 1,
            'train_acc': train_acc,
            'class_0_acc': class_0_acc,
            'class_1_acc': class_1_acc,
            'gap': abs(class_0_acc - class_1_acc)
        })

    # Prepare test graphs
    print("  Preparing test graphs...")
    graphs_test, _ = prepare_gtm_data_multinode(
        test_samples,
        board_dim=BOARD_DIM,
        init_with=graphs_train  # CRITICAL
    )

    # Evaluate on test set
    print("  Evaluating on test set...")
    test_preds = tm.predict(graphs_test)

    # Calculate metrics
    test_acc = 100 * accuracy_score(Y_test, test_preds)
    test_class_0_acc = 100 * (test_preds[Y_test==0] == 0).mean()
    test_class_1_acc = 100 * (test_preds[Y_test==1] == 1).mean()

    # Classification metrics
    precision_0 = precision_score(Y_test, test_preds, pos_label=0, zero_division=0)
    recall_0 = recall_score(Y_test, test_preds, pos_label=0, zero_division=0)
    f1_0 = f1_score(Y_test, test_preds, pos_label=0, zero_division=0)

    precision_1 = precision_score(Y_test, test_preds, pos_label=1, zero_division=0)
    recall_1 = recall_score(Y_test, test_preds, pos_label=1, zero_division=0)
    f1_1 = f1_score(Y_test, test_preds, pos_label=1, zero_division=0)

    training_time = time.time() - start_time

    # Store results
    result = {
        'experiment': config['name'],
        'depth': config['depth'],
        'clauses': EXP3_CLAUSES,
        'T': EXP3_T,
        's': EXP3_S,
        'test_acc': test_acc,
        'test_class_0_acc': test_class_0_acc,
        'test_class_1_acc': test_class_1_acc,
        'class_gap': abs(test_class_0_acc - test_class_1_acc),
        'precision_0': precision_0,
        'recall_0': recall_0,
        'f1_0': f1_0,
        'precision_1': precision_1,
        'recall_1': recall_1,
        'f1_1': f1_1,
        'training_time_sec': training_time,
        'epoch_history': epoch_history
    }

    exp3_results.append(result)
    exp3_predictions.append(test_preds)
    exp3_learning_curves.append(pd.DataFrame(epoch_history))

    # Print summary
    print(f"\n  Results Summary:")
    print(f"    Training Time: {training_time/60:.2f} minutes")
    print(f"    Test Accuracy: {test_acc:.2f}%")
    print(f"    Winner 0: {test_class_0_acc:.2f}% (P={precision_0:.3f}, R={recall_0:.3f}, F1={f1_0:.3f})")
    print(f"    Winner 1: {test_class_1_acc:.2f}% (P={precision_1:.3f}, R={recall_1:.3f}, F1={f1_1:.3f})")
    print(f"    Class Gap: {abs(test_class_0_acc - test_class_1_acc):.2f}%")

print("\n✓ Experiment 3 complete!")

In [None]:
print("\n" + "="*70)
print("FINAL RESULTS SUMMARY")
print("="*70)

print("\n[EXPERIMENT 2: MODEL CAPACITY]")
print("-" * 70)
exp2_df = pd.DataFrame([{
    'Config': r['experiment'],
    'Clauses': r['clauses'],
    'T': r['T'],
    'Test Acc (%)': f"{r['test_acc']:.2f}",
    'Winner 0 (%)': f"{r['test_class_0_acc']:.2f}",
    'Winner 1 (%)': f"{r['test_class_1_acc']:.2f}",
    'Gap (%)': f"{r['class_gap']:.2f}",
    'Time (min)': f"{r['training_time_sec']/60:.1f}"
} for r in exp2_results])
print(exp2_df.to_string(index=False))

print("\n[EXPERIMENT 3: MESSAGE PASSING DEPTH]")
print("-" * 70)
exp3_df = pd.DataFrame([{
    'Config': r['experiment'],
    'Depth': r['depth'],
    'Test Acc (%)': f"{r['test_acc']:.2f}",
    'Winner 0 (%)': f"{r['test_class_0_acc']:.2f}",
    'Winner 1 (%)': f"{r['test_class_1_acc']:.2f}",
    'Gap (%)': f"{r['class_gap']:.2f}",
    'Time (min)': f"{r['training_time_sec']/60:.1f}"
} for r in exp3_results])
print(exp3_df.to_string(index=False))

# ============================================================================
# CELL 6: DETAILED METRICS
# ============================================================================

print("\n" + "="*70)
print("DETAILED METRICS")
print("="*70)

print("\n[EXPERIMENT 2: MODEL CAPACITY - DETAILED METRICS]")
print("-" * 70)
for result in exp2_results:
    print(f"\n{result['experiment']} ({result['clauses']} clauses, T={result['T']}):")
    print(f"  Overall Accuracy: {result['test_acc']:.2f}%")
    print(f"  Winner 0: Precision={result['precision_0']:.3f}, Recall={result['recall_0']:.3f}, F1={result['f1_0']:.3f}")
    print(f"  Winner 1: Precision={result['precision_1']:.3f}, Recall={result['recall_1']:.3f}, F1={result['f1_1']:.3f}")

print("\n[EXPERIMENT 3: MESSAGE PASSING DEPTH - DETAILED METRICS]")
print("-" * 70)
for result in exp3_results:
    print(f"\n{result['experiment']} (Depth={result['depth']}):")
    print(f"  Overall Accuracy: {result['test_acc']:.2f}%")
    print(f"  Winner 0: Precision={result['precision_0']:.3f}, Recall={result['recall_0']:.3f}, F1={result['f1_0']:.3f}")
    print(f"  Winner 1: Precision={result['precision_1']:.3f}, Recall={result['recall_1']:.3f}, F1={result['f1_1']:.3f}")



In [None]:
print("\nGenerating learning curves...")

# Experiment 2
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

for idx, (result, curve) in enumerate(zip(exp2_results, exp2_learning_curves)):
    if idx >= 6:
        break

    axes[idx].plot(curve['epoch'], curve['train_acc'],
                  label='Overall', linewidth=2.5, color='black')
    axes[idx].plot(curve['epoch'], curve['class_0_acc'],
                  label='Winner 0', linestyle='--', linewidth=2, color='blue')
    axes[idx].plot(curve['epoch'], curve['class_1_acc'],
                  label='Winner 1', linestyle='--', linewidth=2, color='red')

    axes[idx].set_title(f"{result['experiment']}\n{result['clauses']} clauses, T={result['T']}",
                       fontsize=11, fontweight='bold')
    axes[idx].set_xlabel('Epoch', fontsize=10)
    axes[idx].set_ylabel('Training Accuracy (%)', fontsize=10)
    axes[idx].legend(fontsize=9)
    axes[idx].grid(True, alpha=0.3)
    axes[idx].set_ylim([0, 105])

if len(exp2_results) < 6:
    axes[5].axis('off')

plt.suptitle('Experiment 2: Learning Curves (Model Capacity)',
            fontsize=14, fontweight='bold', y=1.00)
plt.tight_layout()
plt.savefig('exp2_learning_curves.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved exp2_learning_curves.png")
plt.show()

# Experiment 3
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

for idx, (result, curve) in enumerate(zip(exp3_results, exp3_learning_curves)):
    if idx >= 6:
        break

    axes[idx].plot(curve['epoch'], curve['train_acc'],
                  label='Overall', linewidth=2.5, color='black')
    axes[idx].plot(curve['epoch'], curve['class_0_acc'],
                  label='Winner 0', linestyle='--', linewidth=2, color='blue')
    axes[idx].plot(curve['epoch'], curve['class_1_acc'],
                  label='Winner 1', linestyle='--', linewidth=2, color='red')

    axes[idx].set_title(f"{result['experiment']}\nDepth={result['depth']}",
                       fontsize=11, fontweight='bold')
    axes[idx].set_xlabel('Epoch', fontsize=10)
    axes[idx].set_ylabel('Training Accuracy (%)', fontsize=10)
    axes[idx].legend(fontsize=9)
    axes[idx].grid(True, alpha=0.3)
    axes[idx].set_ylim([0, 105])

if len(exp3_results) < 6:
    axes[5].axis('off')

plt.suptitle('Experiment 3: Learning Curves (Message Passing Depth)',
            fontsize=14, fontweight='bold', y=1.00)
plt.tight_layout()
plt.savefig('exp3_learning_curves.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved exp3_learning_curves.png")
plt.show()

# ============================================================================
# CELL 8: VISUALIZATIONS - CONFUSION MATRICES
# ============================================================================

print("\nGenerating confusion matrices...")

# Experiment 2
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, (result, preds) in enumerate(zip(exp2_results, exp2_predictions)):
    if idx >= 6:
        break

    cm = confusion_matrix(Y_test, preds)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
               ax=axes[idx], cbar=False, annot_kws={'size': 14})

    axes[idx].set_title(f"{result['experiment']}\n{result['clauses']} clauses\nAcc: {result['test_acc']:.1f}%",
                       fontsize=11, fontweight='bold')
    axes[idx].set_xlabel('Predicted Winner', fontsize=10)
    axes[idx].set_ylabel('Actual Winner', fontsize=10)

if len(exp2_results) < 6:
    axes[5].axis('off')

plt.suptitle('Experiment 2: Confusion Matrices (Model Capacity)',
            fontsize=14, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig('exp2_confusion_matrices.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved exp2_confusion_matrices.png")
plt.show()

# Experiment 3
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, (result, preds) in enumerate(zip(exp3_results, exp3_predictions)):
    if idx >= 6:
        break

    cm = confusion_matrix(Y_test, preds)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
               ax=axes[idx], cbar=False, annot_kws={'size': 14})

    axes[idx].set_title(f"{result['experiment']}\nDepth={result['depth']}\nAcc: {result['test_acc']:.1f}%",
                       fontsize=11, fontweight='bold')
    axes[idx].set_xlabel('Predicted Winner', fontsize=10)
    axes[idx].set_ylabel('Actual Winner', fontsize=10)

if len(exp3_results) < 6:
    axes[5].axis('off')

plt.suptitle('Experiment 3: Confusion Matrices (Message Passing Depth)',
            fontsize=14, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig('exp3_confusion_matrices.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved exp3_confusion_matrices.png")
plt.show()

# ============================================================================
# CELL 9: VISUALIZATIONS - COMPARISON PLOTS
# ============================================================================

print("\nGenerating comparison plots...")

# Experiment 2: Accuracy vs Clauses
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

clauses = [r['clauses'] for r in exp2_results]
test_accs = [r['test_acc'] for r in exp2_results]
class_0_accs = [r['test_class_0_acc'] for r in exp2_results]
class_1_accs = [r['test_class_1_acc'] for r in exp2_results]

ax1.plot(clauses, test_accs, marker='o', linewidth=2.5, markersize=10,
        color='darkgreen', label='Overall')
ax1.set_xlabel('Number of Clauses', fontsize=12)
ax1.set_ylabel('Test Accuracy (%)', fontsize=12)
ax1.set_title('Overall Accuracy vs Model Capacity', fontsize=13, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xscale('log')

ax2.plot(clauses, class_0_accs, marker='o', linewidth=2.5, markersize=10,
        color='blue', label='Winner 0')
ax2.plot(clauses, class_1_accs, marker='s', linewidth=2.5, markersize=10,
        color='red', label='Winner 1')
ax2.fill_between(clauses, class_0_accs, class_1_accs, alpha=0.2, color='gray')
ax2.set_xlabel('Number of Clauses', fontsize=12)
ax2.set_ylabel('Test Accuracy (%)', fontsize=12)
ax2.set_title('Per-Class Accuracy vs Model Capacity', fontsize=13, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_xscale('log')

plt.tight_layout()
plt.savefig('exp2_capacity_comparison.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved exp2_capacity_comparison.png")
plt.show()

# Experiment 3: Accuracy vs Depth
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

depths = [r['depth'] for r in exp3_results]
test_accs = [r['test_acc'] for r in exp3_results]
class_0_accs = [r['test_class_0_acc'] for r in exp3_results]
class_1_accs = [r['test_class_1_acc'] for r in exp3_results]

ax1.plot(depths, test_accs, marker='o', linewidth=2.5, markersize=10,
        color='purple', label='Overall')
ax1.set_xlabel('Message Passing Depth', fontsize=12)
ax1.set_ylabel('Test Accuracy (%)', fontsize=12)
ax1.set_title('Overall Accuracy vs Depth', fontsize=13, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.axvline(x=3, color='gray', linestyle='--', alpha=0.5, label='Min Required (3)')
ax1.legend(fontsize=10)

ax2.plot(depths, class_0_accs, marker='o', linewidth=2.5, markersize=10,
        color='blue', label='Winner 0')
ax2.plot(depths, class_1_accs, marker='s', linewidth=2.5, markersize=10,
        color='red', label='Winner 1')
ax2.fill_between(depths, class_0_accs, class_1_accs, alpha=0.2, color='gray')
ax2.set_xlabel('Message Passing Depth', fontsize=12)
ax2.set_ylabel('Test Accuracy (%)', fontsize=12)
ax2.set_title('Per-Class Accuracy vs Depth', fontsize=13, fontweight='bold')
ax2.axvline(x=3, color='gray', linestyle='--', alpha=0.5, label='Min Required (3)')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('exp3_depth_comparison.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved exp3_depth_comparison.png")
plt.show()

# ============================================================================
# CELL 10: VISUALIZATIONS - CLASS GAP
# ============================================================================

print("\nGenerating class gap comparison...")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Experiment 2
gaps_exp2 = [r['class_gap'] for r in exp2_results]
names_exp2 = [r['experiment'] for r in exp2_results]

ax1.bar(range(len(gaps_exp2)), gaps_exp2, color='coral', alpha=0.7,
       edgecolor='darkred', linewidth=2)
ax1.set_xlabel('Configuration', fontsize=12)
ax1.set_ylabel('Accuracy Gap (%)', fontsize=12)
ax1.set_title('Exp 2: Winner 0 vs Winner 1 Gap\n(Lower is Better)',
             fontsize=13, fontweight='bold')
ax1.set_xticks(range(len(names_exp2)))
ax1.set_xticklabels(names_exp2, rotation=45, ha='right')
ax1.grid(True, axis='y', alpha=0.3)

# Experiment 3
gaps_exp3 = [r['class_gap'] for r in exp3_results]
names_exp3 = [r['experiment'] for r in exp3_results]

ax2.bar(range(len(gaps_exp3)), gaps_exp3, color='lightblue', alpha=0.7,
       edgecolor='darkblue', linewidth=2)
ax2.set_xlabel('Configuration', fontsize=12)
ax2.set_ylabel('Accuracy Gap (%)', fontsize=12)
ax2.set_title('Exp 3: Winner 0 vs Winner 1 Gap\n(Lower is Better)',
             fontsize=13, fontweight='bold')
ax2.set_xticks(range(len(names_exp3)))
ax2.set_xticklabels(names_exp3, rotation=45, ha='right')
ax2.grid(True, axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('class_gap_comparison.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved class_gap_comparison.png")
plt.show()

# ============================================================================
# CELL 11: SAVE RESULTS
# ============================================================================

print("\nSaving results...")

# Save CSV
exp2_df.to_csv('exp2_results.csv', index=False)
exp3_df.to_csv('exp3_results.csv', index=False)
print("  ✓ Saved exp2_results.csv")
print("  ✓ Saved exp3_results.csv")

# Save JSON
all_results = {
    'experiment_2_capacity': exp2_results,
    'experiment_3_depth': exp3_results,
    'configuration': {
        'board_dim': BOARD_DIM,
        'train_games': TRAIN_GAMES,
        'test_games': TEST_GAMES,
        'epochs': EPOCHS
    }
}

with open('experiments_results.json', 'w') as f:
    json.dump(all_results, f, indent=2, default=str)
print("  ✓ Saved experiments_results.json")

# ============================================================================
# CELL 12: BEST CONFIGURATIONS SUMMARY
# ============================================================================

# Find best configurations
best_exp2 = max(exp2_results, key=lambda x: x['test_acc'])
best_exp3 = max(exp3_results, key=lambda x: x['test_acc'])

print("\n" + "="*70)
print("BEST CONFIGURATIONS")
print("="*70)

print("\n[EXPERIMENT 2: MODEL CAPACITY]")
print(f"  Best Config: {best_exp2['experiment']}")
print(f"  Clauses: {best_exp2['clauses']}, T: {best_exp2['T']}")
print(f"  Test Accuracy: {best_exp2['test_acc']:.2f}%")
print(f"  Winner 0: {best_exp2['test_class_0_acc']:.2f}%")
print(f"  Winner 1: {best_exp2['test_class_1_acc']:.2f}%")
print(f"  Class Gap: {best_exp2['class_gap']:.2f}%")

print("\n[EXPERIMENT 3: MESSAGE PASSING DEPTH]")
print(f"  Best Config: {best_exp3['experiment']}")
print(f"  Depth: {best_exp3['depth']}")
print(f"  Test Accuracy: {best_exp3['test_acc']:.2f}%")
print(f"  Winner 0: {best_exp3['test_class_0_acc']:.2f}%")
print(f"  Winner 1: {best_exp3['test_class_1_acc']:.2f}%")
print(f"  Class Gap: {best_exp3['class_gap']:.2f}%")

print("\n" + "="*70)
print("EXPERIMENTS COMPLETE!")
print("="*70)
print("\nGenerated Files:")
print("  Learning Curves:")
print("    - exp2_learning_curves.png")
print("    - exp3_learning_curves.png")
print("  Confusion Matrices:")
print("    - exp2_confusion_matrices.png")
print("    - exp3_confusion_matrices.png")
print("  Comparison Plots:")
print("    - exp2_capacity_comparison.png")
print("    - exp3_depth_comparison.png")
print("    - class_gap_comparison.png")
print("  Data Files:")
print("    - exp2_results.csv")
print("    - exp3_results.csv")
print("    - experiments_results.json")
print("\n✓ All experiments completed successfully!")
print("✓ Ready for report writing!")