In [125]:
import pandas as pd
import numpy as np
import dice_ml
import tensorflow.keras

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM
import warnings

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 35)



In [123]:
# load datasets
old_df=pd.read_csv('data/multivariate/kenya_all_indicators.csv')
df = old_df.iloc[:19, 1:4]

df.head()


Unnamed: 0,he,brt,infant
0,4.639261,40.037,28.5
1,4.803683,39.777,28.4
2,4.958624,39.468,28.1
3,5.169111,39.135,27.9
4,5.344938,38.773,27.5


In [124]:
from sklearn.model_selection import train_test_split

print("Dataset Size : ", df.shape, df['brt'].shape)
X_train, X_test, Y_train, Y_test = train_test_split(df, df['brt'],
                                                    train_size=0.70,
                                                    random_state=123)

print("Train/Test Sizes : ",X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)

Dataset Size :  (19, 3) (19,)
Train/Test Sizes :  (13, 3) (6, 3) (13,) (6,)


In [126]:
model = Sequential([
            Dense(50, activation="relu", input_shape=(len(df.columns.values), )),
            Dense(50, activation="relu"),
            Dense(50, activation="relu"),
            Dense(1),
           ])

model.summary()

Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_44 (Dense)             (None, 50)                200       
_________________________________________________________________
dense_45 (Dense)             (None, 50)                2550      
_________________________________________________________________
dense_46 (Dense)             (None, 50)                2550      
_________________________________________________________________
dense_47 (Dense)             (None, 1)                 51        
Total params: 5,351
Trainable params: 5,351
Non-trainable params: 0
_________________________________________________________________


In [127]:
model.compile(optimizer="adam", loss="mean_squared_error", metrics=["mae"])

history = model.fit(X_train, Y_train, batch_size=8, epochs=100, verbose=0)

In [128]:
from sklearn.metrics import mean_squared_error, r2_score

print("Train MSE : %.2f"%mean_squared_error(Y_train, model.predict(X_train)))
print("Test  MSE : %.2f"%mean_squared_error(Y_test, model.predict(X_test)))

Train MSE : 0.23
Test  MSE : 0.36


In [120]:
d = dice_ml.Data(
    dataframe=df, 
    continuous_features=['he'], 
    outcome_name='brt'
)

m = dice_ml.Model(model=model, backend="TF2")

In [121]:
# initiate DiCE
exp = dice_ml.Dice(d, m)
exp

ValueError: Input 0 of layer sequential_10 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape (1, 20)

In [19]:
import random

idx = random.randint(1, len(X_test))

print("Actual Price : %.2f"%Y_test[idx])

sample = dict(zip(boston.feature_names, X_test[idx]))
sample

Actual Price : 21.20


{'CRIM': 3.67367,
 'ZN': 0.0,
 'INDUS': 18.1,
 'CHAS': 0.0,
 'NOX': 0.583,
 'RM': 6.312,
 'AGE': 51.9,
 'DIS': 3.9917,
 'RAD': 24.0,
 'TAX': 666.0,
 'PTRATIO': 20.2,
 'B': 388.62,
 'LSTAT': 10.58}

In [20]:
dice_exp = exp.generate_counterfactuals(sample, total_CFs=4, desired_class=1)



Diverse Counterfactuals found! total time taken: 03 min 08 sec


In [21]:
#show_only_changes used to show changed values only
dice_exp.visualize_as_dataframe(show_only_changes=True)

Query instance (original outcome : 2)


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,3.67367,0.0,18.1,0.0,0.583,6.312,51.9,3.9917,24.0,666.0,20.2,388.6,10.58,2.486



Diverse Counterfactual set (new outcome: 1)


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,2.95178,-,19.5,-,0.608,6.577,51.89999999999993,3.9926000000003423,-,677.1,20.200000000000003,388.5999999999995,10.579999999999991,0.0
1,3.67441,-,-,-,-,6.311000000000008,51.900000000000006,3.2732,-,665.9999999999999,20.2,385.2,10.579999999999997,0.0
2,2.65049,-,18.100000000000005,-,-,6.311000000000019,51.90000000000001,3.7196,-,666.0000000000006,20.2,375.5,13.1,0.0
3,6.77594,-,16.6,-,0.556,6.024,46.5,4.3427,-,634.1,19.4,396.4,10.579999999999991,0.0
