# 0. Getting Started with DECE Engine

In [1]:
from cf_ml.dataset import load_diabetes_dataset
from cf_ml.model import PytorchModelManager
from cf_ml.cf_engine import CFEnginePytorch

Prepare the dataset and the model.

In [2]:
# Load the sample PIMA diabetes dataset through cf_ml's dataset loader.
dataset = load_diabetes_dataset()

# Target model to explain. None means no pre-built model is handed to the
# manager (presumably it then creates its own default model -- TODO confirm).
model = None
# Switch controlling whether the model is (re)trained in this cell.
retrain = True

# Initialize the model manager with the sample dataset and the target model.
model_manager = PytorchModelManager(dataset, model=model)

# (Re)train the model if requested; training reports per-epoch loss and
# train/test accuracy in the cell output.
if retrain:
    model_manager.train()

Epoch: 0, loss=0.682, train_accuracy=0.640, test_accuracy=0.695
Epoch: 1, loss=0.618, train_accuracy=0.640, test_accuracy=0.695
Epoch: 2, loss=0.602, train_accuracy=0.640, test_accuracy=0.695
Epoch: 3, loss=0.443, train_accuracy=0.640, test_accuracy=0.695
Epoch: 4, loss=0.680, train_accuracy=0.684, test_accuracy=0.740
Epoch: 5, loss=0.448, train_accuracy=0.655, test_accuracy=0.701
Epoch: 6, loss=0.672, train_accuracy=0.710, test_accuracy=0.753
Epoch: 7, loss=0.841, train_accuracy=0.717, test_accuracy=0.766
Epoch: 8, loss=0.699, train_accuracy=0.697, test_accuracy=0.747
Epoch: 9, loss=0.522, train_accuracy=0.731, test_accuracy=0.766
Epoch: 10, loss=0.639, train_accuracy=0.728, test_accuracy=0.766
Epoch: 11, loss=0.480, train_accuracy=0.730, test_accuracy=0.766
Epoch: 12, loss=0.867, train_accuracy=0.743, test_accuracy=0.766
Epoch: 13, loss=0.314, train_accuracy=0.710, test_accuracy=0.747
Epoch: 14, loss=0.453, train_accuracy=0.761, test_accuracy=0.779
Epoch: 15, loss=0.507, train_accura

In [3]:
# Initialize the DECE counterfactual engine from the dataset and the model
# manager; every use case in this notebook calls this same `engine` object.
engine = CFEnginePytorch(dataset, model_manager)

## Use Case-1: Basic Usage
Generate one counterfactual explanation for a target example.

In [4]:
# A single patient record, given as a plain mapping from attribute name to
# value; it carries a value for every attribute of the patient.
example = dict(
    Pregnancies=0,
    Glucose=120,
    BloodPressure=70,
    SkinThickness=25,
    BMI=30,
    Insulin=80,
    DiabetesPedigreeFunction=0.5,
    Age=40,
)

In [5]:
# Generate one counterfactual explanation for the single example dict
# (no `setting` is passed, so the engine's defaults apply).
counterfactuals = engine.generate_counterfactual_examples(example)

# Display every generated counterfactual explanation. Columns are laid out as
# [feature_1, feature_2, ..., target_class, prediction_class]; in the table
# these last two appear as `Outcome` and `Outcome_pred`.
counterfactuals.all

[1/1]  Epoch-0, time cost: 0.426s, loss: 0.001, iteration: 500, validation rate: 1.000


Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome,Outcome_pred
0,1.0,122.0,70.0,25.0,93.0,30.3,0.593,42.0,positive,positive


## Use Case-2: Generate Counterfactual Examples for Multiple Examples
Generate one counterfactual explanation for each target example.

In [6]:
# Take the first 500 rows returned by get_train_X(preprocess=False);
# the result is a pandas.DataFrame holding 500 examples.
examples = dataset.get_train_X(preprocess=False).iloc[:500]

In [7]:
# Generate counterfactual explanations for the whole batch of examples in a
# single call (default settings).
counterfactuals = engine.generate_counterfactual_examples(examples)

# Show the first rows of the valid counterfactual explanations.
# NOTE(review): "valid" presumably means the prediction matches the target
# class (Outcome == Outcome_pred in the displayed rows) -- confirm.
counterfactuals.valid.head()

[500/500]  Epoch-0, time cost: 3.900s, loss: 1.227, iteration: 1999, validation rate: 0.978


Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome,Outcome_pred
0,7.0,112.0,78.0,29.0,126.0,31.5,0.691,54.0,negative,negative
1,6.0,135.0,60.0,23.0,0.0,32.6,0.443,25.0,positive,positive
3,2.0,144.0,56.0,21.0,135.0,25.6,0.833,25.0,positive,positive
4,8.0,102.0,0.0,0.0,0.0,29.7,0.183,38.0,negative,negative
5,10.0,58.0,0.0,0.0,0.0,47.1,0.573,40.0,negative,negative


## Use Case-3: Generate Customized Counterfactual Examples 
Generate multiple counterfactual explanations for each target example, subject to user-defined constraints.

In [8]:
# Take the first 5 rows returned by get_train_X(preprocess=False) as targets.
examples = dataset.get_train_X(preprocess=False).iloc[:5]

# Options controlling how counterfactual explanations are generated.
setting = dict(
    # how many counterfactual explanations to produce per example
    num=5,
    # the only attributes the engine is allowed to modify
    changeable_attr=['Glucose', 'BMI', 'BloodPressure'],
    # per-attribute value ranges (min/max) that counterfactuals must stay in
    cf_range=dict(
        BloodPressure=dict(min=50, max=120),
        BMI=dict(min=20, max=50),
    ),
)

In [9]:
# Generate counterfactual explanations for the 5 examples using the custom
# `setting` (number per example, changeable attributes, value ranges).
counterfactuals = engine.generate_counterfactual_examples(examples, setting=setting)

# Display all valid counterfactual explanations.
# NOTE(review): "valid" presumably means the prediction matches the target
# class (Outcome == Outcome_pred in the displayed rows) -- confirm.
counterfactuals.valid

[5/5]  Epoch-0, time cost: 10.457s, loss: 0.260, iteration: 1999, validation rate: 0.920


Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome,Outcome_pred
0,7.0,91.0,88.0,29.0,126.0,41.0,0.692,54.0,negative,negative
1,7.0,120.0,113.0,29.0,126.0,34.7,0.692,54.0,negative,negative
2,7.0,137.0,79.0,29.0,126.0,20.0,0.692,54.0,negative,negative
4,7.0,101.0,71.0,29.0,126.0,35.1,0.692,54.0,negative,negative
5,4.0,148.0,60.0,23.0,0.0,31.8,0.443,22.0,positive,positive
6,4.0,175.0,76.0,23.0,0.0,23.3,0.443,22.0,positive,positive
7,4.0,145.0,50.0,23.0,0.0,28.2,0.443,22.0,positive,positive
8,4.0,114.0,59.0,23.0,0.0,50.0,0.443,22.0,positive,positive
9,4.0,116.0,54.0,23.0,0.0,47.2,0.443,22.0,positive,positive
11,0.0,165.0,116.0,33.0,680.0,50.0,0.427,23.0,negative,negative
