In [1]:
# Sklearn imports
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestClassifier

# DiCE imports
import dice_ml
from dice_ml.utils import helpers  # helper functions

In [2]:
%load_ext autoreload
%autoreload 2

## Preliminaries: Loading a dataset and a ML model trained over it
### Loading the `Adult` dataset

We use the "adult" income dataset from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). For demonstration purposes, we transform the data as described in the **dice_ml.utils.helpers** module.

In [3]:
dataset = helpers.load_adult_income_dataset()

This dataset has 8 features. The outcome is income, which is binarized to 0 (low-income, <=50K) or 1 (high-income, >50K).
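As a rough illustration of that binarization step, here is a minimal sketch that maps raw income labels to 0/1. The function name `binarize_income` is hypothetical; the real transformation lives inside `helpers.load_adult_income_dataset()`. The sketch assumes raw labels look like `<=50K` / `>50K`, possibly with surrounding spaces or a trailing period.

```python
def binarize_income(label: str) -> int:
    """Map a raw Adult income label to 0 (<=50K) or 1 (>50K).

    Hypothetical helper for illustration; not part of dice_ml.
    """
    cleaned = label.strip().rstrip(".")  # tolerate stray spaces and a trailing period
    if cleaned == ">50K":
        return 1
    if cleaned == "<=50K":
        return 0
    raise ValueError(f"unexpected income label: {label!r}")


raw_labels = [" <=50K", ">50K", "<=50K.", " >50K."]
print([binarize_income(x) for x in raw_labels])  # [0, 1, 0, 1]
```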

In [4]:
dataset.head()

| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 28 | Private | Bachelors | Single | White-Collar | White | Female | 60 | 0 |
| 1 | 30 | Self-Employed | Assoc | Married | Professional | White | Male | 65 | 1 |
| 2 | 32 | Private | Some-college | Married | White-Collar | White | Male | 50 | 0 |
| 3 | 20 | Private | Some-college | Single | Service | White | Female | 35 | 0 |
| 4 | 41 | Self-Employed | Some-college | Married | White-Collar | White | Male | 50 | 0 |


In [5]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info

{'age': 'age',
 'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
 'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
 'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
 'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
 'race': 'white or other race?',
 'gender': 'male or female?',
 'hours_per_week': 'total work hours per week',
 'income': '0 (<=50K) vs 1 (>50K)'}

Split the dataset into train and test sets.

In [6]:
target = dataset["income"]
train_dataset, test_dataset, y_train, y_test = train_test_split(dataset,
                                                                target,
                                                                test_size=0.2,
                                                                random_state=0,
                                                                stratify=target)
x_train = train_dataset.drop('income', axis=1)
x_test = test_dataset.drop('income', axis=1)
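Passing `stratify=target` keeps the ratio of low- to high-income rows the same in the train and test splits. The stdlib-only sketch below illustrates that idea (it is not scikit-learn's implementation; `stratified_split` is a hypothetical name): shuffle the indices of each class separately with a fixed seed, then take the same fraction of each class for the test set.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_size=0.2, seed=0):
    """Return (train_idx, test_idx) so each class keeps its overall proportion."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    train_idx, test_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)                      # shuffle within the class only
        n_test = round(len(idxs) * test_size)  # same fraction from every class
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])
    return sorted(train_idx), sorted(test_idx)


labels = [0] * 75 + [1] * 25  # toy labels: 25% positive
train, test = stratified_split(labels)
print(Counter(labels[i] for i in test))  # Counter({0: 15, 1: 5})
```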

Given the train dataset, we construct a data object for DiCE. Since continuous and discrete (categorical) features are perturbed differently, we need to specify the names of the continuous features. DiCE also requires the name of the outcome variable that the ML model will predict.
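To make the distinction concrete, here is a simplified sketch of why DiCE needs to know which features are continuous: a continuous feature can be moved to any value within a numeric range, while a categorical feature can only be swapped for another observed category. The `perturb` function and the numeric bounds below are illustrative assumptions, not DiCE's internal code.

```python
import random


def perturb(row, continuous_ranges, categorical_levels, rng):
    """Perturb one randomly chosen feature of `row`.

    Continuous features are resampled uniformly within [low, high];
    categorical features are resampled from their list of levels.
    Returns (feature_name, perturbed_copy); `row` itself is not modified.
    """
    new_row = dict(row)
    feature = rng.choice(sorted(continuous_ranges) + sorted(categorical_levels))
    if feature in continuous_ranges:
        low, high = continuous_ranges[feature]
        new_row[feature] = rng.randint(low, high)  # integer-valued features here
    else:
        new_row[feature] = rng.choice(categorical_levels[feature])
    return feature, new_row


rng = random.Random(0)
row = {"age": 29, "hours_per_week": 38, "education": "HS-grad", "gender": "Female"}
ranges = {"age": (17, 90), "hours_per_week": (1, 99)}  # illustrative bounds
levels = {"education": ["HS-grad", "Bachelors", "Masters"], "gender": ["Female", "Male"]}
print(perturb(row, ranges, levels, rng))
```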

In [23]:
# Inspect column names, dtypes, and missing values in the training data
train_dataset.info(verbose=1)
# Step 1: dice_ml.Data
d = dice_ml.Data(dataframe=train_dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')

<class 'pandas.core.frame.DataFrame'>
Int64Index: 20838 entries, 20907 to 6862
Data columns (total 9 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             20838 non-null  int64 
 1   workclass       20838 non-null  object
 2   education       20838 non-null  object
 3   marital_status  20838 non-null  object
 4   occupation      20838 non-null  object
 5   race            20838 non-null  object
 6   gender          20838 non-null  object
 7   hours_per_week  20838 non-null  int64 
 8   income          20838 non-null  int64 
dtypes: int64(3), object(6)
memory usage: 1.6+ MB


### Loading the ML model

DiCE supports scikit-learn, TensorFlow, and PyTorch models.

The variable *backend* below indicates the implementation type of DiCE we want to use. Four backends are supported: scikit-learn with backend='sklearn', TensorFlow 1.x with backend='TF1', TensorFlow 2.x with backend='TF2', and PyTorch with backend='PYT'.

Below we train a classification model using sklearn.

In [8]:
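numerical = ["age", "hours_per_week"]
categorical = x_train.columns.difference(numerical)

# One-hot encode the categorical features; categories unseen during fit are ignored.
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

# Pass the continuous features through unchanged (random forests do not need
# scaling). Without this entry, ColumnTransformer drops them by default
# (remainder='drop'), so the model would never see age or hours_per_week.
transformations = ColumnTransformer(
    transformers=[
        ('num', 'passthrough', numerical),
        ('cat', categorical_transformer, categorical)])

# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf = Pipeline(steps=[('preprocessor', transformations),
                      ('classifier', RandomForestClassifier())])
model = clf.fit(x_train, y_train)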
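`OneHotEncoder(handle_unknown='ignore')` turns each categorical column into one indicator column per category seen during `fit`, and encodes a category never seen in training as all zeros instead of raising an error. The stdlib-only sketch below mimics that behaviour for a single column; `fit_categories` and `one_hot` are hypothetical names used only for illustration.

```python
def fit_categories(column):
    """Learn the sorted list of distinct categories from training values."""
    return sorted(set(column))


def one_hot(value, categories):
    """Indicator vector for `value`; an unseen value maps to all zeros."""
    return [1 if value == c else 0 for c in categories]


cats = fit_categories(["Private", "Government", "Private", "Self-Employed"])
print(cats)                            # ['Government', 'Private', 'Self-Employed']
print(one_hot("Private", cats))        # [0, 1, 0]
print(one_hot("Other/Unknown", cats))  # [0, 0, 0]  (unseen category is ignored)
```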

## Generating counterfactual examples using DiCE

We now initialize the DiCE explainer, which needs a dataset and a model. DiCE provides local explanations for the model *m* and requires a query input whose outcome is to be explained.

In [21]:
# Using sklearn backend
m = dice_ml.Model(model=model, backend="sklearn")
# Using method="genetic" (genetic algorithm search) for generating CFs
exp = dice_ml.Dice(d, m, method="genetic")


The `method` parameter specifies the explanation method. DiCE supports three methods for sklearn models: random sampling (`"random"`), genetic algorithm search (`"genetic"`), and kd-tree based generation (`"kdtree"`).
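To convey the idea behind the simplest of these, random sampling, here is a toy sketch: repeatedly resample a few features of the query at random and keep the distinct candidates whose prediction flips to the opposite class. The `toy_model` rule and the `random_counterfactuals` function are hypothetical; DiCE's actual `method="random"` is more involved, so treat this only as an illustration of the search loop.

```python
import random


def toy_model(x):
    """Hypothetical stand-in classifier: 1 if Bachelors or higher and >= 40 hours/week."""
    return int(x["education"] in {"Bachelors", "Masters", "Doctorate"}
               and x["hours_per_week"] >= 40)


def random_counterfactuals(query, model, levels, hours_range,
                           total_cfs=2, max_iter=1000, seed=0):
    """Randomly resample 1-2 features until the prediction flips; keep unique hits."""
    rng = random.Random(seed)
    desired = 1 - model(query)  # analogous to desired_class="opposite"
    features = sorted(levels) + ["hours_per_week"]
    found = []
    for _ in range(max_iter):
        cand = dict(query)
        for f in rng.sample(features, k=rng.randint(1, 2)):
            if f == "hours_per_week":
                cand[f] = rng.randint(*hours_range)
            else:
                cand[f] = rng.choice(levels[f])
        if model(cand) == desired and cand not in found:
            found.append(cand)
            if len(found) == total_cfs:
                break
    return found


query = {"education": "HS-grad", "occupation": "Blue-Collar", "hours_per_week": 38}
levels = {"education": ["HS-grad", "Bachelors", "Masters", "Doctorate"],
          "occupation": ["Blue-Collar", "Professional", "Service"]}
for cf in random_counterfactuals(query, toy_model, levels, hours_range=(1, 99)):
    print(cf)
```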

The next code snippet shows how to generate and visualize counterfactuals. The first argument of the `generate_counterfactuals` method is the _query instances_ on which counterfactuals are desired. This can be a dataframe with one or more rows. 

Below we provide a sample input whose outcome is 0 (low-income) according to the ML model object *m*. Given this query input, we generate counterfactual explanations: perturbed versions of the original input for which the ML model outputs class 1 (high-income). The last column, `income`, shows the classifier's output for each row: class 1 if the predicted probability of high income is at least 0.5, and class 0 otherwise.

In [18]:
e1 = exp.generate_counterfactuals(x_test[0:1], total_CFs=2, desired_class="opposite")
e1.visualize_as_dataframe(show_only_changes=True)

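Query instance (original outcome: 0)

| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 29 | Private | HS-grad | Married | Blue-Collar | White | Female | 38 | 0 |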



Diverse Counterfactual set (new outcome: 1)


| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 30.0 | - | Bachelors | - | Professional | - | - | - | 1 |
| 0 | - | - | Bachelors | - | - | - | Male | 40.0 | 1 |


With `show_only_changes=True`, only the feature values that differ from the query instance are displayed; unchanged values appear as `-`. To see the full feature values of the counterfactuals, set it to `False`.
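Conceptually, the compact view just compares each counterfactual with the query row feature by feature. The minimal sketch below (the function name `show_only_changes` is hypothetical, not DiCE's internal code) reproduces the first counterfactual row of the compact table above from the full rows; note that `38.0 == 38` in Python, so the unchanged `hours_per_week` is masked.

```python
def show_only_changes(query, cf, outcome="income"):
    """Replace feature values equal to the query's with '-'; always show the outcome."""
    return {k: (v if k == outcome or v != query[k] else "-") for k, v in cf.items()}


query = {"age": 29, "workclass": "Private", "education": "HS-grad",
         "marital_status": "Married", "occupation": "Blue-Collar", "race": "White",
         "gender": "Female", "hours_per_week": 38, "income": 0}
cf = {"age": 30.0, "workclass": "Private", "education": "Bachelors",
      "marital_status": "Married", "occupation": "Professional", "race": "White",
      "gender": "Female", "hours_per_week": 38.0, "income": 1}
print(list(show_only_changes(query, cf).values()))
# [30.0, '-', 'Bachelors', '-', 'Professional', '-', '-', '-', 1]
```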

In [19]:
e1.visualize_as_dataframe(show_only_changes=False)

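Query instance (original outcome: 0)

| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 29 | Private | HS-grad | Married | Blue-Collar | White | Female | 38 | 0 |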



Diverse Counterfactual set (new outcome: 1)


| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 30.0 | Private | Bachelors | Married | Professional | White | Female | 38.0 | 1 |
| 0 | 29.0 | Private | Bachelors | Married | Blue-Collar | White | Male | 40.0 | 1 |


That's it! You can try generating counterfactual explanations for other examples using the same code.
It is also possible to restrict which features may vary while generating the counterfactuals (`features_to_vary`), and to specify a permitted range of values for features (`permitted_range`) within which the counterfactuals should be generated.
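The two constraints can be read as a filter on candidate counterfactuals. The sketch below encodes one plausible reading, under stated assumptions: only features listed in `features_to_vary` may change, continuous ranges are inclusive `[low, high]` pairs, categorical ranges are lists of allowed values, and the range applies only to values a feature is changed to. The function `satisfies_constraints` is hypothetical and is not how DiCE enforces these options internally.

```python
def satisfies_constraints(query, cf, features_to_vary="all", permitted_range=None):
    """Check a candidate counterfactual against features_to_vary / permitted_range."""
    permitted_range = permitted_range or {}
    for feature, value in cf.items():
        if value == query[feature]:
            continue  # unchanged features are always acceptable in this sketch
        if features_to_vary != "all" and feature not in features_to_vary:
            return False  # a frozen feature was changed
        allowed = permitted_range.get(feature)
        if allowed is None:
            continue  # no range restriction for this feature
        if isinstance(value, (int, float)):
            low, high = allowed
            if not low <= value <= high:
                return False
        elif value not in allowed:
            return False
    return True


query = {"age": 29, "education": "HS-grad", "occupation": "Blue-Collar"}
vary = ["education", "occupation"]
pr = {"age": [20, 30], "education": ["Doctorate", "Prof-school"]}
print(satisfies_constraints(query, {**query, "education": "Masters", "occupation": "Professional"}, vary))  # True
print(satisfies_constraints(query, {**query, "age": 31}, vary))                            # False
print(satisfies_constraints(query, {**query, "education": "Prof-school"}, permitted_range=pr))  # True
print(satisfies_constraints(query, {**query, "age": 35}, permitted_range=pr))              # False
```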

In [12]:
# Changing only education and occupation
e2 = exp.generate_counterfactuals(x_test[0:1],
                                  total_CFs=2,
                                  desired_class="opposite",
                                  features_to_vary=["education", "occupation"]
                                  )
e2.visualize_as_dataframe(show_only_changes=True)

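Query instance (original outcome: 0)

| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 29 | Private | HS-grad | Married | Blue-Collar | White | Female | 38 | 0 |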



Diverse Counterfactual set (new outcome: 1.0)


| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | - | - | Assoc | - | Service | - | - | - | 1 |
| 1 | - | - | Masters | - | Professional | - | - | - | 1 |


In [13]:
# Restricting age to be between [20,30] and Education to be either {'Doctorate', 'Prof-school'}.
e3 = exp.generate_counterfactuals(x_test[0:1],
                                  total_CFs=2,
                                  desired_class="opposite",
                                  permitted_range={'age': [20, 30], 'education': ['Doctorate', 'Prof-school']})
e3.visualize_as_dataframe(show_only_changes=True)

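Query instance (original outcome: 0)

| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 29 | Private | HS-grad | Married | Blue-Collar | White | Female | 38 | 0 |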



Diverse Counterfactual set (new outcome: 1.0)


| | age | workclass | education | marital_status | occupation | race | gender | hours_per_week | income |
|---|---|---|---|---|---|---|---|---|---|
| 0 | - | Government | - | Single | - | - | - | - | 1 |
| 1 | - | - | Prof-school | - | - | - | Male | - | 1 |


## Generating feature attributions (local and global) using DiCE

DiCE can generate feature importance scores using a summary of the counterfactuals generated. Intuitively, a feature that is changed more often when generating proximal counterfactuals for an input is locally important for causing the model's prediction at the input. Formally, counterfactuals operationalize the __necessity__ criterion for a model explanation: _is the feature value necessary for the given model output?_ 

For more details, refer to the paper, [Towards Unifying Feature Attribution and Counterfactual Explanations: Different Means to the Same End](https://arxiv.org/abs/2011.04917).

### Local feature importance scores

These scores are computed for a given query instance (input point) by summarizing a set of counterfactual examples around the point. 
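Following the intuition stated above (a feature that changes more often across the counterfactuals is more important locally), a minimal sketch scores each feature by the fraction of counterfactuals in which it differs from the query. The function name `local_importance` is hypothetical and this is a simplified reading of the approach in the linked paper, not DiCE's code. Applied to the two counterfactuals shown earlier for `x_test[0:1]`:

```python
def local_importance(query, cfs):
    """Fraction of counterfactuals in which each feature differs from the query."""
    return {f: sum(cf[f] != v for cf in cfs) / len(cfs) for f, v in query.items()}


query = {"age": 29, "workclass": "Private", "education": "HS-grad",
         "marital_status": "Married", "occupation": "Blue-Collar", "race": "White",
         "gender": "Female", "hours_per_week": 38}
cfs = [
    {"age": 30.0, "workclass": "Private", "education": "Bachelors",
     "marital_status": "Married", "occupation": "Professional", "race": "White",
     "gender": "Female", "hours_per_week": 38.0},
    {"age": 29.0, "workclass": "Private", "education": "Bachelors",
     "marital_status": "Married", "occupation": "Blue-Collar", "race": "White",
     "gender": "Male", "hours_per_week": 40.0},
]
print(local_importance(query, cfs))
# {'age': 0.5, 'workclass': 0.0, 'education': 1.0, 'marital_status': 0.0,
#  'occupation': 0.5, 'race': 0.0, 'gender': 0.5, 'hours_per_week': 0.5}
```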

In [14]:
query_instance = x_test[0:1]
imp = exp.local_feature_importance(query_instance, total_CFs=10)
print(imp.local_importance)

The `total_CFs` parameter denotes the number of counterfactuals used to compute the local importance scores. Generally, more counterfactuals give a more stable estimate, at the cost of longer generation time.

### Global feature importance scores

A global importance score per feature can be estimated by aggregating the scores over individual inputs. The more inputs used, the better the estimate of a feature's global importance.
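One simple aggregation is to average each feature's local score over the query instances. This sketch assumes every instance gets the same number of counterfactuals, in which case averaging the local fractions equals pooling all counterfactuals together; `global_importance` is a hypothetical name, and DiCE's own aggregation may differ in details.

```python
def global_importance(per_instance_scores):
    """Average per-feature local importance scores over query instances."""
    features = per_instance_scores[0].keys()
    n = len(per_instance_scores)
    return {f: sum(scores[f] for scores in per_instance_scores) / n for f in features}


scores = [
    {"age": 0.5, "education": 1.0},  # local scores for query instance 1
    {"age": 0.0, "education": 0.5},  # local scores for query instance 2
]
print(global_importance(scores))  # {'age': 0.25, 'education': 0.75}
```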

In [15]:
query_instances = x_test[0:20]
imp = exp.global_feature_importance(query_instances)
print(imp.summary_importance)

 5 

 75%|███████▌  | 15/20 [00:03<00:01,  4.96it/s]

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 8 to 8
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5   race            1 non-null      object
 6   gender          1 non-null      object
 7   hours_per_week  1 non-null      object
dtypes: object(8)
memory usage: 72.0+ bytes
None
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 9 to 9
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5 

 80%|████████  | 16/20 [00:03<00:00,  5.08it/s]

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10 entries, 0 to 9
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             10 non-null     object
 1   workclass       10 non-null     object
 2   education       10 non-null     object
 3   marital_status  10 non-null     object
 4   occupation      10 non-null     object
 5   race            10 non-null     object
 6   gender          10 non-null     object
 7   hours_per_week  10 non-null     object
dtypes: object(8)
memory usage: 768.0+ bytes
None
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 0 to 0
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 

 85%|████████▌ | 17/20 [00:03<00:00,  5.18it/s]

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 2 to 2
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5   race            1 non-null      object
 6   gender          1 non-null      object
 7   hours_per_week  1 non-null      object
dtypes: object(8)
memory usage: 72.0+ bytes
None
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 2 to 2
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5 

 90%|█████████ | 18/20 [00:03<00:00,  5.24it/s]

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 4 to 4
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5   race            1 non-null      object
 6   gender          1 non-null      object
 7   hours_per_week  1 non-null      object
dtypes: object(8)
memory usage: 72.0+ bytes
None
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 5 to 5
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5 

 95%|█████████▌| 19/20 [00:03<00:00,  5.25it/s]

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 6 to 6
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5   race            1 non-null      object
 6   gender          1 non-null      object
 7   hours_per_week  1 non-null      object
dtypes: object(8)
memory usage: 72.0+ bytes
None
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1 entries, 7 to 7
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             1 non-null      object
 1   workclass       1 non-null      object
 2   education       1 non-null      object
 3   marital_status  1 non-null      object
 4   occupation      1 non-null      object
 5 

100%|██████████| 20/20 [00:04<00:00,  4.95it/s]

{'education': 0.64, 'marital_status': 0.27, 'age': 0.25, 'hours_per_week': 0.245, 'workclass': 0.235, 'occupation': 0.22, 'race': 0.115, 'gender': 0.08}
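As we understand DiCE's summary importance, each score above is the fraction of all generated counterfactuals in which that feature differs from its query instance. On that reading, `education` has the highest score because it is the feature changed most often. The pandas sketch below reproduces that counting on hypothetical toy data. `summary_importance` is an illustrative name, not the library's code.

```python
from typing import Dict, List

import pandas as pd


def summary_importance(queries: pd.DataFrame, cfs_per_query: List[pd.DataFrame]) -> Dict[str, float]:
    """Fraction of all counterfactuals in which each feature differs from its query."""
    counts = {col: 0 for col in queries.columns}
    total = 0
    for (_, query), cfs in zip(queries.iterrows(), cfs_per_query):
        for _, cf in cfs.iterrows():
            total += 1
            for col in queries.columns:
                if cf[col] != query[col]:
                    counts[col] += 1
    scores = {col: n / total for col, n in counts.items()}
    # Sort from most to least frequently changed feature
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


# Hypothetical toy data: two query instances with two counterfactuals each
queries = pd.DataFrame({"age": [30, 45],
                        "education": ["HS-grad", "Bachelors"],
                        "hours_per_week": [40, 50]})
cfs_per_query = [
    pd.DataFrame({"age": [30, 35], "education": ["Masters", "Bachelors"], "hours_per_week": [40, 40]}),
    pd.DataFrame({"age": [45, 50], "education": ["Doctorate", "Bachelors"], "hours_per_week": [60, 50]}),
]
print(summary_importance(queries, cfs_per_query))
```

Here `education` changes in 3 of the 4 toy counterfactuals, `age` in 2, and `hours_per_week` in 1.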





## Working with deep learning models (TensorFlow and PyTorch)

We now show examples of gradient-based methods with TensorFlow and PyTorch models. Because gradient-based methods _optimize_ a loss rather than simply sampling points, they can be slower at generating counterfactuals. The loss has three components: **validity** (each CF should have the desired model output), **proximity** (each CF should be close to the original point), and **diversity** (multiple CFs should change different features). The DiCE loss formulation is described in the paper, [Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations](https://arxiv.org/abs/1905.07697).
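The paper combines the three components as a weighted sum to minimize. The average hinge loss measures validity, the average distance to the query measures proximity, and the determinant of a similarity kernel over the counterfactuals measures diversity and is subtracted. The NumPy sketch below illustrates that structure. It makes several simplifying assumptions: a hypothetical linear model (`logit = c @ w + b`), continuous features only, MAD-scaled L1 distance, and illustrative weights. It is not DiCE's implementation.

```python
import numpy as np


def dice_style_loss(cfs, x, w, b, mad, desired_class=1,
                    proximity_weight=0.5, diversity_weight=1.0):
    """Simplified DiCE-style loss for k counterfactuals.

    cfs: (k, d) candidate counterfactuals; x: (d,) query point.
    The model is a hypothetical linear scorer with logit = c @ w + b.
    mad: per-feature median absolute deviation used to scale distances.
    """
    cfs = np.asarray(cfs, dtype=float)
    x = np.asarray(x, dtype=float)
    mad = np.asarray(mad, dtype=float)
    k = cfs.shape[0]
    sign = 1.0 if desired_class == 1 else -1.0

    def dist(a, c):
        # MAD-scaled L1 distance, averaged over features
        return float(np.mean(np.abs(a - c) / mad))

    # Validity: hinge loss on the logit, zero once a CF is past the margin
    logits = cfs @ np.asarray(w, dtype=float) + b
    validity = float(np.mean(np.maximum(0.0, 1.0 - sign * logits)))

    # Proximity: average distance from each CF to the query
    proximity = float(np.mean([dist(c, x) for c in cfs]))

    # Diversity: det of K_ij = 1 / (1 + dist(c_i, c_j)); duplicate CFs drive it to 0
    K = np.array([[1.0 / (1.0 + dist(cfs[i], cfs[j])) for j in range(k)]
                  for i in range(k)])
    diversity = float(np.linalg.det(K))

    total = validity + proximity_weight * proximity - diversity_weight * diversity
    return total, validity, proximity, diversity


w, b = np.array([1.0, 1.0]), -3.0
x = np.array([1.0, 1.0])   # query: logit = -1, i.e. class 0
mad = np.array([1.0, 1.0])

same = dice_style_loss([[3.0, 1.0], [3.0, 1.0]], x, w, b, mad)
diverse = dice_style_loss([[3.0, 1.0], [1.0, 3.0]], x, w, b, mad)
print("identical CFs:", same)
print("diverse CFs:  ", diverse)
```

Both pairs are equally valid and equally close to the query. The diverse pair still gets the lower loss, because its kernel determinant is 8/9 rather than 0.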

Below, we use a pre-trained ML model whose accuracy is comparable to other baselines. For convenience, this sample trained model is included with the DiCE package.

### Explaining a TensorFlow model

In [16]:
backend = 'TF2'  # needs tensorflow installed
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
# Step 2: dice_ml.Model
m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func="ohe-min-max")

UserConfigValidationException: Unable to import tensorflow. Please install tensorflow

Note that finding counterfactuals with TensorFlow 2.x's eager execution is significantly slower than with TensorFlow 1.x's graph execution.

Based on the data object *d* and the model object *m*, we can now instantiate the DiCE class for generating explanations. 

In [None]:
# Step 3: initiate DiCE
exp = dice_ml.Dice(d, m, method="gradient")

Below we provide query instances from `x_test`.

In [None]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(x_test[1:2], total_CFs=4, desired_class="opposite")
# visualize the result, highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)

The counterfactuals generated above differ slightly from those shown in [our paper](https://arxiv.org/pdf/1905.07697.pdf), where the loss convergence condition was made more conservative for rigorous experimentation. To replicate the paper's results, pass the argument *loss_converge_maxiter=2* (the default is 1) to the *exp.generate_counterfactuals()* call above. For more information, see the *generate_counterfactuals()* method in [dice_ml.dice_interfaces.dice_tensorflow.py](https://github.com/interpretml/DiCE/blob/master/dice_ml/dice_interfaces/dice_tensorflow1.py).
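One way to picture `loss_converge_maxiter` is as a patience counter. Optimization stops only after the change in loss has stayed below a small threshold for that many consecutive iterations, so raising it from 1 to 2 makes stopping more conservative. The plain-Python sketch below illustrates this reading. `stopping_iteration`, `loss_diff_thres`, and the loss trace are hypothetical, and DiCE's own stopping rule may include additional checks.

```python
from typing import Optional, Sequence


def stopping_iteration(losses: Sequence[float], loss_diff_thres: float = 1e-5,
                       loss_converge_maxiter: int = 1) -> Optional[int]:
    """Index of the iteration at which a patience-style convergence check would stop.

    Counts consecutive iterations whose loss change is at most loss_diff_thres
    and stops once that count reaches loss_converge_maxiter. Returns None if the
    loss sequence ends first.
    """
    converge_iter = 0
    for i in range(1, len(losses)):
        if abs(losses[i] - losses[i - 1]) <= loss_diff_thres:
            converge_iter += 1
            if converge_iter >= loss_converge_maxiter:
                return i
        else:
            converge_iter = 0  # a large change resets the counter
    return None


losses = [1.0, 0.5, 0.5, 0.4, 0.4, 0.4]  # hypothetical loss trace
print(stopping_iteration(losses, loss_converge_maxiter=1))
print(stopping_iteration(losses, loss_converge_maxiter=2))
```

On this trace, the check stops at iteration 2 with `loss_converge_maxiter=1`. With a value of 2, the single flat step at iteration 2 is not enough, so the check continues until iteration 5.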

### Explaining a PyTorch model

To use DiCE with PyTorch, change the backend variable to 'PYT'. Below, we use a pre-trained PyTorch model whose accuracy is comparable to other baselines; it is also included with the DiCE package for convenience. We additionally need to provide a data transformer function that converts the input dataframe into a one-hot encoded, numeric format.
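The `func="ohe-min-max"` option used below names such a transformer: it one-hot encodes categorical features and min-max scales continuous ones. Here is a rough pandas sketch of that kind of transformation. The function names are hypothetical and the code is not DiCE's implementation. In practice, the category lists and min/max values should be learned on the training data and reused at prediction time, which is why fitting and transforming are separate here.

```python
from typing import Dict, List, Tuple

import pandas as pd


def fit_ohe_min_max(train: pd.DataFrame, continuous: List[str]):
    """Learn category lists and (min, max) ranges from training data."""
    categories = {c: sorted(train[c].unique()) for c in train.columns if c not in continuous}
    ranges = {c: (float(train[c].min()), float(train[c].max())) for c in continuous}
    return categories, ranges


def transform_ohe_min_max(df: pd.DataFrame, categories: Dict[str, list],
                          ranges: Dict[str, Tuple[float, float]]) -> pd.DataFrame:
    """Min-max scale continuous columns and one-hot encode categorical ones."""
    parts = []
    for col, (lo, hi) in ranges.items():
        span = (hi - lo) or 1.0  # avoid division by zero for constant columns
        parts.append(((df[col].astype(float) - lo) / span).rename(col))
    for col, cats in categories.items():
        for cat in cats:
            # categories unseen during fitting get all-zero indicator columns
            parts.append((df[col] == cat).astype(float).rename(f"{col}_{cat}"))
    return pd.concat(parts, axis=1)


train = pd.DataFrame({"age": [20, 60, 40], "gender": ["Male", "Female", "Male"]})
categories, ranges = fit_ohe_min_max(train, continuous=["age"])
query = pd.DataFrame({"age": [30], "gender": ["Female"]})
print(transform_ohe_min_max(query, categories, ranges))
```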

In [None]:
backend = 'PYT'  # needs pytorch installed
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func="ohe-min-max")

Instantiate the DiCE class with the new PyTorch model object *m*. 

In [None]:
exp = dice_ml.Dice(d, m, method="gradient")

In [None]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(x_test[1:3], total_CFs=4, desired_class="opposite")
# highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)

We can also use model-agnostic explainers such as "random" or "genetic", which need only the model's predictions rather than its gradients.
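To show why no gradients are needed, here is a toy sketch of the random-sampling idea. It perturbs a random subset of features, keeps candidates that the model assigns to the desired class, and returns those with the fewest and smallest changes. The classifier, feature ranges, and `random_counterfactuals` function are all hypothetical. DiCE's `method="random"` has its own sampling and selection details.

```python
import random
from typing import Callable, Dict, List, Tuple


def random_counterfactuals(query: Dict[str, int], predict: Callable[[Dict[str, int]], int],
                           ranges: Dict[str, Tuple[int, int]], desired_class: int,
                           total_cfs: int = 4, n_samples: int = 2000,
                           seed: int = 0) -> List[Dict[str, int]]:
    rng = random.Random(seed)
    features = list(query)
    found: List[Dict[str, int]] = []
    for _ in range(n_samples):
        cand = dict(query)
        # Resample a random, non-empty subset of features within their ranges
        for f in rng.sample(features, rng.randint(1, len(features))):
            lo, hi = ranges[f]
            cand[f] = rng.randint(lo, hi)
        # Only the model's prediction is used: no gradients required
        if predict(cand) == desired_class and cand not in found:
            found.append(cand)

    def cost(c: Dict[str, int]) -> Tuple[int, int]:
        # Prefer fewer changed features, then smaller total L1 change
        changed = sum(c[f] != query[f] for f in features)
        l1 = sum(abs(c[f] - query[f]) for f in features)
        return changed, l1

    return sorted(found, key=cost)[:total_cfs]


def predict(row: Dict[str, int]) -> int:
    # Hypothetical classifier: "high income" if age + hours_per_week >= 100
    return int(row["age"] + row["hours_per_week"] >= 100)


query = {"age": 30, "hours_per_week": 40}  # predicted class 0
ranges = {"age": (17, 90), "hours_per_week": (1, 99)}
cfs = random_counterfactuals(query, predict, ranges, desired_class=1)
for cf in cfs:
    print(cf)
```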

In [None]:
m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func="ohe-min-max")
exp = dice_ml.Dice(d, m, method="random")

In [None]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(x_test[1:3], total_CFs=4, desired_class="opposite")
# highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)

## More resources: What's next?

DiCE has multiple configurable options and support for different kinds of models. Follow these notebooks to learn more.

1. You can constrain the features to vary, weigh the relative importance of different features for computing distance, or specify permitted ranges for each feature. Check out the [Customizing Counterfactuals Notebook](DiCE_with_advanced_options.ipynb). 
2. You can use it for multi-class classification or regression models. [Counterfactuals for Multi-class Classification and Regression Models Notebook](DiCE_multiclass_classification_and_regression.ipynb).
3. Explore the different model-agnostic explanation methods in DiCE. [Model-agnostic Counterfactual Generation Methods](DiCE_model_agnostic_CFs.ipynb).
4. You can generate CFs even without access to training data (e.g., for privacy-sensitive data). [DiCE with Private Data Notebook](DiCE_with_private_data.ipynb).
5. Feasibility of counterfactuals is an important consideration. You can try out this VAE-based method that adds a CF likelihood term to the loss. [VAE-based Counterfactuals (PyTorch only)](DiCE_getting_started_feasible.ipynb).

