# Hex Graph Tsetlin Machine Experiments

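This notebook implements the full data pipeline and experimental suite for predicting the winner of a Hex game from its final board position.
The board size is configurable through `BOARD_DIM`, and the notebook reproduces the "Model Capacity" (number of clauses) and "Message Passing Depth" experiments.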

In [None]:
import os
import subprocess
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# --- Plot appearance ---
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100

# === CONFIGURATION ===
BOARD_DIM = 3    # board side length -- CHANGE THIS to 3, 4, 11 etc.
N_GAMES = 10000  # number of games the engine simulates
SEED = 42        # passed to the engine and used for the train/test split

# All generated artefacts live under RUNS_DIR; the CSV name encodes the board size and game count.
RUNS_DIR = "runs"
CSV_PATH = os.path.join(RUNS_DIR, f"hex_moves_dim{BOARD_DIM}_n{N_GAMES}.csv")
os.makedirs(RUNS_DIR, exist_ok=True)


## 1. Data Generation
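We call the Hex engine through the wrapper script `scripts/run_hex.sh`, passing `BOARD_DIM` as an environment variable. The script recompiles the C code whenever `BOARD_DIM` changes and dumps one CSV row per move to `CSV_PATH`.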

In [None]:
# Run the Hex engine through its wrapper script.
# BOARD_DIM is passed as an environment variable; the script uses it to decide
# whether the C engine needs recompiling.
cmd = [
    "./scripts/run_hex.sh",
    "--games", str(N_GAMES),
    "--seed", str(SEED),
    "--dump-moves", CSV_PATH,
]

env = os.environ.copy()
env["BOARD_DIM"] = str(BOARD_DIM)

print(f"Generating {N_GAMES} games for {BOARD_DIM}x{BOARD_DIM} board...")
# perf_counter is monotonic, so the elapsed time is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
subprocess.run(cmd, env=env, check=True)
print(f"Done in {time.perf_counter() - start_time:.2f}s")

# check=True only guarantees a zero exit status; make sure the move dump was
# actually written before the processing cell tries to read it.
if not os.path.isfile(CSV_PATH) or os.path.getsize(CSV_PATH) == 0:
    raise RuntimeError(f"Hex engine exited successfully but did not write {CSV_PATH}")

print(f"Generated: {CSV_PATH}")

## 2. Process Data
Replay each game's moves from the CSV to reconstruct its final board, then create a stratified 50/50 training/test split.

In [None]:
def load_and_process_data(csv_path, offset=0, board_dim=None):
    """Reconstruct the final board of every game in a Hex move-dump CSV.

    The CSV holds one row per move with the columns
    ``game_id, move_idx, player, row, col, position, winner``. Stones are placed
    from ``row``/``col`` (not ``position``), so any border padding in the engine's
    raw position index is irrelevant. Moves whose coordinates fall outside the
    ``board_dim x board_dim`` board are ignored.

    Args:
        csv_path: Path or file-like object accepted by ``pd.read_csv``.
        offset: Unused; kept so existing callers keep working.
        board_dim: Board side length. Defaults to the notebook-level ``BOARD_DIM``.

    Returns:
        ``(p0_maps, p1_maps, labels)``. ``p0_maps`` and ``p1_maps`` are ``int8``
        arrays of shape ``(n_games, board_dim**2)`` with a 1 at every row-major
        cell ``r * board_dim + c`` occupied by player 0 (``player == 0``) or by
        the other player, respectively. ``labels`` holds the ``winner`` value of
        each game's first CSV row. Games are ordered by ascending ``game_id``.
    """
    if board_dim is None:
        board_dim = BOARD_DIM

    print(f"Loading {csv_path}...")
    df = pd.read_csv(csv_path)
    n_nodes = board_dim * board_dim

    # Map each game_id to an output row, in ascending game_id order
    # (the same order df.groupby("game_id") would iterate in).
    game_codes, game_ids = pd.factorize(df["game_id"], sort=True)
    n_games = len(game_ids)

    rows = df["row"].to_numpy(dtype=np.int64)
    cols = df["col"].to_numpy(dtype=np.int64)
    players = df["player"].to_numpy(dtype=np.int64)

    # Keep only moves on the playable board (and with a valid game id).
    on_board = (
        (game_codes >= 0)
        & (rows >= 0) & (rows < board_dim)
        & (cols >= 0) & (cols < board_dim)
    )
    node_idx = rows * board_dim + cols

    # Vectorised replay: scatter every move into its game's row in one shot,
    # instead of iterating over the moves in Python.
    p0_maps = np.zeros((n_games, n_nodes), dtype=np.int8)
    p1_maps = np.zeros((n_games, n_nodes), dtype=np.int8)
    is_p0 = on_board & (players == 0)
    is_p1 = on_board & (players != 0)
    p0_maps[game_codes[is_p0], node_idx[is_p0]] = 1
    p1_maps[game_codes[is_p1], node_idx[is_p1]] = 1

    # The label is the winner recorded on each game's first row in file order.
    labels = (
        df.drop_duplicates("game_id")
        .set_index("game_id")["winner"]
        .reindex(game_ids)
        .to_numpy()
    )

    return p0_maps, p1_maps, labels

X_feat, O_feat, Y = load_and_process_data(CSV_PATH)

# Stratified split over game indices so train and test keep the same winner
# distribution. 50/50 suits small datasets; lower TEST_SIZE for larger ones.
TEST_SIZE = 0.5
indices = np.arange(len(Y))
train_idx, test_idx = train_test_split(
    indices, test_size=TEST_SIZE, random_state=SEED, stratify=Y
)

X_train = (X_feat[train_idx], O_feat[train_idx])
Y_train = Y[train_idx]
X_test = (X_feat[test_idx], O_feat[test_idx])
Y_test = Y[test_idx]

# The split must be a partition of all games.
assert len(Y_train) + len(Y_test) == len(Y), "split must cover every game"
assert np.intersect1d(train_idx, test_idx).size == 0, "train and test games overlap"

print(f"Train: {len(Y_train)} samples, Test: {len(Y_test)} samples")
# Report every class present, not just 0 and 1.
class_labels, class_counts = np.unique(Y_train, return_counts=True)
print("Train Class Dist: " + ", ".join(f"{label}={count}" for label, count in zip(class_labels, class_counts)))

## 3. Graph Preparation for GTM
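Convert the boards into the `GraphTsetlinMachine` input format. Each cell becomes a graph node, connected to its (up to six) hexagonal neighbours and labelled with exactly one symbol: `Empty`, `Player0` or `Player1`.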

In [None]:
# GraphTsetlinMachine is optional: without it the data cells still run and
# every training / plotting cell is skipped via HAS_GTM.
HAS_GTM = True
try:
    from GraphTsetlinMachine.graphs import Graphs
    from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
except ImportError:
    HAS_GTM = False
    print("WARNING: GraphTsetlinMachine not installed. Training cells will fail.")

def _hex_adjacency(board_dim):
    """Map each cell of a ``board_dim x board_dim`` Hex board to its neighbours.

    Cells are identified by their row-major index ``r * board_dim + c``. The six
    Hex neighbours of ``(r, c)`` are ``(r, c+1)``, ``(r, c-1)``, ``(r-1, c+1)``,
    ``(r+1, c-1)``, ``(r-1, c)`` and ``(r+1, c)``. Neighbours that fall off the
    board are dropped, so edge and corner cells have fewer than six.

    Returns:
        dict mapping every cell index to the list of its neighbouring cell
        indices, in the offset order listed above.
    """
    offsets = [(0, 1), (0, -1), (-1, 1), (1, -1), (-1, 0), (1, 0)]
    adjacency = {}
    for r in range(board_dim):
        for c in range(board_dim):
            adjacency[r * board_dim + c] = [
                (r + dr) * board_dim + (c + dc)
                for dr, dc in offsets
                if 0 <= r + dr < board_dim and 0 <= c + dc < board_dim
            ]
    return adjacency


def prepare_gtm_graphs(x_feat, o_feat, board_dim, init_with=None,
                       hypervector_size=1024, hypervector_bits=2):
    """Encode Hex boards as a GraphTsetlinMachine ``Graphs`` collection.

    Each board becomes one graph with ``board_dim**2`` nodes (row-major cell
    indices). Each node is connected to its Hex neighbours with edge type ``0``
    and carries exactly one symbol: ``"Player0"`` if ``x_feat`` marks the cell,
    otherwise ``"Player1"`` if ``o_feat`` marks it, otherwise ``"Empty"``.

    Args:
        x_feat: Per-board player-0 occupancy vectors of length ``board_dim**2``.
        o_feat: Per-board player-1 occupancy vectors, same length as ``x_feat``.
        board_dim: Board side length.
        init_with: Optional existing ``Graphs`` passed through to the constructor.
            The test set uses the training graphs here so both share hypervectors.
        hypervector_size: Forwarded to ``Graphs`` (default 1024).
        hypervector_bits: Forwarded to ``Graphs`` (default 2).

    Returns:
        The encoded ``Graphs`` object, or ``None`` when GraphTsetlinMachine is
        not installed.

    Raises:
        ValueError: if ``x_feat`` and ``o_feat`` hold different numbers of boards.
    """
    if not HAS_GTM:
        return None

    n_samples = len(x_feat)
    if len(o_feat) != n_samples:
        raise ValueError(
            f"x_feat has {n_samples} boards but o_feat has {len(o_feat)}"
        )
    n_nodes = board_dim * board_dim
    adjacency = _hex_adjacency(board_dim)

    graphs = Graphs(
        n_samples,
        symbols=["Empty", "Player0", "Player1"],
        hypervector_size=hypervector_size,
        hypervector_bits=hypervector_bits,
        init_with=init_with,
    )

    # The Graphs object is filled in stages, each followed by its prepare_* call:
    # node counts -> nodes (with edge counts) -> edges -> node symbols -> encode().
    # NOTE(review): presumably this order is required by the Graphs API; keep it.

    # 1. Number of nodes per graph.
    for i in range(n_samples):
        graphs.set_number_of_graph_nodes(i, n_nodes)
    graphs.prepare_node_configuration()

    # 2. Declare every node together with its number of outgoing edges.
    for i in range(n_samples):
        for u in range(n_nodes):
            graphs.add_graph_node(i, u, len(adjacency[u]))
    graphs.prepare_edge_configuration()

    # 3. Add the edges; all share a single edge type, 0.
    for i in range(n_samples):
        for u in range(n_nodes):
            for v in adjacency[u]:
                graphs.add_graph_node_edge(i, u, v, 0)

    # 4. Exactly one symbol per node. Player0 takes precedence if both maps are set.
    for i in range(n_samples):
        p0 = x_feat[i]
        p1 = o_feat[i]
        for u in range(n_nodes):
            if p0[u]:
                graphs.add_graph_node_property(i, u, "Player0")
            elif p1[u]:
                graphs.add_graph_node_property(i, u, "Player1")
            else:
                graphs.add_graph_node_property(i, u, "Empty")

    graphs.encode()
    return graphs

if HAS_GTM:
    print("Preparing Training Graphs...")
    graphs_train = prepare_gtm_graphs(*X_train, BOARD_DIM)

    # Crucial: build the test graphs from the training graphs (init_with) so
    # both sets share the same hypervectors.
    print("Preparing Test Graphs...")
    graphs_test = prepare_gtm_graphs(*X_test, BOARD_DIM, init_with=graphs_train)

## 4. Experiment 2: Model Capacity
Vary the number of clauses (`clause_configs`) while keeping `T`, `s` and the message-passing depth fixed.

In [None]:
EXP_EPOCHS = 30
clause_configs = [100, 400, 1600]  # clause counts to sweep

# Hyper-parameters held fixed across the sweep
HYPER_T = 2000
HYPER_S = 5.0
DEPTH = 5

results_cap = []

if HAS_GTM:
    for clauses in clause_configs:
        print(f"\n--- Running Capacity Exp: {clauses} Clauses ---")
        tm = MultiClassGraphTsetlinMachine(
            number_of_clauses=clauses,
            T=HYPER_T, s=HYPER_S, depth=DEPTH,
            message_size=512, message_bits=2,
            max_included_literals=32,
            grid=(16 * 13, 1, 1), block=(128, 1, 1),
        )

        # One epoch per fit() call so progress can be reported every 10 epochs.
        for epoch in range(1, EXP_EPOCHS + 1):
            tm.fit(graphs_train, Y_train, epochs=1, incremental=True)
            if epoch % 10 == 0:
                print(f"Epoch {epoch}/{EXP_EPOCHS}")

        # Evaluate on the held-out games.
        preds = tm.predict(graphs_test)
        acc = accuracy_score(Y_test, preds)
        print(f"Test Accuracy: {acc*100:.2f}%")

        results_cap.append(dict(Clauses=clauses, Accuracy=acc, Predictions=preds))

## 5. Experiment 3: Message Passing Depth
Vary the message-passing depth (`depth_configs`) with the number of clauses fixed at `FIXED_CLAUSES`.

In [None]:
depth_configs = [1, 3, 5, 8]  # message-passing depths to sweep
FIXED_CLAUSES = 400           # clause count held fixed across the sweep

results_depth = []

if HAS_GTM:
    for d in depth_configs:
        print(f"\n--- Running Depth Exp: Depth {d} ---")
        tm = MultiClassGraphTsetlinMachine(
            number_of_clauses=FIXED_CLAUSES,
            T=HYPER_T, s=HYPER_S, depth=d,
            message_size=512, message_bits=2,
            max_included_literals=32,
            grid=(16 * 13, 1, 1), block=(128, 1, 1),
        )

        # Train incrementally, one epoch per fit() call.
        for _epoch in range(EXP_EPOCHS):
            tm.fit(graphs_train, Y_train, epochs=1, incremental=True)

        # Evaluate on the held-out games.
        preds = tm.predict(graphs_test)
        acc = accuracy_score(Y_test, preds)
        print(f"Test Accuracy: {acc*100:.2f}%")

        results_depth.append(dict(Depth=d, Accuracy=acc, Predictions=preds))

## 6. Visualization
Plot the test accuracy of both experiments side by side.

In [None]:
# Two panels: accuracy vs. clause count (left) and vs. message-passing depth (right).
if HAS_GTM:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Plot Capacity
    clauses = [r['Clauses'] for r in results_cap]
    accs = [r['Accuracy'] * 100 for r in results_cap]
    ax1.plot(clauses, accs, marker='o', color='green', linewidth=2)
    ax1.set_xscale('log')
    # By default a log axis labels only decade ticks (100, 1000, ...), which
    # leaves most tested clause counts unlabelled. Label exactly the tested values.
    ax1.set_xticks(clauses)
    ax1.set_xticklabels([str(c) for c in clauses])
    ax1.minorticks_off()
    ax1.set_xlabel('Clauses')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title(f'Accuracy vs Capacity (Depth={DEPTH})')
    ax1.grid(True, alpha=0.3)

    # Plot Depth
    depths = [r['Depth'] for r in results_depth]
    accs_d = [r['Accuracy'] * 100 for r in results_depth]
    ax2.plot(depths, accs_d, marker='s', color='purple', linewidth=2)
    ax2.set_xticks(depths)  # one tick per tested depth
    ax2.set_xlabel('Depth')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title(f'Accuracy vs Depth (Clauses={FIXED_CLAUSES})')
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()