In [2]:
from time import perf_counter
import random
import pickle
import os 
import uuid

import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import tqdm

import green_tsetlin as gt

def get_dataset(data_dir):
    t0 = perf_counter()
    
    dataset_train = pickle.load(open(os.path.join(data_dir, "train_dataset.pkl"),"rb"))
    dataset_test = pickle.load(open(os.path.join(data_dir, "val_dataset.pkl"), "rb"))

    X_train = np.array(dataset_train["images"])
    y_train = np.array(dataset_train["labels"])

    X_test = np.array(dataset_test["images"])
    y_test = np.array(dataset_test["labels"])
    
    X_train = np.where(X_train.reshape((X_train.shape[0], 45 * 45)) > 200, 1, 0)
    X_train = X_train.astype(np.uint8)
        
    X_test = np.where(X_test.reshape((X_test.shape[0], 45 * 45)) > 200, 1, 0)
    X_test = X_test.astype(np.uint8)
    
    y_train = y_train.astype(np.uint32)
    y_test = y_test.astype(np.uint32)

    print("X_train.shape:{}".format(X_train.shape))
    print("y_train.shape:{}".format(y_train.shape))
    print("X_test.shape:{}".format(X_test.shape))
    print("y_test.shape:{}".format(y_test.shape))

    t1 = perf_counter()    
    delta = t1 - t0
    print("getting data time:{}".format(delta))

    return X_train, X_test, y_train, y_test

data_dir = "/home/steffenm/data/cv/dataset"

xt, xe, yt, ye = get_dataset(data_dir)
#print(xe)

n_clauses = 5000
n_literals = xt.shape[1]
n_classes = 41
s = 10.0
n_literal_budget = 10
threshold = 1000    
n_jobs = 1
seed = 42

tm = gt.TsetlinMachine(n_literals=n_literals, n_clauses=n_clauses, n_classes=n_classes, s=s, threshold=threshold, literal_budget=n_literal_budget)

trainer = gt.Trainer(tm, n_epochs=1, seed=seed, n_jobs=n_jobs, progress_bar=True)
trainer.set_train_data(xt, yt)
trainer.set_eval_data(xe, ye)    
trainer.train()

tm_save_path = os.path.join(data_dir, "tm_state.npz")
tm.save_state(tm_save_path)

print("--- results ---")
print(trainer.results)
print("--")


predictor = tm.get_predictor()

total=0
for i, x in enumerate(xe):
    y_pred = predictor.predict(x)
    print("y_pred:{}".format(y_pred))
    print("y_true:{}".format(ye[i]))
    if y_pred == ye[i]:
        total += 1
accuracy = total/len(xe)


print("accuracy:{}".format(accuracy))


print("<done>")
    

X_train.shape:(83093, 2025)
y_train.shape:(83093,)
X_test.shape:(20774, 2025)
y_test.shape:(20774,)
getting data time:0.6231197998858988


Processing epoch 1 of 1, train acc: 0.540, best eval score: 0.585 (epoch: 0): 100%|██████████| 1/1 [01:57<00:00, 117.61s/it]


--- results ---
{'train_time_of_epochs': [109.73825527983718], 'best_eval_score': 0.5854433426398382, 'best_eval_epoch': 0, 'n_epochs': 1, 'train_log': [0.539961248239924], 'eval_log': [0.5854433426398382], 'did_early_exit': False}
--
y_pred:12
y_true:14
y_pred:12
y_true:8
y_pred:19
y_true:19
y_pred:6
y_true:6
y_pred:21
y_true:9
y_pred:10
y_true:10
y_pred:11
y_true:11
y_pred:1
y_true:21
y_pred:20
y_true:20
y_pred:16
y_true:16
y_pred:22
y_true:22
y_pred:12
y_true:10
y_pred:8
y_true:8
y_pred:20
y_true:20
y_pred:4
y_true:5
y_pred:16
y_true:18
y_pred:17
y_true:17
y_pred:2
y_true:2
y_pred:26
y_true:26
y_pred:20
y_true:20
y_pred:12
y_true:15
y_pred:7
y_true:21
y_pred:23
y_true:23
y_pred:21
y_true:25
y_pred:16
y_true:16
y_pred:3
y_true:3
y_pred:20
y_true:20
y_pred:22
y_true:21
y_pred:1
y_true:1
y_pred:18
y_true:18
y_pred:22
y_true:0
y_pred:23
y_true:23
y_pred:7
y_true:7
y_pred:4
y_true:4
y_pred:15
y_true:0
y_pred:12
y_true:12
y_pred:20
y_true:20
y_pred:15
y_true:15
y_pred:12
y_true:21
y_pred:

In [3]:
ds = gt.DenseState.load_from_file(tm_save_path)
    
rs = gt.RuleSet(False)
rs.compile_from_dense_state(ds)

print(rs.rules[0])
print(rs.weights[0])

predictor = gt.Predictor(explanation="none", multi_label=False)
predictor._set_ruleset(rs)
predictor._allocate_backend()


print("predictor.predict(x):", predictor.predict(xe[1]))
print("votes:", predictor._inf.get_votes())

[784]
[25, 14, 4, -30, 33, -13, -19, -33, -43, -2, -27, 47, -64, -12, -2, -16, 7, -14, -71, -16, 82, 57, -14, 26, -20, -10, 4, -6, 2, 0, -2, -6, -1, -2, 0, -6, -6, -4, -4, -4, -3]
predictor.predict(x): 12
votes: [-1740  -335 -1318 -1875   416  -883 -1264  -961  -520  -808  -638  -631
   526 -1514 -1646   392 -1293  -914 -1272 -2453  -384 -1726  -683 -1475
  -846 -1871  -856  -736  -692  -669  -692  -740  -670  -679  -681  -708
  -746  -704  -691  -683  -676]
