# 0. Getting Started with the DECE Engine

In [1]:
from cf_ml.dataset import load_diabetes_dataset
from cf_ml.model import PytorchModelManager
from cf_ml.cf_engine import CFEnginePytorch

Prepare the dataset and the model.

In [2]:
# Load the sample PIMA dataset that ships with the library.
dataset = load_diabetes_dataset()

# Target model to explain. It is passed to the manager as-is; with None the
# manager presumably falls back to its own default model (TODO confirm).
model = None
# Flip to False to skip the training step below.
retrain = True

# Wrap the sample dataset and the target model in a model manager.
model_manager = PytorchModelManager(dataset, model=model)

# Train (or retrain) the model only when requested.
if retrain:
    model_manager.train(verbose=False)

# Report the accuracy on the train split first, then on the test split.
print("Train Accuracy: {:.3f}, Test Accuracy: {:.3f}".format(
    *(model_manager.evaluate(split) for split in ('train', 'test'))))

Train Accuracy: 0.798, Test Accuracy: 0.766


In [3]:
# Initialize the DECE counterfactual engine from the dataset and the model
# manager prepared in the previous cell.
engine = CFEnginePytorch(dataset, model_manager)

## Use Case-1: Basic Usage
Generate one counterfactual explanation for a target example.

In [4]:
# A single patient record: one value for every attribute of the dataset,
# keyed by attribute name.
example = {
    'Pregnancies': 0, 'Glucose': 120, 'BloodPressure': 70, 'SkinThickness': 25,
    'BMI': 30, 'Insulin': 80, 'DiabetesPedigreeFunction': 0.5, 'Age': 40,
}

In [5]:
# Generate a counterfactual explanation for the single example defined above,
# using the engine's default settings (no `setting` argument is passed).
counterfactuals = engine.generate_counterfactual_examples(example)

# Display every generated counterfactual as a table: the feature columns,
# followed by the target class (Outcome) and the model's predicted class
# (Outcome_pred).
counterfactuals.all

[1/1]  Epoch-0, time cost: 0.392s, loss: 0.001, iteration: 500, validation rate: 1.000


Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome,Outcome_pred
0,1.0,124.0,72.0,25.0,80.0,30.0,0.503,41.0,positive,positive


## Use Case-2: Generate Counterfactual Examples for Multiple Examples
Generate one counterfactual explanation for each target example.

In [6]:
# Take the first 500 training examples as a pandas.DataFrame.
# preprocess=False presumably keeps the raw attribute values (TODO confirm).
examples = dataset.get_train_X(preprocess=False).head(500)

In [7]:
# Generate counterfactual explanations for the whole batch of examples in a
# single call, using the engine's default settings.
counterfactuals = engine.generate_counterfactual_examples(examples)

# Preview only the first rows of the valid counterfactual explanations
# (presumably those whose predicted class matches the target class; verify
# against the engine's documentation).
counterfactuals.valid.head()

[500/500]  Epoch-0, time cost: 4.286s, loss: 1.254, iteration: 1999, validation rate: 0.986


Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome,Outcome_pred
0,7.0,124.0,78.0,29.0,126.0,29.0,0.692,54.0,negative,negative
1,5.0,138.0,59.0,23.0,0.0,31.1,0.463,27.0,positive,positive
2,0.0,159.0,91.0,32.0,684.0,51.4,0.408,22.0,negative,negative
3,2.0,139.0,55.0,21.0,135.0,29.8,0.839,26.0,positive,positive
4,8.0,106.0,0.0,0.0,0.0,29.7,0.183,38.0,negative,negative


## Use Case-3: Generate Customized Counterfactual Examples 
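Generate multiple counterfactual explanations for each target example, subject to user-defined constraints: which attributes may change and the value ranges they must stay within.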

In [8]:
# Take the first 5 training examples as the targets for this use case.
examples = dataset.get_train_X(preprocess=False).head(5)

# Options that customize counterfactual explanation generation.
setting = {
    'num': 5,  # number of counterfactual explanations for each example
    'changeable_attr': ['Glucose', 'BMI', 'BloodPressure'],  # attributes that are allowed to change
    # variation ranges of the counterfactual explanations, one min/max pair
    # per constrained attribute
    'cf_range': {
        'BloodPressure': {'min': 50, 'max': 120},
        'BMI': {'min': 20, 'max': 50},
    },
}

In [9]:
# Generate counterfactual explanations for the selected examples, this time
# passing the custom `setting` defined in the previous cell.
counterfactuals = engine.generate_counterfactual_examples(examples, setting=setting)

# Display all valid counterfactual explanations (the full table, not a preview).
counterfactuals.valid

[5/5]  Epoch-0, time cost: 3.175s, loss: 0.284, iteration: 543, validation rate: 0.960


Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome,Outcome_pred
1,7.0,114.0,77.0,29.0,126.0,31.5,0.692,54.0,negative,negative
2,7.0,129.0,99.0,29.0,126.0,30.2,0.692,54.0,negative,negative
3,7.0,135.0,90.0,29.0,126.0,26.3,0.692,54.0,negative,negative
4,7.0,147.0,85.0,29.0,126.0,21.6,0.692,54.0,negative,negative
5,4.0,147.0,60.0,23.0,0.0,32.0,0.443,22.0,positive,positive
6,4.0,143.0,50.0,23.0,0.0,29.3,0.443,22.0,positive,positive
7,4.0,135.0,52.0,23.0,0.0,35.2,0.443,22.0,positive,positive
8,4.0,134.0,63.0,23.0,0.0,39.3,0.443,22.0,positive,positive
9,4.0,123.0,55.0,23.0,0.0,45.1,0.443,22.0,positive,positive
10,0.0,165.0,90.0,33.0,680.0,46.4,0.427,23.0,negative,negative
