# In [None]:
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from time import time
from collections import Counter
import gc
from GraphTsetlinMachine.graphs import Graphs
from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
from sklearn.metrics import classification_report, confusion_matrix

seed_value = 42  # single seed shared by every RNG used in this script
np.random.seed(seed_value)  # NumPy global RNG (np.random.choice / shuffle below)
random.seed(seed_value)  # Python's built-in `random` module

class Args:
    """Bag of experiment hyper-parameters exposed as attributes.

    The constructor starts from the defaults below. Every keyword argument
    then overrides the default of the same name, or adds a new attribute
    when no such default exists.
    """

    def __init__(self, **kwargs):
        settings = {
            # training schedule
            'epochs': 10,
            'batch_size': 500,
            'patience': 5,
            # Tsetlin machine
            'number_of_clauses': 20000,
            'T': 16000,
            's': 5.0,
            'depth': 10,
            'max_included_literals': 128,
            # hypervector / message encoding
            'hypervector_size': 4096,
            'hypervector_bits': 4,
            'message_size': 4096,
            'message_bits': 4,
            'double_hashing': True,
        }
        settings.update(kwargs)
        for name, value in settings.items():
            setattr(self, name, value)

args = Args()

# Load the dataset and keep a reproducible random sample of at most 30000 rows.
start_time = time()
try:
    data = pd.read_csv('datasett/5moves_13x13.csv')
except FileNotFoundError:
    print("Error: Dataset not found.")
    # exit() is the site-module helper meant for the interactive shell;
    # programs should raise SystemExit (what sys.exit does) instead.
    raise SystemExit(-1) from None
# DataFrame.sample (without replacement) raises ValueError when n exceeds
# the number of rows, so cap the sample size at the dataset size.
data = data.sample(min(30000, len(data)), random_state=seed_value).reset_index(drop=True)
end_time = time()
print(f"Loading data took {end_time - start_time:.2f} seconds")
print(f"Dataset size after sampling: {data.shape}")

board_size = 13
# One column per cell, in row-major order: cell{row}_{col}.
cell_columns = [f'cell{row}_{col}' for row in range(board_size) for col in range(board_size)]

required_columns = ['winner', 'starting_player'] + cell_columns
missing_columns = [col for col in required_columns if col not in data.columns]
if missing_columns:
    print(f"Error: Missing columns: {missing_columns}")
    raise SystemExit(-1)

X_df = data[cell_columns]
y = data['winner'].values.astype(int)
starting_player = data['starting_player'].values.astype(int)

if X_df.isnull().values.any():
    print("Warning: Missing values detected. Filling with 0.")
    X_df = X_df.fillna(0)

# Remap labels to 0..k-1 unless they already lie in {0, 1}.
unique_labels = np.unique(y)
if not set(unique_labels).issubset({0, 1}):
    label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
    # np.unique returns sorted labels, so each label's index in the mapping
    # equals its searchsorted position.
    y = np.searchsorted(unique_labels, y)
    print("Labels mapped to:", label_mapping)

def augment_data(X_df, y, sp):
    """Double a dataset by appending a 180-degree rotation of every board.

    The cell columns are stored in row-major order, so reversing a row's
    values maps cell (r, c) to (n-1-r, n-1-c), i.e. a 180-degree rotation.
    On a Hex board (the six-neighbour adjacency used in prepare_graph_data)
    this rotation keeps each player's pair of goal edges. The winner and
    starting player are therefore copied unchanged.

    A plain X <-> O colour swap is NOT label-preserving in Hex. The players
    connect different pairs of edges, so a colour swap would also require
    transposing the board and flipping the winner and starting player. No
    colour swap is performed here for that reason.

    Args:
        X_df: DataFrame with one row per board and row-major cell columns.
        y: array of labels aligned with ``X_df`` rows.
        sp: array of starting players aligned with ``X_df`` rows.

    Returns:
        Tuple ``(X, y, sp)`` holding the original rows followed by their
        rotated copies. ``X`` has a fresh RangeIndex. The inputs are not
        modified.
    """
    # Build the rotated frame from raw values with the ORIGINAL column labels.
    # Reordering columns with iloc[:, ::-1] alone has no effect after
    # pd.concat, because concat aligns rows by column label.
    rotated = pd.DataFrame(X_df.to_numpy()[:, ::-1], columns=X_df.columns)
    return (
        pd.concat([X_df, rotated], ignore_index=True),
        np.concatenate([y, y]),
        np.concatenate([sp, sp]),
    )

# Split the ORIGINAL boards first and only then augment the training part.
# augment_data returns originals followed by their transformed copies.
# Augmenting before a positional 80/20 split would therefore fill the test
# set with transformed copies of boards that are also in the training set,
# which leaks training data into evaluation.
train_size = int(len(X_df) * 0.8)
X_train_df = X_df.iloc[:train_size].reset_index(drop=True)
y_train = y[:train_size]
sp_train = starting_player[:train_size]
X_test_df = X_df.iloc[train_size:].reset_index(drop=True)
y_test = y[train_size:]
sp_test = starting_player[train_size:]

X_train_df, y_train, sp_train = augment_data(X_train_df, y_train, sp_train)
print(f"Training set size after augmentation: {X_train_df.shape}")
print("Training class distribution after augmentation:", Counter(y_train))

print(f"X_train shape: {X_train_df.shape}")
print(f"X_test shape: {X_test_df.shape}")

# Balance the training set by undersampling every class to the size of the
# smallest one, then shuffle the selected rows.
train_label_counts = Counter(y_train)
min_class_size = min(train_label_counts.values())
print("Training set class distribution before balancing:", train_label_counts)
class_indices = {cls: np.where(y_train == cls)[0] for cls in np.unique(y_train)}
selected_indices = np.concatenate([
    np.random.choice(indices, min_class_size, replace=False) for indices in class_indices.values()
])
np.random.shuffle(selected_indices)

X_train_df = X_train_df.iloc[selected_indices].reset_index(drop=True)
y_train = y_train[selected_indices]
sp_train = sp_train[selected_indices]

print("Balanced training set class distribution:", Counter(y_train))

# Raw cell value -> node symbol name.
value_to_symbol = {1: 'X', -1: 'O', 0: 'Empty'}

# Full symbol vocabulary, passed as `symbols` to every Graphs object. Order:
# fixed symbols, one symbol per row, one per column, then the binned
# feature symbols.
symbol_names = (
    [
        'X', 'O', 'Empty',
        'StartingPlayer0', 'StartingPlayer1',
        'Center', 'Edge', 'Corner',
        'Bridge',
        'IsCriticalBlock',
    ]
    + [f'Row_r{i}' for i in range(board_size)]
    + [f'Col_c{i}' for i in range(board_size)]
    + [f'DistFromCenter_{level}' for level in ('Near', 'Mid', 'Far')]
    + [
        f'{feature}_{level}'
        for feature in ('NeighborX', 'NeighborO', 'TotalX', 'TotalO')
        for level in ('Low', 'Medium', 'High')
    ]
)

# Offsets (dr, dc) of the six neighbours of a Hex cell when the board is
# stored as an n x n grid. The order is rotational: consecutive entries
# (wrapping around) are themselves adjacent directions.
_HEX_DIRECTIONS = [(-1, 0), (-1, 1), (0, 1), (1, 0), (1, -1), (0, -1)]


def _bin_label(count, low_max, medium_max):
    """Return 'Low', 'Medium' or 'High' for ``count`` (bounds are inclusive)."""
    if count <= low_max:
        return 'Low'
    if count <= medium_max:
        return 'Medium'
    return 'High'


def _forms_bridge(cells, node_id_map, r, c):
    """Return True if the stone at (r, c) forms a Hex two-bridge.

    A two-bridge joins (r, c) to a stone of the same symbol at offset
    d1 + d2, where d1 and d2 are consecutive entries of ``_HEX_DIRECTIONS``.
    The two cells at d1 and d2 (the carrier) are adjacent to both stones and
    must both be empty.

    Args:
        cells: symbol per node, indexed by node id.
        node_id_map: (row, col) -> node id, for on-board cells only.
        r, c: coordinates of the stone being tested.
    """
    sym = cells[node_id_map[(r, c)]]

    def at(row, col):
        # Off-board cells yield None, which never equals a symbol.
        idx = node_id_map.get((row, col))
        return None if idx is None else cells[idx]

    for i, (dr1, dc1) in enumerate(_HEX_DIRECTIONS):
        dr2, dc2 = _HEX_DIRECTIONS[(i + 1) % len(_HEX_DIRECTIONS)]
        if (
            at(r + dr1 + dr2, c + dc1 + dc2) == sym
            and at(r + dr1, c + dc1) == 'Empty'
            and at(r + dr2, c + dc2) == 'Empty'
        ):
            return True
    return False


def prepare_graph_data(X_df_batch, sp_series_batch, y_batch):
    """Encode a batch of boards as a ``Graphs`` object for the Graph TM.

    Each board becomes one graph with one node per cell (row-major node ids,
    ``board_size**2`` nodes). Every cell gets an edge to each of its
    in-bounds neighbours in ``_HEX_DIRECTIONS``. Each node receives these
    properties from ``symbol_names``:
      - its stone symbol;
      - the board's starting player;
      - binned board-wide X and O counts;
      - a position class (Center / Corner / Edge; border cells only);
      - its row and column;
      - a binned Manhattan distance from the centre;
      - binned counts of neighbouring X and O stones;
      - 'Bridge' if it is a stone forming a two-bridge.

    Args:
        X_df_batch: DataFrame, one row per board, cell columns in row-major
            order. Values are mapped via ``value_to_symbol``; unmapped values
            count as 'Empty'.
        sp_series_batch: positionally indexable starting players. Presumably
            0/1, so that ``StartingPlayer{sp}`` is a declared symbol. TODO
            confirm against the dataset.
        y_batch: labels of the batch. Accepted for interface compatibility
            but deliberately unused. A node feature derived from the target
            would leak the label into the model's input, including at test
            time.

    Returns:
        The ``Graphs`` object after ``graphs.encode()``.

    NOTE(review): a new Graphs object is built on every call. Training and
    prediction batches then only agree if symbol hypervectors are identical
    across instances (plausibly so with double hashing). Confirm against
    GraphTsetlinMachine, or share them via the library's ``init_with``
    option.
    """
    num_graphs = X_df_batch.shape[0]
    num_board_nodes = board_size ** 2

    graphs = Graphs(
        number_of_graphs=num_graphs,
        symbols=symbol_names,
        hypervector_size=args.hypervector_size,
        hypervector_bits=args.hypervector_bits,
        double_hashing=args.double_hashing,
    )

    # Node ids enumerate cells row-major, matching the cell column order.
    nodes = [(r, c) for r in range(board_size) for c in range(board_size)]
    node_id_map = {rc: idx for idx, rc in enumerate(nodes)}

    # The board topology is the same for every graph; compute it once.
    edges = [
        [
            node_id_map[(r + dr, c + dc)]
            for dr, dc in _HEX_DIRECTIONS
            if (r + dr, c + dc) in node_id_map
        ]
        for r, c in nodes
    ]

    for graph_id in range(num_graphs):
        graphs.set_number_of_graph_nodes(graph_id, num_board_nodes)
    graphs.prepare_node_configuration()

    for graph_id in range(num_graphs):
        for node_id, neighbours in enumerate(edges):
            graphs.add_graph_node(graph_id, node_id, len(neighbours))
    graphs.prepare_edge_configuration()

    # Board-independent per-cell properties, computed once per call.
    center = board_size // 2
    border = (0, board_size - 1)
    static_properties = []
    for r, c in nodes:
        props = [f'Row_r{r}', f'Col_c{c}']
        if r == center and c == center:
            props.append('Center')
        elif r in border and c in border:
            props.append('Corner')
        elif r in border or c in border:
            # Only cells on the board's border are 'Edge'; interior cells
            # carry no position-class symbol.
            props.append('Edge')
        # Manhattan (not hex) distance from the centre cell.
        dist = abs(r - center) + abs(c - center)
        if dist <= 4:
            props.append('DistFromCenter_Near')
        elif dist <= 8:
            props.append('DistFromCenter_Mid')
        else:
            props.append('DistFromCenter_Far')
        static_properties.append(props)

    for graph_id in range(num_graphs):
        board_state = X_df_batch.iloc[graph_id].values.astype(int)
        cells = [value_to_symbol.get(val, 'Empty') for val in board_state]

        # Properties shared by every node of this board.
        per_board = (
            f'StartingPlayer{sp_series_batch[graph_id]}',
            f"TotalX_{_bin_label(cells.count('X'), 56, 112)}",
            f"TotalO_{_bin_label(cells.count('O'), 56, 112)}",
        )

        for idx, (r, c) in enumerate(nodes):
            sym = cells[idx]
            neighbour_syms = [cells[n] for n in edges[idx]]
            node_properties = [
                sym,
                *per_board,
                *static_properties[idx],
                f"NeighborX_{_bin_label(neighbour_syms.count('X'), 2, 4)}",
                f"NeighborO_{_bin_label(neighbour_syms.count('O'), 2, 4)}",
            ]
            if sym in ('X', 'O') and _forms_bridge(cells, node_id_map, r, c):
                node_properties.append('Bridge')
            for prop in node_properties:
                graphs.add_graph_node_property(graph_id, idx, prop)

        for node_id, neighbours in enumerate(edges):
            for neighbor_id in neighbours:
                graphs.add_graph_node_edge(graph_id, node_id, neighbor_id, edge_type_name=0)

    graphs.encode()
    return graphs

print("Initializing Tsetlin Machine...")
tm = MultiClassGraphTsetlinMachine(
    args.number_of_clauses,
    args.T,
    args.s,
    len(np.unique(y_train)),
    depth=args.depth,
    max_included_literals=args.max_included_literals,
    message_size=args.message_size,
    message_bits=args.message_bits,
)

y_train = np.array(y_train, dtype=np.int32)
y_test = np.array(y_test, dtype=np.int32)
sp_train = np.array(sp_train)
sp_test = np.array(sp_test)


def _predict_in_batches(X_frame, y_values, sp_values):
    """Predict a label for every row of ``X_frame`` with the global ``tm``.

    Graphs are built ``args.batch_size`` rows at a time and released after
    each chunk to bound memory use. ``y_values`` is forwarded unchanged to
    ``prepare_graph_data``. Returns a 1-D array aligned with ``X_frame`` rows.
    """
    predictions = []
    for start in range(0, len(X_frame), args.batch_size):
        stop = start + args.batch_size
        graphs = prepare_graph_data(
            X_frame.iloc[start:stop].reset_index(drop=True),
            sp_values[start:stop],
            y_values[start:stop],
        )
        predictions.extend(tm.predict(graphs))
        del graphs
        gc.collect()
    return np.array(predictions)


best_test_accuracy = 0
epochs_no_improve = 0
# Stay None when args.epochs == 0, so the reports below can be skipped.
train_predictions = None
test_predictions = None

start_training = time()
for epoch in range(args.epochs):
    print(f"\nEpoch {epoch+1}/{args.epochs}")

    # Reshuffle the training set every epoch.
    order = np.random.permutation(len(X_train_df))
    X_train_df = X_train_df.iloc[order].reset_index(drop=True)
    y_train = y_train[order]
    sp_train = sp_train[order]

    for start in range(0, len(X_train_df), args.batch_size):
        stop = start + args.batch_size
        y_batch = y_train[start:stop]
        graphs_batch = prepare_graph_data(
            X_train_df.iloc[start:stop].reset_index(drop=True),
            sp_train[start:stop],
            y_batch,
        )
        tm.fit(graphs_batch, y_batch, epochs=1, incremental=True)
        del graphs_batch
        gc.collect()

    # Estimate training accuracy on a random subset of at most 2000 rows;
    # sampling without replacement needs size <= population.
    eval_size = min(2000, len(X_train_df))
    eval_indices = np.random.choice(len(X_train_df), size=eval_size, replace=False)
    y_eval = y_train[eval_indices]
    train_predictions = _predict_in_batches(
        X_train_df.iloc[eval_indices].reset_index(drop=True),
        y_eval,
        sp_train[eval_indices],
    )
    train_accuracy = np.mean(y_eval == train_predictions)

    test_predictions = _predict_in_batches(X_test_df, y_test, sp_test)
    test_accuracy = np.mean(y_test == test_predictions)

    print(f"Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}")

    # NOTE(review): early stopping monitors the TEST set, so the reported
    # test accuracy is optimistically biased. A separate validation split
    # would avoid this.
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= args.patience:
            print("Early stopping triggered.")
            break

stop_training = time()
print(f"\nTraining Time: {stop_training - start_training:.2f} seconds")

if test_predictions is None:
    print("No training epochs were run; skipping evaluation reports.")
else:
    # The reports below describe the LAST epoch's model, not the best one.
    print(f"Best test accuracy over epochs: {best_test_accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_test, test_predictions, digits=4))
    print("Confusion Matrix:")
    print(confusion_matrix(y_test, test_predictions))

    unique_train_preds, counts_train_preds = np.unique(train_predictions, return_counts=True)
    print("\nUnique predictions on training set:", unique_train_preds)
    print("Training set predictions distribution:", dict(zip(unique_train_preds, counts_train_preds)))

    unique_test_preds, counts_test_preds = np.unique(test_predictions, return_counts=True)
    print("Unique predictions on test set:", unique_test_preds)
    print("Test set predictions distribution:", dict(zip(unique_test_preds, counts_test_preds)))
