In [1]:
import pickle
import logging
import optuna
from optuna.exceptions import TrialPruned
from tqdm.auto import tqdm
from GraphTsetlinMachine.graphs import Graphs
from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
from sklearn.metrics import f1_score, precision_score, recall_score
import math

In [2]:
# Hide Optuna's INFO-level log lines; only WARNING and above are emitted.
optuna.logging.set_verbosity(logging.WARNING)

In [3]:
hypervector_bits = 2
hypervector_size = 64

samples = [1000,10000,100000]
board_sizes = [3,4,5,6,7,8,9,10,11,12,13,14,15]
moves_before = [0, 2, 5]

In [4]:
def stop_when_100_accuracy(study, trial):
    """Optuna callback: stop the study once a finished trial reaches 100% test accuracy.

    Parameters
    ----------
    study : optuna.study.Study
        The running study; ``study.stop()`` ends the current ``optimize`` call.
    trial : optuna.trial.FrozenTrial
        The trial that just finished. Its ``value`` can be ``None`` when the
        trial ended without a value (e.g. pruned before reporting anything);
        that case is ignored instead of raising ``TypeError``.
    """
    if trial.value is not None and trial.value >= 100:
        study.stop()

In [5]:
def objective(trial):
    """Optuna objective: train one Graph Tsetlin Machine and return its test accuracy (%).

    Reads the current sweep configuration from notebook globals set by the
    driver loop: ``n_samples``, ``board_size``, ``mbf``, ``dataset`` (used as
    the progress-bar label) and the loaded data ``graphs_train``, ``Y_train``,
    ``graphs_test``, ``Y_test``.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample the hyperparameters.

    Returns
    -------
    float
        Test accuracy in percent after the last epoch that was run. Training
        stops early once accuracy reaches 100% and weighted F1 reaches 1.

    Raises
    ------
    optuna.exceptions.TrialPruned
        After epoch 5, if the best accuracy so far is below 90% and the
        study's pruner decides to prune based on the per-epoch accuracies
        reported via ``trial.report``.
    """
    max_clauses = n_samples * 10
    # suggest_int takes integer bounds; truncate explicitly instead of passing a float.
    max_epochs = int(math.sqrt(n_samples) * max(1, mbf))

    number_of_clauses = trial.suggest_int('number_of_clauses', board_size**2, max_clauses)
    # NOTE(review): despite its name, 'T_factor' is used directly as the threshold T,
    # not as a multiplier. The parameter name is kept so trials already stored in the
    # study remain comparable.
    T = trial.suggest_float('T_factor', 0.5, number_of_clauses * 1.2)
    s = trial.suggest_float('s', 0.1, board_size)
    depth = trial.suggest_int('depth', 2, board_size + 1)
    epochs = trial.suggest_int('epochs', 15, max_epochs)
    # Message-passing settings are fixed, not tuned.
    message_size = 32
    message_bits = 2

    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses,
        T,
        s,
        depth=depth,
        message_size=message_size,
        message_bits=message_bits,
        number_of_state_bits=8,
        boost_true_positive_feedback=1,
    )

    print(f"Start trial with c={number_of_clauses}, T={T}, s={s}, d={depth}, e={epochs}")

    # Metrics from the epoch with the highest test accuracy (shown in the progress bar only).
    best_test_acc = 0
    best_f1 = 0
    best_prec = 0
    best_rec = 0

    # Context manager guarantees the bar is closed on early return or pruning.
    with tqdm(range(epochs), desc=f"{dataset}", leave=True) as progress_bar:
        for epoch in progress_bar:
            tm.fit(graphs_train, Y_train, epochs=1, incremental=True)

            # Predict once per epoch and reuse the predictions for every metric.
            y_pred = tm.predict(graphs_test)
            result_test = 100 * (y_pred == Y_test).mean()
            f1_score_test = f1_score(Y_test, y_pred, average='weighted', zero_division=0)
            precision_test = precision_score(Y_test, y_pred, average='weighted', zero_division=0)
            recall_test = recall_score(Y_test, y_pred, average='weighted', zero_division=0)

            if result_test > best_test_acc:
                best_test_acc = result_test
                best_f1 = f1_score_test
                best_prec = precision_test
                best_rec = recall_test

            # User attributes hold the metrics of the most recent epoch,
            # matching the value returned below.
            trial.set_user_attr("f1", f1_score_test)
            trial.set_user_attr("precision", precision_test)
            trial.set_user_attr("recall", recall_test)

            progress_bar.set_postfix({
                'Acc': f'{result_test:.2f}%',
                'BestAcc': f'{best_test_acc:.2f}%',
                'F1': f'{best_f1:.2f}',
                'Prec': f'{best_prec:.2f}',
                'Rec': f'{best_rec:.2f}'
            })

            # Early stopping: perfect accuracy and perfect weighted F1.
            if result_test >= 100 and f1_score_test >= 1:
                return result_test

            # Pruners decide from reported intermediate values; without this
            # report, should_prune() has nothing to compare and never prunes.
            trial.report(result_test, epoch)
            if epoch > 5 and best_test_acc < 90 and trial.should_prune():
                raise TrialPruned()

    return result_test

In [None]:
# Run one Optuna study per dataset configuration, iterating over the
# training-set sizes, board sizes and moves-before values of the sweep grid.
for n_samples in tqdm(samples, desc="Samples"):
    for board_size in tqdm(board_sizes, desc="Board Sizes", leave=False):
        for mbf in tqdm(moves_before, desc="Moves Before", leave=False):
            # n_samples, board_size, mbf, dataset and the arrays loaded below are
            # module-level names that `objective` reads directly.
            dataset = f"{board_size}x{board_size}_{mbf}"
            # NOTE(review): pickle.load can execute arbitrary code; only load graph
            # files produced by this project.
            with open(f"graphs/{dataset}_{n_samples}.pkl", 'rb') as f:
                graphs_train, graphs_test, X_train, Y_train, X_test, Y_test = pickle.load(f)

            # NOTE(review): the study name omits n_samples, so every sample size for a
            # given board/moves-before pair shares one study, and a 100% result at a
            # smaller sample size skips the larger ones. Confirm this is intended.
            study = optuna.create_study(
                direction="maximize",  # single objective: test accuracy (%)
                study_name=f"Study_{dataset}",
                storage="sqlite:///results/optuna/ja_tsehex.db",
                load_if_exists=True,
                pruner=optuna.pruners.MedianPruner(),  # Optuna's default pruner, made explicit
            )

            # best_value raises ValueError while the study has no completed trial
            # (e.g. the first run of a configuration); treat that as "no result yet".
            try:
                best_value = study.best_value
            except ValueError:
                best_value = None

            if best_value is not None and best_value >= 100:
                print(f"Study {study.study_name} already has 100% accuracy. Skipping further optimization.")
                continue

            try:
                study.optimize(objective, n_trials=1000, callbacks=[stop_when_100_accuracy])
            except KeyboardInterrupt:
                # Ctrl-C abandons the current study; the loop moves on to the next configuration.
                print("Optimization interrupted!")
                try:
                    print(f"Best result so far: {study.best_params}")
                except ValueError:
                    print("No completed trials yet.")

Samples:   0%|          | 0/3 [00:00<?, ?it/s]

Board Sizes:   0%|          | 0/13 [00:00<?, ?it/s]

Moves Before:   0%|          | 0/3 [00:00<?, ?it/s]

Study Study_3x3_0 already has 100% accuracy. Skipping further optimization.
Study Study_3x3_2 already has 100% accuracy. Skipping further optimization.
Study Study_3x3_5 already has 100% accuracy. Skipping further optimization.


Moves Before:   0%|          | 0/3 [00:00<?, ?it/s]

Study Study_4x4_0 already has 100% accuracy. Skipping further optimization.
Study Study_4x4_2 already has 100% accuracy. Skipping further optimization.
Initialization of sparse structure.
Start trial with c=9876, T=3574.001809734016, s=0.1423227380752139, d=5, e=157


4x4_5:   0%|          | 0/157 [00:00<?, ?it/s]