# Hex Game - GTM Analysis & Experiments

This notebook performs a comprehensive analysis of the Graph Tsetlin Machine (GTM) on the Hex game.
It implements:
1.  **Robust Data Generation**: Reproducible, seeded games (saved to CSV).
2.  **End-Game Analysis**: Evaluation at Final, End-2, and End-5 moves.
3.  **Parameter Search**: Experiments for Model Capacity (Clauses) and Message Passing (Depth).
4.  **Scaling**: Flexible configuration for 3x3, 11x11, etc.

---

In [None]:
import os
import pickle
import subprocess
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
import gc

# Optional dependency: the GraphTsetlinMachine package supplies the graph
# container (Graphs) and the classifier used by the training cells.
# HAS_GTM records whether both imports succeeded; later cells check it
# before touching any GTM object.
try:
    from GraphTsetlinMachine.graphs import Graphs
    from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
except ImportError:
    HAS_GTM = False
    print("WARNING: GraphTsetlinMachine not installed. Training will fail.")
else:
    HAS_GTM = True
    print("‚úì GraphTsetlinMachine detected.")

# Plot appearance shared by every figure produced in this notebook.
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 120

In [None]:
# ==========================================
#            GLOBAL CONFIGURATION
# ==========================================

BOARD_DIM = 11        # Side length of the board (11 -> 11x11)
N_GAMES = 8000        # Number of games requested from the generator
SEED = 20             # Seed handed to the generator

# Hyperparameter presets keyed by board size. Each tuple is ordered as
# (CLAUSES, T, S, DEPTH, EPOCHS, HYPERVECTOR_SIZE, MESSAGE_SIZE, GRID_MULT).
_PRESETS = {
    7: (2000, 50, 0.8, 3, 100, 256, 256, 3),
    3: (400, 400, 2.5, 3, 100, 256, 256, 1),
    11: (2000, 60, 0.9, 3, 40, 1024, 1024, 4),
}
# Applied to any board size that has no dedicated preset.
_FALLBACK_PRESET = (200, 400, 2.5, 3, 25, 256, 256, 1)

(CLAUSES, T, S, DEPTH, EPOCHS,
 HYPERVECTOR_SIZE, MESSAGE_SIZE, GRID_MULT) = _PRESETS.get(BOARD_DIM, _FALLBACK_PRESET)

# Only the 11x11 preset sets the augmentation flag; other sizes do not define it here.
if BOARD_DIM == 11:
    USE_ROTATION_AUGMENTATION = True

# Echo the active configuration as one banner.
_banner_lines = [
    f"\n{'='*60}",
    f"  CONFIGURATION: {BOARD_DIM}x{BOARD_DIM} Hex Board",
    f"{'='*60}",
    f"  Games:         {N_GAMES:,}",
    f"  Seed:          {SEED}",
    f"  Clauses:       {CLAUSES:,}",
    f"  T:             {T}",
    f"  S:             {S}",
    f"  Depth:         {DEPTH}",
    f"  Epochs:        {EPOCHS}",
    f"  Hypervec Size: {HYPERVECTOR_SIZE}",
    f"  Message Size:  {MESSAGE_SIZE}",
    f"{'='*60}\n",
]
print("\n".join(_banner_lines))

# Paths
# The CSV name encodes board size and game count (not the seed).
RUNS_DIR = "runs"
os.makedirs(RUNS_DIR, exist_ok=True)
CSV_PATH = os.path.join(RUNS_DIR, f"hex_moves_dim{BOARD_DIM}_n{N_GAMES}.csv")

## 1. Data Generation & Processing
We use the C engine to generate games and save them to CSV. This ensures we have a permanent, reproducible dataset.

In [None]:
def generate_data():
    """Generate the move-log CSV at CSV_PATH unless it already exists.

    Runs ``./scripts/run_hex.sh`` with ``--games N_GAMES``, ``--seed SEED`` and
    ``--dump-moves CSV_PATH``; the board size is passed to the script through
    the ``BOARD_DIM`` environment variable.

    If the script fails or is interrupted, any CSV it left behind is removed
    before the error propagates, so a truncated file is never mistaken for a
    finished dataset by the existence check on a later call.

    Raises:
        subprocess.CalledProcessError: the script exited with a non-zero status.
    """
    if os.path.exists(CSV_PATH):
        print(f"Found existing data: {CSV_PATH}")
        # To force regeneration, delete CSV_PATH and call generate_data() again
        # (removing it here would still hit the return below).
        return

    print(f"Generating {N_GAMES} games for {BOARD_DIM}x{BOARD_DIM} (Seed {SEED})...")
    cmd = [
        "./scripts/run_hex.sh",
        "--games", str(N_GAMES),
        "--seed", str(SEED),
        "--dump-moves", CSV_PATH
    ]

    # Child environment = current environment plus the board size.
    env = os.environ.copy()
    env["BOARD_DIM"] = str(BOARD_DIM)

    start = time.time()
    try:
        subprocess.run(cmd, env=env, check=True)
    except (subprocess.CalledProcessError, KeyboardInterrupt):
        # The file did not exist before this call, so anything present now is
        # a partial dump from the aborted run.
        if os.path.exists(CSV_PATH):
            os.remove(CSV_PATH)
        raise
    print(f"Done in {time.time() - start:.2f}s")

def load_and_process_data(offset=0):
    """Load the move CSV and rebuild one board state per game.

    Args:
        offset: Number of trailing moves to drop from every game before the
            board is reconstructed (0 = final position, 2 = two moves before
            the end, ...). Games with ``offset`` moves or fewer are skipped.

    Returns:
        Tuple ``(x_feat, o_feat, labels)`` of NumPy arrays. For each kept game,
        ``x_feat`` holds a 0/1 int8 vector of length ``BOARD_DIM * BOARD_DIM``
        marking the cells played by player 0, ``o_feat`` the cells played by
        any other player value, and ``labels`` the game's ``winner`` value
        (taken from the game's first CSV row).
    """
    print(f"Processing data with Offset={offset}...")
    df = pd.read_csv(CSV_PATH)

    n_nodes = BOARD_DIM * BOARD_DIM
    x_feat = []
    o_feat = []
    labels = []

    # Group by game
    for _, group in tqdm(df.groupby("game_id"), desc="Replaying Games"):
        # Skip if game is too short
        if len(group) <= offset:
            continue

        # Slice moves
        moves = group.iloc[:-offset] if offset > 0 else group
        winner = group["winner"].iloc[0]

        rows = moves["row"].astype(np.int64).to_numpy()
        cols = moves["col"].astype(np.int64).to_numpy()
        players = moves["player"].astype(np.int64).to_numpy()

        # Validate each coordinate on its own: checking only the flattened
        # index would let an out-of-range column wrap onto the next row.
        on_board = (rows >= 0) & (rows < BOARD_DIM) & (cols >= 0) & (cols < BOARD_DIM)
        cells = rows[on_board] * BOARD_DIM + cols[on_board]
        is_player0 = players[on_board] == 0

        p0 = np.zeros(n_nodes, dtype=np.int8)
        p1 = np.zeros(n_nodes, dtype=np.int8)
        p0[cells[is_player0]] = 1
        p1[cells[~is_player0]] = 1

        x_feat.append(p0)
        o_feat.append(p1)
        labels.append(winner)

    return np.array(x_feat), np.array(o_feat), np.array(labels)

def prepare_graphs(x, o, init_with=None):
    """Build a GraphTsetlinMachine ``Graphs`` object with one graph per board.

    Every board becomes a graph of ``BOARD_DIM * BOARD_DIM`` nodes (row-major
    cell index). A node carries exactly one symbol property: "X" where ``x``
    is 1, otherwise "O" where ``o`` is 1, otherwise "E". Nodes are linked to
    their in-board hex neighbours with edge type 0.

    Args:
        x: Per-sample 0/1 cell vectors for the "X" stones.
        o: Per-sample 0/1 cell vectors for the "O" stones.
        init_with: Forwarded to ``Graphs``; the notebook passes the training
            graphs here when building the test graphs.

    Returns:
        The encoded ``Graphs`` instance, or None when GraphTsetlinMachine is
        not available.
    """
    if not HAS_GTM:
        return None

    sample_count = len(x)
    cell_count = BOARD_DIM * BOARD_DIM

    # Symbol of every cell of every board ("X" takes precedence over "O").
    symbols_per_board = [
        [
            "X" if x[g][cell] == 1 else "O" if o[g][cell] == 1 else "E"
            for cell in range(cell_count)
        ]
        for g in range(sample_count)
    ]

    graphs = Graphs(
        sample_count,
        symbols=["E", "X", "O"],  # Explicit order
        hypervector_size=HYPERVECTOR_SIZE,
        hypervector_bits=2,
        init_with=init_with
    )

    # Six-neighbour hex topology, clipped at the board border.
    hex_steps = [(0, 1), (0, -1), (-1, 1), (1, -1), (-1, 0), (1, 0)]
    neighbours = {
        row * BOARD_DIM + col: [
            (row + d_row) * BOARD_DIM + (col + d_col)
            for d_row, d_col in hex_steps
            if 0 <= row + d_row < BOARD_DIM and 0 <= col + d_col < BOARD_DIM
        ]
        for row in range(BOARD_DIM)
        for col in range(BOARD_DIM)
    }

    # Phase 1: declare the node count of every graph.
    for g in range(sample_count):
        graphs.set_number_of_graph_nodes(g, cell_count)
    graphs.prepare_node_configuration()

    # Phase 2: declare every node together with its outgoing-edge count.
    for g in range(sample_count):
        for cell in range(cell_count):
            graphs.add_graph_node(g, cell, len(neighbours[cell]))
    graphs.prepare_edge_configuration()

    # Phase 3: add the edges (each undirected link appears once per direction).
    for g in range(sample_count):
        for cell in range(cell_count):
            for other in neighbours[cell]:
                graphs.add_graph_node_edge(g, cell, other, 0)

    # Phase 4: attach the symbol property of each node.
    for g in range(sample_count):
        board_symbols = symbols_per_board[g]
        for cell in range(cell_count):
            graphs.add_graph_node_property(g, cell, board_symbols[cell])

    graphs.encode()
    return graphs

# ============================================================
#         ROTATION AUGMENTATION (Optional Enhancement)
# ============================================================

# Default the switch to False only when nothing has defined it yet.
# The 11x11 configuration preset sets this flag itself; an unconditional
# assignment here would silently overwrite that choice (and any manual
# override) every time this cell runs. To change it by hand, assign
# USE_ROTATION_AUGMENTATION after running this cell.
if "USE_ROTATION_AUGMENTATION" not in globals():
    USE_ROTATION_AUGMENTATION = False  # Set to True to enable

def rotate_board_90(board_1d, board_dim):
    """Return the flattened board turned 90 degrees clockwise.

    ``board_1d`` is read as a row-major ``board_dim x board_dim`` grid; the
    result is a new 1-D array in the same row-major layout.
    """
    grid = board_1d.reshape(board_dim, board_dim)
    # Clockwise quarter turn == flip the rows top-to-bottom, then transpose.
    turned = np.flipud(grid).T
    return turned.flatten()

def augment_with_rotation(X_raw, O_raw, Y):
    """Double the dataset with a side-swapping, label-flipping board symmetry.

    For every sample the output holds the original followed by its mirrored
    twin (samples are interleaved: original, twin, original, twin, ...).
    The twin is built by:

    1. transposing the board, which exchanges the top/bottom and left/right
       edge pairs,
    2. swapping the two colours, so the stones spanning a given edge pair
       still belong to the player who owns that pair, and
    3. flipping the label (``1 - Y``), because the winner is now the other
       player.

    Why not a plain 90-degree array rotation: the hex adjacency used by
    ``prepare_graphs`` links (r, c) to (r-1, c+1) and (r+1, c-1). A quarter
    turn maps those links to (r+1, c+1) / (r-1, c-1), which are not
    neighbours, so connected chains can be broken. Transposition maps the
    neighbour set onto itself. Flipping the label without swapping colours
    would also give twins whose stone colours contradict the real samples.

    Args:
        X_raw: ``(n_samples, n_cells)`` 0/1 array of player-0 ("X") stones.
        O_raw: ``(n_samples, n_cells)`` 0/1 array of player-1 ("O") stones.
        Y: ``(n_samples,)`` array of 0/1 winner labels.

    Returns:
        ``(X_augmented, O_augmented, Y_augmented)`` with twice as many
        samples as the input (unchanged copies for empty input).

    Raises:
        ValueError: the number of cells per sample is not a perfect square.
    """
    print("\n" + "="*60)
    print("  APPLYING ROTATION AUGMENTATION")
    print("="*60)

    X_raw = np.asarray(X_raw)
    O_raw = np.asarray(O_raw)
    Y = np.asarray(Y)
    n_samples = len(Y)

    if n_samples == 0:
        # Nothing to mirror; hand back empty copies.
        X_augmented, O_augmented, Y_augmented = X_raw.copy(), O_raw.copy(), Y.copy()
    else:
        n_nodes = X_raw.shape[1]
        board_dim = int(round(n_nodes ** 0.5))
        if board_dim * board_dim != n_nodes:
            raise ValueError(f"Expected square boards, got {n_nodes} cells per sample")

        def _transpose_boards(boards):
            # Transpose every flattened board in one vectorised step.
            grids = boards.reshape(n_samples, board_dim, board_dim)
            return grids.transpose(0, 2, 1).reshape(n_samples, n_nodes)

        X_augmented = np.empty((2 * n_samples, n_nodes), dtype=X_raw.dtype)
        O_augmented = np.empty((2 * n_samples, n_nodes), dtype=O_raw.dtype)
        Y_augmented = np.empty(2 * n_samples, dtype=Y.dtype)

        # Even slots: the untouched originals.
        X_augmented[0::2] = X_raw
        O_augmented[0::2] = O_raw
        Y_augmented[0::2] = Y

        # Odd slots: transposed boards with colours swapped and label flipped.
        X_augmented[1::2] = _transpose_boards(O_raw)
        O_augmented[1::2] = _transpose_boards(X_raw)
        Y_augmented[1::2] = 1 - Y

    print(f"Original samples:   {len(Y):,}")
    print(f"Augmented samples:  {len(Y_augmented):,} (2x)")
    print(f"Winner 0: {np.sum(Y_augmented==0):,} ({100*np.mean(Y_augmented==0):.1f}%)")
    print(f"Winner 1: {np.sum(Y_augmented==1):,} ({100*np.mean(Y_augmented==1):.1f}%)")
    print("="*60 + "\n")

    return X_augmented, O_augmented, Y_augmented

# Report which data mode the following cells will run with.
if not USE_ROTATION_AUGMENTATION:
    print("\nüìä Using standard data (no augmentation)")
else:
    print("\n‚ö†Ô∏è ROTATION AUGMENTATION ENABLED")
    print("   Training data will be doubled with 90¬∞ rotations")

### Generate Data Now

In [None]:
# Create the move CSV at CSV_PATH if it is not already on disk; an existing
# file is reused as-is.
generate_data()
print("Data Ready!")

In [None]:
# # ============================================================
# #         SYSTEMATIC PARAMETER SEARCH - Find Best Config
# # ============================================================
# 
# print("\n" + "="*70)
# print("  üî¨ SYSTEMATIC PARAMETER SEARCH")
# print("="*70)
# print(f"Board: {BOARD_DIM}√ó{BOARD_DIM}")
# print(f"Strategy: Test configs with short training to find most promising")
# print(f"Goal: Maximize Winner 0 accuracy while keeping overall >50%")
# print("="*70 + "\n")
# 
# # Load data once
# print("Loading dataset...")
# X_raw_base, O_raw_base, Y_base = load_and_process_data(offset=0)
# 
# # Test with smaller subset for speed
# N_SEARCH_SAMPLES = min(3000, len(Y_base))
# X_search = X_raw_base[:N_SEARCH_SAMPLES]
# O_search = O_raw_base[:N_SEARCH_SAMPLES]
# Y_search = Y_base[:N_SEARCH_SAMPLES]
# 
# print(f"Using {N_SEARCH_SAMPLES:,} samples for parameter search")
# print(f"Winner 0: {np.sum(Y_search==0)} ({100*np.mean(Y_search==0):.1f}%)")
# print(f"Winner 1: {np.sum(Y_search==1)} ({100*np.mean(Y_search==1):.1f}%)")
# 
# baseline = max(np.mean(Y_search==0), np.mean(Y_search==1))
# print(f"Baseline: {baseline*100:.1f}%\n")
# 
# # Define parameter grid to search
# search_configs = [
#     # Format: {name, clauses, T, s, depth, use_rotation}
# 
#     # Baseline configurations (no rotation)
#     {"name": "Baseline_Conservative", "clauses": 1600, "T": 200, "s": 2.0, "depth": 5, "rotation": False},
#     {"name": "Baseline_Permissive", "clauses": 1600, "T": 100, "s": 1.0, "depth": 3, "rotation": False},
#     {"name": "Baseline_LowT", "clauses": 2000, "T": 80, "s": 1.2, "depth": 4, "rotation": False},
# 
#     # High capacity variations
#     {"name": "HighCap_Conservative", "clauses": 2500, "T": 150, "s": 1.5, "depth": 5, "rotation": False},
#     {"name": "HighCap_Permissive", "clauses": 3000, "T": 100, "s": 1.0, "depth": 4, "rotation": False},
# 
#     # Deep message passing
#     {"name": "Deep_Moderate", "clauses": 1800, "T": 120, "s": 1.3, "depth": 8, "rotation": False},
#     {"name": "Deep_Permissive", "clauses": 2000, "T": 80, "s": 1.0, "depth": 10, "rotation": False},
# 
#     # Shallow but huge
#     {"name": "Shallow_Huge", "clauses": 3500, "T": 100, "s": 1.0, "depth": 2, "rotation": False},
# 
#     # Ultra-low threshold
#     {"name": "UltraLowT_1", "clauses": 2000, "T": 60, "s": 0.9, "depth": 3, "rotation": False},
#     {"name": "UltraLowT_2", "clauses": 2500, "T": 40, "s": 0.8, "depth": 3, "rotation": False},
# 
#     # WITH ROTATION - Best performing configs from above
#     {"name": "Rotation_Permissive", "clauses": 1600, "T": 100, "s": 1.0, "depth": 3, "rotation": True},
#     {"name": "Rotation_HighCap", "clauses": 2500, "T": 100, "s": 1.0, "depth": 4, "rotation": True},
#     {"name": "Rotation_LowT", "clauses": 2000, "T": 60, "s": 0.9, "depth": 3, "rotation": True},
#     {"name": "Rotation_Deep", "clauses": 2000, "T": 80, "s": 1.0, "depth": 8, "rotation": True},
# ]
# 
# SEARCH_EPOCHS = 20  # Short training for quick evaluation
# EVAL_EVERY = 5
# 
# print(f"Testing {len(search_configs)} configurations ({SEARCH_EPOCHS} epochs each)")
# print(f"Estimated time: ~{len(search_configs) * 3:.0f} minutes\n")
# 
# search_results = []
# 
# for i, cfg in enumerate(search_configs, 1):
#     print(f"\n{'='*70}")
#     print(f"  [{i}/{len(search_configs)}] {cfg['name']}")
#     print(f"{'='*70}")
#     print(f"  Clauses: {cfg['clauses']}, T: {cfg['T']}, s: {cfg['s']}, Depth: {cfg['depth']}")
#     print(f"  Rotation: {'YES' if cfg['rotation'] else 'NO'}")
# 
#     start_time = time.time()
# 
#     try:
#         # Prepare data (with or without rotation)
#         if cfg['rotation']:
#             print("  Applying rotation augmentation...")
#             X_data, O_data, Y_data = augment_with_rotation(X_search, O_search, Y_search)
#         else:
#             X_data, O_data, Y_data = X_search.copy(), O_search.copy(), Y_search.copy()
# 
#         # Build graphs
#         print("  Building graphs...", end="", flush=True)
#         g_search = prepare_graphs(X_data, O_data)
#         print(" Done!")
# 
#         # Create model
#         tm = MultiClassGraphTsetlinMachine(
#             number_of_clauses=cfg['clauses'],
#             T=cfg['T'],
#             s=cfg['s'],
#             depth=cfg['depth'],
#             message_size=512,
#             message_bits=2,
#             max_included_literals=96,
#             grid=(16*13*3, 1, 1),
#             block=(128, 1, 1)
#         )
# 
#         # Training with tracking
#         print(f"  Training {SEARCH_EPOCHS} epochs...", end="", flush=True)
# 
#         epoch_results = []
#         best_w0 = 0
#         best_overall = 0
# 
#         for ep in range(SEARCH_EPOCHS):
#             tm.fit(g_search, Y_data, epochs=1, incremental=True)
# 
#             if (ep + 1) % EVAL_EVERY == 0:
#                 preds = tm.predict(g_search)
#                 acc = accuracy_score(Y_data, preds)
#                 acc0 = accuracy_score(Y_data[Y_data==0], preds[Y_data==0]) if np.sum(Y_data==0) > 0 else 0
#                 acc1 = accuracy_score(Y_data[Y_data==1], preds[Y_data==1]) if np.sum(Y_data==1) > 0 else 0
# 
#                 epoch_results.append({
#                     'epoch': ep + 1,
#                     'overall': acc,
#                     'w0': acc0,
#                     'w1': acc1
#                 })
# 
#                 if acc0 > best_w0:
#                     best_w0 = acc0
#                 if acc > best_overall:
#                     best_overall = acc
# 
#                 print(".", end="", flush=True)
# 
#         print(" Done!")
# 
#         # Final evaluation
#         final_preds = tm.predict(g_search)
#         final_acc = accuracy_score(Y_data, final_preds)
#         final_acc0 = accuracy_score(Y_data[Y_data==0], final_preds[Y_data==0]) if np.sum(Y_data==0) > 0 else 0
#         final_acc1 = accuracy_score(Y_data[Y_data==1], final_preds[Y_data==1]) if np.sum(Y_data==1) > 0 else 0
# 
#         pred_counts = np.bincount(final_preds, minlength=2)
# 
#         # Calculate "improvement score" - prioritizes W0 improvement
#         w0_improvement = final_acc0 - (0.5 if cfg['rotation'] else baseline)
#         overall_improvement = final_acc - (0.5 if cfg['rotation'] else baseline)
# 
#         # Weighted score: W0 is 3x more important since it's the problem
#         improvement_score = (w0_improvement * 3.0) + overall_improvement
# 
#         training_time = time.time() - start_time
# 
#         # Store results
#         result = {
#             'name': cfg['name'],
#             'clauses': cfg['clauses'],
#             'T': cfg['T'],
#             's': cfg['s'],
#             'depth': cfg['depth'],
#             'rotation': cfg['rotation'],
#             'final_overall': final_acc * 100,
#             'final_w0': final_acc0 * 100,
#             'final_w1': final_acc1 * 100,
#             'best_w0': best_w0 * 100,
#             'best_overall': best_overall * 100,
#             'gap': abs(final_acc0 - final_acc1) * 100,
#             'improvement_score': improvement_score * 100,
#             'pred_w0_count': int(pred_counts[0]),
#             'pred_w1_count': int(pred_counts[1]),
#             'time_sec': training_time,
#             'epoch_history': epoch_results
#         }
# 
#         search_results.append(result)
# 
#         # Print summary
#         print(f"\n  Results:")
#         print(f"    Overall: {final_acc*100:5.1f}% | W0: {final_acc0*100:5.1f}% | W1: {final_acc1*100:5.1f}%")
#         print(f"    Best W0: {best_w0*100:5.1f}% | Gap: {abs(final_acc0-final_acc1)*100:5.1f}%")
#         print(f"    Improvement Score: {improvement_score*100:+.1f}")
#         print(f"    Predictions: W0={pred_counts[0]}, W1={pred_counts[1]}")
#         print(f"    Time: {training_time:.1f}s")
# 
#         # Memory cleanup
#         del tm, g_search
#         gc.collect()
# 
#     except Exception as e:
#         print(f"\n  ‚úó FAILED: {str(e)[:100]}")
#         search_results.append({
#             'name': cfg['name'],
#             'error': str(e),
#             'failed': True
#         })
#         continue
# 
# # ============================================================
# #  ANALYSIS & RECOMMENDATIONS
# # ============================================================
# 
# print("\n" + "="*70)
# print("  üìä PARAMETER SEARCH COMPLETE")
# print("="*70)
# 
# successful_results = [r for r in search_results if 'error' not in r]
# 
# if successful_results:
#     # Sort by different metrics
#     by_score = sorted(successful_results, key=lambda x: x['improvement_score'], reverse=True)
#     by_w0 = sorted(successful_results, key=lambda x: x['final_w0'], reverse=True)
#     by_overall = sorted(successful_results, key=lambda x: x['final_overall'], reverse=True)
#     by_balanced = sorted(successful_results, key=lambda x: x['gap'])
# 
#     print("\n" + "="*70)
#     print("  üèÜ TOP 5 CONFIGURATIONS BY IMPROVEMENT SCORE")
#     print("="*70)
#     print(f"{'Rank':<5} {'Config':<25} {'Overall':>8} {'W0':>8} {'W1':>8} {'Gap':>8} {'Score':>8}")
#     print("-"*70)
# 
#     for rank, r in enumerate(by_score[:5], 1):
#         marker = "üåü" if r['rotation'] else "  "
#         print(f"{rank:<5} {marker}{r['name']:<23} {r['final_overall']:7.1f}% {r['final_w0']:7.1f}% "
#               f"{r['final_w1']:7.1f}% {r['gap']:7.1f}% {r['improvement_score']:7.1f}")
# 
#     print("\n" + "="*70)
#     print("  üéØ TOP 5 BY WINNER 0 ACCURACY (The Hard One)")
#     print("="*70)
#     print(f"{'Rank':<5} {'Config':<25} {'W0':>8} {'Overall':>8} {'Gap':>8}")
#     print("-"*70)
# 
#     for rank, r in enumerate(by_w0[:5], 1):
#         marker = "üåü" if r['rotation'] else "  "
#         print(f"{rank:<5} {marker}{r['name']:<23} {r['final_w0']:7.1f}% {r['final_overall']:7.1f}% {r['gap']:7.1f}%")
# 
#     print("\n" + "="*70)
#     print("  ‚öñÔ∏è TOP 5 MOST BALANCED (Lowest Gap)")
#     print("="*70)
#     print(f"{'Rank':<5} {'Config':<25} {'Gap':>8} {'W0':>8} {'W1':>8}")
#     print("-"*70)
# 
#     for rank, r in enumerate(by_balanced[:5], 1):
#         marker = "üåü" if r['rotation'] else "  "
#         print(f"{rank:<5} {marker}{r['name']:<23} {r['gap']:7.1f}% {r['final_w0']:7.1f}% {r['final_w1']:7.1f}%")
# 
#     # Comparison: With vs Without Rotation
#     with_rot = [r for r in successful_results if r['rotation']]
#     without_rot = [r for r in successful_results if not r['rotation']]
# 
#     if with_rot and without_rot:
#         print("\n" + "="*70)
#         print("  üîÑ ROTATION AUGMENTATION ANALYSIS")
#         print("="*70)
# 
#         avg_w0_with = np.mean([r['final_w0'] for r in with_rot])
#         avg_w0_without = np.mean([r['final_w0'] for r in without_rot])
#         avg_overall_with = np.mean([r['final_overall'] for r in with_rot])
#         avg_overall_without = np.mean([r['final_overall'] for r in without_rot])
# 
#         print(f"\nAverage Winner 0 Accuracy:")
#         print(f"  Without rotation: {avg_w0_without:.1f}%")
#         print(f"  With rotation:    {avg_w0_with:.1f}%")
#         print(f"  Difference:       {avg_w0_with - avg_w0_without:+.1f}%")
# 
#         print(f"\nAverage Overall Accuracy:")
#         print(f"  Without rotation: {avg_overall_without:.1f}%")
#         print(f"  With rotation:    {avg_overall_with:.1f}%")
#         print(f"  Difference:       {avg_overall_with - avg_overall_without:+.1f}%")
# 
#         if avg_w0_with > avg_w0_without + 5:
#             print(f"\n‚úÖ Rotation provides meaningful W0 improvement (+{avg_w0_with - avg_w0_without:.1f}%)")
#         elif avg_w0_with > avg_w0_without:
#             print(f"\n‚ö†Ô∏è Rotation provides marginal improvement (+{avg_w0_with - avg_w0_without:.1f}%)")
#         else:
#             print(f"\n‚ùå Rotation does not help ({avg_w0_with - avg_w0_without:.1f}%)")
# 
#     # Save results
#     results_df = pd.DataFrame([{
#         'name': r['name'],
#         'clauses': r['clauses'],
#         'T': r['T'],
#         's': r['s'],
#         'depth': r['depth'],
#         'rotation': r['rotation'],
#         'overall': r['final_overall'],
#         'w0': r['final_w0'],
#         'w1': r['final_w1'],
#         'gap': r['gap'],
#         'score': r['improvement_score']
#     } for r in successful_results])
# 
#     results_file = os.path.join(RUNS_DIR, f'parameter_search_{BOARD_DIM}x{BOARD_DIM}.csv')
#     results_df.to_csv(results_file, index=False)
#     print(f"\nüìÅ Results saved to: {results_file}")
# 
#     # RECOMMENDATION
#     print("\n" + "="*70)
#     print("  üí° RECOMMENDATION FOR FULL TRAINING")
#     print("="*70)
# 
#     best = by_score[0]
#     print(f"\nBest Configuration: {best['name']}")
#     print(f"  Clauses:  {best['clauses']}")
#     print(f"  T:        {best['T']}")
#     print(f"  s:        {best['s']}")
#     print(f"  Depth:    {best['depth']}")
#     print(f"  Rotation: {'YES' if best['rotation'] else 'NO'}")
#     print(f"\nExpected Performance (full training):")
#     print(f"  Overall:  {best['final_overall']:.1f}% ‚Üí ~{best['final_overall']+2:.1f}%")
#     print(f"  Winner 0: {best['final_w0']:.1f}% ‚Üí ~{best['final_w0']+3:.1f}%")
#     print(f"  Winner 1: {best['final_w1']:.1f}% ‚Üí ~{best['final_w1']-2:.1f}%")
# 
#     if best['final_w0'] < 30:
#         print(f"\n‚ö†Ô∏è WARNING: Even best config has low W0 accuracy ({best['final_w0']:.1f}%)")
#         print(f"   This confirms the fundamental limitation for {BOARD_DIM}√ó{BOARD_DIM} boards.")
# 
#     print("\n" + "="*70)
# 
# else:
#     print("\n‚ùå No successful configurations")
# 
# print("\n‚úì Parameter search complete!")

## 2. Main Experiment: End, End-2, End-5 Analysis
We train and test the model on three different game states:
1.  **End**: The final position (Easy).
2.  **End-2**: Two moves before the end (Harder).
3.  **End-5**: Five moves before the end (Hardest).

This uses the parameters configured at the top.

In [None]:
# Board states are evaluated at three distances from the end of each game:
# 0 = final position, 2 and 5 = that many moves before the end.
offsets = [0, 2, 5]
results_main = []

if HAS_GTM:
    for off in offsets:
        print(f"\n{'='*60}")
        print(f"  Running OFFSET = {off} ({BOARD_DIM}x{BOARD_DIM})")
        print(f"{'='*60}")

        # 1. Load
        X_raw, O_raw, Y = load_and_process_data(offset=off)
        if len(Y) == 0:
            # Every game was too short for this offset; there is nothing to train on.
            print(f"No games longer than {off} moves; skipping offset {off}.")
            continue

        # 2. Split (80/20) on the real games, before any augmentation
        split = int(0.8 * len(Y))
        X_train = (X_raw[:split], O_raw[:split])
        Y_train = Y[:split]
        X_test = (X_raw[split:], O_raw[split:])
        Y_test = Y[split:]

        # Augment the training portion only: the test set must consist of
        # genuine games so the reported accuracy is not measured on synthetic
        # samples (or on twins of boards seen during training).
        if USE_ROTATION_AUGMENTATION:
            X_aug, O_aug, Y_train = augment_with_rotation(X_train[0], X_train[1], Y_train)
            X_train = (X_aug, O_aug)

        print(f"Train: {len(Y_train):,}, Test: {len(Y_test):,}")
        print(f"Winner 0: {np.sum(Y_train==0):,} ({100*np.mean(Y_train==0):.1f}%)")
        print(f"Winner 1: {np.sum(Y_train==1):,} ({100*np.mean(Y_train==1):.1f}%)")

        # 3. Graphs (test graphs are initialised from the training graphs)
        print("\nBuilding Graphs...")
        t0 = time.time()
        g_train = prepare_graphs(X_train[0], X_train[1])
        g_test = prepare_graphs(X_test[0], X_test[1], init_with=g_train)
        print(f"Done in {time.time()-t0:.1f}s")

        # 4. Train
        print(f"\nTraining GTM...")
        tm = MultiClassGraphTsetlinMachine(
            number_of_clauses=CLAUSES,
            T=T,
            s=S,
            depth=DEPTH,
            message_size=MESSAGE_SIZE,  # Use config value
            message_bits=2,
            max_included_literals=min(64, int(32 * GRID_MULT)),  # Scale with board
            grid=(16*13*GRID_MULT, 1, 1),  # Scale grid with complexity
            block=(128, 1, 1)
        )

        history = []
        eval_interval = max(1, EPOCHS // 10)  # Evaluate roughly 10 times in total

        t_start = time.time()
        for ep in tqdm(range(EPOCHS), desc=f"Offset {off}"):
            tm.fit(g_train, Y_train, epochs=1, incremental=True)

            # Periodic evaluation (not every epoch - saves time); the first
            # and last epochs are always evaluated.
            if (ep + 1) % eval_interval == 0 or ep == 0 or ep == EPOCHS-1:
                p_train = tm.predict(g_train)
                acc = accuracy_score(Y_train, p_train)
                history.append((ep+1, acc * 100))
                tqdm.write(f"  Epoch {ep+1:3d}: Train Acc = {acc*100:.1f}%")

        training_time = time.time() - t_start

        # 5. Final Evaluation
        print(f"\nEvaluating...")
        preds = tm.predict(g_test)
        final_acc = accuracy_score(Y_test, preds)
        cm = confusion_matrix(Y_test, preds)

        # Per-class accuracy (0 when the class is absent from the test set)
        acc0 = accuracy_score(Y_test[Y_test==0], preds[Y_test==0]) if np.sum(Y_test==0) > 0 else 0
        acc1 = accuracy_score(Y_test[Y_test==1], preds[Y_test==1]) if np.sum(Y_test==1) > 0 else 0

        # Precision / recall / F1 per class (metric functions come from the
        # notebook's top-level sklearn import)
        p0 = precision_score(Y_test, preds, pos_label=0, zero_division=0)
        p1 = precision_score(Y_test, preds, pos_label=1, zero_division=0)
        r0 = recall_score(Y_test, preds, pos_label=0, zero_division=0)
        r1 = recall_score(Y_test, preds, pos_label=1, zero_division=0)
        f1_0 = f1_score(Y_test, preds, pos_label=0, zero_division=0)
        f1_1 = f1_score(Y_test, preds, pos_label=1, zero_division=0)

        print(f"Precision:  W0={p0*100:.2f}% | W1={p1*100:.2f}%")
        print(f"Recall:     W0={r0*100:.2f}% | W1={r1*100:.2f}%")
        print(f"F1 Score:   W0={f1_0*100:.2f}% | W1={f1_1*100:.2f}%")

        print(f"\n{'='*60}")
        print(f"  RESULTS - Offset {off}")
        print(f"{'='*60}")
        print(f"Overall Accuracy:  {final_acc*100:.2f}%")
        print(f"Winner 0 Accuracy: {acc0*100:.2f}%")
        print(f"Winner 1 Accuracy: {acc1*100:.2f}%")
        print(f"Class Gap:         {abs(acc0-acc1)*100:.2f}%")
        print(f"Training Time:     {training_time/60:.1f} min")
        print(f"{'='*60}\n")

        results_main.append({
            "offset": off,
            "acc": final_acc * 100,
            "acc0": acc0 * 100,
            "acc1": acc1 * 100,
            "gap": abs(acc0-acc1) * 100,
            "recall0": r0 * 100,
            "recall1": r1 * 100,
            "f1_0": f1_0 * 100,
            "f1_1": f1_1 * 100,
            "history": history,
            "cm": cm,
            "preds": preds,
            "y_test": Y_test,
            "training_time": training_time
        })

        # Memory cleanup
        del tm, g_train, g_test
        gc.collect()

    # Save results summary (scalar metrics only; the recall/F1 columns are
    # appended after the original ones so existing readers keep working)
    summary_df = pd.DataFrame([{
        'offset': r['offset'],
        'accuracy': r['acc'],
        'acc_winner0': r['acc0'],
        'acc_winner1': r['acc1'],
        'gap': r['gap'],
        'time_min': r['training_time']/60,
        'recall_winner0': r['recall0'],
        'recall_winner1': r['recall1'],
        'f1_winner0': r['f1_0'],
        'f1_winner1': r['f1_1']
    } for r in results_main])

    summary_path = os.path.join(RUNS_DIR, f"main_results_{BOARD_DIM}x{BOARD_DIM}.csv")
    summary_df.to_csv(summary_path, index=False)
    print(f"‚úì Results saved to {summary_path}")

Replaying Games:  71%|███████   | 5659/8000 [00:16<00:06, 351.52it/s]

In [None]:
# =========================================================
#       PARAMETER EXPERIMENTS (Scalable)
# =========================================================

# Define experiment ranges based on board size
if BOARD_DIM <= 5:
    # Small boards: test more granular
    CLAUSE_TESTS = [int(CLAUSES * f) for f in [0.5, 1.0, 2.0, 4.0]]
    DEPTH_TESTS = [max(1, DEPTH-2), DEPTH, DEPTH+2, DEPTH+5]
    REDUCED_EPOCHS = max(15, EPOCHS // 2)
elif BOARD_DIM <= 11:
    # Medium boards: test around optimal
    CLAUSE_TESTS = [int(CLAUSES * f) for f in [0.67, 1.0, 1.5, 2.0]]
    DEPTH_TESTS = [max(1, DEPTH-3), DEPTH, DEPTH+3, DEPTH+6]
    REDUCED_EPOCHS = max(20, EPOCHS // 2)
else:
    # Large boards: test fewer configs (time-expensive)
    CLAUSE_TESTS = [int(CLAUSES * f) for f in [0.75, 1.0, 1.5]]
    DEPTH_TESTS = [DEPTH-3, DEPTH, DEPTH+5]
    REDUCED_EPOCHS = max(25, EPOCHS // 2)

# Drop invalid depths (< 1) and duplicates up front, so the printed plan
# matches what is actually run and no depth is trained twice.
DEPTH_TESTS = sorted({d for d in DEPTH_TESTS if d >= 1})

print("\n" + "="*60)
print(f"  PARAMETER EXPERIMENTS - {BOARD_DIM}x{BOARD_DIM}")
print("="*60)
print(f"  Clause tests: {CLAUSE_TESTS}")
print(f"  Depth tests:  {DEPTH_TESTS}")
print(f"  Epochs:       {REDUCED_EPOCHS} (reduced for speed)")
print("="*60 + "\n")


def _fit_and_score(g_train, Y_train, g_test, Y_test, n_clauses, depth, progress_label):
    """Train one GTM for REDUCED_EPOCHS epochs and score it on the test graphs.

    All hyperparameters other than ``n_clauses`` and ``depth`` come from the
    global configuration. Prints a one-line summary and returns a dict with
    overall / per-winner accuracy, their gap (all in percent) and the elapsed
    wall-clock time in minutes (model construction through evaluation).
    """
    t0 = time.time()

    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses=n_clauses,
        T=T, s=S, depth=depth,
        message_size=MESSAGE_SIZE,
        message_bits=2,
        max_included_literals=min(64, int(32 * GRID_MULT)),
        grid=(16*13*GRID_MULT, 1, 1),
        block=(128, 1, 1)
    )

    for _ in tqdm(range(REDUCED_EPOCHS), desc=progress_label, leave=False):
        tm.fit(g_train, Y_train, epochs=1, incremental=True)

    preds = tm.predict(g_test)
    acc = accuracy_score(Y_test, preds)
    # Per-class accuracy falls back to 0 when the class is absent from the test set.
    acc0 = accuracy_score(Y_test[Y_test==0], preds[Y_test==0]) if np.sum(Y_test==0) > 0 else 0
    acc1 = accuracy_score(Y_test[Y_test==1], preds[Y_test==1]) if np.sum(Y_test==1) > 0 else 0

    elapsed = time.time() - t0

    print(f"  ‚Üí Overall: {acc*100:.1f}% | W0: {acc0*100:.1f}% | W1: {acc1*100:.1f}% | Time: {elapsed/60:.1f}min\n")

    # Release the model before the next configuration is built.
    del tm
    gc.collect()

    return {
        "overall": acc*100,
        "winner0": acc0*100,
        "winner1": acc1*100,
        "gap": abs(acc0-acc1)*100,
        "time_min": elapsed/60
    }


# Use offset=0 (final position) for all experiments.
# The experiments only run when the main experiment produced results.
if results_main:
    X_raw, O_raw, Y = load_and_process_data(offset=0)
    split = int(0.8 * len(Y))
    X_train = (X_raw[:split], O_raw[:split])
    Y_train = Y[:split]
    X_test = (X_raw[split:], O_raw[split:])
    Y_test = Y[split:]

    print("Building graphs for experiments...")
    g_train = prepare_graphs(X_train[0], X_train[1])
    g_test = prepare_graphs(X_test[0], X_test[1], init_with=g_train)
    print("Done!\n")

    # === EXPERIMENT 1: Model Capacity ===
    print("\n" + "-"*60)
    print("EXPERIMENT 1: Model Capacity (Varying Clauses)")
    print("-"*60 + "\n")

    results_capacity = []

    for c in CLAUSE_TESTS:
        print(f"Testing {c} clauses...")
        # Depth stays at the configured value while the clause count varies.
        scores = _fit_and_score(
            g_train, Y_train, g_test, Y_test,
            n_clauses=c, depth=DEPTH, progress_label=f"  {c} clauses"
        )
        results_capacity.append({"clauses": c, **scores})

    # === EXPERIMENT 2: Message Passing Depth ===
    print("\n" + "-"*60)
    print("EXPERIMENT 2: Message Passing Depth")
    print("-"*60 + "\n")

    results_depth = []

    for d in DEPTH_TESTS:
        print(f"Testing depth {d}...")
        # Clause count stays at the configured value while the depth varies.
        scores = _fit_and_score(
            g_train, Y_train, g_test, Y_test,
            n_clauses=CLAUSES, depth=d, progress_label=f"  Depth {d}"
        )
        results_depth.append({"depth": d, **scores})

    # Save results
    df_cap = pd.DataFrame(results_capacity)
    df_depth = pd.DataFrame(results_depth)

    cap_path = os.path.join(RUNS_DIR, f"capacity_{BOARD_DIM}x{BOARD_DIM}.csv")
    depth_path = os.path.join(RUNS_DIR, f"depth_{BOARD_DIM}x{BOARD_DIM}.csv")

    df_cap.to_csv(cap_path, index=False)
    df_depth.to_csv(depth_path, index=False)

    print(f"\n{'='*60}")
    print(f"‚úì Results saved!")
    print(f"  {cap_path}")
    print(f"  {depth_path}")
    print(f"{'='*60}\n")

## 3. Parameter Search (Capacity & Depth)
Use these experiments to find the best configuration for scaling to 11x11.
This runs on the **Final** board state (Offset 0) by default to save time.

In [None]:
# === COMPREHENSIVE VISUALIZATION ===
# Builds one 3x2 figure (learning curves from the main experiment, then the
# capacity sweep and the depth sweep), saves it under RUNS_DIR and prints both
# sweep tables. The cell is a no-op when either sweep is missing or empty.


def history_to_xy(history):
    """Split a training history into parallel ``(epochs, accuracies)`` lists.

    Accepts either a sequence of ``(epoch, accuracy)`` pairs or a bare
    sequence of accuracies, in which case epochs are numbered from 1.
    An empty history yields two empty lists.
    """
    points = list(history)
    if not points:
        return [], []
    if isinstance(points[0], (tuple, list)):
        epochs, accs = zip(*points)
        return list(epochs), list(accs)
    return list(range(1, len(points) + 1)), points


def draw_accuracy_panel(ax, df, x_col, xlabel, title):
    """Plot overall and per-winner accuracy from ``df`` against ``df[x_col]``."""
    ax.plot(df[x_col], df['overall'], 'o-', color='black', linewidth=2, markersize=8, label='Overall')
    ax.plot(df[x_col], df['winner0'], 's--', color='blue', linewidth=1.5, markersize=6, label='Winner 0')
    ax.plot(df[x_col], df['winner1'], '^--', color='red', linewidth=1.5, markersize=6, label='Winner 1')
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Accuracy (%)")
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)


def draw_gap_panel(ax, df, x_col, xlabel, color):
    """Draw the per-class accuracy gap from ``df`` as bars, one per ``df[x_col]``."""
    # x values are cast to str so the bars are evenly spaced categories.
    ax.bar(df[x_col].astype(str), df['gap'], color=color, alpha=0.7)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Accuracy Gap (%)")
    ax.set_title("Class Imbalance (Winner 0 - Winner 1)", fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')


# Look the result lists up by name so the cell can run even if an earlier
# experiment cell was skipped; empty lists are treated like missing ones
# because the DataFrame columns used below would not exist.
capacity_rows = globals().get('results_capacity')
depth_rows = globals().get('results_depth')

if capacity_rows and depth_rows:

    fig = plt.figure(figsize=(16, 12))
    gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

    df_cap = pd.DataFrame(capacity_rows)
    df_depth = pd.DataFrame(depth_rows)

    # --- ROW 1: Main Experiment Results ---
    # Optional: the row stays empty when the main experiment was not run.
    main_rows = globals().get('results_main') or []
    if main_rows:
        ax1 = fig.add_subplot(gs[0, :])
        for res in main_rows:
            epochs, accs = history_to_xy(res['history'])
            ax1.plot(epochs, accs, 'o-', label=f"Offset {res['offset']} (Final: {res['acc']:.1f}%)", linewidth=2)
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Training Accuracy (%)")
        ax1.set_title(f"Learning Curves - {BOARD_DIM}x{BOARD_DIM} Board", fontsize=14, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

    # --- ROW 2: Capacity Experiment ---
    ax2 = fig.add_subplot(gs[1, 0])
    draw_accuracy_panel(ax2, df_cap, 'clauses', "Number of Clauses", "Model Capacity vs Accuracy")

    ax3 = fig.add_subplot(gs[1, 1])
    draw_gap_panel(ax3, df_cap, 'clauses', "Number of Clauses", 'coral')

    # --- ROW 3: Depth Experiment ---
    ax4 = fig.add_subplot(gs[2, 0])
    draw_accuracy_panel(ax4, df_depth, 'depth', "Message Passing Depth", "Depth vs Accuracy")

    ax5 = fig.add_subplot(gs[2, 1])
    draw_gap_panel(ax5, df_depth, 'depth', "Message Passing Depth", 'skyblue')

    plt.suptitle(f"GTM Analysis - {BOARD_DIM}x{BOARD_DIM} Hex Board", fontsize=16, fontweight='bold', y=0.995)

    plot_path = os.path.join(RUNS_DIR, f"full_analysis_{BOARD_DIM}x{BOARD_DIM}.png")
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\n{'='*70}")
    print(f"  FINAL SUMMARY - {BOARD_DIM}x{BOARD_DIM} Board")
    print(f"{'='*70}\n")

    print("CAPACITY EXPERIMENT:")
    print(df_cap.to_string(index=False))

    print(f"\n{'='*70}")
    print("DEPTH EXPERIMENT:")
    print(df_depth.to_string(index=False))

    print(f"\n{'='*70}")
    print(f"📊 Plot saved: {plot_path}")
    print(f"{'='*70}\n")

In [None]:
# ============================================================
#         GENERATE LATEX TABLE FOR REPORT
# ============================================================

# Row labels for the offsets evaluated in the report; any other offset is
# printed as its plain number.
OFFSET_LABELS = {0: "0 (End)", 2: "2 (End-2)", 5: "5 (End-5)"}


def build_latex_table(results, board_dim):
    """Return the lines of a LaTeX table summarising per-offset test results.

    Args:
        results: Iterable of result dicts providing the keys ``offset``,
            ``y_test`` (only its length is used, as the sample count),
            ``acc``, ``recall0``, ``recall1`` and ``f1_1``. The four metric
            values are printed as given with two decimals and a ``\\%`` suffix.
        board_dim: Board side length, used in the caption and the label so
            they always match the board the results were produced on.

    Returns:
        A list of strings, one per output line, without trailing newlines.
    """
    lines = [
        "\\begin{table}[htbp]",
        "\\centering",
        "\\begin{tabular}{|c|c|c|c|c|c|}",
        "\\hline",
        "\\textbf{Offset} & \\textbf{Samples} & \\textbf{Accuracy} & \\textbf{Recall W0} & \\textbf{Recall W1} & \\textbf{F1 (W1)} \\\\",
        "\\hline",
    ]

    for res in results:
        offset = res['offset']
        label = OFFSET_LABELS.get(offset, str(offset))
        n_samples = len(res['y_test'])
        acc = res['acc']
        recall0 = res['recall0']
        recall1 = res['recall1']
        f1_w1 = res['f1_1']
        lines.append(
            f"{label:9s} & {n_samples:4d} & {acc:5.2f}\\% & {recall0:5.2f}\\% & {recall1:5.2f}\\% & {f1_w1:5.2f}\\% \\\\"
        )

    lines += [
        "\\hline",
        "\\end{tabular}",
        f"\\caption{{Test results for the required offset evaluation on ${board_dim}\\times{board_dim}$.}}",
        f"\\label{{tab:{board_dim}x{board_dim}-offset-results}}",
        "\\end{table}",
    ]
    return lines


# Looked up by name so the cell is a no-op (instead of a NameError) when the
# main experiment cell has not been run.
if globals().get('results_main'):
    print("\n" + "="*70)
    print("  LATEX TABLE - Copy this into your report")
    print("="*70 + "\n")

    print("\n".join(build_latex_table(results_main, BOARD_DIM)))

    print("\n" + "="*70 + "\n")

In [None]:
# === SAVE PLOTS WITH CORRECT NAMES ===
# Writes the report figures to RUNS_DIR. File names are prefixed with the
# configured board size (e.g. "11x11_learningcurve.png" when BOARD_DIM is 11),
# so results from different board sizes cannot be mislabelled or overwritten.


def history_to_xy(history):
    """Split a training history into parallel ``(epochs, accuracies)`` lists.

    Accepts either a sequence of ``(epoch, accuracy)`` pairs or a bare
    sequence of accuracies, in which case epochs are numbered from 1.
    An empty history yields two empty lists.
    """
    points = list(history)
    if not points:
        return [], []
    if isinstance(points[0], (tuple, list)):
        epochs, accs = zip(*points)
        return list(epochs), list(accs)
    return list(range(1, len(points) + 1)), points


def board_plot_path(stem):
    """Return ``RUNS_DIR/<BOARD_DIM>x<BOARD_DIM>_<stem>.png``."""
    return os.path.join(RUNS_DIR, f"{BOARD_DIM}x{BOARD_DIM}_{stem}.png")


# Paths actually written by this cell; drives the final confirmation message.
saved_plots = []

# Result lists are looked up by name so a skipped experiment cell makes the
# matching section a no-op instead of raising NameError.
main_rows = globals().get('results_main')
if main_rows:
    # Learning curves
    curve_labels = {0: "End (Offset 0)", 2: "End-2 (Offset 2)", 5: "End-5 (Offset 5)"}
    plt.figure(figsize=(10, 6))
    for res in main_rows:
        epochs, accs = history_to_xy(res['history'])
        plt.plot(epochs, accs, 'o-', label=curve_labels.get(res['offset'], f"Offset {res['offset']}"), linewidth=2)
    plt.xlabel("Epoch")
    plt.ylabel("Training Accuracy (%)")
    plt.title("Learning Curves: End vs End-2 vs End-5")
    plt.legend()
    plt.grid(True, alpha=0.3)
    curve_path = board_plot_path("learningcurve")
    plt.savefig(curve_path, dpi=300, bbox_inches='tight')
    saved_plots.append(curve_path)
    plt.show()

    # Confusion matrices
    # One panel per result; squeeze=False keeps `axes` two-dimensional even
    # for a single result, so indexing works for any number of offsets.
    matrix_labels = {0: "End", 2: "End-2", 5: "End-5"}
    n_panels = len(main_rows)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)
    for ax, res in zip(axes[0], main_rows):
        sns.heatmap(res['cm'], annot=True, fmt='d', cmap='Blues', ax=ax, annot_kws={'size': 14})

        label = matrix_labels.get(res['offset'], f"Offset {res['offset']}")
        ax.set_title(f"{label} (Acc: {res['acc']:.1f}%)")

        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
    plt.suptitle("Confusion Matrices", fontsize=16)
    plt.tight_layout()
    matrix_path = board_plot_path("confusionmatrix")
    plt.savefig(matrix_path, dpi=300, bbox_inches='tight')
    saved_plots.append(matrix_path)
    plt.show()

# Capacity and Depth plots
# Empty sweeps are skipped as well: the columns plotted below would not exist.
capacity_rows = globals().get('results_capacity')
depth_rows = globals().get('results_depth')
if capacity_rows and depth_rows:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    df_cap = pd.DataFrame(capacity_rows)
    ax1.plot(df_cap['clauses'], df_cap['overall'], 'o-', color='black', linewidth=2, markersize=8, label='Overall')
    ax1.plot(df_cap['clauses'], df_cap['winner0'], 's--', color='blue', linewidth=1.5, markersize=6, label='Winner 0')
    ax1.plot(df_cap['clauses'], df_cap['winner1'], '^--', color='red', linewidth=1.5, markersize=6, label='Winner 1')
    ax1.set_xlabel("Number of Clauses")
    ax1.set_ylabel("Accuracy (%)")
    ax1.set_title("Model Capacity")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    df_depth = pd.DataFrame(depth_rows)
    ax2.plot(df_depth['depth'], df_depth['overall'], 'o-', color='black', linewidth=2, markersize=8, label='Overall')
    ax2.plot(df_depth['depth'], df_depth['winner0'], 's--', color='blue', linewidth=1.5, markersize=6, label='Winner 0')
    ax2.plot(df_depth['depth'], df_depth['winner1'], '^--', color='red', linewidth=1.5, markersize=6, label='Winner 1')
    ax2.set_xlabel("Message Passing Depth")
    ax2.set_ylabel("Accuracy (%)")
    ax2.set_title("Message Passing Depth")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    sweep_path = board_plot_path("model_message")
    plt.savefig(sweep_path, dpi=300, bbox_inches='tight')
    saved_plots.append(sweep_path)
    plt.show()

    # Class gap plots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.bar(df_cap['clauses'].astype(str), df_cap['gap'], color='coral', alpha=0.7)
    ax1.set_xlabel("Number of Clauses")
    ax1.set_ylabel("Class Gap (%)")
    ax1.set_title("Capacity vs Class Gap")
    ax1.grid(True, alpha=0.3, axis='y')

    ax2.bar(df_depth['depth'].astype(str), df_depth['gap'], color='skyblue', alpha=0.7)
    ax2.set_xlabel("Message Passing Depth")
    ax2.set_ylabel("Class Gap (%)")
    ax2.set_title("Depth vs Class Gap")
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    gap_path = board_plot_path("classgap_capacity_depth")
    plt.savefig(gap_path, dpi=300, bbox_inches='tight')
    saved_plots.append(gap_path)
    plt.show()

# Only claim success when at least one figure was actually written.
if saved_plots:
    print(f"\n✓ All plots saved to {RUNS_DIR}/")