In [2]:
from pathlib import Path

import matplotlib.pyplot as plt
import mplhep as hep
import numpy as np
import pandas as pd
import tensorflow as tf
from joblib import load
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
# Global matplotlib style for every figure below: mplhep's ROOT style, with the
# firamath style applied on top of it (later entries override earlier ones).
plt.style.use([
    hep.style.ROOT,
    hep.style.firamath,
])

In [3]:
version = '9.0.5'
DATA_DIR = Path('../data_files') / version
# Columns removed to form the feature matrix: the target `category` plus `Lb_M` and `IsSimulated`.
NON_FEATURE_COLUMNS = ['category', 'Lb_M', 'IsSimulated']


def load_split(split):
    """Read ``DATA_DIR/<split>.csv`` and return ``(frame, features, target)``.

    ``frame`` is the full CSV (first column used as the index), ``features`` is the
    frame without NON_FEATURE_COLUMNS, and ``target`` is its ``category`` column.
    """
    frame = pd.read_csv(DATA_DIR / f'{split}.csv', index_col=[0])
    return frame, frame.drop(columns=NON_FEATURE_COLUMNS), frame['category']


train, X_train, y_train = load_split('train')
val, X_val, y_val = load_split('val')
test, X_test, y_test = load_split('test')

# The previous version of this cell also left `df` bound to the test frame via a chained
# `test = df = ...` assignment; keep the alias so any later cell relying on it still works.
# NOTE(review): drop this line once no later cell uses `df`.
df = test

In [4]:
# Trained models for this data `version`, keyed by the short name used in result tables:
# a Keras network saved by the neural_network notebooks, plus four joblib-persisted
# classifiers (the XGB file has no `_tune` suffix, unlike the other three).
# NOTE(review): joblib.load unpickles arbitrary objects, so only load model files you trust;
# estimators saved under a different scikit-learn version may also warn (see output below).
models = {
    'NN': tf.keras.models.load_model(f'../neural_network/models/v{version}'),
    'KNN': load(f'../classification_methods/models/KNN_{version}_tune.joblib'),
    'RFC': load(f'../classification_methods/models/RFC_{version}_tune.joblib'),
    'DTC': load(f'../classification_methods/models/DTC_{version}_tune.joblib'),
    'XGB': load(f'../classification_methods/models/XGB_{version}.joblib'),
}
# Individual handles, in the same insertion order as the dict above.
nn, knn, rfc, dtc, xgb = models.values()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
  from pandas import MultiIndex, Int64Index


In [21]:
# F1 score of every model's positive-class prediction on one data split. A score strictly
# above F1_THRESHOLD counts as a positive prediction, which is the same `>` rule that
# tfa.metrics.F1Score(threshold=...) applies.
# NOTE(review): this scores the *training* split, so it overstates how well the models
# generalize. Point X_eval/y_eval at (X_val, y_val) or (X_test, y_test) for a fair comparison.
F1_THRESHOLD = 0.8
X_eval, y_eval = X_train, y_train


def positive_scores(model, features):
    """Return a 1-D array of positive-class scores from ``model`` for ``features``.

    Models exposing ``predict_proba`` contribute column 1 of its output. Any other model
    (presumably the Keras network) contributes its ``predict`` output, flattened.
    Errors raised by the model itself are no longer swallowed.
    """
    if hasattr(model, 'predict_proba'):
        return model.predict_proba(features)[:, 1]
    return model.predict(features).ravel()


X_eval_array = X_eval.to_numpy()  # loop-invariant: convert once rather than once per model
f1_scores = pd.Series(
    {
        name: f1_score(y_eval, positive_scores(model, X_eval_array) > F1_THRESHOLD)
        for name, model in models.items()
    },
    name='F1',
)
f1_scores

NN : 0.925909698009491
KNN : 0.9093238711357117
RFC : 0.9053465127944946
DTC : 0.886944591999054
XGB : 0.9372956156730652
