In [1]:
from util import *

from src.dataset import load_hospital
from src.counterfactual import get_baseline_counterfactuals
import joblib

import random
import warnings

import numpy as np
import torch

# NOTE(review): this silences *every* warning (deprecations, numerical issues, ...);
# consider narrowing it to the specific categories this notebook actually triggers.
warnings.filterwarnings('ignore')

# Seed every RNG that the data loader, SimplEx weight fitting or the baseline search
# might draw from, so that Restart & Run All reproduces the outputs below.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Get a model, together with the `encoder` and `scaler` stored in the same file.
# joblib unpickles the file, so only load model files from a trusted source.
model, encoder, scaler = joblib.load('models/hospital.gz') # Model should have the BlackBox interface
# The network contains Dropout and BatchNorm1d layers: force inference mode so that
# predictions do not depend on the mode the model happened to be saved in.
# `eval()` returns the module itself, so this line also displays the model summary.
model.eval()

TabularModel(
  (lin1): Linear(in_features=10, out_features=200, bias=True)
  (lin2): Linear(in_features=200, out_features=50, bias=True)
  (lin3): Linear(in_features=50, out_features=2, bias=True)
  (bn1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drops): Dropout(p=0.3, inplace=False)
)

In [2]:
# Number of examples requested per batch from `load_hospital`.
# NOTE(review): presumably this is the loader's batch size, so the corpus and the
# test set each get this many rows — confirm against src.dataset.load_hospital.
BATCH_SIZE = 100

# Draw two consecutive batches from the held-out split: the first one is the SimplEx
# corpus, the second one holds the test examples to be explained. Labels are unused.
loader = iter(load_hospital(BATCH_SIZE, train=False))
X_corpus, _ = next(loader)
X_test, _ = next(loader)

# Fit the SimplEx decomposition of the test examples over the corpus (logs progress).
simplex = get_simplex(model, X_corpus, X_test, verbose=True)

Weight Fitting Epoch: 2000/10000 ; Error: 6.78 ; Regulator: 17.1 ; Reg Factor: 1
Weight Fitting Epoch: 4000/10000 ; Error: 1.34 ; Regulator: 2.86 ; Reg Factor: 1
Weight Fitting Epoch: 6000/10000 ; Error: 0.621 ; Regulator: 0.828 ; Reg Factor: 1
Weight Fitting Epoch: 8000/10000 ; Error: 0.492 ; Regulator: 0.291 ; Reg Factor: 1
Weight Fitting Epoch: 10000/10000 ; Error: 0.451 ; Regulator: 0.108 ; Reg Factor: 1


In [3]:
i = 40      # position of the test example to explain
n_cfs = 5   # counterfactuals requested from each method

# Slicing (rather than indexing) keeps the leading batch dimension.
x = simplex.test_examples[i:i + 1]
# Target the runner-up class — for this 2-output classifier, simply the other class.
desired_class = model(x).topk(2).indices[0, 1]
# NOTE(review): assumes the categorical features occupy the first
# len(encoder.cols) input columns — confirm against the encoder's layout.
cat_indices = list(range(len(encoder.cols)))

# Baseline: counterfactuals searched for in the corpus.
baseline_cfs = get_baseline_counterfactuals(
    model=model,
    target=desired_class,
    test=x,
    corpus=X_corpus,
    n_counterfactuals=n_cfs,
)

# SimplEx-guided counterfactuals for the same test example.
cfs = simplex.get_counterfactuals(
    test_id=i,
    model=model,
    n_counterfactuals=n_cfs,
    cat_indices=cat_indices,
)

 10%|████                                     | 10/100 [00:00<00:00, 316.85it/s]


In [4]:
# Feature names used to label the rendered tables.
# NOTE(review): 'Hipertension' and 'Handcap' are presumably the raw dataset's own
# spellings — keep them as-is rather than "correcting" them.
cols = [
    'Gender', 'Neighbourhood', 'Scholarship', 'Hipertension', 'Diabetes',
    'Alcoholism', 'SMS_received', 'Handcap', 'Age', 'ScheduleDays',
]

# Render the original example next to each SimplEx counterfactual.
display_tabular_cfs(cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,RESISTÊNCIA,0,0,0,0,0,1,17,9



Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,RESISTÊNCIA,0,0,0,0,0,0,17,9



Predicted:  tensor(0)  ||  Desired:  tensor(0)  ||  Orginal:  tensor(1)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,RESISTÊNCIA,0,0,0,0,0,0,17,9



Predicted:  tensor(0)  ||  Desired:  tensor(0)  ||  Orginal:  tensor(1)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,GRANDE VITÓRIA,0,0,0,0,0,0,17,9



Predicted:  tensor(0)  ||  Desired:  tensor(0)  ||  Orginal:  tensor(1)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,JUCUTUQUARA,0,0,0,0,0,0,17,9



Predicted:  tensor(0)  ||  Desired:  tensor(0)  ||  Orginal:  tensor(1)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,MÁRIO CYPRESTE,0,0,0,0,0,0,17,9



Predicted:  tensor(0)  ||  Desired:  tensor(0)  ||  Orginal:  tensor(1)
************************************************************************************************************************


In [5]:
# Render the baseline counterfactuals with the same helper and arguments as the
# SimplEx ones, so the two methods can be compared side by side.
display_tabular_cfs(baseline_cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,RESISTÊNCIA,0,0,0,0,0,1,17,9



Kept counterfactual generation: 


Unnamed: 0,Gender,Neighbourhood,Scholarship,Hipertension,Diabetes,Alcoholism,SMS_received,Handcap,Age,ScheduleDays
0,F,MÁRIO CYPRESTE,0,1,0,0,0,0,60,23



Predicted:  tensor(0)  ||  Desired:  tensor(0)  ||  Orginal:  tensor(1)
************************************************************************************************************************
