In [1]:
from util import *

from src.dataset import load_hospital
from src.counterfactual import get_baseline_counterfactuals

import joblib
import time

import warnings
# NOTE: this silences *every* warning for the rest of the session (including
# library deprecation warnings), not just the noisy ones.
warnings.filterwarnings('ignore')

# Load the trained classifier together with its feature encoder and scaler.
# joblib.load unpickles the file, which can execute arbitrary code: only load
# model files from a trusted source.
model, encoder, scaler = joblib.load('models/hospital.gz') # Model should have the BlackBox interface

# The network contains Dropout and BatchNorm layers, and a pickled module keeps
# whatever train/eval mode it was in when it was saved. Switch to inference mode so
# predictions are deterministic and BatchNorm uses its running statistics.
# eval() returns the module itself, so the cell still displays the model.
model.eval()

2022-11-13 16:48:38.548140: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


TabularModel(
  (lin1): Linear(in_features=10, out_features=200, bias=True)
  (lin2): Linear(in_features=200, out_features=50, bias=True)
  (lin3): Linear(in_features=50, out_features=2, bias=True)
  (bn1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drops): Dropout(p=0.3, inplace=False)
)

In [2]:
import random

import numpy as np
import torch

# Fix the RNGs before any data is drawn so that Restart & Run All reproduces the
# same batches and fitted weights, in case the loader shuffles or the helpers
# sample randomly (neither is visible from this notebook).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Draw two batches from the *test* split (100 is presumably the batch size):
# the first batch is the SimplEx corpus, the second the examples to explain.
# Labels are not needed, hence `_`.
loader = iter(load_hospital(100, train=False))
X_corpus, _ = next(loader)
X_test, _ = next(loader)

# Fit SimplEx weights expressing each test example in terms of the corpus.
simplex = get_simplex(model, X_corpus, X_test, verbose=True)

Weight Fitting Epoch: 2000/10000 ; Error: 6.78 ; Regulator: 17.1 ; Reg Factor: 1
Weight Fitting Epoch: 4000/10000 ; Error: 1.34 ; Regulator: 2.86 ; Reg Factor: 1
Weight Fitting Epoch: 6000/10000 ; Error: 0.621 ; Regulator: 0.828 ; Reg Factor: 1
Weight Fitting Epoch: 8000/10000 ; Error: 0.492 ; Regulator: 0.291 ; Reg Factor: 1
Weight Fitting Epoch: 10000/10000 ; Error: 0.451 ; Regulator: 0.108 ; Reg Factor: 1


In [3]:
%%time
# Index of the example to explain (later used to slice simplex.test_examples).
test_id = 10

# Generate SimplEx counterfactuals for this example. `x` is presumably the example
# being explained and `desired_class` the class the counterfactuals should be
# predicted as -- both are reused by the following cells.
cfs, x, desired_class = get_simplex_cf_tabular(simplex, model, test_id, encoder)

CPU times: user 2.26 s, sys: 117 ms, total: 2.37 s
Wall time: 322 ms


In [5]:
# Column labels for the rendered tables, one per model input feature.
# Spellings such as 'Hipertension' and 'Handcap' are intentional -- presumably they
# match the dataset's own column names -- so do not "correct" them.
cols = [
    'Gender',
    'Neighbourhood',
    'Scholarship',
    'Hipertension',
    'Diabetes',
    'Alcoholism',
    'SMS_received',
    'Handcap',
    'Age',
    'ScheduleDays',
]

# Re-select the explained example as a one-row slice (keeps the leading batch axis).
# NOTE(review): this replaces the `x` returned by get_simplex_cf_tabular; confirm the
# two are the same representation.
x = simplex.test_examples[test_id : test_id + 1]
display_tabular_cfs(cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,REDENÇÃO,0,1,1,0,0,0,57,14



Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,REDENÇÃO,0,1,1,0,0,1,57,14



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,CARATOÍRA,0,0,0,0,0,1,30,15



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,M,MARIA ORTIZ,0,0,0,0,0,1,30,15



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************


In [6]:
# Baseline counterfactuals for the same example and target class, computed from the
# same corpus that SimplEx uses.
baseline_cfs = get_baseline_counterfactuals(
    model=model,
    target=desired_class,
    test=x,
    corpus=X_corpus,
)

display_tabular_cfs(baseline_cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,REDENÇÃO,0,1,1,0,0,0,57,14



Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,M,MARIA ORTIZ,0,0,0,0,0,1,40,11



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************


In [7]:
%%time
# CFProto counterfactual for the same example, using X_corpus as reference data.
# In the recorded run this is by far the slowest method (~1 min wall time).
cf_proto_cf = get_cfproto_cf(X_corpus, model, x)
display_tabular_cfs(cf_proto_cf, model, x, desired_class, scaler, encoder, cols)




2022-11-13 16:49:58.966269: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
No encoder specified. Using k-d trees to represent class prototypes.
2022-11-13 16:49:59.233441: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled


Original: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,REDENÇÃO,0,1,1,0,0,0,57,14



Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,REDENÇÃO,0,1,1,0,0,1,57,14



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
CPU times: user 6min 44s, sys: 30.3 s, total: 7min 14s
Wall time: 1min


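# Comparison

Benchmark the three counterfactual methods on the first 20 test examples: SimplEx, the corpus-based baseline from `get_baseline_counterfactuals` (column `nn`), and CFProto. For each example we record each method's wall-clock time and its sparsity. Sparsity is the number of input entries in which the method's first counterfactual differs from the original example. Summary statistics are written to `results/hospital_times.csv` and `results/hospital_sparsity.csv`.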

In [8]:
from tqdm.auto import tqdm

# Number of test examples (indices 0 .. N_COMPARE - 1) to benchmark. In the recorded
# run each example took roughly a minute, dominated by CFProto.
N_COMPARE = 20


def _n_changed(cf, original) -> int:
    """Return how many entries of counterfactual `cf` differ from `original`, as a plain int."""
    return int((cf != original).sum())


def _compare_methods(idx):
    """Run SimplEx, the corpus baseline and CFProto on test example `idx`.

    Returns two lists, each ordered (simplex, nn, cfproto):
      - seconds spent in each method (monotonic perf_counter timing),
      - sparsity of each method's first counterfactual (entries changed vs. the example).

    All intermediates are local, so the demo variables of the earlier cells
    (`test_id`, `cfs`, `x`, `desired_class`, `baseline_cfs`, `cf_proto_cf`) are not
    overwritten by this benchmark.
    """
    t_start = time.perf_counter()
    simplex_cfs, x_idx, target = get_simplex_cf_tabular(simplex, model, idx, encoder)
    t_simplex = time.perf_counter()
    nn_cfs = get_baseline_counterfactuals(model=model, target=target, test=x_idx,
                                          corpus=X_corpus)
    t_nn = time.perf_counter()
    proto_cfs = get_cfproto_cf(X_corpus, model, x_idx)
    t_proto = time.perf_counter()

    durations = [t_simplex - t_start, t_nn - t_simplex, t_proto - t_nn]
    # NOTE(review): unlike the single-example cells above, the baselines here receive the
    # `x` returned by get_simplex_cf_tabular rather than simplex.test_examples[idx:idx + 1];
    # confirm the two are the same representation.
    changed = [_n_changed(method_cfs[0], x_idx)
               for method_cfs in (simplex_cfs, nn_cfs, proto_cfs)]
    return durations, changed


# One row per test example; if the loop is interrupted, only completed examples are kept.
times = []
sparsity = []
for test_idx in tqdm(range(N_COMPARE)):
    durations, changed = _compare_methods(test_idx)
    times.append(durations)
    sparsity.append(changed)

  0%|                                                    | 0/20 [00:00<?, ?it/s]No encoder specified. Using k-d trees to represent class prototypes.
  5%|██▏                                         | 1/20 [00:56<17:49, 56.30s/it]No encoder specified. Using k-d trees to represent class prototypes.
 10%|████▍                                       | 2/20 [02:00<18:21, 61.21s/it]No encoder specified. Using k-d trees to represent class prototypes.
 15%|██████▌                                     | 3/20 [03:10<18:23, 64.93s/it]No encoder specified. Using k-d trees to represent class prototypes.
 20%|████████▊                                   | 4/20 [04:13<17:08, 64.31s/it]No encoder specified. Using k-d trees to represent class prototypes.
 25%|███████████                                 | 5/20 [05:14<15:44, 62.99s/it]No encoder specified. Using k-d trees to represent class prototypes.
 30%|█████████████▏                              | 6/20 [06:15<14:32, 62.32s/it]No encoder specified. Usin

In [9]:
import os

# Summary statistics (count/mean/std/quantiles, in seconds) of the per-example runtimes.
# `count` shows how many examples actually completed in the comparison loop.
os.makedirs('results', exist_ok=True)  # to_csv does not create missing directories
pd.DataFrame(times, columns=['simplex', 'nn', 'cfproto']).describe().to_csv('results/hospital_times.csv')

In [10]:
import os

# Summary statistics of sparsity (number of entries changed per counterfactual).
# Counts may arrive as 0-d tensors, so cast each to a plain int so describe() reports
# numeric stats. This avoids DataFrame.applymap, which is deprecated since pandas 2.1
# (its FutureWarning is hidden by the global warnings filter).
os.makedirs('results', exist_ok=True)  # to_csv does not create missing directories
sparsity_df = pd.DataFrame([[int(v) for v in row] for row in sparsity],
                           columns=['simplex', 'nn', 'cfproto'])
sparsity_df.describe().to_csv('results/hospital_sparsity.csv')