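# Counterfactuals Guided by Prototypes on the Heart Disease Dataset

This notebook trains a small neural network on the heart-disease data. It then uses alibi's `CounterFactualProto` to find the feature changes that would flip a "disease" prediction to "no disease".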

In [None]:
!pip install alibi

Collecting alibi
[?25l  Downloading https://files.pythonhosted.org/packages/89/c1/7e6bbb4a69d84063d84dbf39ef5f95d9ee230379c542bed4e44ca8b878d8/alibi-0.5.5-py3-none-any.whl (228kB)
[K     |█▍                              | 10kB 17.2MB/s eta 0:00:01[K     |██▉                             | 20kB 1.7MB/s eta 0:00:01[K     |████▎                           | 30kB 2.2MB/s eta 0:00:01[K     |█████▊                          | 40kB 2.5MB/s eta 0:00:01[K     |███████▏                        | 51kB 2.0MB/s eta 0:00:01[K     |████████▋                       | 61kB 2.3MB/s eta 0:00:01[K     |██████████                      | 71kB 2.4MB/s eta 0:00:01[K     |███████████▌                    | 81kB 2.7MB/s eta 0:00:01[K     |█████████████                   | 92kB 2.9MB/s eta 0:00:01[K     |██████████████▍                 | 102kB 2.8MB/s eta 0:00:01[K     |███████████████▉                | 112kB 2.8MB/s eta 0:00:01[K     |█████████████████▏              | 122kB 2.8MB/s eta 0:00:0

In [None]:
# --- Imports and TensorFlow configuration -------------------------------------
import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages (40 == logging.ERROR)
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
# Everything below runs after disable_v2_behavior(), so Keras models built later in the
# notebook run in graph (non-eager) mode, as the check at the bottom of this cell shows.
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# NOTE(review): matplotlib/plt, os and DecisionTreeClassifier are not used in the cells
# shown here; consider removing them.
from alibi.explainers import CounterFactualProto

# Sanity check: with TF2 behaviour disabled, eager execution should be off.
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False

TF version:  2.3.0
Eager execution enabled:  False


In [None]:
# Load the heart-disease CSV (presumably uploaded to a Colab runtime) and preview it.
HEART_CSV_PATH = '/content/heartu.csv'
df = pd.read_csv(HEART_CSV_PATH)
df.head(5)

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,condition
0,69,1,0,160,234,1,2,131,0,0.1,1,1,0,0
1,69,0,0,140,239,0,0,151,0,1.8,0,2,0,0
2,66,0,0,150,226,0,0,114,0,2.6,2,0,0,0
3,65,1,0,138,282,1,2,174,0,1.4,1,1,0,1
4,64,1,0,110,211,0,2,144,1,1.8,1,0,0,0


In [None]:
# (number of rows, number of columns) of the raw frame
len(df.index), len(df.columns)

(297, 14)

In [None]:
# The 13 predictors fed to the model; 'condition' is the 0/1 label.
feature_names = [
    'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg',
    'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal',
]
data = df.loc[:, feature_names]
target = df['condition']

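Standardize each feature to zero mean and unit standard deviation. The mean and standard deviation are kept so that the counterfactual can later be mapped back to the original units.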

In [None]:
# Standardize every feature to zero mean / unit standard deviation.
# The statistics are computed from the raw columns of `df`, not from `data`, so
# re-running this cell is idempotent. Previously a second run re-standardized the
# already-scaled `data` and overwrote mu/sigma with ~0/~1, which silently broke the
# un-scaling of the counterfactual later on.
# NOTE(review): mu/sigma are computed on the full dataset before the train/test
# split, so test-set statistics leak into the preprocessing.
features_raw = df[feature_names]
mu = features_raw.mean(axis=0)
sigma = features_raw.std(axis=0)
data = (features_raw - mu) / sigma

Split the data into training and test sets, and one-hot encode the labels for the model's two-unit softmax output.

In [None]:
# Fixed random_state keeps the split reproducible; the labels are then one-hot
# encoded to match the model's 2-unit softmax output.
x_train, x_test, y_train, y_test = train_test_split(data, target, random_state=0)
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

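Seed the NumPy and TensorFlow random number generators so that model training, and therefore the counterfactual found below, is reproducible across runs.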

In [None]:
# One seed shared by NumPy and TensorFlow so repeated runs start from the same state.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
def nn_model(n_features=13, n_classes=2, hidden_units=40):
    """Build and compile a small fully connected classifier.

    Two ReLU hidden layers of ``hidden_units`` units feed a softmax output layer.
    The model is compiled with categorical cross-entropy and plain SGD, so it
    expects one-hot encoded targets (e.g. from ``to_categorical``).

    Args:
        n_features: Number of input features (13 for the heart-disease data).
        n_classes: Number of output classes (2: no disease / disease).
        hidden_units: Width of each of the two hidden layers.

    Returns:
        A compiled ``tf.keras`` ``Model``.
    """
    x_in = Input(shape=(n_features,))
    x = Dense(hidden_units, activation='relu')(x_in)
    x = Dense(hidden_units, activation='relu')(x)
    x_out = Dense(n_classes, activation='softmax')(x)
    nn = Model(inputs=x_in, outputs=x_out)
    nn.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return nn

Build and train the neural network, then save it to disk.

In [None]:
# Fresh model -> print architecture -> train quietly for 500 epochs -> persist as HDF5.
nn = nn_model()
nn.summary()
nn.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=500,
    verbose=0,
)
nn.save('nn_heart.h5', save_format='h5')

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 13)]              0         
_________________________________________________________________
dense (Dense)                (None, 40)                560       
_________________________________________________________________
dense_1 (Dense)              (None, 40)                1640      
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 82        
Total params: 2,282
Trainable params: 2,282
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Reload the saved network and report its accuracy on the held-out test set.
nn = load_model('nn_heart.h5')
score = nn.evaluate(x_test, y_test, verbose=0)  # [loss, accuracy] since metrics=['accuracy']
test_accuracy = score[1]
print('Test accuracy: ', test_accuracy)

Test accuracy:  0.82666665


The model reaches a test accuracy of about 82.67%, which is reasonable for a dataset of only 297 patients.

In [None]:
# Convert the feature frames to plain NumPy arrays, because later cells index rows
# positionally and call .reshape on them. np.asarray is a no-op on arrays, so unlike
# `x = x.to_numpy()` this cell can be re-run without raising AttributeError.
x_test = np.asarray(x_test)
x_train = np.asarray(x_train)

Select one test instance (row 3 of the test set) whose counterfactual we want to find.

In [None]:
# The explainer works on batches, so wrap test row 3 in a leading batch axis of size 1.
shape = (1,) + x_test[3].shape
X = x_test[3].reshape(shape)
shape

(1, 13)

In [None]:
# Reload the trained network from disk for the explainer.
nn = load_model('nn_heart.h5')

# Confine the counterfactual search to the per-feature min/max seen in the
# (standardized) training data.
feature_range = (x_train.min(axis=0), x_train.max(axis=0))

# use_kdtree=True: class prototypes are represented with k-d trees built from the
# data passed to fit(), since no encoder is supplied.
cf = CounterFactualProto(
    nn,
    shape,
    use_kdtree=True,
    theta=10.,
    max_iterations=1000,
    feature_range=feature_range,
    c_init=1.,
    c_steps=10,
)
cf.fit(x_train)
explanation = cf.explain(X)

No encoder specified. Using k-d trees to represent class prototypes.


In [None]:
print('Original prediction: {}'.format(explanation.orig_class))
print('Counterfactual prediction: {}'.format(explanation.cf['class']))

# Map the original instance and the counterfactual back to the original feature units
# by undoing the standardization (x * sigma + mu).
# np.asarray accepts the pandas Series from the standardization cell as well as arrays.
# Binding new names, instead of `sigma = sigma.to_numpy()`, keeps mu/sigma intact, so
# re-running this cell no longer fails with AttributeError.
sigma_arr = np.asarray(sigma)
mu_arr = np.asarray(mu)
orig = X * sigma_arr + mu_arr
counterfactual = explanation.cf['X'] * sigma_arr + mu_arr
delta = counterfactual - orig

# Report only the features the counterfactual actually changed (ignore float noise).
min_change = 1e-4
for i, f in enumerate(feature_names):
    if np.abs(delta[0, i]) > min_change:
        print('{}: {}'.format(f, delta[0, i]))

Original prediction: 1
Counterfactual prediction: 0
cp: -0.8131889454245802
trestbps: -20.180294808502538
thalach: 7.900992597841537
oldpeak: -1.798546826159645
ca: -1.8905086346674485
thal: -0.44745449912977175


The original prediction is 1 (disease). The counterfactual explanation shows which feature changes would flip the model's prediction to 0 (no disease). The printed values are those changes in the original, unstandardized units.

Here we see that the chest pain type (`cp`) should decrease by ```|floor(-0.813)| = 1```, i.e. the chest pain should belong to the recognizable (symptomatic) category. The resting blood pressure (`trestbps`) should also be lower, by about 20. The maximum heart rate (`thalach`) should be higher by around 8, which can be achieved through regular exercise. The `oldpeak` value should be lowered by about 1.8, for example by following a DASH diet. Finally, `ca` should be reduced by ```|floor(-1.89)| = 2```, by detecting the blockage using angioplasty and performing bypass surgery if needed.

