In [10]:
# standard data science packages
import numpy as np
import pandas as pd

# imodels imports
from imodels.tree.rf_plus.rf_plus.rf_plus_models import \
    RandomForestPlusRegressor, RandomForestPlusClassifier
from imodels.tree.rf_plus.feature_importance.rfplus_explainer import \
    RFPlusMDI, AloRFPlusMDI

# functions for subgroup experiments
from subgroup_detection import *
from subgroup_experiment import *
import shap

# sklearn imports
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, \
    accuracy_score, r2_score, f1_score, log_loss, root_mean_squared_error

# pipeline imports
from subgroup import *

In [11]:
# set inputs
seed = 1
dataids = [361247, 361243, 361242, 361251, 361253, 361260, 361259, 361256, 361254, 361622]
dataid = dataids[0]
clustertype = "hierarchical"

In [12]:
# get data
X, y = get_openml_data(dataid)

# split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3,
                                                    random_state=seed)

# check if task is regression or classification
if len(np.unique(y)) == 2:
    task = 'classification'
else:
    task = 'regression'
    
# fit the prediction models
rf, rf_plus_baseline, rf_plus = fit_models(X_train, y_train, task)

  X, y = get_openml_data(dataid)
  dataset = get_dataset(task.dataset_id, *dataset_args, **get_dataset_kwargs)
  return datasets.get_dataset(self.dataset_id)
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:   11.8s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:   14.3s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  4.5min finished


In [13]:
# obtain shap feature importances
shap_explainer = shap.TreeExplainer(rf)
shap_train_values, shap_train_rankings = get_shap(X_train, shap_explainer,
                                                    task)
shap_test_values, shap_test_rankings = get_shap(X_test, shap_explainer,
                                                task)

In [14]:
# create list of lmdi variants
lmdi_variants = create_lmdi_variant_map()

In [15]:
# obtain lmdi feature importances
lmdi_explainers = get_lmdi_explainers(rf_plus, lmdi_variants,
                                        rf_plus_baseline = rf_plus_baseline)
lfi_train_values, lfi_train_rankings = get_lmdi(X_train, y_train,
                                                lmdi_variants,
                                                lmdi_explainers)
lfi_test_values, lfi_test_rankings = get_lmdi(X_test, None,
                                                lmdi_variants,
                                                lmdi_explainers)
# add shap to the dictionaries
lfi_train_values["shap"] = shap_train_values
lfi_train_rankings["shap"] = shap_train_rankings
lfi_test_values["shap"] = shap_test_values
lfi_test_rankings["shap"] = shap_test_rankings

# add the raw data to the dictionaries as a baseline of comparison
lfi_train_values["rawdata"] = X_train
lfi_test_values["rawdata"] = X_test

In [16]:
# get the clusterings
# method_to_labels, method_to_indices = get_train_clusters(lfi_train_values, clustertype)
train_clusters = get_train_clusters(lfi_train_values, clustertype)
cluster_centroids = get_cluster_centroids(lfi_train_values, train_clusters)
test_clusters = get_test_clusters(lfi_test_values, cluster_centroids)

In [17]:
# compute the performance
metrics_to_scores = compute_performance(X_train, X_test, y_train, y_test,
                                        train_clusters, test_clusters, task)

In [18]:
metrics_to_scores

{'r2': {'lmdi_baseline': {2: 0.6445649451638324,
   3: 0.8680766126738856,
   4: 0.913119037406979,
   5: -215.23091111985252,
   6: -215.91991426100282,
   7: -78.7535766025578,
   8: -84.01922175742995,
   9: -141.54733632707973,
   10: -149.47990644371674},
  'aloo_l2_signed_normed_leafavg_rank': {2: 0.9927271668877967,
   3: 0.998027875479179,
   4: 0.9980136251092391,
   5: 0.9980137595945218,
   6: 0.9980868009483398,
   7: 0.9989535457257578,
   8: 0.9989539218165313,
   9: 0.9990000203695893,
   10: 0.9994746285775445},
  'aloo_l2_signed_normed_leafavg_norank': {2: 0.9927271668877967,
   3: 0.998027875479179,
   4: 0.9980136251092391,
   5: 0.9980137595945218,
   6: 0.9980868009483398,
   7: 0.9989535457257578,
   8: 0.9989539218165313,
   9: 0.9990000203695893,
   10: 0.9994746285775445},
  'aloo_l2_signed_normed_noleafavg_rank': {2: 0.9927271668877967,
   3: 0.8129611156493682,
   4: 0.6971885071248699,
   5: 0.8822557273813058,
   6: 0.9980904410847861,
   7: 0.9993043660518