In [5]:
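# Dependencies -- uncomment to install into the running kernel's environment.
# %pip install torch
# %pip install dice-ml
# %pip install imbalanced-learn  # provides the `imblearn` module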

In [6]:
from torch.utils.data import DataLoader
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, accuracy_score

import pandas as pd
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

In [None]:
class CustomDataLoader:
    """Load a CSV dataset and prepare it for a binary classifier.

    Typical use is ``load_dataset()`` -> ``preprocess_data()`` ->
    ``get_data_split()`` -> ``oversample()``.  Methods that need the data
    read it from ``filepath`` on first use, so forgetting ``load_dataset()``
    no longer fails with an opaque ``NoneType`` error.

    Parameters
    ----------
    filepath : str or path-like
        Location of the CSV file, passed unchanged to ``pandas.read_csv``.
    target_column : str, default 'stroke'
        Name of the label column that ``get_data_split`` separates from the
        features.
    """

    def __init__(self, filepath, target_column='stroke'):
        self.filepath = filepath
        self.target_column = target_column
        # Populated by load_dataset(); replaced by preprocess_data().
        self.data = None

    def _require_data(self):
        """Return the current frame, reading the CSV first if nothing is loaded."""
        if self.data is None:
            self.load_dataset()
        return self.data

    def load_dataset(self):
        """Read ``self.filepath`` with ``pandas.read_csv`` into ``self.data``."""
        self.data = pd.read_csv(self.filepath)

    def preprocess_data(self):
        """Clean and encode ``self.data``; the result replaces ``self.data``.

        Steps: drop every row containing a missing value, one-hot encode the
        columns ``pandas.get_dummies`` selects by default (``drop_first=True``
        removes one level per encoded column), then cast boolean columns to
        ``int``.
        """
        # Work on new frames instead of dropna(inplace=True) so the frame is
        # only swapped in once the whole pipeline has succeeded.
        cleaned = self._require_data().dropna()
        encoded = pd.get_dummies(cleaned, drop_first=True)
        bool_cols = encoded.select_dtypes(include=['bool']).columns
        encoded[bool_cols] = encoded[bool_cols].astype(int)
        self.data = encoded

    def get_data_split(self, test_size=0.2, random_state=42):
        """Split ``self.data`` into train and test parts.

        Parameters
        ----------
        test_size : float, default 0.2
            Forwarded to ``train_test_split``.
        random_state : int, default 42
            Forwarded to ``train_test_split``.

        Returns
        -------
        The value returned by ``train_test_split(X, y, ...)``, where ``X`` is
        every column except ``self.target_column`` and ``y`` is that column.
        Callers unpack it as ``X_train, X_test, y_train, y_test``.
        """
        frame = self._require_data()
        X = frame.drop(self.target_column, axis=1)
        y = frame[self.target_column]
        return train_test_split(X, y, test_size=test_size, random_state=random_state)

    def oversample(self, X_train, y_train, random_state=42):
        """Resample the training set with imblearn's ``SMOTE``.

        Parameters
        ----------
        X_train, y_train
            Training features and labels, passed to ``SMOTE.fit_resample``.
            Apply this to the training split only.
        random_state : int, default 42
            Seed given to ``SMOTE``.

        Returns
        -------
        tuple
            ``(X_res, y_res)`` as returned by ``SMOTE.fit_resample``.
        """
        smote = SMOTE(random_state=random_state)
        X_res, y_res = smote.fit_resample(X_train, y_train)
        return X_res, y_res

# Point the loader at the stroke CSV, then read and clean it so that
# `data_loader.data` holds the encoded frame used by the cells below.
DATASET_PATH = '../healthcare-dataset-stroke-data.csv'

data_loader = CustomDataLoader(DATASET_PATH)
data_loader.load_dataset()
data_loader.preprocess_data()

In [8]:
# Quick look at what get_data_split() returns with its default arguments.
# The result is only displayed here, not stored.
data_loader.get_data_split()

[         id   age  hypertension  heart_disease  avg_glucose_level   bmi  \
 3565  68302  40.0             0              0              65.77  31.2   
 898   62716  59.0             0              0              81.64  32.8   
 2707  46498  57.0             0              0             217.40  36.6   
 4198   4148  81.0             0              0              71.18  23.9   
 2746  35315  65.0             0              0              95.88  28.5   
 ...     ...   ...           ...            ...                ...   ...   
 4613  45530  19.0             0              0              89.30  22.1   
 511   27832  51.0             0              0              82.93  29.7   
 3247  64498  53.0             0              0              90.65  22.1   
 3946   8041  11.0             0              0              93.51  20.8   
 916   67864  63.0             0              0              57.82  28.8   
 
       gender_Male  gender_Other  ever_married_Yes  work_type_Never_worked  \
 3565   

In [None]:
# Split the data for evaluation
X_train, X_test, y_train, y_test = data_loader.get_data_split()

# Only the training portion is oversampled; the test set is left as split.
X_train, y_train = data_loader.oversample(X_train, y_train)

# Give all four pieces a fresh 0..n-1 index so that positions and labels match.
X_train, X_test, y_train, y_test = (
    part.reset_index(drop=True)
    for part in (X_train, X_test, y_train, y_test)
)




In [10]:
# Display the training feature frame.
X_train

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes
0,68302,40.000000,0,0,65.770000,31.200000,0,0,1,0,1,0,0,1,0,1,0
1,62716,59.000000,0,0,81.640000,32.800000,0,0,1,0,0,1,0,1,0,0,0
2,46498,57.000000,0,0,217.400000,36.600000,0,0,1,0,1,0,0,1,0,1,0
3,4148,81.000000,0,0,71.180000,23.900000,1,0,1,0,0,1,0,1,1,0,0
4,35315,65.000000,0,0,95.880000,28.500000,1,0,1,0,0,1,0,1,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7537,67937,76.212411,0,0,189.499430,30.325110,0,0,0,0,1,0,0,1,0,0,1
7538,58928,70.431539,0,0,232.729252,27.063771,0,0,1,0,0,0,0,0,0,1,0
7539,28311,76.816771,0,0,213.019245,29.988552,0,0,1,0,1,0,0,1,0,0,0
7540,2135,79.909279,1,0,98.558905,32.092584,0,0,1,0,0,1,0,0,0,0,0


In [11]:
# Display the training feature frame.
# NOTE(review): same expression as the preceding cell -- likely redundant.
X_train

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes
0,68302,40.000000,0,0,65.770000,31.200000,0,0,1,0,1,0,0,1,0,1,0
1,62716,59.000000,0,0,81.640000,32.800000,0,0,1,0,0,1,0,1,0,0,0
2,46498,57.000000,0,0,217.400000,36.600000,0,0,1,0,1,0,0,1,0,1,0
3,4148,81.000000,0,0,71.180000,23.900000,1,0,1,0,0,1,0,1,1,0,0
4,35315,65.000000,0,0,95.880000,28.500000,1,0,1,0,0,1,0,1,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7537,67937,76.212411,0,0,189.499430,30.325110,0,0,0,0,1,0,0,1,0,0,1
7538,58928,70.431539,0,0,232.729252,27.063771,0,0,1,0,0,0,0,0,0,1,0
7539,28311,76.816771,0,0,213.019245,29.988552,0,0,1,0,1,0,0,1,0,0,0
7540,2135,79.909279,1,0,98.558905,32.092584,0,0,1,0,0,1,0,0,0,0,0


In [None]:
# Fit blackbox model
# A fixed seed keeps the fitted forest -- and with it the metrics below and
# every explanation built on `rf` -- identical across kernel restarts.
rf = RandomForestClassifier(random_state=42)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)

# Report macro-averaged F1 alongside accuracy.
# NOTE(review): the target looks heavily imbalanced (few positive test rows),
# so accuracy alone is probably not a meaningful score here -- confirm.
print(f"F1 Score {f1_score(y_test, y_pred, average='macro')}")
print(f"Accuracy {accuracy_score(y_test, y_pred)}")

F1 Score 0.5539842132690772
Accuracy 0.9032586558044806


In [13]:

# Wrap labels and predictions as Series so both can be filtered the same way.
y_test_series, y_pred_series = pd.Series(y_test), pd.Series(y_pred)

# Index labels of the rows whose true / predicted class equals 1.
test_indices = y_test_series.index[y_test_series.eq(1)].tolist()
pred_indices = y_pred_series.index[y_pred_series.eq(1)].tolist()

print("Test indices:", test_indices)
print("Prediction indices:", pred_indices)

Test indices: [11, 30, 35, 62, 73, 110, 113, 122, 143, 166, 167, 198, 229, 238, 274, 277, 299, 312, 327, 336, 362, 388, 426, 434, 475, 488, 499, 538, 573, 582, 592, 598, 613, 682, 685, 734, 752, 795, 804, 807, 809, 832, 851, 873, 878, 903, 910, 917, 928, 944, 964, 965, 978]
Prediction indices: [0, 49, 76, 79, 90, 101, 131, 134, 138, 142, 206, 238, 254, 265, 281, 312, 349, 353, 357, 368, 370, 375, 379, 388, 407, 417, 418, 419, 426, 433, 466, 486, 583, 587, 596, 638, 644, 649, 651, 656, 667, 675, 676, 679, 704, 712, 734, 785, 788, 795, 797, 807, 823, 888, 892, 910, 961, 963, 964, 971]


In [14]:
# Create diverse counterfactual explanations
import dice_ml

# Dataset
# Describe the preprocessed frame to DiCE: which columns are continuous and
# which one is the outcome.
# NOTE(review): `id` is not listed as continuous, so DiCE presumably handles it
# like the other non-continuous columns -- confirm that is intended.
continuous_cols = ['age', 'avg_glucose_level', 'bmi']
data_dice = dice_ml.Data(
    dataframe=data_loader.data,
    continuous_features=continuous_cols,
    outcome_name='stroke',
)

In [15]:
# Model
# Wrap the fitted random forest for DiCE using the sklearn backend.
rf_dice = dice_ml.Model(model=rf, backend="sklearn")

# Build the explainer from the dataset description and the wrapped model.
# Random sampling, genetic algorithm, kd-tree,...
explainer = dice_ml.Dice(data_dice, rf_dice, method="random")

In [16]:
# Display test row 10; slicing with 10:11 keeps it as a one-row DataFrame.
X_test[10:11]

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes
10,12336,73.0,0,0,87.56,24.1,0,0,1,0,0,1,0,1,0,1,0


In [17]:

# Query instance for the counterfactual search: test row 10 as a one-row frame.
# Take an explicit copy -- a later cell assigns into `input_datapoint`, and
# writing into a bare slice of `X_test` is the chained-assignment pattern
# pandas warns about (SettingWithCopyWarning).
input_datapoint = X_test.iloc[10:11].copy()



In [18]:
# Display test row 10 as a one-row DataFrame.
# NOTE(review): same expression as an earlier cell -- likely redundant.
X_test[10:11]

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes
10,12336,73.0,0,0,87.56,24.1,0,0,1,0,0,1,0,1,0,1,0


In [19]:
# Names of the boolean-typed columns in the query row.
bool_cols = input_datapoint.select_dtypes(include='bool').columns

In [20]:
# Cast the boolean columns of the query row to int.
# NOTE(review): preprocess_data() already casts bool columns to int, so
# `bool_cols` is presumably empty here and this is a no-op -- confirm.
input_datapoint[bool_cols] = input_datapoint[bool_cols].astype(int)

In [21]:
# Display the query row that will be passed to the explainer.
input_datapoint

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes
10,12336,73.0,0,0,87.56,24.1,0,0,1,0,0,1,0,1,0,1,0


In [22]:
# Ask DiCE for three counterfactuals that flip the predicted class.
# `id` is a record identifier, so it is excluded from the features the search
# may change; every other column of the query row remains free to vary.
cf_features_to_vary = [col for col in input_datapoint.columns if col != 'id']

# Seed the random search explicitly so re-running this cell gives the same set.
cf = explainer.generate_counterfactuals(input_datapoint,
                                        total_CFs=3,
                                        desired_class="opposite",
                                        features_to_vary=cf_features_to_vary,
                                        random_seed=42)

  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
  candidate_cfs.at[k, selected

In [23]:


# Render the query instance and its counterfactuals as tables;
# show_only_changes=True is passed so unchanged cells are shown as "-".
cf.visualize_as_dataframe(show_only_changes=True)


Query instance (original outcome : 0)


Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,stroke
0,12336,73.0,0,0,87.559998,24.1,0,0,1,0,0,1,0,1,0,1,0,0



Diverse Counterfactual set (new outcome: 1)


Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Male,gender_Other,ever_married_Yes,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Urban,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,stroke
0,-,-,-,-,171.43,-,-,-,-,-,-,-,-,0.0,-,-,-,-
1,-,-,-,-,-,-,-,-,-,-,-,0.0,-,0.0,-,-,-,1.0
2,-,-,-,-,-,-,-,-,-,-,-,-,-,-,-,-,-,1.0


Indices of people whose age is above 70: [0, 10, 22, 39, 56, 62, 63, 76, 79, 88, 106, 110, 113, 131, 139, 142, 143, 147, 183, 195, 202, 205, 206, 208, 216, 222, 235, 240, 247, 255, 258, 265, 267, 272, 274, 276, 277, 288, 304, 311, 312, 320, 324, 327, 336, 347, 349, 357, 362, 370, 376, 388, 389, 417, 419, 426, 432, 434, 443, 462, 465, 466, 468, 475, 480, 486, 492, 495, 501, 505, 542, 556, 575, 582, 596, 598, 601, 605, 611, 618, 634, 635, 640, 647, 651, 652, 658, 660, 668, 675, 679, 683, 685, 702, 709, 710, 722, 724, 732, 734, 749, 753, 764, 782, 785, 787, 790, 792, 799, 803, 807, 809, 824, 848, 851, 878, 880, 892, 895, 903, 907, 926, 938, 942, 943, 944, 949, 952, 960, 961, 971, 978, 980]
