In [1]:
# Project helpers.
# NOTE(review): star import. get_simplex and display_tabular_cfs are used
# later in this notebook but never imported explicitly, so they presumably
# come from here. TODO: replace with explicit names.
# Kept as the first import so later explicit imports still win on name clashes.
from util import *

import warnings

import joblib

from src.dataset import load_diabetes
from src.counterfactual import get_baseline_counterfactuals

# Silence every warning for the rest of the kernel session.
# NOTE(review): blanket filter that can hide real problems; consider
# narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

# Load the saved (model, encoder, scaler) triple.
# The model should have the BlackBox interface.
# NOTE(review): joblib.load unpickles its input and can execute arbitrary
# code. Only load model files from a trusted source.
model, encoder, scaler = joblib.load('models/diabetes.gz')

# Last expression of the cell: display the model architecture.
model

TabularModel(
  (lin1): Linear(in_features=13, out_features=200, bias=True)
  (lin2): Linear(in_features=200, out_features=50, bias=True)
  (lin3): Linear(in_features=50, out_features=2, bias=True)
  (bn1): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drops): Dropout(p=0.3, inplace=False)
)

In [2]:
# First positional argument to load_diabetes; presumably the batch size,
# i.e. the number of corpus / test examples. TODO confirm against
# src.dataset.load_diabetes.
BATCH_SIZE = 100

# Two successive batches from the train=False split: the first is the SimplEx
# corpus, the second holds the test examples to explain.
loader = iter(load_diabetes(BATCH_SIZE, train=False))

# Unpack each batch into named variables rather than `X, _ = ...`.
# Assigning `_` in the notebook namespace makes IPython stop updating its
# `_`/`__`/`___` output-history variables for the rest of the session.
# The second element is presumably the labels (unused below).
X_corpus, y_corpus = next(loader)
X_test, y_test = next(loader)

# Fit the SimplEx corpus weights; verbose=True prints the fitting progress.
simplex = get_simplex(model, X_corpus, X_test, verbose = True)

Weight Fitting Epoch: 2000/10000 ; Error: 27.5 ; Regulator: 44.4 ; Reg Factor: 1
Weight Fitting Epoch: 4000/10000 ; Error: 18.9 ; Regulator: 7.95 ; Reg Factor: 1
Weight Fitting Epoch: 6000/10000 ; Error: 17.2 ; Regulator: 2.27 ; Reg Factor: 1
Weight Fitting Epoch: 8000/10000 ; Error: 16.8 ; Regulator: 0.763 ; Reg Factor: 1
Weight Fitting Epoch: 10000/10000 ; Error: 16.6 ; Regulator: 0.281 ; Reg Factor: 1


In [4]:
# Example to explain and how many counterfactuals to request from each method.
i = 40        # position of the example within simplex.test_examples
n_cfs = 5     # counterfactuals requested from both SimplEx and the baseline

# Slice rather than index, so x keeps a leading batch dimension of size 1.
x = simplex.test_examples[i:i+1]

# Target the class the model ranks second for x (index 1 of topk(2)).
desired_class = model(x).topk(2).indices[0, 1]

# Indices 0 .. len(encoder.cols)-1 are passed as the categorical features.
# Presumably the encoded columns come first in the feature layout; TODO confirm.
cat_indices = [*range(len(encoder.cols))]

# Baseline: works from the raw test example and the corpus batch.
baseline_cfs = get_baseline_counterfactuals(
    model=model,
    target=desired_class,
    test=x,
    corpus=X_corpus,
    n_counterfactuals=n_cfs,
)

# SimplEx: identifies the example by its position in simplex.test_examples.
cfs = simplex.get_counterfactuals(
    test_id=i,
    model=model,
    n_counterfactuals=n_cfs,
    cat_indices=cat_indices,
)

 10%|████                                     | 10/100 [00:00<00:00, 262.41it/s]


In [5]:
# Column headers used when rendering the counterfactual tables.
# The first four show up as `category_*` codes in the tables below; the
# remaining nine are shown as plain numbers.
# NOTE(review): the order presumably has to match the model's feature order;
# TODO confirm.
cols = [
    'GenHlth',
    'Age',
    'Education',
    'Income',
    'HighBP',
    'BMI',
    'HighChol',
    'DiffWalk',
    'HeartDiseaseorAttack',
    'PhysHlth',
    'HvyAlcoholConsump',
    'Sex',
    'CholCheck',
]

# Render the original example, then each SimplEx counterfactual with the
# model's prediction next to the desired class (see the output below).
# NOTE(review): scaler/encoder are presumably used to map features back to
# readable values for display; TODO confirm in util.display_tabular_cfs.
display_tabular_cfs(cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_7,category_5,category_7,0,27,0,0,0,0,1,1,1



Kept counterfactual generation: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_7,category_5,category_7,0,55,0,0,0,0,1,1,1



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_11,category_5,category_7,1,27,1,0,0,0,1,1,1



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_7,category_5,category_7,1,34,1,0,0,0,1,1,1



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_7,category_5,category_7,1,27,1,0,0,21,1,1,1



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_8,category_5,category_6,0,30,1,0,0,30,1,1,1



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************


In [6]:
# Same rendering as for SimplEx, applied to the baseline counterfactuals so
# the two methods can be compared on the same example and desired class.
display_tabular_cfs(baseline_cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_2,category_7,category_5,category_7,0,27,0,0,0,0,1,1,1



Kept counterfactual generation: 


Unnamed: 0,GenHlth,Age,Education,Income,HighBP,BMI,HighChol,DiffWalk,HeartDiseaseorAttack,PhysHlth,HvyAlcoholConsump,Sex,CholCheck
0,category_3,category_11,category_4,category_3,1,28,1,0,0,20,0,0,1



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
