<a href="https://colab.research.google.com/github/22Ifeoma22/22Ifeoma22/blob/main/Counterfactuals_in_XAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install libraries

In [None]:
# Uncomment to install PyTorch if it is missing from this runtime
# (prefer `%pip` so the package is installed into the kernel's environment):
# %pip install torch

In [None]:
# Uncomment to install DiCE if `import dice_ml` fails in this runtime
# (prefer `%pip` so the package is installed into the kernel's environment):
# %pip install dice-ml

## Import libraries

In [None]:
# %% Imports
from torch.utils.data import DataLoader
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, accuracy_score

import pandas as pd
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

In [None]:
# Dataset location, kept for reference only: the loader cell below passes
# this same path as a literal, so this variable is intentionally disabled.
# path = '/content/data/healthcare-dataset-stroke-data.csv'

## Load the data

In [None]:
# %% Custom DataLoader
class CustomDataLoader:
    """Load a CSV dataset and prepare it for classification on the ``stroke`` column.

    Intended call order: ``load_dataset()`` -> ``preprocess_data()`` ->
    ``get_data_split()`` -> ``oversample()`` (on the training portion only).
    """

    def __init__(self, filepath):
        """Store the CSV location; nothing is read until ``load_dataset()``.

        Args:
            filepath: Path (or any source accepted by ``pd.read_csv``) of the CSV.
        """
        self.filepath = filepath
        self.data = None  # populated by load_dataset()

    def _require_data(self):
        """Return the loaded frame, failing with a clear message if none is loaded."""
        if self.data is None:
            raise RuntimeError(
                "No data loaded: call load_dataset() before using this method."
            )
        return self.data

    def load_dataset(self):
        """Read the CSV at ``self.filepath`` into ``self.data``."""
        self.data = pd.read_csv(self.filepath)

    def preprocess_data(self):
        """Drop rows containing missing values, then one-hot encode via ``pd.get_dummies``.

        The result replaces ``self.data``. The previously loaded frame object is
        left untouched (no in-place mutation), so references to the raw data
        held elsewhere stay valid.

        Raises:
            RuntimeError: If ``load_dataset()`` has not been called.
        """
        without_missing = self._require_data().dropna()
        self.data = pd.get_dummies(without_missing)

    def get_data_split(self, test_size=0.2, random_state=42, stratify=False):
        """Split the data into train and test features/labels.

        Args:
            test_size: Fraction (or absolute count) of rows held out for testing.
            random_state: Seed forwarded to ``train_test_split``.
            stratify: When True, preserve the ``stroke`` class proportions in
                both splits. Defaults to False, which keeps the original
                (unstratified) behaviour.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` as returned by
            ``train_test_split``.

        Raises:
            RuntimeError: If ``load_dataset()`` has not been called.
        """
        frame = self._require_data()
        X = frame.drop('stroke', axis=1)
        y = frame['stroke']
        return train_test_split(
            X,
            y,
            test_size=test_size,
            random_state=random_state,
            stratify=y if stratify else None,
        )

    def oversample(self, X_train, y_train, random_state=42):
        """Balance the training set with SMOTE.

        Args:
            X_train: Training features.
            y_train: Training labels.
            random_state: Seed for SMOTE; defaults to the previously
                hard-coded value of 42.

        Returns:
            ``(X_res, y_res)``: the resampled features and labels.
        """
        smote = SMOTE(random_state=random_state)
        X_res, y_res = smote.fit_resample(X_train, y_train)
        return X_res, y_res

# %% Load and preprocess data
# Point the loader at the stroke CSV, read it, then clean and one-hot encode it.
data_loader = CustomDataLoader(
    filepath='/content/data/healthcare-dataset-stroke-data.csv'
)
data_loader.load_dataset()
data_loader.preprocess_data()

## Train-Test data split

In [None]:
# Hold out a test set, then rebalance only the training portion with SMOTE
X_train, X_test, y_train, y_test = data_loader.get_data_split()
X_train, y_train = data_loader.oversample(X_train, y_train)

# Give every split a fresh 0..n-1 index so later row lookups by number line up
X_train, X_test, y_train, y_test = (
    part.reset_index(drop=True)
    for part in (X_train, X_test, y_train, y_test)
)


(3927, 22)
(982, 22)
(7542, 22)
(982, 22)


## Random Forest Classifier

In [None]:
# %% Fit blackbox model
# Seed the forest: without a fixed random_state the fitted trees, the metrics
# below, and every counterfactual generated from this model change on each run.
rf = RandomForestClassifier(random_state=42)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
# Macro-averaged F1 weights both classes equally, alongside plain accuracy
print(f"F1 Score {f1_score(y_test, y_pred, average='macro')}")
print(f"Accuracy {accuracy_score(y_test, y_pred)}")

In [None]:
# Wrap labels and predictions as Series so both can be filtered the same way
y_test_series = pd.Series(y_test)
y_pred_series = pd.Series(y_pred)

# Index values where the true label / the predicted label equals 1
test_indices = y_test_series.loc[y_test_series.eq(1)].index.tolist()
pred_indices = y_pred_series.loc[y_pred_series.eq(1)].index.tolist()

print("Test indices:", test_indices)
print("Prediction indices:", pred_indices)

Test indices: [11, 30, 35, 62, 73, 110, 113, 122, 143, 166, 167, 198, 229, 238, 274, 277, 299, 312, 327, 336, 362, 388, 426, 434, 475, 488, 499, 538, 573, 582, 592, 598, 613, 682, 685, 734, 752, 795, 804, 807, 809, 832, 851, 873, 878, 903, 910, 917, 928, 944, 964, 965, 978]
Prediction indices: [131, 254, 288, 379, 613, 679, 704, 971]


## Create Counterfactual Explanations

In [None]:
# %% Create diverse counterfactual explanations
import dice_ml

# Dataset wrapper for DiCE, built from the full preprocessed frame.
# NOTE(review): that frame still contains the `id` column — confirm it should
# be exposed to the explainer as a feature.
data_dice = dice_ml.Data(
    dataframe=data_loader.data,
    # Columns to treat as continuous (for the perturbation strategy)
    continuous_features=['age', 'avg_glucose_level', 'bmi'],
    outcome_name='stroke',
)

## Creating the Data and Model Objects for DiCE (Diverse Counterfactual Explanations)

In [None]:
# Model wrapper: hand the fitted random forest to DiCE through its sklearn
# backend (other backends exist for tf, torch, ...)
rf_dice = dice_ml.Model(model=rf, backend="sklearn")

# Explainer: "random" sampling here; genetic algorithm, kd-tree, ... are
# alternative generation methods
explainer = dice_ml.Dice(data_dice, rf_dice, method="random")

## Generating and Visualizing Counterfactual Explanations

In [None]:
# %% Create explanation
# Query instance: the single test row at position 10, kept as a one-row frame
input_datapoint = X_test.iloc[10:11]

# Ask the black-box explainer for 3 counterfactuals that flip the prediction
cf = explainer.generate_counterfactuals(
    input_datapoint,
    total_CFs=3,
    desired_class="opposite",
)

In [None]:
# Show the instance that was actually explained (row 10 of X_test); the
# previous version printed row 0, which is not the query passed to DiCE.
print(input_datapoint)

In [None]:
# Render the query instance and its counterfactuals as tables.
# show_only_changes=True keeps the focus on the features that differ from
# the query; pass show_only_changes=False to see complete rows instead.
cf.visualize_as_dataframe(show_only_changes=True)

## Creating Feasible (Conditional) Counterfactuals

In [None]:
# Index values of the test rows whose age exceeds 70
older_than_70 = X_test['age'].gt(70)
indices_above_70 = X_test.loc[older_than_70].index.tolist()

print("Indices of people whose age is above 70:", indices_above_70)

Indices of people whose age is above 70: [0, 10, 22, 39, 56, 62, 63, 76, 79, 88, 106, 110, 113, 131, 139, 142, 143, 147, 183, 195, 202, 205, 206, 208, 216, 222, 235, 240, 247, 255, 258, 265, 267, 272, 274, 276, 277, 288, 304, 311, 312, 320, 324, 327, 336, 347, 349, 357, 362, 370, 376, 388, 389, 417, 419, 426, 432, 434, 443, 462, 465, 466, 468, 475, 480, 486, 492, 495, 501, 505, 542, 556, 575, 582, 596, 598, 601, 605, 611, 618, 634, 635, 640, 647, 651, 652, 658, 660, 668, 675, 679, 683, 685, 702, 709, 710, 722, 724, 732, 734, 749, 753, 764, 782, 785, 787, 790, 792, 799, 803, 807, 809, 824, 848, 851, 878, 880, 892, 895, 903, 907, 926, 938, 942, 943, 944, 949, 952, 960, 961, 971, 978, 980]


In [None]:
# %% Create feasible (conditional) Counterfactuals
# Restrict the explainer to these features ...
features_to_vary = ['avg_glucose_level', 'bmi', 'smoking_status_smokes']
# ... and bound the continuous ones (presumably [min, max] — see DiCE docs)
permitted_range = {'avg_glucose_level': [40, 300], 'bmi': [15, 45]}

# Test row to explain (index 139 appears in the age > 70 list)
i = 139
input_datapoint2 = X_test.iloc[i:i + 1]

print("Label of test data: ", y_test.loc[i])
print(input_datapoint2.to_string(index=False))

# Generate counterfactuals under the feature and range constraints above
cf = explainer.generate_counterfactuals(
    input_datapoint2,
    total_CFs=10,
    desired_class="opposite",
    permitted_range=permitted_range,
    features_to_vary=features_to_vary,
)
# Visualize it
cf.visualize_as_dataframe(show_only_changes=True)

Label of test data:  0
   id  age  hypertension  heart_disease  avg_glucose_level  bmi  gender_Female  gender_Male  gender_Other  ever_married_No  ever_married_Yes  work_type_Govt_job  work_type_Never_worked  work_type_Private  work_type_Self-employed  work_type_children  Residence_type_Rural  Residence_type_Urban  smoking_status_Unknown  smoking_status_formerly smoked  smoking_status_never smoked  smoking_status_smokes
44873 81.0             0              0              125.2 40.0              1            0             0                0                 1                   0                       0                  0                        1                   0                     0                     1                       0                               0                            1                      0


100%|██████████| 1/1 [00:00<00:00,  2.82it/s]

Query instance (original outcome : 0)





Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Female,gender_Male,gender_Other,ever_married_No,...,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,stroke
0,44873,81.0,0,0,125.199997,40.0,1,0,0,0,...,0,1,0,0,1,0,0,1,0,0



Diverse Counterfactual set (new outcome: 1)


Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,gender_Female,gender_Male,gender_Other,ever_married_No,...,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,stroke
0,-,-,-,-,-,-,0.0,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0
1,-,-,-,-,-,-,-,-,-,-,...,-,-,-,-,0.0,-,-,-,-,1.0
2,-,-,-,-,-,-,-,-,-,-,...,-,-,-,-,-,-,-,0.0,-,1.0
3,-,-,-,-,-,-,0.0,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0
4,-,-,-,-,-,-,-,-,-,-,...,-,0.0,-,-,-,-,-,-,-,1.0
5,-,-,-,-,-,18.9,-,-,-,-,...,-,-,-,-,0.0,-,-,-,-,1.0
6,-,-,-,-,-,-,0.0,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0
7,-,-,-,-,-,-,0.0,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0
8,-,-,-,-,257.4,-,0.0,-,-,-,...,-,-,-,-,-,-,-,-,-,1.0
9,-,-,-,-,-,-,-,-,-,-,...,-,-,-,-,-,-,-,0.0,-,1.0


In [None]:
# True label of the explained test row (label-based lookup on the reset index)
print(y_test.loc[i])

0
