In [1]:
import os
import sys
import platform
from pathlib import Path
from GraphTsetlinMachine.graphs import Graphs
from GraphTsetlinMachine.tm import MultiClassGraphTsetlinMachine
import pickle
import logging
import optuna
from optuna.exceptions import TrialPruned 
from tqdm.auto import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score
import math
from functools import partial

In [2]:
# Limit Optuna's logger to WARNING level and above for the rest of the notebook.
optuna.logging.set_verbosity(optuna.logging.WARNING)

In [3]:
def get_machine_info():
    """Print and return the host name, operating system, and current user.

    The user name is taken from the ``USER`` environment variable, falling
    back to ``USERNAME`` when ``USER`` is unset or empty.

    Returns:
        tuple: ``(machine_name, os_name, user)``. ``user`` is ``None`` when
        neither environment variable is set.
    """
    host = platform.node()
    system = platform.system()

    login = os.getenv("USER")
    if not login:
        login = os.getenv("USERNAME")

    # Report each value on its own line, in the order Machine / OS / User.
    for label, value in (("Machine", host), ("OS", system), ("User", login)):
        print(f"{label}: {value}")

    return host, system, login

In [4]:
# Identify the current host; these values select the storage location in the next cell.
machine_name, os_name, user = get_machine_info()

Machine: Corsair
OS: Linux
User: jon


In [5]:
# Pick the storage root: a fixed directory under /mnt/b on one specific
# workstation, otherwise the current working directory.
if (machine_name, os_name, user) == ("Corsair", "Linux", "jon"):
    windows_drive = Path("/mnt/b/TsetlinModels")
    storage_root = windows_drive
else:
    storage_root = Path()

# One sub-directory per artefact kind, created if it does not exist yet.
paths = {}
for folder in ("data", "models", "graphs"):
    paths[folder] = storage_root / folder
    os.makedirs(paths[folder], exist_ok=True)

In [6]:
# Hypervector settings. Not referenced by the cells visible here; presumably
# consumed where the graph datasets are built. TODO confirm before changing.
hypervector_bits = 2
hypervector_size = 64

# Sweep axes for the study loop; each combination maps to one pickled dataset
# file named "<size>x<size>_<moves_before>_<samples>.pkl".
samples = [1_000, 10_000, 100_000]
board_sizes = list(range(3, 16))  # 3 through 15 inclusive
moves_before = [0, 2, 5]

In [7]:
def stop_when_100_accuracy(study, trial):
    """Optuna callback that stops the study once a trial reaches 100% accuracy.

    Args:
        study: The running study; ``study.stop()`` is called when the target
            is reached.
        trial: The trial that has just finished. Its ``value`` is the
            objective's return value (accuracy in percent) and may be ``None``
            when the trial ended without an objective value (for example a
            pruned or failed trial).
    """
    value = trial.value
    # Guard against value-less trials: comparing None with an int raises TypeError.
    if value is not None and value >= 100:
        study.stop()

In [8]:
def objective(trial, graphs_train, graphs_test, X_train, Y_train, X_test, Y_test, board_size, mbf, n_samples, dataset):
    """Optuna objective: train a graph Tsetlin machine and return its test accuracy.

    Hyperparameters are sampled from ``trial``; the model is then trained one
    epoch at a time and evaluated on the test graphs after every epoch.

    Args:
        trial: Optuna trial used to sample hyperparameters, report intermediate
            accuracy, and store the latest F1 / precision / recall as user
            attributes.
        graphs_train: Training graphs passed to ``tm.fit``.
        graphs_test: Test graphs passed to ``tm.predict``.
        X_train: Unused; kept for interface compatibility.
        Y_train: Training labels passed to ``tm.fit``.
        X_test: Unused; kept for interface compatibility.
        Y_test: Test labels compared element-wise against the predictions.
        board_size: Bounds the search ranges for clauses, ``s`` and ``depth``.
        mbf: Scales the upper bound on epochs (treated as at least 1).
        n_samples: Scales the upper bounds on clauses and epochs.
        dataset: Label shown on the progress bar.

    Returns:
        Test accuracy in percent after the last executed epoch (the loop exits
        early when accuracy reaches 100 and weighted F1 reaches 1).

    Raises:
        TrialPruned: When, after epoch 5 with best accuracy still below 90,
            the study's pruner decides the trial should stop.
    """
    max_clauses = n_samples * 10
    # Explicit truncation: this is an integer bound for suggest_int.
    max_epochs = int(math.sqrt(n_samples) * max(1, mbf))

    number_of_clauses = trial.suggest_int('number_of_clauses', board_size**2, max_clauses)
    # The parameter is stored as 'T_factor' but is passed to the model as T directly.
    T = trial.suggest_float('T_factor', 0.5, number_of_clauses * 1.2)
    s = trial.suggest_float('s', 0.1, board_size)
    depth = trial.suggest_int('depth', 2, board_size+1)
    epochs = trial.suggest_int('epochs', 15, max_epochs)
    message_size = 32
    message_bits = 2

    tm = MultiClassGraphTsetlinMachine(
        number_of_clauses,
        T,
        s,
        depth=depth,
        message_size=message_size,
        message_bits=message_bits,
        number_of_state_bits=8,
        boost_true_positive_feedback=1,
        # presumably GPU launch dimensions; TODO confirm against the library docs
        grid=(16*13,1,1),
        block=(128,1,1),
    )

    print(f"Start trial with c={number_of_clauses}, T={T}, s={s}, d={depth}, e={epochs}")

    best_test_acc = 0
    best_f1 = 0
    best_prec = 0
    best_rec = 0

    progress_bar = tqdm(range(epochs), desc=f"{dataset}", leave=True)
    for epoch in progress_bar:
        tm.fit(graphs_train, Y_train, epochs=1, incremental=True)

        # Predict once per epoch and reuse the result for every metric.
        predictions = tm.predict(graphs_test)
        result_test = 100 * (predictions == Y_test).mean()

        f1_score_test = f1_score(Y_test, predictions, average='weighted', zero_division=0)
        precision_test = precision_score(Y_test, predictions, average='weighted', zero_division=0)
        recall_test = recall_score(Y_test, predictions, average='weighted', zero_division=0)

        # Pruners decide from reported intermediate values; without this call
        # trial.should_prune() below has nothing to compare against.
        trial.report(result_test, step=epoch)

        if result_test > best_test_acc:
            best_test_acc = result_test
            best_f1 = f1_score_test
            best_prec = precision_test
            best_rec = recall_test

        # User attributes hold the metrics of the most recent epoch (not the best one).
        trial.set_user_attr("f1", f1_score_test)
        trial.set_user_attr("precision", precision_test)
        trial.set_user_attr("recall", recall_test)

        # The progress bar shows current accuracy plus the metrics of the best epoch so far.
        progress_bar.set_postfix({
            'Acc':f'{result_test:.2f}%',
            'BestAcc': f'{best_test_acc:.2f}%',
            'F1': f'{best_f1:.2f}',
            'Prec': f'{best_prec:.2f}',
            'Rec': f'{best_rec:.2f}'
        })

        # Perfect score: no point in training further.
        if result_test >= 100 and f1_score_test >= 1:
            return result_test

        # Only consider pruning after a short warm-up and while accuracy is still poor.
        if epoch > 5 and best_test_acc < 90:
            if trial.should_prune():
                raise TrialPruned()

    return result_test

In [None]:
# Running multiple studies for different configurations of the dataset

# SQLite creates the database file but not its parent directory, so make sure
# the directory used in the storage URL below exists.
os.makedirs(Path("results") / "optuna", exist_ok=True)

for n_samples in tqdm(samples, desc="Samples"):
    for board_size in tqdm(board_sizes, desc="Board Sizes", leave=False):
        for mbf in tqdm(moves_before, desc="Moves Before", leave=False):
            dataset = f"{board_size}x{board_size}_{mbf}"
            file_path = paths["graphs"] / f"{dataset}_{n_samples}.pkl"
            # NOTE(security): pickle.load executes arbitrary code from the file;
            # only load dataset files produced by a trusted source.
            with open(file_path, 'rb') as f:
                graphs_train, graphs_test, X_train, Y_train, X_test, Y_test = pickle.load(f)

            # NOTE(review): the study name does not include n_samples, so runs
            # with different sample counts share (and resume) the same study.
            # Confirm this is intended before renaming; renaming would orphan
            # the studies already stored in the database.
            study = optuna.create_study(
                #directions=["maximize", "minimize"],  # Maximize accuracy, minimize number of clauses
                direction="maximize",
                study_name=f"Study_{dataset}",
                storage="sqlite:///results/optuna/ja_tsehex_local.db",
                load_if_exists=True
                #pruner=optuna.pruners.MedianPruner()
            )

            if len(study.trials) > 0:
                try:
                    if study.best_trial.value >= 100:
                        print(f"Study {study.study_name} already has 100% accuracy. Skipping further optimization.")
                        continue
                except ValueError:
                    # best_trial raises ValueError when no trial has completed yet.
                    print(f"No valid trials found for {study.study_name}, continuing with optimization.")
            else:
                print(f"No trials found for {study.study_name}. Running new optimization.")

            objective_with_params = partial(objective, graphs_train=graphs_train, graphs_test=graphs_test, 
                                            X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test, 
                                            board_size=board_size, mbf=mbf, n_samples=n_samples, dataset=dataset)

            try:
                study.optimize(objective_with_params, n_trials=1000, callbacks=[stop_when_100_accuracy])
            except KeyboardInterrupt:
                # An interrupt ends the current study only; the loop then moves
                # on to the next configuration.
                print("Optimization interrupted!")
                try:
                    print(f"Best result so far: {study.best_params}")
                except ValueError:
                    # Interrupted before any trial completed: there is no best trial.
                    print("Best result so far: no completed trials.")

Samples:   0%|          | 0/3 [00:00<?, ?it/s]

Board Sizes:   0%|          | 0/13 [00:00<?, ?it/s]

Moves Before:   0%|          | 0/3 [00:00<?, ?it/s]

No valid trials found for Study_3x3_0, continuing with optimization.
Initialization of sparse structure.
Start trial with c=1233, T=919.0441024847984, s=2.5603653098604906, d=2, e=16


3x3_0:   0%|          | 0/16 [00:00<?, ?it/s]

No trials found for Study_3x3_2. Running new optimization.
Initialization of sparse structure.
Start trial with c=3341, T=3673.391876189983, s=1.0109512776574654, d=4, e=31


3x3_2:   0%|          | 0/31 [00:00<?, ?it/s]

No trials found for Study_3x3_5. Running new optimization.
Initialization of sparse structure.
Start trial with c=5415, T=4668.892643124797, s=2.5644675202604024, d=2, e=90


3x3_5:   0%|          | 0/90 [00:00<?, ?it/s]

Moves Before:   0%|          | 0/3 [00:00<?, ?it/s]

No trials found for Study_4x4_0. Running new optimization.
Initialization of sparse structure.
Start trial with c=7267, T=2720.448763450113, s=2.2738383356690126, d=5, e=24


4x4_0:   0%|          | 0/24 [00:00<?, ?it/s]

Initialization of sparse structure.
Start trial with c=8972, T=4244.283469671237, s=3.066712622794668, d=3, e=24


4x4_0:   0%|          | 0/24 [00:00<?, ?it/s]

Initialization of sparse structure.
Start trial with c=7326, T=2672.7410691482687, s=2.4445553575249295, d=3, e=26


4x4_0:   0%|          | 0/26 [00:00<?, ?it/s]

No trials found for Study_4x4_2. Running new optimization.
Initialization of sparse structure.
Start trial with c=1545, T=1375.7552943988285, s=2.4666582439026485, d=4, e=28


4x4_2:   0%|          | 0/28 [00:00<?, ?it/s]