In [None]:
import numpy as np
import shap
from sklearn import metrics as metrics
from sklearn.ensemble import RandomForestClassifier

from classifiers import Explainer
from classifiers import LocalClassifierPerLevel, LocalClassifierPerNode, LocalClassifierPerParentNode
from classifiers import datasets, metrics as hmetrics

In [None]:
# Load train and test splits via `datasets.load_platypus()`
X_train, X_test, Y_train, Y_test = datasets.load_platypus()

# Seed for the random forests: without it every Restart & Run All trains
# different trees, so the reported metrics and SHAP plots are not reproducible.
SEED = 42

# Use random forest classifiers for every node; this estimator is passed as
# `local_classifier` to each hierarchical classifier in the notebook.
rfc = RandomForestClassifier(random_state=SEED)

In [None]:
# LocalClassifierPerNode

# Local Classifier Per Node

In [None]:
# Constructor options for the per-node model, gathered in one place
lcpn_options = {
    "local_classifier": rfc,
    "replace_classifiers": False,
    "binary_policy": "inclusive",
    "edge_list": "./hierarchy.csv",
    "verbose": 0,
}
lcpn_classifier = LocalClassifierPerNode(**lcpn_options)

# Fit the per-node model on the training split
lcpn_classifier.fit(X_train, Y_train)

# Predict on the test split; with return_uncertainty=True the result is
# unpacked into two values (predictions and their uncertainties)
lcpn_output = lcpn_classifier.predict(X_test, return_uncertainty=True)
predictions, uncertainties = lcpn_output

In [None]:
# Hierarchical scores for whatever is currently bound to `predictions`,
# compared against the true test labels
precision = hmetrics.h_precision_score(Y_test, predictions)
recall = hmetrics.h_recall_score(Y_test, predictions)
f1 = hmetrics.h_f1_score(Y_test, predictions)

# Report the three scores together
print("hprecision:", precision)
print("hrecall:", recall)
print("hf1:", f1)

In [None]:
# Flat (non-hierarchical) evaluation of whatever is currently bound to
# `predictions`. `np` is NumPy and must be imported in the notebook's import cell.

# Turn the true labels into an array so they can be flattened
true_labels = np.array(Y_test.to_list())

# Flatten the true and predicted labels to calculate micro/macro/weighted metrics
flat_true_labels = true_labels.flatten()
flat_predicted_labels = predictions.flatten()

# Averaging strategies, in the same order as the names unpacked below
averages = ("micro", "macro", "weighted")

# Calculate micro, macro, and weighted precision, recall, and F1 score
micro_precision, macro_precision, weighted_precision = (
    metrics.precision_score(flat_true_labels, flat_predicted_labels, average=avg)
    for avg in averages
)
micro_recall, macro_recall, weighted_recall = (
    metrics.recall_score(flat_true_labels, flat_predicted_labels, average=avg)
    for avg in averages
)
micro_f1, macro_f1, weighted_f1 = (
    metrics.f1_score(flat_true_labels, flat_predicted_labels, average=avg)
    for avg in averages
)

print("Micro Precision:", micro_precision)
print("Macro Precision:", macro_precision)
print("Weighted Precision:", weighted_precision)

print("Micro Recall:", micro_recall)
print("Macro Recall:", macro_recall)
print("Weighted Recall:", weighted_recall)

print("Micro F1 Score:", micro_f1)
print("Macro F1 Score:", macro_f1)
print("Weighted F1 Score:", weighted_f1)

# Calculate Hamming Loss on the same flattened labels
loss = metrics.hamming_loss(flat_true_labels, flat_predicted_labels)

print("Hamming Loss:", loss)

In [None]:
# Wrap the fitted per-node classifier in an Explainer, using the training
# features as `data`, and explain every test sample
explainer = Explainer(lcpn_classifier, data=X_train.values, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Boolean mask of test samples whose first-level prediction is "Respiratory"
first_level_predictions = lcpn_classifier.predict(X_test)[:, 0]
respiratory_idx = first_level_predictions == "Respiratory"

# Restrict the explanations to level 0, one class, and the masked samples.
# NOTE(review): the class key is "Respiratory_1" while the mask above compares
# against "Respiratory" -- confirm this matches the labels in the explanations.
shap_filter = {"level": 0, "class": "Respiratory_1", "sample": respiratory_idx}

# .sel() applies the filter and returns the matching subset
shap_val_respiratory = explanations.sel(shap_filter)

# Violin plot of the selected SHAP values, labelled with the training columns
feature_names = X_train.columns.values
shap.plots.violin(
    shap_val_respiratory.shap_values,
    feature_names=feature_names,
    plot_size=(13, 8),
)

# Local Classifier Per Parent Node

In [None]:
# Per-parent-node model built on the shared random forest estimator.
# No edge list is passed here; the original left it commented out:
#     edge_list="./hierarchy.csv",
lcppn_classifier = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False)

# Train local classifier per parent node
lcppn_classifier.fit(X_train, Y_train)

# Predict on the test split; this rebinds `predictions`, replacing any value
# assigned to that name by earlier cells
predictions = lcppn_classifier.predict(X_test)
print(predictions)

# Local Classifier Per Level

In [None]:
# Per-level model built on the shared random forest estimator.
# No edge list is passed here; the original left it commented out:
#     edge_list="./hierarchy.csv",
lcpl_classifier = LocalClassifierPerLevel(local_classifier=rfc, replace_classifiers=False)

# Fit the per-level model on the training split
lcpl_classifier.fit(X_train, Y_train)

# Predict on the test split; this rebinds `predictions`, replacing any value
# assigned to that name by earlier cells
predictions = lcpl_classifier.predict(X_test)
print(predictions)