# Counterfactuals guided by prototypes on the California housing dataset

Requirements
- Python 3.10
- `pip install alibi[tensorflow]`

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import os
# Opt in to the legacy Keras implementation. The variable is set before
# TensorFlow is imported below so that the import can see it.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf
tf.get_logger().setLevel(40) # 40 == logging.ERROR: suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical

# `os` is already imported at the top of this cell; it is not imported again here.
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from alibi.explainers import CounterfactualProto

# Report the runtime configuration; eager execution is expected to be off
# because TF2 behaviour was disabled above.
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False

  from .autonotebook import tqdm as notebook_tqdm


TF version:  2.14.1
Eager execution enabled:  False


## Load the [California housing](https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset) dataset

In [2]:
# Fetch the dataset as pandas objects, then expose the feature table and the
# target column as plain NumPy arrays for the rest of the notebook.
california = fetch_california_housing(as_frame=True)
feature_names = california.feature_names
X, target = (part.to_numpy() for part in (california.data, california.target))

In [3]:
# Preview the first rows of the feature table (values are still unscaled here).
california.data.head()

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25


In [4]:
# Binary label: 1.0 where the target is strictly above its median, 0.0 otherwise.
y = (target > np.median(target)).astype(np.float64)

In [5]:
# Standardise every feature column (subtract the mean, divide by the std).
#
# The statistics are derived from the raw feature table instead of from `X`
# itself, which makes this cell idempotent: re-running it can no longer
# standardise already-standardised data and overwrite `mu`/`sigma`, which are
# needed later to map instances back to the original units.
#
# NOTE: `mu` and `sigma` are computed over the full dataset, i.e. before the
# train/test split performed in the next cell.
X_raw = california.data.to_numpy()
mu = X_raw.mean(axis=0)
sigma = X_raw.std(axis=0)
X = (X_raw - mu) / sigma

In [6]:
# Hold out 20% of the rows for testing; the fixed random_state keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# One-hot encode both label vectors to match the model's two-unit softmax output.
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

## Black-box model

### Train the model

In [7]:
# Seed NumPy and TensorFlow with the same value before building the model.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

In [8]:
def nn_model(n_features=8, n_hidden=40, n_classes=2):
    """Build and compile the feed-forward classifier used as the black-box model.

    The network has two ReLU hidden layers of equal width followed by a
    softmax output layer. The defaults reproduce the original fixed
    architecture (8 inputs, 40 hidden units per layer, 2 classes), so calling
    ``nn_model()`` without arguments behaves exactly as before.

    Parameters
    ----------
    n_features : int, optional
        Number of input features per instance.
    n_hidden : int, optional
        Number of units in each of the two hidden layers.
    n_classes : int, optional
        Number of output classes (width of the softmax layer).

    Returns
    -------
    Model
        A Keras model compiled with categorical cross-entropy loss, the
        ``'sgd'`` optimizer and the accuracy metric.
    """
    x_in = Input(shape=(n_features,))
    x = Dense(n_hidden, activation='relu')(x_in)
    x = Dense(n_hidden, activation='relu')(x)
    x_out = Dense(n_classes, activation='softmax')(x)
    nn = Model(inputs=x_in, outputs=x_out)
    # Categorical cross-entropy expects one-hot encoded labels.
    nn.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return nn

In [14]:
# Directory that holds the cached classifier, and the HDF5 file inside it.
model_dir_parts = ("models", "alibi", "housing")
nn_path = os.path.join(*model_dir_parts)
nn_model_path = os.path.join(nn_path, "nn_california.h5")

In [12]:
# Create the model directory (and any missing parents). `exist_ok=True` makes
# this a no-op when the directory already exists and avoids the window between
# a separate existence check and the creation.
os.makedirs(nn_path, exist_ok=True)

In [15]:
# Reuse the cached classifier when it is already on disk; otherwise build,
# train and cache it so later runs can skip the training step.
if os.path.exists(nn_model_path):
    nn = load_model(nn_model_path)
else:
    nn = nn_model()
    nn.summary()
    nn.fit(X_train, y_train, batch_size=64, epochs=500, verbose=0)
    nn.save(nn_model_path, save_format='h5')

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 8)]               0         
                                                                 
 dense_3 (Dense)             (None, 40)                360       
                                                                 
 dense_4 (Dense)             (None, 40)                1640      
                                                                 
 dense_5 (Dense)             (None, 2)                 82        
                                                                 
Total params: 2082 (8.13 KB)
Trainable params: 2082 (8.13 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


  saving_api.save_model(


In [16]:
# Evaluate on the held-out split; the entry at index 1 is reported as the accuracy.
score = nn.evaluate(X_test, y_test, verbose=0)
test_accuracy = score[1]
print('Test accuracy: ', test_accuracy)

Test accuracy:  0.87548447


  updates = self.state_updates


## Generate a counterfactual guided by the nearest class prototype

In [17]:
# Pick the second test row and add a leading batch axis.
# Note: from here on `X` refers to this single instance, not the full dataset.
instance = X_test[1]
X = np.expand_dims(instance, axis=0)
shape = X.shape

In [18]:
# Per-feature lower and upper bounds, taken from the training data.
feature_range = (X_train.min(axis=0), X_train.max(axis=0))

# Configure the explainer for the black-box model and the instance shape.
cf = CounterfactualProto(
    nn,
    shape,
    use_kdtree=True,
    theta=10.,
    max_iterations=1000,
    feature_range=feature_range,
    c_init=1.,
    c_steps=10,
)

# Fit on the training data; left as the last expression so the cell displays the result.
cf.fit(X_train)

  updates=self.state_updates,
No encoder specified. Using k-d trees to represent class prototypes.


CounterfactualProto(meta={
  'name': 'CounterfactualProto',
  'type': ['blackbox', 'tensorflow', 'keras'],
  'explanations': ['local'],
  'params': {
              'kappa': 0.0,
              'beta': 0.1,
              'gamma': 0.0,
              'theta': 10.0,
              'cat_vars': None,
              'ohe': False,
              'use_kdtree': True,
              'learning_rate_init': 0.01,
              'max_iterations': 1000,
              'c_init': 1.0,
              'c_steps': 10,
              'eps': (0.001, 0.001),
              'clip': (-1000.0, 1000.0),
              'update_num_grad': 1,
              'write_dir': None,
              'feature_range': (array([-1.77429947, -2.19618048, -1.83504572, -1.61076772, -1.25612255,
       -0.22899997, -1.44288613, -2.38599234]), array([  5.85828581,   1.85618152,  55.16323628,  51.78248741,
        30.25033022, 119.41910319,   2.95806762,   2.62528006])),
              'shape': (1, 8),
              'is_model': True,
              '

In [19]:
# Run the counterfactual search for the single test instance held in `X`.
explanation = cf.explain(X)

In [20]:
# Show the class recorded for the original instance next to the class of the counterfactual.
# NOTE(review): presumably `explanation.cf` is None when no counterfactual is found — confirm.
print(f'Original prediction: {explanation.orig_class}')
print(f'Counterfactual prediction: {explanation.cf["class"]}')

Original prediction: 0
Counterfactual prediction: 1


In [21]:
# Undo the standardisation so both instances are expressed in the original feature units.
orig, counterfactual = (z * sigma + mu for z in (X, explanation.cf['X']))
delta = counterfactual - orig

# Report only the features whose value changed by more than 1e-4.
for name, change in zip(feature_names, delta[0]):
    if np.abs(change) > 1e-4:
        print(f'{name}: {change}')

AveOccup: -0.8686982136340218


In [22]:
# Display the original instance in de-standardised units, one column per feature.
pd.DataFrame(orig, columns=feature_names)

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude
0,2.5313,30.0,5.039384,1.193493,1565.0,2.679795,35.14,-119.46


In [23]:
# Display the counterfactual in de-standardised units for comparison with the original.
pd.DataFrame(counterfactual, columns=feature_names)

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude
0,2.5313,30.0,5.039384,1.193493,1565.000004,1.811096,35.14,-119.46
