In [1]:
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report

import pandas as pd
import time
import metrics

# counterfactuals libraries
import dice_ml
from nice import NICE

import tensorflow as tf
tf.get_logger().setLevel(40) 
tf.compat.v1.disable_v2_behavior() 
from alibi.explainers import CounterfactualProto
from alibi.utils import ohe_to_ord, ord_to_ohe

2023-10-03 19:01:51.587059: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-03 19:01:51.739121: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-03 19:01:51.742856: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv('data/TelcoChurn.csv')

In [3]:
df.shape

(7043, 20)

In [4]:
df.head()

Unnamed: 0,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,Female,0,Yes,No,1,No,No phone service,DSL,No,Yes,No,No,No,No,Month-to-month,Yes,Electronic check,29.85,29.85,No
1,Male,0,No,No,34,Yes,No,DSL,Yes,No,Yes,No,No,No,One year,No,Mailed check,56.95,1889.5,No
2,Male,0,No,No,2,Yes,No,DSL,Yes,Yes,No,No,No,No,Month-to-month,Yes,Mailed check,53.85,108.15,Yes
3,Male,0,No,No,45,No,No phone service,DSL,Yes,No,Yes,Yes,No,No,One year,No,Bank transfer (automatic),42.3,1840.75,No
4,Female,0,No,No,2,Yes,No,Fiber optic,No,No,No,No,No,No,Month-to-month,Yes,Electronic check,70.7,151.65,Yes


In [5]:
# convert TotalCharges to numeric
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')
# impude missing values with mean
df['TotalCharges'] = df['TotalCharges'].fillna(df['TotalCharges'].mean())

# convert Churn to numeric
df['Churn'] = df['Churn'].map({'Yes': 1, 'No': 0})

In [6]:
features = df.iloc[:, :-1]
labels = df.iloc[:, -1].values

target_name = 'Churn'
feature_names = list(features.columns)

In [7]:
categorical_ids = [index for index, dtype in enumerate(df.dtypes) if dtype == 'object']
numerical_ids = [index for index, dtype in enumerate(df.dtypes) if dtype != 'object'][:-1]

In [8]:
categorical_features = [feature_names[index] for index in categorical_ids]
numerical_features = [feature_names[index] for index in numerical_ids]

In [9]:
# 将文本特征保存起来
category_map_tmp = {}
for col in categorical_features:
    le = LabelEncoder()
    features[col] = le.fit_transform(features[col].values)
    category_map_tmp[col] = list(le.classes_)

In [10]:
# 要进行counterfactual的个数
N_CF = 20

target_names = ['Good', 'Bad']

In [11]:
X, Y = df.iloc[:, :-1], df.iloc[:, -1]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=7)

# DICE 方法

In [12]:
# Define numerical standard scaler.
num_transf = StandardScaler()

# Define categorical one-hot encoder.
cat_transf = OneHotEncoder(
    categories=[range(len(x)) for x in category_map_tmp.values()],
    handle_unknown="ignore"
)

# Define column transformer
preprocessor = ColumnTransformer(
    transformers=[
        ("cat", cat_transf, categorical_ids),
        ("num", num_transf, numerical_ids),
    ],
    sparse_threshold=0
)
# Fit preprocessor.
preprocessor.fit(X_train)

# Preprocess train and test dataset.
X_train_ohe = preprocessor.transform(X_train)

# Select one of the below classifiers.
clf = RandomForestClassifier(n_estimators=100, random_state=42)

# Fit the classifier.
clf.fit(X_train_ohe, Y_train)
# Define prediction function.
predictor = lambda x: clf.predict_proba(preprocessor.transform(x))

print(classification_report(y_true=Y_test, y_pred=predictor(X_test).argmax(axis=1)))

              precision    recall  f1-score   support

           0       0.82      0.88      0.85      1021
           1       0.60      0.47      0.53       388

    accuracy                           0.77      1409
   macro avg       0.71      0.68      0.69      1409
weighted avg       0.76      0.77      0.76      1409



In [13]:
d = dice_ml.Data(dataframe=df, continuous_features=numerical_features, outcome_name=target_name)

# 使用sklearn作为backend
backend = 'sklearn'

# 将sklearn的预测器包装成dice_ml的预测器
class ModelWrapper:
    def __init__(self, predictor_func):
        self.predictor_func = predictor_func
    
    def predict_proba(self, instances):
        return self.predictor_func(instances)
    
    def predict(self, instances):
       
        prob_preds = self.predictor_func(instances)
    
        return np.argmax(prob_preds, axis=1)

model_wrapper = ModelWrapper(predictor)

m = dice_ml.Model(model=model_wrapper, backend=backend)

In [14]:
# 计算proximity
dice_method = "random"

dice_result = []

for i in range(N_CF):
    counterfactuals_list = []
    exp = dice_ml.Dice(d, m, method=dice_method)
    query_instance_df = pd.DataFrame([X_test.iloc[i]])
    dice_exp = exp.generate_counterfactuals(query_instance_df, total_CFs=5, desired_class="opposite")
    
    final_cfs_df = dice_exp.cf_examples_list[0].final_cfs_df
    
    counterfactuals_list.append((query_instance_df, final_cfs_df))
    
    metrics_dice = metrics.calculate_metrics(
        counterfactuals_list, df,
        numerical_features, categorical_features,
        preprocessor, dice_method,target_name
    )
    dice_result.append(metrics_dice)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
100%|██████████| 1/1 [00:01<00:00,  1.05s/it]
100%|██████████| 1/1 [00:01<00:00,  1.60s/it]
100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
100%|██████████| 1/1 [00:07<00:00,  7.31s/it]
100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
100%|██████████| 1/1 [00:02<00:00,  2.35s/it]
100%|██████████| 1/1 [00:00<00:00,  1.35it/s]
100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
100%|██████████| 1/1 [00:01<00:00,  1.32s/it]
100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
100%|██████████| 1/1 [03:30<00:00, 210.42s/it]
100%|██████████| 1/1 [00:00<00:00,  1.90it/s]
100%|██████████| 1/1 [00:00<00:00,  1.93it/s]
100%|██████████| 1/1 [00:00<00:00,  1.53it/s]
100%|██████████| 1/1 [00:00<00:00,  1.95it/s]
100%|██████████| 1/1 [00:00<00:00,  2.24it/s]
100%|██████████| 1/1 [00:00<00:00,  1.72it/s]
100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
100%|██████████| 1/1 [00:00<00:00,  1.68it/s]


In [15]:
# calculate the average of the metrics
dice_avg_proximity_cont = np.mean([x['avg_proximity_cont'] for x in dice_result])
dice_avg_proximity_cat = np.mean([x['avg_proximity_cat'] for x in dice_result])
dice_avg_sparsity = np.mean([x['avg_sparsity'] for x in dice_result])

print(f"Average proximity for continuous features: {dice_avg_proximity_cont}")
print(f"Average proximity for categorical features: {dice_avg_proximity_cat}")
print(f"Average sparsity: {dice_avg_sparsity}")

Average proximity for continuous features: 0.0
Average proximity for categorical features: 0.05731707317073172
Average sparsity: 2.35


In [16]:
# validate and time

# Select some positive examples.
X_negative = X_test[np.argmax(predictor(X_test), axis=1) == 0]
query_instance_df = pd.DataFrame(X_negative[0:N_CF], columns=feature_names)

counterfactuals_list = []
dice_time_list = []
dice_validity_list = []

for _, instance in query_instance_df.iterrows():
    instance_df = pd.DataFrame(instance).T
    
    start_time = time.time()
    
    dice_exp = exp.generate_counterfactuals(instance_df, total_CFs=5, desired_class="opposite")
    
    time_taken = time.time() - start_time
    dice_time_list.append(time_taken)
    
    # dice_exp[0].cf_examples_list[0].final_cfs_df
    if hasattr(dice_exp, 'cf_examples_list') and dice_exp.cf_examples_list[0]:
        cf_df = dice_exp.cf_examples_list[0].final_cfs_df
        if cf_df is not None:
            counterfactuals_list.append((instance_df, cf_df))
            dice_validity_list.append(1)
        else:
            counterfactuals_list.append((instance_df, None))
            dice_validity_list.append(0)
    else:
        counterfactuals_list.append((instance_df, None))
        dice_validity_list.append(0)

dice_avg_time = np.mean(dice_time_list)
dice_avg_validity = np.mean(dice_validity_list)

print("Average Time Taken per instance:", dice_avg_time)
print("Average Validity:", dice_avg_validity) 

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.23it/s]
100%|██████████| 1/1 [00:00<00:00,  1.36it/s]
100%|██████████| 1/1 [00:00<00:00,  1.81it/s]
100%|██████████| 1/1 [00:00<00:00,  1.62it/s]
100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
100%|██████████| 1/1 [00:00<00:00,  2.00it/s]
100%|██████████| 1/1 [00:00<00:00,  2.01it/s]
100%|██████████| 1/1 [00:00<00:00,  1.60it/s]
100%|██████████| 1/1 [00:00<00:00,  1.67it/s]
100%|██████████| 1/1 [02:54<00:00, 174.64s/it]
100%|██████████| 1/1 [00:02<00:00,  2.71s/it]
100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
100%|██████████| 1/1 [00:00<00:00,  2.36it/s]
100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
100%|██████████| 1/1 [02:54<00:00, 174.30s/it]
100%|██████████| 1/1 [00:10<00:00, 10.10s/it]
100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
100%|██████████| 1/1 [00:00<00:00,  1.79it/s]

Average Time Taken per instance: 18.53922688961029
Average Validity: 1.0





In [17]:
dice_exp.visualize_as_dataframe()

Query instance (original outcome : 0)


Unnamed: 0,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,Male,0,Yes,No,71,Yes,Yes,Fiber optic,Yes,Yes,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),116.099998,8310.549805,0



Diverse Counterfactual set (new outcome: 1.0)


Unnamed: 0,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,Male,0.0,Yes,No,19.0,Yes,Yes,Fiber optic,Yes,Yes,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),116.1,2457.7303,1
1,Male,1.0,Yes,No,39.0,Yes,Yes,DSL,Yes,Yes,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),116.1,8310.55,1
2,Male,0.0,Yes,No,19.0,Yes,Yes,Fiber optic,Yes,Yes,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),104.19,8310.55,1
3,Male,1.0,Yes,No,6.0,Yes,Yes,Fiber optic,Yes,Yes,No,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),116.1,4201.4054,1
4,Male,1.0,Yes,No,39.0,Yes,Yes,DSL,Yes,No,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),116.1,8310.55,1


# NICE 方法

In [18]:
X_nice = X.values
y_nice = Y.values

X_train_nice, X_test_nice, y_train_nice, y_test_nice = train_test_split(X_nice, y_nice, test_size=0.2, random_state=42)

clf_nice = Pipeline([
    ('preprocessor', ColumnTransformer(
        [
            ('num', num_transf , numerical_ids),
            ('cat', cat_transf, categorical_ids)
        ]
    )),
    ('classifier', RandomForestClassifier(n_estimators=100, random_state=42))]
)

clf_nice.fit(X_train_nice, y_train_nice)

In [19]:
predict_fn_nice = lambda x: clf_nice.predict_proba(x)

NICE_explainer = NICE(
    X_train=X_train_nice,
    predict_fn=predict_fn_nice,
    y_train=y_train_nice,
    cat_feat=categorical_ids,
    num_feat=numerical_ids
)

In [20]:
# 计算proximity

nice_result = []

for i in range(N_CF):
    counterfactuals_list_nice = []
    
    query_instance_df_nice = pd.DataFrame(X_test_nice[i:i+1,:], columns=feature_names)
    nice_exp = NICE_explainer.explain(X_test_nice[i:i+1,:])[0]

    final_cfs_nice = pd.DataFrame([nice_exp], columns=feature_names)
    final_cfs_nice[target_name] = clf_nice.predict(final_cfs_nice)
    
    counterfactuals_list_nice.append((query_instance_df_nice, final_cfs_nice))
    
    metrics_nice = metrics.calculate_metrics(
        counterfactuals_list_nice, df,
        numerical_features, categorical_features,
        preprocessor,'nice',target_name
    )
    nice_result.append(metrics_nice)



In [21]:
counterfactuals_list_nice

[(  gender SeniorCitizen Partner Dependents tenure PhoneService MultipleLines  \
  0   Male           0.0      No        Yes   22.0          Yes            No   
  
    InternetService OnlineSecurity OnlineBackup DeviceProtection TechSupport  \
  0     Fiber optic             No          Yes              Yes          No   
  
    StreamingTV StreamingMovies        Contract PaperlessBilling  \
  0         Yes              No  Month-to-month              Yes   
  
        PaymentMethod MonthlyCharges TotalCharges  
  0  Electronic check           89.4       2001.5  ,
    gender  SeniorCitizen Partner Dependents  tenure PhoneService MultipleLines  \
  0   Male            0.0      No        Yes    22.0          Yes            No   
  
    InternetService OnlineSecurity OnlineBackup DeviceProtection TechSupport  \
  0     Fiber optic             No          Yes              Yes          No   
  
    StreamingTV StreamingMovies        Contract PaperlessBilling  \
  0         Yes             

In [22]:
# calculate the average of the metrics
nice_avg_proximity_cont = np.mean([x['avg_proximity_cont'] for x in nice_result])
nice_avg_proximity_cat = np.mean([x['avg_proximity_cat'] for x in nice_result])
nice_avg_sparsity = np.mean([x['avg_sparsity'] for x in nice_result])

print(f"Average proximity for continuous features: {nice_avg_proximity_cont}")
print(f"Average proximity for categorical features: {nice_avg_proximity_cat}")
print(f"Average sparsity: {nice_avg_sparsity}")

Average proximity for continuous features: 0.0
Average proximity for categorical features: 0.04878048780487805
Average sparsity: 2.0


In [23]:
# validate and time

X_negative_nice = X_test_nice[np.argmax(predict_fn_nice(X_test_nice), axis=1) == 0]
query_instance_df_nice = pd.DataFrame(X_negative_nice[0:N_CF], columns=feature_names)

counterfactuals_list_nice = []
nice_time_list = []
nice_validity_list = []

for _, instance in query_instance_df_nice.iterrows():
    instance_df = pd.DataFrame(instance).T
    
    start_time = time.time()
    
    nice_exp = NICE_explainer.explain(instance_df.values)
    
    time_taken = time.time() - start_time
    nice_time_list.append(time_taken)
    
    cf_df = pd.DataFrame([nice_exp[0]], columns=feature_names)
    cf_df_prob = clf_nice.predict_proba(cf_df)
    if cf_df_prob[0][1] > 0.55:
        counterfactuals_list_nice.append((instance_df, cf_df))
        nice_validity_list.append(1)
    else:
        counterfactuals_list_nice.append((instance_df, None))
        nice_validity_list.append(0) 
        
nice_avg_time = np.mean(nice_time_list)
nice_avg_validity = np.mean(nice_validity_list)

print("Average Time Taken per instance:", nice_avg_time)
print("Average Validity:", nice_avg_validity)



Average Time Taken per instance: 0.1255448579788208
Average Validity: 0.8




# PROTOTYPE方法

In [24]:
categorical_ids, numerical_ids

([0, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 4, 17, 18])

In [25]:
data_perm = np.random.permutation(np.c_[features, labels])
X_alibi = data_perm[:,:-1]
y_alibi = data_perm[:,-1]

idx = 6000
y_train_alibi, y_test_alibi = y_alibi[:idx], y_alibi[idx:]

# 将文本特征和数值特征重新排列，使得文本特征在前，数值特征在后
X_alibi = np.c_[X_alibi[:,0], X_alibi[:,2:4], X_alibi[:, 5:17], X_alibi[:,1],
                X_alibi[:,4], X_alibi[:,17:]]

feature_names_alibi = categorical_features + numerical_features

print(feature_names_alibi)

['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'SeniorCitizen', 'tenure', 'MonthlyCharges', 'TotalCharges']


In [26]:
# 将文本特征进行one-hot编码,并记录其位置与长度

category_map = {}
for i, (_, v) in enumerate(category_map_tmp.items()):
    category_map[i] = v 
    
cat_vars_ord = {}
n_categories = len(list(category_map.keys()))
for i in range(n_categories):
    cat_vars_ord[i] = len(np.unique(X_alibi[:, i]))
print(cat_vars_ord)

cat_vars_ohe = ord_to_ohe(X_alibi, cat_vars_ord)[1]
print(cat_vars_ohe)

{0: 2, 1: 2, 2: 2, 3: 2, 4: 3, 5: 3, 6: 3, 7: 3, 8: 3, 9: 3, 10: 3, 11: 3, 12: 3, 13: 2, 14: 4}
{0: 2, 2: 2, 4: 2, 6: 2, 8: 3, 11: 3, 14: 3, 17: 3, 20: 3, 23: 3, 26: 3, 29: 3, 32: 3, 35: 2, 37: 4}


In [27]:
categorical_features_alibi = categorical_features
numerical_features_alibi = numerical_features

In [28]:
# 对数据进行预处理
X_num = X_alibi[:, -len(numerical_ids):].astype(np.float32, copy=False)
xmin, xmax = X_num.min(axis=0), X_num.max(axis=0)
rng = (-1., 1.)
X_num_scaled = (X_num - xmin) / (xmax - xmin) * (rng[1] - rng[0]) + rng[0]

X_cat = X_alibi[:, :-len(numerical_ids)].copy()
ohe = OneHotEncoder(categories='auto', sparse_output=False).fit(X_cat)
X_cat_ohe = ohe.transform(X_cat)

# 对特征位置进行重新排序，使得文本特征在前，数值特征在后
X_alibi = np.c_[X_cat_ohe, X_num_scaled].astype(np.float32, copy=False)
X_train_alibi, X_test_alibi = X_alibi[:idx, :], X_alibi[idx:, :]
print(X_train_alibi.shape, X_test_alibi.shape)

(6000, 45) (1043, 45)


In [29]:
# define data preprocessor
num_transf = StandardScaler()
cat_transf = OneHotEncoder(
    categories=[range(len(x)) for x in category_map.values()],
    handle_unknown='ignore'
)
preprocessor_alibi = ColumnTransformer(
    transformers=[
        ('cat', cat_transf, categorical_ids),
        ('num', num_transf, numerical_ids)
    ],
    sparse_threshold=0
)

# fit data preprocessor
preprocessor_alibi = preprocessor_alibi.fit(features)

In [30]:
# 训练模型
clf_ablit = RandomForestClassifier(n_estimators=100, random_state=42)
clf_ablit.fit(X_train_alibi, y_train_alibi)

In [31]:
def predict_fn(x):
    pred_prob = clf_ablit.predict_proba(x)
    return np.hstack([1 - pred_prob[:,1].reshape(-1, 1), pred_prob[:,1].reshape(-1, 1)])

In [32]:
# Initialize the explainer object
X_alibi = X_test_alibi[0].reshape((1,) + X_test_alibi[0].shape)

shape = X_alibi.shape
beta = .01
c_init = 1.
c_steps = 5
max_iterations = 500
rng = (-1., 1.)  # scale features between -1 and 1
rng_shape = (1,) + features.shape[1:]
feature_range = ((np.ones(rng_shape) * rng[0]).astype(np.float32), 
                 (np.ones(rng_shape) * rng[1]).astype(np.float32))

In [33]:
cf = CounterfactualProto(predict_fn,
                         shape,
                         beta=beta,
                         cat_vars=cat_vars_ohe,
                         ohe=True,  # OHE flag
                         max_iterations=max_iterations,
                         feature_range=feature_range,
                         c_init=c_init,
                         c_steps=c_steps
                        )

cf.fit(X_train_alibi, d_type='abdm', disc_perc=[25, 50, 75])



CounterfactualProto(meta={
  'name': 'CounterfactualProto',
  'type': ['blackbox', 'tensorflow', 'keras'],
  'explanations': ['local'],
  'params': {
              'kappa': 0.0,
              'beta': 0.01,
              'gamma': 0.0,
              'theta': 0.0,
              'cat_vars': {
                            0: 2,
                            2: 2,
                            4: 2,
                            6: 2,
                            8: 3,
                            11: 3,
                            14: 3,
                            17: 3,
                            20: 3,
                            23: 3,
                            26: 3,
                            29: 3,
                            32: 3,
                            35: 2,
                            37: 4}
                          ,
              'ohe': True,
              'use_kdtree': False,
              'learning_rate_init': 0.01,
              'max_iterations': 500,
              'c_init

In [34]:
def describe_instance(X, explanation, target_names, eps=1e-2):
    print('Original instance: {}  -- proba: {}'.format(target_names[explanation.orig_class],
                                                       explanation.orig_proba[0]))
    print('Counterfactual instance: {}  -- proba: {}'.format(target_names[explanation.cf['class']],
                                                             explanation.cf['proba'][0]))
    print('\nCounterfactual perturbations...')
    print('\nCategorical:')
    X_orig_ord = ohe_to_ord(X, cat_vars_ohe)[0]
    X_cf_ord = ohe_to_ord(explanation.cf['X'], cat_vars_ohe)[0]
    delta_cat = {}
    for i, (_, v) in enumerate(category_map.items()):
        cat_orig = v[int(X_orig_ord[0, i])]
        cat_cf = v[int(X_cf_ord[0, i])]
        if cat_orig != cat_cf:
            delta_cat[feature_names_alibi[i]] = [cat_orig, cat_cf]
    if delta_cat:
        for k, v in delta_cat.items():
            print('{}: {}  -->   {}'.format(k, v[0], v[1]))
    print('\nNumerical:')
    delta_num = X_cf_ord[0, -4:] - X_orig_ord[0, -4:]
    n_keys = len(list(cat_vars_ord.keys()))
    for i in range(delta_num.shape[0]):
        if np.abs(delta_num[i]) > eps:
            print('{}: {:.2f}  -->   {:.2f}'.format(feature_names_alibi[i+n_keys],
                                            X_orig_ord[0,i+n_keys],
                                            X_cf_ord[0,i+n_keys]))
            

def calculate_proximity_pro(X_orig_ord, X_cf_ord, explanation, df):
    counterfactuals_list = []
    
    query_instance_df_alibi = pd.DataFrame(X_orig_ord, columns=feature_names_alibi)
    final_cfs_alibi = pd.DataFrame(X_cf_ord, columns=feature_names_alibi)
    final_cfs_alibi[target_name] = 0 if explanation.cf['proba'][0][1] < 0.5 else 1
    
    counterfactuals_list.append((query_instance_df_alibi, final_cfs_alibi))
    
    metrics_alibi = metrics.calculate_metrics(
        counterfactuals_list, df, numerical_features_alibi, 
        categorical_features_alibi, preprocessor_alibi,"prototype",target_name
    )
    
    return metrics_alibi

In [35]:
pd_churn= pd.DataFrame(features.values, columns=feature_names)
pd_churn[target_name] = labels

In [36]:
y_pred_alibi=predict_fn(X_test_alibi).argmax(axis=1)
instances_alibi = X_test_alibi[y_pred_alibi == 0][:N_CF]

metrics_alibi = []
time_alibi = []
counterfactuals_list_alibi = []

for i in range(N_CF):
    
    instance = instances_alibi[i].reshape(1, -1)
    start_time = time.time()
    explanation = cf.explain(instance)
    end_take = time.time() - start_time
    time_alibi.append(end_take)
    if explanation.cf is not None:
        counterfactuals_list_alibi.append(explanation.cf['X'])
        X_orig_ord = ohe_to_ord(X_alibi, cat_vars_ohe)[0]
        X_cf_ord = ohe_to_ord(explanation.cf['X'], cat_vars_ohe)[0]
        metric= calculate_proximity_pro(X_orig_ord, X_cf_ord, explanation, pd_churn)
        metrics_alibi.append(metric)
        describe_instance(X_alibi, explanation, target_names)
    else:
        counterfactuals_list_alibi.append(None)

2023-10-03 19:18:18.689777: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:353] MLIR V1 optimization pass is not enabled


Original instance: Good  -- proba: [0.81 0.19]
Counterfactual instance: Bad  -- proba: [0.49 0.51]

Counterfactual perturbations...

Categorical:

Numerical:
MonthlyCharges: 0.03  -->   0.04


No counterfactual found!


Original instance: Good  -- proba: [0.65 0.35]
Counterfactual instance: Bad  -- proba: [0.46033333 0.53966667]

Counterfactual perturbations...

Categorical:
gender: Female  -->   Male
InternetService: Fiber optic  -->   No
OnlineSecurity: No  -->   No internet service
OnlineBackup: No  -->   No internet service
DeviceProtection: No  -->   No internet service
TechSupport: No  -->   No internet service
StreamingTV: No  -->   No internet service
StreamingMovies: No  -->   No internet service
PaperlessBilling: Yes  -->   No
PaymentMethod: Electronic check  -->   Mailed check

Numerical:
tenure: -0.92  -->   -0.95
MonthlyCharges: 0.03  -->   0.79
TotalCharges: -0.96  -->   -1.00


No counterfactual found!


Original instance: Good  -- proba: [0.58 0.42]
Counterfactual instance: Bad  -- proba: [0.4 0.6]

Counterfactual perturbations...

Categorical:
DeviceProtection: No  -->   Yes
PaperlessBilling: Yes  -->   No

Numerical:
tenure: -0.92  -->   -0.97
MonthlyCharges: 0.03  -->   0.14
TotalCharges: -0.96  -->   -0.99


No counterfactual found!


Original instance: Good  -- proba: [0.78 0.22]
Counterfactual instance: Bad  -- proba: [0.48 0.52]

Counterfactual perturbations...

Categorical:
gender: Female  -->   Male
MultipleLines: No  -->   Yes
OnlineBackup: No  -->   Yes
DeviceProtection: No  -->   Yes
TechSupport: No  -->   Yes
StreamingTV: No  -->   Yes
StreamingMovies: No  -->   Yes
Contract: Month-to-month  -->   One year
PaperlessBilling: Yes  -->   No

Numerical:
tenure: -0.92  -->   -0.54
MonthlyCharges: 0.03  -->   0.83
TotalCharges: -0.96  -->   -0.24


No counterfactual found!
No counterfactual found!


Original instance: Good  -- proba: [0.67 0.33]
Counterfactual instance: Bad  -- proba: [0.49 0.51]

Counterfactual perturbations...

Categorical:
gender: Female  -->   Male
MultipleLines: No  -->   Yes
DeviceProtection: No  -->   Yes

Numerical:
tenure: -0.92  -->   -0.44
MonthlyCharges: 0.03  -->   0.18
TotalCharges: -0.96  -->   -0.56


No counterfactual found!


Original instance: Good  -- proba: [0.56 0.44]
Counterfactual instance: Bad  -- proba: [0.49 0.51]

Counterfactual perturbations...

Categorical:
gender: Female  -->   Male
Partner: No  -->   Yes
PhoneService: Yes  -->   No
MultipleLines: No  -->   No phone service
InternetService: Fiber optic  -->   DSL

Numerical:
tenure: -0.92  -->   -0.86
MonthlyCharges: 0.03  -->   -0.87
TotalCharges: -0.96  -->   -0.98


No counterfactual found!
No counterfactual found!
No counterfactual found!
No counterfactual found!
No counterfactual found!


Original instance: Good  -- proba: [0.66 0.34]
Counterfactual instance: Bad  -- proba: [0.48 0.52]

Counterfactual perturbations...

Categorical:
MultipleLines: No  -->   Yes
OnlineBackup: No  -->   Yes
PaymentMethod: Electronic check  -->   Mailed check

Numerical:
tenure: -0.92  -->   -0.39
MonthlyCharges: 0.03  -->   0.11
TotalCharges: -0.96  -->   -0.67


No counterfactual found!
No counterfactual found!


In [37]:
# calculate the average of the proximity metrics
pro_avg_proximity_cont = np.mean([x['avg_proximity_cont'] for x in metrics_alibi])
pro_avg_proximity_cat = np.mean([x['avg_proximity_cat'] for x in metrics_alibi])
pro_avg_sparsity= np.mean([x['avg_proximity_cat'] for x in metrics_alibi])


print(f"Average proximity for continuous features: {pro_avg_proximity_cont}")
print(f"Average proximity for categorical features: {pro_avg_proximity_cat}")
print(f"Average sparsity: {pro_avg_sparsity}")

Average proximity for continuous features: 0.35714285714285715
Average proximity for categorical features: 0.2578397212543554
Average sparsity: 0.2578397212543554


In [38]:
# validate and time
pro_avg_time = np.mean(time_alibi)

invalid_count = [cf for cf in counterfactuals_list_alibi if cf is None]
pro_avg_validity = 1 - len(invalid_count) / N_CF
print("Average Time Taken per instance:", pro_avg_time)
print("Average Validity:", pro_avg_validity)

Average Time Taken per instance: 171.31638306379318
Average Validity: 0.35


# Results

In [39]:
result_dict = {
    "Dice": {
        "proximity_cont": dice_avg_proximity_cont,
        "proximity_cat": dice_avg_proximity_cat,
        "sparsity": dice_avg_sparsity,
        "time(s)": dice_avg_time,
        "validity": dice_avg_validity
    },
    "Nice": {
        "proximity_cont": nice_avg_proximity_cont,
        "proximity_cat": nice_avg_proximity_cat,
        "sparsity": nice_avg_sparsity,
        "time(s)": nice_avg_time,
        "validity": nice_avg_validity
    },
    "Prototype": {
        "proximity_cont": pro_avg_proximity_cont,
        "proximity_cat": pro_avg_proximity_cat,
        "sparsity": pro_avg_sparsity,
        "time(s)": pro_avg_time,
        "validity": pro_avg_validity
    }
}

result = pd.DataFrame(result_dict).T.round(3)
result

Unnamed: 0,proximity_cont,proximity_cat,sparsity,time(s),validity
Dice,0.0,0.057,2.35,18.539,1.0
Nice,0.0,0.049,2.0,0.126,0.8
Prototype,0.357,0.258,0.258,171.316,0.35
