In [49]:
# Date: 8/23/2024
# Note: The datasets and saved models should be located in the same directory.

In [50]:
import pandas as pd
import numpy as np
import random
import pickle

np.random.seed(42)
random.seed(42)

# Sklearn imports
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, mean_squared_error

# DiCE imports
import dice_ml
from dice_ml.utils import helpers  # helper functions

In [51]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [52]:
# Read the first row of the file to get the number of columns
first_row = pd.read_csv('run1_data.csv', nrows=1, index_col=0)
num_columns = len(first_row.columns)

In [53]:
# Read the data ('run1_data.csv') into pandas dataframe, set appropriate column names, and drop the first index column
data = pd.read_csv(
            'run1_data.csv',
            names=[f"c{i+1}" if i < 4 else f"x{i-3}" for i in range(num_columns)],
            index_col=0
        )

In [54]:
# display the data
data.head()

Unnamed: 0,c1,c2,c3,c4,x1,x2,x3,x4,x5,x6,...,x25,x26,x27,x28,x29,x30,x31,x32,x33,x34
0,1169.540683,18.496408,15.224072,31.452473,32.19883,38.288165,58.805226,5.443632,14.134122,20.11951,...,0.684054,0.603952,0.954927,0.695864,0.878525,1.430532,1.26016,1.270764,1.096701,1.291621
1,1209.921175,177.206756,164.514041,178.586172,81.03383,42.163961,88.478525,4.721448,81.971076,13.227354,...,0.68355,1.08599,1.355671,1.289681,0.587842,1.432993,0.999951,0.864257,0.983606,1.015225
2,1244.897008,108.12598,92.286233,115.203093,99.464347,32.396348,36.581992,26.289042,41.961135,3.942099,...,1.014885,1.394907,1.018665,1.107409,0.974816,1.209591,0.615243,1.232794,1.378043,1.019064
3,1302.654609,49.409119,66.761985,42.448202,68.049677,18.728996,97.176322,15.703224,62.271299,11.106618,...,0.682895,1.2869,0.811387,1.271514,1.414949,0.726272,1.49116,0.811467,1.385639,1.193628
4,1342.027851,87.009089,44.307198,131.063038,5.452473,8.67821,85.915581,28.412539,87.278679,31.121029,...,1.213165,1.275869,1.18405,1.297898,1.088989,1.359719,0.942214,0.887986,0.781126,0.59533


In [55]:
# For classification labels, replace all values greater than 0 with 1
for label in ["c2", "c3", "c4"]:
    data[label] = data[label].apply(lambda x: 1 if x > 0 else 0)

# Split the data into four different dataframes, each with one label
c1_data = data.drop(columns=["c2", "c3", "c4"])
c2_data = data.drop(columns=["c1", "c3", "c4"])
c3_data = data.drop(columns=["c1", "c2", "c4"])
c4_data = data.drop(columns=["c1", "c2", "c3"])


In [56]:
# experiment with DTC model using run1_data and c2 label
c2_data

Unnamed: 0,c2,x1,x2,x3,x4,x5,x6,x7,x8,x9,...,x25,x26,x27,x28,x29,x30,x31,x32,x33,x34
0,1,32.198830,38.288165,58.805226,5.443632,14.134122,20.119510,2.208966,31.272813,52.438734,...,0.684054,0.603952,0.954927,0.695864,0.878525,1.430532,1.260160,1.270764,1.096701,1.291621
1,1,81.033830,42.163961,88.478525,4.721448,81.971076,13.227354,26.149467,17.446112,55.342038,...,0.683550,1.085990,1.355671,1.289681,0.587842,1.432993,0.999951,0.864257,0.983606,1.015225
2,1,99.464347,32.396348,36.581992,26.289042,41.961135,3.942099,53.485392,14.652861,1.834099,...,1.014885,1.394907,1.018665,1.107409,0.974816,1.209591,0.615243,1.232794,1.378043,1.019064
3,1,68.049677,18.728996,97.176322,15.703224,62.271299,11.106618,68.134329,37.693453,24.226849,...,0.682895,1.286900,0.811387,1.271514,1.414949,0.726272,1.491160,0.811467,1.385639,1.193628
4,1,5.452473,8.678210,85.915581,28.412539,87.278679,31.121029,9.193852,0.944540,18.204899,...,1.213165,1.275869,1.184050,1.297898,1.088989,1.359719,0.942214,0.887986,0.781126,0.595330
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2495,0,67.440017,23.891418,44.480850,2.553818,9.746482,18.445405,35.914888,4.488894,47.597646,...,0.717102,1.023931,1.273903,1.116558,0.503754,0.595845,0.590680,1.040949,0.669079,0.560457
2496,0,67.459284,24.264945,44.478146,3.352594,10.579603,18.407612,35.907831,4.488894,47.633294,...,0.717079,1.021588,1.258366,1.115985,0.503724,0.595845,0.601449,1.040774,0.669079,0.560499
2497,0,67.459254,24.260590,44.478221,2.677064,10.551194,18.362377,35.907877,3.434927,47.632029,...,0.716482,1.077299,1.274509,1.115286,0.501379,0.595911,0.598460,1.040774,0.669218,0.560422
2498,0,67.459700,23.876126,44.480887,2.432265,2.432756,17.304496,35.916217,3.071389,49.976248,...,0.718572,1.017637,1.259721,1.116720,0.503051,0.595853,0.736168,1.042430,0.658176,0.558035


In [57]:
# extract the target label (c2)
target = c2_data['c2']
# split the data into training and testing sets
train_dataset, test_dataset, y_train, y_test = train_test_split(c2_data, target, test_size=0.2, random_state=0, stratify=target)

# remove the target label from training and testing sets
x_train = train_dataset.drop('c2', axis=1)
x_test = test_dataset.drop('c2', axis=1)


In [58]:
# construct a data object for DiCE with train dataset. 
# Specify the names of the continuous features and the name of the output variable that the ML model will predict.
# (this dataset contains only continuous features so specify all of them)

data_object = dice_ml.Data(dataframe=train_dataset, 
                 continuous_features=['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21', 'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x31', 'x32', 'x33', 'x34'], 
                 outcome_name='c2')

In [59]:
# load the saved DTC model for run1 data and c2 label
model_name = 'run1_data_c2_dtc.pkl'
with open(model_name, 'rb') as f:
    dtc_model = pickle.load(f)


In [60]:
# Using sklearn backend
dice_model = dice_ml.Model(model=dtc_model, backend="sklearn")

# create an instance of the Dice class, which is used to generate counterfactual explanations
explanation_instance = dice_ml.Dice(data_object, dice_model, method="random")

In [61]:
# get a bad design sample (denoted by 1 in the label) to test on
sample = test_dataset[test_dataset['c2'] == 1].head(1)
sample = sample.drop('c2', axis=1)
sample

Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x25,x26,x27,x28,x29,x30,x31,x32,x33,x34
461,57.239981,24.276352,45.785583,9.011235,14.580442,26.974037,37.513459,1.276198,67.737769,30.177336,...,1.270062,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065


In [62]:
# check the column names
sample.columns

Index(['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11',
       'x12', 'x13', 'x14', 'x15', 'x16', 'x17', 'x18', 'x19', 'x20', 'x21',
       'x22', 'x23', 'x24', 'x25', 'x26', 'x27', 'x28', 'x29', 'x30', 'x31',
       'x32', 'x33', 'x34'],
      dtype='object')

In [63]:
# generate 2 counterfactuals that can change the original outcome (1) to desired class (0)
counterfactuals = explanation_instance.generate_counterfactuals(sample, total_CFs=2, desired_class="opposite")

# set show_only_changes to True to see only changed feature values for the counterfactuals.
counterfactuals.visualize_as_dataframe(show_only_changes=True)

100%|██████████| 1/1 [00:29<00:00, 29.29s/it]

Query instance (original outcome : 1)





Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239983,24.276352,45.785583,9.011234,14.580441,26.974037,37.513458,1.276198,67.73777,30.177336,...,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065,1



Diverse Counterfactual set (new outcome: 0.0)


Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239980712666565,24.276352233416432,19.887116,9.011234500013389,14.58044163540698,26.974036577297383,37.513458735588245,1.2761983859472423,67.73776949579245,30.177335976221084,...,1.2628857826313442,1.359830690767527,1.131628243227746,1.3919979657704755,0.5657352583181996,1.3102829765601027,1.2981556587473106,1.0233797658981632,1.1580652326974783,0.0
1,57.239980712666565,24.276352233416432,45.785583039452106,9.011234500013389,14.58044163540698,26.974036577297383,37.513458735588245,1.2761983859472423,67.73776949579245,30.177335976221084,...,1.2628857826313442,1.359830690767527,1.131628243227746,1.3919979657704755,0.5657352583181996,1.3102829765601027,1.2981556587473106,1.0233797658981632,1.1580652326974783,0.0


In [64]:
# set show_only_changes to False, if you want to see the full feature values for the counterfactuals.
counterfactuals.visualize_as_dataframe(show_only_changes=False)

Query instance (original outcome : 1)


Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239983,24.276352,45.785583,9.011234,14.580441,26.974037,37.513458,1.276198,67.73777,30.177336,...,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065,1



Diverse Counterfactual set (new outcome: 0.0)


Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239981,24.276352,19.887116,9.011235,14.580442,26.974037,37.513459,1.276198,67.737769,30.177336,...,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065,0
1,57.239981,24.276352,45.785583,9.011235,14.580442,26.974037,37.513459,1.276198,67.737769,30.177336,...,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065,0


In [65]:
# set features_to_vary for node locations (features between x1 and x20)
counterfactuals_nodes = explanation_instance.generate_counterfactuals(sample,
                                  total_CFs=2,
                                  desired_class="opposite",
                                  features_to_vary=["x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20"]
                                  )
counterfactuals_nodes.visualize_as_dataframe(show_only_changes=True)

100%|██████████| 1/1 [00:43<00:00, 43.02s/it]

Query instance (original outcome : 1)





Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239983,24.276352,45.785583,9.011234,14.580441,26.974037,37.513458,1.276198,67.73777,30.177336,...,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065,1



Diverse Counterfactual set (new outcome: 0.0)


Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239980712666565,24.276352233416432,45.785583039452106,21.4626236,14.58044163540698,26.974036577297383,37.513458735588245,1.2761983859472423,67.73776949579245,30.177335976221084,...,1.2628857826313442,1.359830690767527,1.131628243227746,1.3919979657704755,0.5657352583181996,1.3102829765601027,1.2981556587473106,1.0233797658981632,1.1580652326974783,0.0
1,57.239980712666565,24.276352233416432,45.785583039452106,9.011234500013389,14.58044163540698,26.974036577297383,37.513458735588245,1.2761983859472423,67.73776949579245,30.177335976221084,...,1.2628857826313442,1.359830690767527,1.131628243227746,1.3919979657704755,0.5657352583181996,1.3102829765601027,1.2981556587473106,1.0233797658981632,1.1580652326974783,0.0


In [66]:
# set features_to_vary for design variables to control edges (features between x21 and x34)
counterfactuals_edges = explanation_instance.generate_counterfactuals(sample,
                                  total_CFs=2,
                                  desired_class="opposite",
                                  features_to_vary=["x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x29", "x30", "x31", "x32", "x33", "x34"]
                                  )
counterfactuals_edges.visualize_as_dataframe(show_only_changes=True)

100%|██████████| 1/1 [00:29<00:00, 29.62s/it]

Query instance (original outcome : 1)





Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239983,24.276352,45.785583,9.011234,14.580441,26.974037,37.513458,1.276198,67.73777,30.177336,...,1.262886,1.359831,1.131628,1.391998,0.565735,1.310283,1.298156,1.02338,1.158065,1



Diverse Counterfactual set (new outcome: 0.0)


Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,...,x26,x27,x28,x29,x30,x31,x32,x33,x34,c2
0,57.239980712666565,24.276352233416432,45.785583039452106,9.011234500013389,14.58044163540698,26.974036577297383,37.513458735588245,1.2761983859472423,67.73776949579245,30.177335976221084,...,1.2628857826313442,1.359830690767527,1.131628243227746,1.3919979657704755,0.5657352583181996,1.0824298,1.2981556587473106,1.0233797658981632,1.1580652326974783,0.0
1,57.239980712666565,24.276352233416432,45.785583039452106,9.011234500013389,14.58044163540698,26.974036577297383,37.513458735588245,1.2761983859472423,67.73776949579245,30.177335976221084,...,1.2628857826313442,1.359830690767527,1.131628243227746,1.3919979657704755,0.5657352583181996,1.0879233,1.2981556587473106,1.28722508,1.1580652326974783,0.0


In [67]:
# local feature importance scores
query_instance = sample
importance = explanation_instance.local_feature_importance(query_instance, total_CFs=10)
# These scores are computed for the given query instance by summarizing a set of counterfactual examples (here we set to have 10) around the point.
print(importance.local_importance)

100%|██████████| 1/1 [02:24<00:00, 144.85s/it]

[{'x18': 1.0, 'x34': 0.3, 'x5': 0.2, 'x6': 0.2, 'x12': 0.2, 'x14': 0.2, 'x20': 0.2, 'x26': 0.2, 'x4': 0.1, 'x10': 0.1, 'x13': 0.1, 'x15': 0.1, 'x22': 0.1, 'x24': 0.1, 'x25': 0.1, 'x29': 0.1, 'x30': 0.1, 'x31': 0.1, 'x1': 0.0, 'x2': 0.0, 'x3': 0.0, 'x7': 0.0, 'x8': 0.0, 'x9': 0.0, 'x11': 0.0, 'x16': 0.0, 'x17': 0.0, 'x19': 0.0, 'x21': 0.0, 'x23': 0.0, 'x27': 0.0, 'x28': 0.0, 'x32': 0.0, 'x33': 0.0}]



