In [1]:
# NOTE(review): star import — presumably the source of get_simplex,
# display_tabular_cfs and torch used in later cells; prefer explicit names.
from util import *

from src.dataset import load_adult
from src.counterfactual import get_baseline_counterfactuals
import joblib

import warnings
# Blanket suppression keeps the rendered notebook clean, but also hides
# deprecation/numerical warnings — narrow this if something misbehaves.
warnings.filterwarnings('ignore')

# Load the pretrained Adult model together with the encoder and scaler that the
# display helpers use later. joblib.load unpickles the file, so only load
# artifacts from a trusted source.
model, encoder, scaler = joblib.load('models/adult.gz') # Model should have the BlackBox interface
# The network has Dropout and BatchNorm1d layers and is later evaluated on a
# single-row batch, which train-mode BatchNorm rejects; force inference mode
# so predictions are deterministic regardless of how the model was saved.
model.eval()
model

TabularModel(
  (lin1): Linear(in_features=12, out_features=200, bias=True)
  (lin2): Linear(in_features=200, out_features=50, bias=True)
  (lin3): Linear(in_features=50, out_features=2, bias=True)
  (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drops): Dropout(p=0.3, inplace=False)
)

In [2]:
# Take the first two batches of the Adult test split: the first becomes the
# SimplEx corpus, the second the test examples to explain. Labels are unused.
# NOTE(review): 100 is presumably the batch size; if load_adult shuffles,
# seed the RNG so the example picked by index in the next cell is stable.
loader = iter(load_adult(100, train=False))
(X_corpus, _), (X_test, _) = next(loader), next(loader)

# Fit SimplEx on the corpus/test split; verbose prints weight-fitting progress.
simplex = get_simplex(model, X_corpus, X_test, verbose=True)

Weight Fitting Epoch: 2000/10000 ; Error: 6.72 ; Regulator: 39.9 ; Reg Factor: 1
Weight Fitting Epoch: 4000/10000 ; Error: 2.12 ; Regulator: 5.96 ; Reg Factor: 1
Weight Fitting Epoch: 6000/10000 ; Error: 1.36 ; Regulator: 1.63 ; Reg Factor: 1
Weight Fitting Epoch: 8000/10000 ; Error: 1.15 ; Regulator: 0.544 ; Reg Factor: 1
Weight Fitting Epoch: 10000/10000 ; Error: 1.08 ; Regulator: 0.197 ; Reg Factor: 1


In [3]:
# Pick one test example and request counterfactuals that move the model to its
# second-ranked class (the model has two outputs, so this is the other class).
i = 40      # index into simplex.test_examples
n_cfs = 5   # number of counterfactuals requested from each method

x = simplex.test_examples[i:i+1]  # slicing keeps a leading batch dimension of 1
desired_class = model(x).topk(2).indices[0, 1]
# NOTE(review): assumes the encoded categorical features occupy the leading
# input columns, one per entry of encoder.cols — TODO confirm.
cat_indices = list(range(len(encoder.cols)))

# Name the features the SimplEx search may change instead of hard-coding a
# positional 0/1 vector. FEATURE_NAMES follows the same column order as the
# display labels used for the tables (presumably the model's input order);
# the resulting mask is [1,1,1,0,0,0,0,0,0,0,1,0].
FEATURE_NAMES = [
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'gender', 'native-country',
    'capital-gain', 'capital-loss', 'hours-per-week', 'age',
]
MUTABLE_FEATURES = {'workclass', 'education', 'marital-status', 'hours-per-week'}
mask = torch.Tensor([1 if name in MUTABLE_FEATURES else 0 for name in FEATURE_NAMES])

baseline_cfs = get_baseline_counterfactuals(model=model, target=desired_class, test=x,
                                            corpus=X_corpus, n_counterfactuals=n_cfs)

cfs = simplex.get_counterfactuals(test_id=i, model=model, n_counterfactuals=n_cfs,
                                  cat_indices=cat_indices, mask=mask)

 19%|███████▊                                 | 19/100 [00:00<00:00, 548.51it/s]


In [4]:
# Column labels for the rendered tables, presumably in the model's input
# feature order. In the output the first eight show category labels and the
# last four show numbers.
cols = [
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'gender', 'native-country',
    'capital-gain', 'capital-loss', 'hours-per-week', 'age',
]

# Show the original example, then each SimplEx counterfactual with its
# predicted vs. desired class.
display_tabular_cfs(cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Self-emp-not-inc,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,United-States,0,0,40,33



Kept counterfactual generation: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Self-emp-not-inc,Bachelors,Married-civ-spouse,Craft-repair,Husband,White,Male,United-States,0,0,40,33



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Self-emp-not-inc,Bachelors,Married-civ-spouse,Craft-repair,Husband,White,Male,United-States,0,0,45,33



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Federal-gov,Masters,Married-civ-spouse,Craft-repair,Husband,White,Male,United-States,0,0,40,33



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
Kept counterfactual generation: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Federal-gov,Masters,Married-civ-spouse,Craft-repair,Husband,White,Male,United-States,0,0,43,33



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************


In [5]:
# For comparison, render the counterfactuals produced by
# get_baseline_counterfactuals for the same example and target class.
display_tabular_cfs(baseline_cfs, model, x, desired_class, scaler, encoder, cols)

Original: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Self-emp-not-inc,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,United-States,0,0,40,33



Kept counterfactual generation: 


Unnamed: 0,workclass,education,marital-status,occupation,relationship,race,gender,native-country,capital-gain,capital-loss,hours-per-week,age
0,Local-gov,Some-college,Married-civ-spouse,Exec-managerial,Husband,White,Male,United-States,0,0,40,52



Predicted:  tensor(1)  ||  Desired:  tensor(1)  ||  Orginal:  tensor(0)
************************************************************************************************************************
