# Methods - Optimized Version

In [316]:
from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
from GraphTsetlinMachine.graphs import Graphs
import numpy as np
import subprocess
import time
import os
from tqdm import tqdm

# Configuration
BOARD_DIM = 11
NOTEBOOK_DIR = os.path.dirname(os.path.abspath("Tsetlin.ipynb"))
HEX_DIR = os.path.join(NOTEBOOK_DIR, "TsetlinMachine/hex")

if not os.path.exists(HEX_DIR):
    raise FileNotFoundError(f"ERROR: Cannot find hex.c at {HEX_DIR}")

print("Building hex using make...")

try:
    result = subprocess.run(
        ["make"],
        cwd=HEX_DIR,
        capture_output=True,
        text=True
    )

    print("=== Make Output ===")
    print(result.stdout)
    if result.stderr.strip():
        print("=== Make Errors ===")
        print(result.stderr)

    if result.returncode == 0:
        print("\n✓ Build successful!")
    else:
        print("\n❌ Build failed! See errors above.")

except Exception as e:
    print("Exception when running make:", e)

def c_position_to_node_id(c_position, board_dim=BOARD_DIM):
    padded_dim = board_dim + 2
    i = c_position // padded_dim
    j = c_position % padded_dim
    node_id = (i - 1) * board_dim + (j - 1)

    if node_id < 0 or node_id >= board_dim * board_dim:
        return None
    return node_id

def get_hex_edges(board_dim=BOARD_DIM):
    edges = []
    neighbor_offsets = [(0, 1), (0, -1), (-1, 1), (1, -1), (-1, 0), (1, 0)]

    for i in range(board_dim):
        for j in range(board_dim):
            node_id = i * board_dim + j
            for di, dj in neighbor_offsets:
                ni, nj = i + di, j + dj
                if 0 <= ni < board_dim and 0 <= nj < board_dim:
                    neighbor_id = ni * board_dim + nj
                    edges.append((node_id, neighbor_id))

    return edges

def parse_game_output(output):
    games = []
    current_game = None

    for line in output.split('\n'):
        line = line.strip()

        if line == "GAME_START":
            current_game = {'moves': [], 'winner': -1}
        elif line.startswith("MOVE"):
            if current_game is not None:
                parts = line.split()
                if len(parts) >= 3:
                    position = int(parts[1])
                    player = int(parts[2])
                    current_game['moves'].append((position, player))
        elif line.startswith("WINNER"):
            if current_game is not None:
                parts = line.split()
                if len(parts) >= 2:
                    current_game['winner'] = int(parts[1])
        elif line == "GAME_END":
            if current_game and current_game['winner'] != -1:
                games.append(current_game)
            current_game = None

    return games

def create_training_data_from_game(moves, winner, board_dim=BOARD_DIM):
    """
    Create ONE training sample per game:
      - board_state: final board (0=empty, 1=player0, 2=player1)
      - label: winner of the game (0 or 1)

    We keep node_features as before for compatibility, but the main
    object we care about is the final board_state.
    """
    num_nodes = board_dim * board_dim
    board_state = np.zeros(num_nodes, dtype=np.int32)
    edges = get_hex_edges(board_dim)  # not used in new Graphs, but kept

    # Play through the whole game to get the final board
    for c_position, player in moves:
        node_id = c_position_to_node_id(c_position, board_dim)
        if node_id is None:
            print(f"Skipping invalid move: c_pos={c_position}")
            continue
        board_state[node_id] = player + 1  # 1 = player 0, 2 = player 1

    # Build node_features from the FINAL full board_state
    node_features = np.zeros((num_nodes, 3), dtype=np.int32)
    for nid in range(num_nodes):
        if board_state[nid] == 1:
            node_features[nid, 0] = 1  # player_0 stone
        elif board_state[nid] == 2:
            node_features[nid, 1] = 1  # player_1 stone
        else:
            node_features[nid, 2] = 1  # empty

    label = int(winner)  # 0 or 1

    sample = {
        'board_state': board_state.reshape(board_dim, board_dim),
        'node_features': node_features,   # still there if you need it
        'edges': edges,                   # unused in new Graph building
        'position': -1,                   # not used now
        'player': -1,                     # not used now
        'label': label
    }

    # Return as a list to match old API
    return [sample]


def prepare_training_data(games, board_dim=BOARD_DIM):
    """
    Turn a list of games into a list of FINAL-state → winner samples.
    Exactly one sample per game.
    """
    all_samples = []

    print(f"Processing {len(games)} games into training samples...")

    # Game-level statistics
    player_0_wins = sum(1 for g in games if g['winner'] == 0)
    player_1_wins = sum(1 for g in games if g['winner'] == 1)
    print(f"Game outcomes (per game):")
    print(f"  Player 0 wins: {player_0_wins}")
    print(f"  Player 1 wins: {player_1_wins}")

    for game in tqdm(games, desc="Processing games"):
        samples = create_training_data_from_game(game['moves'], game['winner'], board_dim)
        all_samples.extend(samples)

    if len(all_samples) == 0:
        print("ERROR: No training samples created! Check your logic.")
        return all_samples

    labels = [s['label'] for s in all_samples]
    unique, counts = np.unique(labels, return_counts=True)
    print(f"\nLabel distribution (winner classes, per FINAL board):")
    for label, count in zip(unique, counts):
        print(f"  Winner {label}: {count} games ({count/len(labels)*100:.1f}%)")

    # Quick sanity check of final board states
    print("\nSample final board state check (first 5 samples):")
    for i in range(min(5, len(all_samples))):
        sample = all_samples[i]
        pieces = np.sum(sample['node_features'][:, :2])
        empties = np.sum(sample['node_features'][:, 2])
        print(f"  Sample {i}: {pieces} stones, {empties} empty cells, label(winner)={sample['label']}")

    print(f"{'='*60}\n")

    return all_samples


def generate_game_data(num_games=1000, hex_dir=HEX_DIR):
    hex_executable = os.path.join(hex_dir, "hex")

    if not os.path.exists(hex_executable):
        print(f"Executable not found at {hex_executable}")
        return []

    print(f"Generating {num_games} games...")

    try:
        result = subprocess.run(
            [hex_executable, str(num_games)],
            cwd=hex_dir,
            capture_output=True,
            text=True,
            timeout=120
        )

        if result.returncode != 0:
            print(f"Error running hex executable:")
            print(result.stderr)
            return []

        games = parse_game_output(result.stdout)
        print(f"Successfully parsed {len(games)} games from output")
        return games

    except Exception as e:
        print(f"Error running hex executable: {e}")
        return []


def prepare_gtm_data(training_samples, board_dim=BOARD_DIM,
                     hypervector_size=1024, hypervector_bits=2):
    """
    Build a Graphs object where each board cell is a node, etc...
    """
    from GraphTsetlinMachine.graphs import Graphs

    Y = np.array([s['label'] for s in training_samples], dtype=np.int32)

    num_graphs = len(training_samples)
    num_nodes = board_dim * board_dim

    symbols = ['player_0', 'player_1', 'empty']

    print("Creating multi-node Hex Graphs object...")
    print(f"  Number of graphs: {num_graphs}")
    print(f"  Nodes per graph: {num_nodes} ({board_dim} x {board_dim})")
    # 🔧 avoid printing the whole object, just its length
    print(f"  Number of symbols: {len(symbols)}")
    print(f"  Hypervector size: {hypervector_size}, bits: {hypervector_bits}")

    graphs, Y = prepare_gtm_data(
        training_samples,
        board_dim=BOARD_DIM,
        hypervector_size=hypervector_size,
        hypervector_bits=hypervector_bits
    )

    # 1) Set number of nodes
    print("Step 1: Setting number of nodes per graph...")
    for graph_id in tqdm(range(num_graphs), desc="Setting nodes"):
        graphs.set_number_of_graph_nodes(graph_id, num_nodes)

    # 2) Node configuration
    print("Step 2: Preparing node configuration...")
    graphs.prepare_node_configuration()

    # 3) Add nodes with edge counts
    print("Step 3: Adding nodes with edge counts...")
    hex_edges = get_hex_edges(board_dim)

    edge_counts = np.zeros(num_nodes, dtype=np.uint32)
    for src, _ in hex_edges:
        edge_counts[src] += 1

    for graph_id in tqdm(range(num_graphs), desc="Adding nodes"):
        for node_id in range(num_nodes):
            graphs.add_graph_node(graph_id, node_id, int(edge_counts[node_id]))

    # 4) Edge configuration
    print("Step 4: Preparing edge configuration...")
    graphs.prepare_edge_configuration()

    # 5) Add edges
    print("Step 5: Adding edges...")
    edge_type_name = "hex_neighbor"
    if edge_type_name not in graphs.edge_type_id:
        graphs.edge_type_id[edge_type_name] = len(graphs.edge_type_id)
    edge_type_id = graphs.edge_type_id[edge_type_name]

    edge_data = [(src, dst, edge_type_id) for (src, dst) in hex_edges]

    for graph_id in tqdm(range(num_graphs), desc="Populating edges"):
        node_index = graphs.node_index[graph_id]

        for src_node_id, dest_node_id, etype in edge_data:
            base = graphs.edge_index[node_index + src_node_id]
            offset = graphs.graph_node_edge_counter[node_index + src_node_id]
            edge_idx = base + offset

            graphs.edge[edge_idx][0] = dest_node_id
            graphs.edge[edge_idx][1] = etype
            graphs.graph_node_edge_counter[node_index + src_node_id] += 1

    # 6) Add node properties
    print("Step 6: Adding node properties (stone occupancy)...")
    for graph_id in tqdm(range(num_graphs), desc="Adding properties"):
        node_features = training_samples[graph_id]['node_features']
        for node_id in range(num_nodes):
            if node_features[node_id, 0] == 1:
                graphs.add_graph_node_property(graph_id, node_id, "player_0")
            elif node_features[node_id, 1] == 1:
                graphs.add_graph_node_property(graph_id, node_id, "player_1")
            else:
                graphs.add_graph_node_property(graph_id, node_id, "empty")

    # 7) Encode graphs
    print("Step 7: Encoding graphs...")
    graphs.encode()

    print(f"\n✓ Prepared {num_graphs} multi-node Hex graphs")
    unique_labels, label_counts = np.unique(Y, return_counts=True)
    label_dist = ", ".join([f"Winner {lab}={cnt}" for lab, cnt in zip(unique_labels, label_counts)])
    print(f"  Label distribution: {label_dist}")

    return graphs, Y



def train_model(graphs, Y, epochs=100):
    """
    Train a MultiClassGraphTsetlinMachine to predict the WINNER (0 or 1)
    from a given board state.
    """
    NUMBER_OF_CLAUSES = 3000
    T = 25000
    S = 10.0
    DEPTH = 3
    MESSAGE_SIZE = 128
    MESSAGE_BITS = 2

    print("Initializing Graph Tsetlin Machine...")
    print(f"  Clauses: {NUMBER_OF_CLAUSES}")
    print(f"  T: {T}")
    print(f"  s: {S}")
    print(f"  Depth: {DEPTH}")
    print(f"  Message Size: {MESSAGE_SIZE}")

    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses=NUMBER_OF_CLAUSES,
        T=T,
        s=S,
        number_of_state_bits=8,
        depth=DEPTH,
        message_size=MESSAGE_SIZE,
        message_bits=MESSAGE_BITS,
        max_included_literals=32,
        grid=(16 * 13, 1, 1),
        block=(128, 1, 1)
    )

    # Class balancing: oversample minority class of winner
    class_0_indices = np.where(Y == 0)[0]
    class_1_indices = np.where(Y == 1)[0]

    print(f"\nClass distribution before balancing (winner classes):")
    print(f"  Winner 0: {len(class_0_indices)} states")
    print(f"  Winner 1: {len(class_1_indices)} states")

    if len(class_0_indices) > 0 and len(class_1_indices) > 0:
        # Balance by repeating minority class
        if len(class_0_indices) < len(class_1_indices):
            oversample_ratio = len(class_1_indices) // len(class_0_indices)
            class_0_indices = np.tile(class_0_indices, oversample_ratio)
        else:
            oversample_ratio = len(class_0_indices) // len(class_1_indices)
            class_1_indices = np.tile(class_1_indices, oversample_ratio)

        balanced_indices = np.concatenate([class_0_indices, class_1_indices])
        np.random.shuffle(balanced_indices)
        print(f"After balancing: {len(balanced_indices)} total samples")
    else:
        balanced_indices = np.arange(len(Y))
        print("Warning: only one winner class present; no balancing applied.")

    print(f"\nStarting training for {epochs} epochs...")
    print("="*60)

    start_total = time.time()

    for epoch in range(epochs):
        start_epoch = time.time()

        # NOTE: current GTM implementation uses full graph set; balancing
        # is mainly informational here. To actually subsample, GTM would need
        # support for graph subsets.
        tm.fit(graphs, Y, epochs=1, incremental=True)
        elapsed = time.time() - start_epoch

        predictions = tm.predict(graphs)
        accuracy = 100 * (predictions == Y).mean()

        class_0_mask = (Y == 0)
        class_1_mask = (Y == 1)
        class_0_acc = 100 * (predictions[class_0_mask] == 0).mean() if class_0_mask.any() else 0
        class_1_acc = 100 * (predictions[class_1_mask] == 1).mean() if class_1_mask.any() else 0

        print(f"Epoch {epoch+1}/{epochs} - Acc: {accuracy:.2f}% "
              f"(Winner 0 states: {class_0_acc:.1f}%, Winner 1 states: {class_1_acc:.1f}%) - {elapsed:.2f}s")

    total_time = time.time() - start_total
    print("\n" + "="*60)
    print(f"✓ Training completed in {total_time:.2f}s ({total_time/60:.2f} minutes)")

    print("\nFinal Evaluation...")
    predictions = tm.predict(graphs)

    accuracy = 100 * (predictions == Y).mean()
    print(f"\nOverall Accuracy: {accuracy:.2f}%")

    for class_id in [0, 1]:
        mask = Y == class_id
        if mask.any():
            class_acc = 100 * (predictions[mask] == class_id).mean()
            print(f"Winner {class_id} states: {class_acc:.2f}% "
                  f"({(predictions[mask] == class_id).sum()}/{mask.sum()})")

    return tm, predictions


Building hex using make...
=== Make Output ===
make: 'hex' is up to date.


✓ Build successful!


## Generate Game Data

Run this cell to generate Hex games and create training samples.

In [317]:
# Generate games
NUM_GAMES = 1000  # Adjust as needed

print(f"Generating {NUM_GAMES} Hex games...")
games = generate_game_data(NUM_GAMES)

if not games:
    raise Exception("No games generated! Check hex executable.")

print(f"\n✓ Successfully generated {len(games)} games")

# Process into training samples (FINAL state -> winner)
training_samples = prepare_training_data(games, BOARD_DIM)
print("\n" + "="*60)
print("DIAGNOSTIC: Checking final board states:")
print("="*60)
for i in range(min(10, len(training_samples))):
    sample = training_samples[i]
    non_zero = np.sum(sample['node_features'][:, :2])  # Count player pieces
    empty = np.sum(sample['node_features'][:, 2])      # Count empty cells
    print(f"Sample {i}: {non_zero} stones on board, {empty} empty cells, label(winner)={sample['label']}")

print("="*60 + "\n")
print(f"\n✓ Training data ready: {len(training_samples)} samples")


Generating 1000 Hex games...
Generating 1000 games...
Successfully parsed 1000 games from output

✓ Successfully generated 1000 games
Processing 1000 games into training samples...
Game outcomes (per game):
  Player 0 wins: 505
  Player 1 wins: 495


Processing games: 100%|██████████| 1000/1000 [00:00<00:00, 2336.25it/s]


Label distribution (winner classes, per FINAL board):
  Winner 0: 505 games (50.5%)
  Winner 1: 495 games (49.5%)

Sample final board state check (first 5 samples):
  Sample 0: 116 stones, 5 empty cells, label(winner)=1
  Sample 1: 115 stones, 6 empty cells, label(winner)=0
  Sample 2: 108 stones, 13 empty cells, label(winner)=1
  Sample 3: 115 stones, 6 empty cells, label(winner)=0
  Sample 4: 112 stones, 9 empty cells, label(winner)=1


DIAGNOSTIC: Checking final board states:
Sample 0: 116 stones on board, 5 empty cells, label(winner)=1
Sample 1: 115 stones on board, 6 empty cells, label(winner)=0
Sample 2: 108 stones on board, 13 empty cells, label(winner)=1
Sample 3: 115 stones on board, 6 empty cells, label(winner)=0
Sample 4: 112 stones on board, 9 empty cells, label(winner)=1
Sample 5: 111 stones on board, 10 empty cells, label(winner)=0
Sample 6: 102 stones on board, 19 empty cells, label(winner)=1
Sample 7: 106 stones on board, 15 empty cells, label(winner)=1
Sample 8: 109 s




## Prepare Data for Graph Tsetlin Machine

Convert training samples into the GTM Graphs format.

In [318]:
from GraphTsetlinMachine.graphs import Graphs

# Extract labels (winner 0/1 per FINAL board)
Y = np.array([s['label'] for s in training_samples], dtype=np.uint32)

print("Preparing data in GTM Graphs format (MNIST-style)...")

# --- Define symbols: one for each (player, cell) ---

symbols = []
for i in range(BOARD_DIM):
    for j in range(BOARD_DIM):
        symbols.append(f"P0_{i}_{j}")  # player 0 stone at (i,j)
        symbols.append(f"P1_{i}_{j}")  # player 1 stone at (i,j)

num_nodes = 1  # ONE node per graph, just like 'Image Node' in MNIST example

graphs = Graphs(
    number_of_graphs=len(training_samples),
    symbols=symbols,
    hypervector_size=1024,
    hypervector_bits=2,
    double_hashing=False,
    # one_hot_encoding=False by default unless you want one-hot
)

print("Step 1: Setting number of nodes for each graph to 1...")
for graph_id in range(len(training_samples)):
    graphs.set_number_of_graph_nodes(graph_id, num_nodes)

print("Step 2: Preparing node configuration...")
graphs.prepare_node_configuration()

print("Step 3: Adding the single 'Board' node to each graph...")
for graph_id in range(len(training_samples)):
    number_of_outgoing_edges = 0  # we don't use edges here
    graphs.add_graph_node(graph_id, 'Board', number_of_outgoing_edges)

print("Step 4: Preparing edge configuration (no edges, but required call)...")
graphs.prepare_edge_configuration()

print("Step 5: Adding node properties based on FINAL board state...")

for graph_id in range(len(training_samples)):
    if graph_id % 1000 == 0:
        print(f"  Processing graph {graph_id}/{len(training_samples)}")

    # node_features: shape (board_dim^2, 3)
    node_features = training_samples[graph_id]['node_features']

    for node_id in range(BOARD_DIM * BOARD_DIM):
        i = node_id // BOARD_DIM
        j = node_id % BOARD_DIM

        if node_features[node_id, 0] == 1:
            # Player 0 stone at (i,j)
            graphs.add_graph_node_property(graph_id, 'Board', f"P0_{i}_{j}")
        elif node_features[node_id, 1] == 1:
            # Player 1 stone at (i,j)
            graphs.add_graph_node_property(graph_id, 'Board', f"P1_{i}_{j}")
        # empty cells add no property

print("Step 6: Encoding graphs...")
graphs.encode()

print(f"\n✓ Prepared {len(training_samples)} graphs")
print(f"  Nodes per graph: {num_nodes}")
print(f"  Symbols per graph node: up to {BOARD_DIM*BOARD_DIM*2} (stones only)")
unique_labels, label_counts = np.unique(Y, return_counts=True)
label_dist = ", ".join([f"Winner {label}={count}" for label, count in zip(unique_labels, label_counts)])
print(f"  Label distribution: {label_dist}")


Preparing data in GTM Graphs format (MNIST-style)...
Step 1: Setting number of nodes for each graph to 1...
Step 2: Preparing node configuration...
Step 3: Adding the single 'Board' node to each graph...
Step 4: Preparing edge configuration (no edges, but required call)...
Step 5: Adding node properties based on FINAL board state...
  Processing graph 0/1000
Step 6: Encoding graphs...

✓ Prepared 1000 graphs
  Nodes per graph: 1
  Symbols per graph node: up to 242 (stones only)
  Label distribution: Winner 0=505, Winner 1=495


## Train the Graph Tsetlin Machine

Train the model on the prepared graph data.

In [319]:
tm, predictions = train_model(graphs, Y, epochs=100)

Initializing Graph Tsetlin Machine...
  Clauses: 3000
  T: 25000
  s: 10.0
  Depth: 3
  Message Size: 128
Initialization of sparse structure.

Class distribution before balancing (winner classes):
  Winner 0: 505 states
  Winner 1: 495 states
After balancing: 1000 total samples

Starting training for 100 epochs...
Epoch 1/100 - Acc: 64.20% (Winner 0 states: 37.0%, Winner 1 states: 91.9%) - 1.31s
Epoch 2/100 - Acc: 72.60% (Winner 0 states: 55.4%, Winner 1 states: 90.1%) - 0.64s
Epoch 3/100 - Acc: 79.30% (Winner 0 states: 76.0%, Winner 1 states: 82.6%) - 0.65s
Epoch 4/100 - Acc: 80.10% (Winner 0 states: 82.6%, Winner 1 states: 77.6%) - 0.64s
Epoch 5/100 - Acc: 80.20% (Winner 0 states: 80.4%, Winner 1 states: 80.0%) - 0.65s
Epoch 6/100 - Acc: 83.70% (Winner 0 states: 85.0%, Winner 1 states: 82.4%) - 0.65s
Epoch 7/100 - Acc: 82.80% (Winner 0 states: 84.0%, Winner 1 states: 81.6%) - 0.66s
Epoch 8/100 - Acc: 83.40% (Winner 0 states: 83.2%, Winner 1 states: 83.6%) - 0.66s
Epoch 9/100 - Acc: 8

## Save the Trained Model

In [320]:
import pickle

# Save the trained model
model_path = "TsetlinMachine/hex_tm_model.pkl"
print(f"Saving trained model to {model_path}...")

# Get the model state (weights, clauses, etc.) instead of the whole object
model_state = tm.get_state()

# Save the state along with model configuration
model_save_data = {
    'state': model_state,
    'config': {
        'number_of_clauses': tm.number_of_clauses,
        'T': tm.T,
        's': tm.s,
        'number_of_state_bits': tm.number_of_state_bits,
        'depth': tm.depth,
        'message_size': tm.message_size,
        'message_bits': tm.message_bits,
        'max_included_literals': tm.max_included_literals,
        'grid': (16 * 13, 1, 1),
        'block': (128, 1, 1)
    },
    'symbols': symbols,
    'num_nodes': num_nodes,
    'board_dim': BOARD_DIM
}

with open(model_path, 'wb') as f:
    pickle.dump(model_save_data, f)

print("✓ Model saved successfully")

# Show some example predictions
print("\nExample predictions (first 10):")
for i in range(min(10, len(predictions))):
    pred = predictions[i]
    true = Y[i]
    status = "✓" if pred == true else "✗"
    print(f"{status} Sample {i}: Pred={pred}, True={true}")

Saving trained model to TsetlinMachine/hex_tm_model.pkl...
✓ Model saved successfully

Example predictions (first 10):
✓ Sample 0: Pred=1, True=1
✓ Sample 1: Pred=0, True=0
✓ Sample 2: Pred=1, True=1
✓ Sample 3: Pred=0, True=0
✓ Sample 4: Pred=1, True=1
✓ Sample 5: Pred=0, True=0
✓ Sample 6: Pred=1, True=1
✓ Sample 7: Pred=1, True=1
✓ Sample 8: Pred=0, True=0
✓ Sample 9: Pred=1, True=1


## Load a Trained Model (Optional)

Use this to load a previously trained model.

In [321]:
def load_trained_model(model_path):
    """Load a previously trained Tsetlin Machine"""
    import pickle
    from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine

    print(f"Loading model from {model_path}...")

    with open(model_path, 'rb') as f:
        model_data = pickle.load(f)

    # Recreate the model with saved configuration
    config = model_data['config']
    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses=config['number_of_clauses'],
        T=config['T'],
        s=config['s'],
        number_of_state_bits=config['number_of_state_bits'],
        depth=config['depth'],
        message_size=config['message_size'],
        message_bits=config['message_bits'],
        max_included_literals=config['max_included_literals'],
        grid=config['grid'],
        block=config['block']
    )

    # Restore the model state
    tm.set_state(model_data['state'])

    print("✓ Model loaded successfully")

    return tm, model_data

# Example usage:
loaded_tm, loaded_data = load_trained_model("TsetlinMachine/hex_tm_model.pkl")

Loading model from TsetlinMachine/hex_tm_model.pkl...
Initialization of sparse structure.
✓ Model loaded successfully


In [322]:
import contextlib
import io
import numpy as np

def evaluate_trained_model_quiet(
    model_path="TsetlinMachine/hex_tm_model.pkl",
    num_test_games=200,
    hypervector_size=1024,
    hypervector_bits=2
):
    """Quiet evaluation: suppress all internal prints, show only final metrics."""

    # Suppress prints using redirect_stdout
    def silent_call(fn, *args, **kwargs):
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            return fn(*args, **kwargs)

    # 1) Load model
    tm, model_data = silent_call(load_trained_model, model_path)
    board_dim = model_data.get("board_dim", BOARD_DIM)
    print("✓ Model loaded")

    # 2) Generate new test games (suppress print spam)
    test_games = silent_call(generate_game_data, num_test_games)

    if not test_games:
        print("ERROR: No test games generated.")
        return

    # 3) Convert games → samples (quiet)
    test_samples = []
    for g in test_games:
        samples = silent_call(
            create_training_data_from_game,
            g["moves"],
            g["winner"],
            board_dim
        )
        test_samples.extend(samples)

    if len(test_samples) == 0:
        print("ERROR: No test samples created.")
        return

    Y_test = np.array([s["label"] for s in test_samples], dtype=np.int32)

    # 4) Build graphs (quiet)
    graphs_test, Y_check = silent_call(
        prepare_gtm_data,
        test_samples,
        board_dim,
        hypervector_size,
        hypervector_bits
    )

    print("✓ Test data prepared")

    # 5) Predict (quiet)
    preds = silent_call(tm.predict, graphs_test).astype(int)

    # 6) Accuracy
    overall = 100 * (preds == Y_test).mean()

    # Winner-wise accuracy
    mask0 = (Y_test == 0)
    mask1 = (Y_test == 1)

    acc0 = 100 * (preds[mask0] == 0).mean() if mask0.any() else 0
    acc1 = 100 * (preds[mask1] == 1).mean() if mask1.any() else 0

    # Print only relevant things
    print(f"Overall accuracy: {overall:.2f}%")
    print(f"Winner 0 accuracy: {acc0:.2f}%")
    print(f"Winner 1 accuracy: {acc1:.2f}%")

    return {
        "overall": overall,
        "winner0": acc0,
        "winner1": acc1,
        "preds": preds,
        "labels": Y_test
    }


In [323]:
results = evaluate_trained_model_quiet(
    "TsetlinMachine/hex_tm_model.pkl",
    num_test_games=300,
    hypervector_size=1024,
    hypervector_bits=2
)


✓ Model loaded


RecursionError: maximum recursion depth exceeded