In [1]:
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, f1_score
from psyke import Extractor
from psyke.extraction.hypercubic.strategy import AdaptiveStrategy
from psyke.extraction.hypercubic import Grid, FeatureRanker
from psyke.utils.logic import pretty_theory
from psyke.utils import Target


# Load the dataset: 'Class' is the target column, every other column is a feature.
DATA_PATH = 'dataset_vec - Copy.csv'
dataset = pd.read_csv(DATA_PATH)

x = dataset.drop('Class', axis=1)
y = dataset['Class']
# The rest of the notebook treats the LAST column as the target (iloc[:, -1]);
# join() puts the label last, and the rename gives it the name 'class', which is
# the head predicate of the extracted theories.
dataset = x.join(y.rename('class'))

# 50/50 split with a fixed seed so the reported numbers are reproducible.
train, test = train_test_split(dataset, test_size=0.5, random_state=0)
X_train, y_train = train.iloc[:, :-1], train.iloc[:, -1]
X_test, y_test = test.iloc[:, :-1], test.iloc[:, -1]

# Black-box model to be explained.
# Alternatives tried: MLPClassifier(alpha=1, max_iter=1000), DecisionTreeClassifier().
predictor = KNeighborsClassifier(n_neighbors=5)
predictor.fit(X_train, y_train)

# sklearn metrics take (y_true, y_pred). Weighted F1 weights each class by its
# support in y_true, so swapping the arguments changes the score. Predict once, reuse.
test_pred = predictor.predict(X_test)
print(f'Accuracy: {accuracy_score(y_test, test_pred):.2f}')
print(f'F1: {f1_score(y_test, test_pred, average="weighted"):.2f}')

# Extract a CART surrogate of the predictor from the training split and score it on
# the test split: plain accuracy/F1 against the true labels, and the same metrics
# with `predictor` passed in (fidelity to the black box's outputs).
cart = Extractor.cart(predictor, simplify=True)
theory_from_cart = cart.extract(train)
print('CART performance ({} rules):\nAccuracy = {:.2f}\nAccuracy fidelity = {:.2f}\nF1 = {:.2f}\nF1 fidelity = {:.2f}\n'
      .format(cart.n_rules, cart.accuracy(test), cart.accuracy(test, predictor),
              cart.f1(test), cart.f1(test, predictor)))
print('CART extracted rules:\n\n' + pretty_theory(theory_from_cart))


Accuracy: 0.93
F1: 0.93
CART performance (3 rules):
Accuracy = 0.86
Fidelity = 0.88
F1 = 0.82
F1 = 0.85


CART extracted rules:

class(C1, C10, C100, C11, C12, C13, C14, C15, C16, C17, C18, C19, C2, C20, C21, C22, C23, C24, C25, C26, C27, C28, C29, C3, C30, C31, C32, C33, C34, C35, C36, C37, C38, C39, C4, C40, C41, C42, C43, C44, C45, C46, C47, C48, C49, C5, C50, C51, C52, C53, C54, C55, C56, C57, C58, C59, C6, C60, C61, C62, C63, C64, C65, C66, C67, C68, C69, C7, C70, C71, C72, C73, C74, C75, C76, C77, C78, C79, C8, C80, C81, C82, C83, C84, C85, C86, C87, C88, C89, C9, C90, C91, C92, C93, C94, C95, C96, C97, C98, C99, Reason_C1, Reason_C10, Reason_C100, Reason_C11, Reason_C12, Reason_C13, Reason_C14, Reason_C15, Reason_C16, Reason_C17, Reason_C18, Reason_C19, Reason_C2, Reason_C20, Reason_C21, Reason_C22, Reason_C23, Reason_C24, Reason_C25, Reason_C26, Reason_C27, Reason_C28, Reason_C29, Reason_C3, Reason_C30, Reason_C31, Reason_C32, Reason_C33, Reason_C34, Reason_C35, Reason_C36, Rea

In [3]:
# Rank features by relevance to the trained predictor. Only the training rows are
# used, so no test-set information leaks into the extractor that is scored on `test`.
train_features = train.iloc[:, :-1]
ranked = FeatureRanker(train_features.columns).fit(predictor, train_features).rankings()

# GridEx with 1 grid iteration. AdaptiveStrategy's [(0.85, 8)] presumably splits
# features whose normalised importance is >= 0.85 into 8 slices (TODO confirm
# against the psyke docs). threshold / min_examples are psyke GridEx tuning knobs.
gridEx = Extractor.gridex(predictor, Grid(1, AdaptiveStrategy(ranked, [(0.85, 8)])), threshold=.1, min_examples=1)
theory_from_gridEx = gridEx.extract(train)

# Accuracy/F1 against the true labels, then the same metrics with `predictor`
# passed in (fidelity to the black box's outputs).
print('GridEx performance ({} rules):\nAccuracy = {:.2f}\nAccuracy fidelity = {:.2f}\nF1 = {:.2f}\nF1 fidelity = {:.2f}\n'
      .format(gridEx.n_rules, gridEx.accuracy(test), gridEx.accuracy(test, predictor),
              gridEx.f1(test), gridEx.f1(test, predictor)))
print('GridEx extracted rules:\n\n' + pretty_theory(theory_from_gridEx))


GridEx performance (4 rules):
Accuracy = 0.86
Accuracy fidelity = 0.87
F1 = 0.82
F1 = 0.84

GridEx extracted rules:

class(C1, C10, C100, C11, C12, C13, C14, C15, C16, C17, C18, C19, C2, C20, C21, C22, C23, C24, C25, C26, C27, C28, C29, C3, C30, C31, C32, C33, C34, C35, C36, C37, C38, C39, C4, C40, C41, C42, C43, C44, C45, C46, C47, C48, C49, C5, C50, C51, C52, C53, C54, C55, C56, C57, C58, C59, C6, C60, C61, C62, C63, C64, C65, C66, C67, C68, C69, C7, C70, C71, C72, C73, C74, C75, C76, C77, C78, C79, C8, C80, C81, C82, C83, C84, C85, C86, C87, C88, C89, C9, C90, C91, C92, C93, C94, C95, C96, C97, C98, C99, Reason_C1, Reason_C10, Reason_C100, Reason_C11, Reason_C12, Reason_C13, Reason_C14, Reason_C15, Reason_C16, Reason_C17, Reason_C18, Reason_C19, Reason_C2, Reason_C20, Reason_C21, Reason_C22, Reason_C23, Reason_C24, Reason_C25, Reason_C26, Reason_C27, Reason_C28, Reason_C29, Reason_C3, Reason_C30, Reason_C31, Reason_C32, Reason_C33, Reason_C34, Reason_C35, Reason_C36, Reason_C37, Rea