In [1]:
from datetime import datetime
import time
import warnings
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, HistGradientBoostingClassifier
from sklearn import metrics
from sklearn.neighbors import KDTree

import shap
import xgboost as xgb
import pickle
import graphviz
import dice_ml
from scipy.optimize import fsolve
from dice_ml.utils import helpers
from tableone import TableOne, load_dataset
from IPython.display import Latex

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
shap.initjs()
pd.set_option('display.max_columns', None)
warnings.filterwarnings('ignore')

In [3]:
data_raw = pd.read_stata('MIMIC-SAD_dta_files/MIMIC-IV.dta')
data = data_raw.drop(['deliriumtime', 'hosp_mort', 'icu28dmort', 'stay_id', 'icustay', 'hospstay', 'sepsistime'], axis=1).dropna()
dummies = pd.get_dummies(data['race'])
data = data.drop('race',axis=1).join(dummies)
dummies = pd.get_dummies(data['first_careunit'])
data = data.drop('first_careunit',axis=1).join(dummies)

xgb_matrix = xgb.DMatrix(data.drop(['sad'], axis=1))

model = pickle.load(open("xgb.pkl", "rb"))

In [4]:
is_categorical = ['gender', 'vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke', 'AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown', 'CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']
data[:50]

Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,sad,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU
0,44.0,79.0,0,37.0,100.0,28.0,98.0,107.0,66.0,75.0,8.5,12.9,268.0,12.0,0.9,102.0,138.0,105.0,3.5,2.2,7.8,3.4,1.3,14.5,37.400002,25.0,12.0,15.0,0.0,0.0,1.0,0.0,3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1,0,0,0,0,0,1,0,0,0,0,0,0
1,56.0,119.300003,0,36.720001,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,15.0,0.0,0.0,1.0,0.0,8,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0
4,76.0,77.599998,1,36.720001,59.0,21.0,97.0,107.0,90.0,94.0,9.4,11.0,280.0,10.0,0.5,123.0,136.0,100.0,3.3,1.5,9.1,3.6,1.0,11.3,24.9,24.0,15.0,15.0,0.0,0.0,0.0,0.0,3,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,1,0,0,0,0,0,0,0,0,0,0,1
5,83.0,72.0,0,36.330002,109.0,16.0,100.0,111.0,63.0,79.0,4.8,13.3,307.0,62.0,2.8,108.0,136.0,108.0,3.6,2.1,6.4,4.1,1.4,16.200001,26.9,18.0,14.0,15.0,1.0,0.0,1.0,1.0,3,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,0,0,0,1,0
6,57.0,77.5,0,38.669998,101.0,23.0,99.0,130.0,84.0,93.0,17.200001,15.1,261.0,25.0,1.0,100.0,138.0,105.0,4.3,2.0,8.5,4.0,1.2,13.5,33.799999,21.0,16.0,13.0,1.0,0.0,0.0,1.0,3,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0
7,78.0,69.699997,0,36.169998,91.0,24.0,93.0,103.0,66.0,75.0,18.700001,10.7,122.0,11.0,1.0,89.0,136.0,108.0,3.9,1.6,7.1,2.9,2.0,21.5,40.700001,20.0,12.0,15.0,1.0,0.0,1.0,1.0,4,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1,0,0,0,0,0,0,0,0,0,0,1,0
8,85.0,106.099998,0,34.439999,90.0,20.0,92.0,148.0,76.0,100.0,7.8,12.1,263.0,63.0,3.2,145.0,142.0,98.0,5.2,1.8,8.4,9.6,1.9,20.1,29.0,19.0,30.0,15.0,1.0,0.0,1.0,1.0,4,1.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0,0,0,0,0,1,1,0,0,0,0,0,0
9,61.0,100.900002,0,37.5,81.0,13.0,99.0,102.0,78.0,90.0,16.200001,10.3,162.0,14.0,0.7,165.0,133.0,105.0,5.0,2.6,8.1,2.7,1.2,13.6,25.9,23.0,13.0,15.0,1.0,0.0,1.0,1.0,3,0.0,0.0,0.0,1.0,1.0,1.0,0.0,1.0,0,0,0,0,1,0,0,1,0,0,0,0,0
11,87.0,76.199997,1,36.110001,87.0,16.0,100.0,107.0,43.0,59.0,10.9,11.1,332.0,37.0,1.4,100.0,137.0,96.0,3.9,2.2,8.8,4.2,2.8,28.700001,38.0,27.0,18.0,15.0,0.0,0.0,1.0,0.0,2,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,1,0,1,0,0,0,0,0,0
13,64.0,80.0,0,37.939999,91.0,14.0,100.0,70.0,42.0,49.0,15.6,11.3,150.0,18.0,1.4,134.0,142.0,110.0,4.5,2.4,11.5,3.2,1.1,13.3,28.1,23.0,11.0,15.0,1.0,0.0,1.0,0.0,8,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0,0,0,0,1,0,0,1,0,0,0,0,0


In [5]:
data_no_target = data.drop('sad',axis=1)

In [6]:
data_no_target.dtypes

age            float32
weight         float32
gender            int8
temperature    float32
heart_rate     float32
resp_rate      float32
spo2           float32
sbp            float32
dbp            float32
mbp            float32
wbc            float32
hemoglobin     float32
platelet       float32
bun            float32
cr             float32
glu            float32
Na             float32
Cl             float32
K              float32
Mg             float32
Ca             float32
P              float32
inr            float32
pt             float32
ptt            float32
bicarbonate    float32
aniongap       float64
gcs            float64
vent           float32
crrt           float32
vaso           float32
seda           float32
sofa_score        int8
ami            float32
ckd            float32
copd           float32
hyperte        float32
dm             float32
aki            float32
stroke         float32
AISAN            uint8
BLACK            uint8
HISPANIC         uint8
OTHER      

In [7]:
xgb_pred = np.where(model.predict(xgb_matrix) > 0.5, 1, 0)
np.count_nonzero((xgb_pred == data['sad']).astype(int))

8436

# KDTree sklearn - Euclidian

In [8]:
model.predict(xgb.DMatrix(data.iloc[[1]].drop(['sad'], axis=1)))

array([0.36970806], dtype=float32)

In [9]:
tree = KDTree(data.drop('sad',axis=1))

In [10]:
class KDTreeCounterFactual:
    def __init__(self, data, model):
        """
        :param data: data, *exclusief* target
        :param model: XGBoost model op basis waarvan we counterfactuals genereren
        """
        self.tree = KDTree(data)
        self.model = model
        

    def generate(self, X, n):
        """
        :param X: case waar we counterfactuals voor genereren; enkele rij in een pd df
        :param n: aantal counterfactuals
        """
        pred = self.model.predict(xgb.DMatrix(X))
        j = 1
        while True:
            dst, i = self.tree.query(X, k=(n*10)**j, return_distance=True)
            d = data_no_target.iloc[i[0]]
            d['reg'] = self.model.predict(xgb.DMatrix(d))
            d['pred'] = (d['reg'] > 0.5)
            d['dst'] = dst[0]
            if np.count_nonzero(d['pred']==(not (pred>0.5))) >= n:
                break
            j += 1
        return d[d['pred']==(not (pred>0.5))][:n]

In [11]:
model.predict(xgb.DMatrix(data.iloc[[1]].drop(['sad'], axis=1)))

array([0.36970806], dtype=float32)

In [12]:
x = data.iloc[[1]]
x

Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,sad,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU
1,56.0,119.300003,0,36.720001,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,15.0,0.0,0.0,1.0,0.0,8,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0


In [84]:
cf_kdtree = KDTreeCounterFactual(data_no_target, model)
start = time.time()
cfs = cf_kdtree.generate(data_no_target.iloc[[1]], 50)
print(time.time() - start)
cfs

0.042379140853881836


Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU,reg,pred,dst
2912,62.0,103.0,0,33.200001,78.0,16.0,97.0,89.0,51.0,58.0,10.9,8.1,55.0,101.0,8.6,70.0,138.0,105.0,4.9,2.9,7.9,6.6,2.5,27.200001,45.900002,17.0,21.0,4.0,0.0,0.0,1.0,0.0,8,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0,0.870129,True,50.399212
6348,69.0,104.0,1,36.5,93.0,31.0,93.0,71.0,43.0,50.0,0.9,11.5,37.0,38.0,3.0,101.0,136.0,103.0,6.1,2.3,8.9,9.7,1.6,18.1,37.400002,15.0,28.0,12.0,1.0,1.0,1.0,1.0,3,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0,0,0,0,0,1,0,0,0,0,0,0,1,0.841014,True,52.857719
12476,54.0,143.5,0,36.610001,82.0,8.0,98.0,83.0,42.0,49.0,7.1,9.5,36.0,68.0,2.3,118.0,135.0,101.0,5.2,2.3,8.6,6.0,3.1,34.799999,47.200001,26.0,13.0,12.0,0.0,0.0,1.0,0.0,4,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0,0.681989,True,54.418489
6652,62.0,94.5,1,36.560001,92.0,26.0,99.0,89.0,43.0,55.0,6.7,7.2,67.0,43.0,2.7,68.0,133.0,107.0,4.6,1.9,7.1,4.2,2.2,22.6,38.599998,17.0,14.0,15.0,1.0,0.0,1.0,1.0,6,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0,0.650931,True,57.073599
792,75.0,105.900002,0,36.560001,83.0,20.0,100.0,103.0,66.0,77.0,1.1,9.3,61.0,76.0,2.4,108.0,139.0,101.0,4.4,2.6,7.7,4.5,1.4,15.5,21.4,23.0,15.0,15.0,1.0,0.0,0.0,0.0,5,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0,0,0,0,1,0,1,0,0,0,0,0,0,0.645523,True,61.315624
9934,78.0,96.599998,0,35.830002,64.0,16.0,99.0,87.0,60.0,69.0,9.7,7.7,61.0,93.0,5.4,63.0,146.0,109.0,5.8,2.0,8.9,8.5,1.4,15.0,33.900002,20.0,17.0,7.0,1.0,1.0,1.0,0.0,10,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0,0.884454,True,61.651458
311,54.0,73.5,1,36.939999,64.0,18.0,99.0,97.0,52.0,65.0,5.0,7.6,55.0,80.0,3.3,106.0,136.0,103.0,3.8,2.6,7.9,7.3,2.3,25.200001,41.5,15.0,22.0,12.0,0.0,0.0,0.0,0.0,3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,1,0,0,0,0,0,0,1,0,0.667059,True,65.47174
10610,63.0,106.300003,1,37.779999,95.0,22.0,100.0,81.0,52.0,59.0,2.1,6.7,41.0,16.0,1.2,85.0,134.0,102.0,4.5,2.1,6.9,4.7,2.1,22.200001,60.0,9.0,28.0,15.0,1.0,0.0,1.0,1.0,3,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,0,1,0,0,0,1,0,0,0,0.755925,True,66.169053
5650,70.0,83.400002,1,37.389999,90.0,29.0,90.0,104.0,43.0,55.0,8.5,8.2,77.0,64.0,2.4,75.0,136.0,103.0,4.8,2.2,7.7,6.3,1.9,20.4,34.799999,19.0,14.0,15.0,0.0,1.0,1.0,0.0,9,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0,0,0,0,0,1,0,0,0,1,0,0,0,0.588144,True,67.497919
7040,60.0,88.150002,0,36.5,67.0,11.0,79.0,108.0,60.0,76.0,9.1,7.7,14.0,82.0,2.6,71.0,130.0,97.0,5.3,2.8,9.6,5.6,3.5,37.700001,64.599998,17.0,16.0,14.0,1.0,0.0,1.0,1.0,11,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0,0,0,0,0,1,0,0,1,0,0,0,0,0.834421,True,67.57478


In [14]:
X = data_no_target.iloc[[1]]
j = 1
n = 1

pred = model.predict(xgb.DMatrix(X))

dst, i = tree.query(X, k=n*j*10, return_distance=True)
d = data_no_target.iloc[i[0]]
print(model.predict(xgb.DMatrix(d)))
d['reg'] = model.predict(xgb.DMatrix(d))
print((d['reg'] > 0.5))
d['pred'] = (d['reg'] > 0.5)
print(dst[0])
d[d['pred']==(not (pred>0.5))][:2]

[0.36970806 0.24656451 0.45233867 0.8701292  0.8410143  0.6819893
 0.41461933 0.35539132 0.30063853 0.42608678]
1        False
13735    False
4477     False
2912      True
6348      True
12476     True
7861     False
12615    False
9830     False
10906    False
Name: reg, dtype: bool
[ 0.         36.54837399 39.1571644  50.39921151 52.85771926 54.41848888
 54.67960276 55.09871653 55.26797535 55.92425837]


Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU,reg,pred
2912,62.0,103.0,0,33.200001,78.0,16.0,97.0,89.0,51.0,58.0,10.9,8.1,55.0,101.0,8.6,70.0,138.0,105.0,4.9,2.9,7.9,6.6,2.5,27.200001,45.900002,17.0,21.0,4.0,0.0,0.0,1.0,0.0,8,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0,0.870129,True
6348,69.0,104.0,1,36.5,93.0,31.0,93.0,71.0,43.0,50.0,0.9,11.5,37.0,38.0,3.0,101.0,136.0,103.0,6.1,2.3,8.9,9.7,1.6,18.1,37.400002,15.0,28.0,12.0,1.0,1.0,1.0,1.0,3,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0,0,0,0,0,1,0,0,0,0,0,0,1,0.841014,True


# Genetic

In [148]:
class GeneticCounterfactual:
    def __init__(self, data, model, is_categorical, limit=10, population_size=10000):
        self.data = data
        self.lim_min = np.array(data.min(axis=0), dtype=np.float32)
        self.lim_max = np.array(data.max(axis=0), dtype=np.float32)
        self.model = model
        self.is_categorical = data.columns.isin(is_categorical)
        self.population = None
        self.population_size = population_size
        self.limit = limit
        
    def generate(self, X, n):
        target = (self.model.predict(xgb.DMatrix(X)) > .5)[0]
        self._populate()
        self._cull(np.array(X), target)
        for i in range(self.limit):
            self._repopulate()
            self._cull(np.array(X), target)
            
        fitness = np.apply_along_axis(lambda x: self._fitness_euclidian(np.array(X), x), 1, self.population[:n])
        res = pd.DataFrame(data=self.population[:n], columns=self.data.columns)
        res['reg'] = self.model.predict(xgb.DMatrix(res))
        res['pred'] = (res['reg'] > 0.5)
        res['fitness'] = fitness  # optioneel; voegt een kolom met de resultaten van de fitness functie toe
        return res
        
    def _populate(self):
        population = np.random.rand(self.population_size, self.data.shape[1])
        for i in range(population.shape[1]):
            population[:,i] = population[:,i] * (self.lim_max[i] - self.lim_min[i]) + self.lim_min[i]
            if self.is_categorical[i]:
                population[:,i] = population[:,i].round()
        self.population = population

    def _cull(self, X, target):
        pred = self.model.predict(xgb.DMatrix(pd.DataFrame(data=self.population, columns=self.data.columns)))
        self.population = self.population[np.invert((pred > .5) == target)]
        if self.population.shape[0] > (self.population_size/2):
            self._sort(X)
            self.population = self.population[:int(self.population_size/2),:]

    def _sort(self, X):
        fitness = np.apply_along_axis(lambda x: self._fitness_euclidian(X, x), 1, self.population)  # hier kan de fitness functie aangepast worden.
        self.population = self.population[np.argsort(fitness)]

    def _repopulate(self):
        while self.population.shape[0] <= self.population_size/2:
            pop2 = self.population.copy()
            for i in range(pop2.shape[1]):
                if random.randint(0, 1):
                    pop2[:,i] = np.roll(pop2[:,i], 1)
            self.population = np.concatenate([self.population, pop2])
        
        if self.population.shape[0] < self.population_size:
            pop2 = self.population.copy()[:(self.population_size - self.population.shape[0])]
            for i in range(pop2.shape[1]):
                if random.randint(0, 1):
                    pop2[:,i] = np.roll(pop2[:,i], 1)
            self.population = np.concatenate([self.population, pop2])
            
    def _fitness_euclidian(self, a, b):
        return np.sqrt(np.sum((a-b)**2))

    def _fitness_euclidian_relative(self, a, b):
        return np.sqrt(np.sum(((np.abs(a-b)-self.lim_min)/self.lim_max)**2))

In [151]:
is_categorical = ['gender', 'vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke', 'AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown', 'CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']
cf_g = GeneticCounterfactual(data_no_target, model, is_categorical, limit=10, population_size=data.shape[0])

In [146]:
print(model.predict(xgb.DMatrix(data.drop(['sad'], axis=1).iloc[[1]])))
data_no_target.iloc[[1]]

[0.36970806]


Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU
1,56.0,119.300003,0,36.720001,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,15.0,0.0,0.0,1.0,0.0,8,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0


In [152]:
cf_g.limit = 10
start = time.time()
cfs = cf_g.generate(data_no_target.iloc[[1]], 50)
print(time.time() - start)
cfs

init
9727
aaaaaaaaaaaaaaaaaa
11196
1469
bbbbbbbbbbbbbbbbbb
9727
0
cull1
5598
pop 0
396
aaaaaaaaaaaaaaaaaa
11196
396
bbbbbbbbbbbbbbbbbb
10800
0
cull 0
0
pop 1
325
aaaaaaaaaaaaaaaaaa
11196
325
bbbbbbbbbbbbbbbbbb
10871
0
cull 1
0
pop 2
244
aaaaaaaaaaaaaaaaaa
11196
244
bbbbbbbbbbbbbbbbbb
10952
0
cull 2
0
pop 3
172
aaaaaaaaaaaaaaaaaa
11196
172
bbbbbbbbbbbbbbbbbb
11024
0
cull 3
0
pop 4
167
aaaaaaaaaaaaaaaaaa
11196
167
bbbbbbbbbbbbbbbbbb
11029
0
cull 4
0
pop 5
128
aaaaaaaaaaaaaaaaaa
11196
128
bbbbbbbbbbbbbbbbbb
11068
0
cull 5
0
pop 6
124
aaaaaaaaaaaaaaaaaa
11196
124
bbbbbbbbbbbbbbbbbb
11072
0
cull 6
0
pop 7
137
aaaaaaaaaaaaaaaaaa
11196
137
bbbbbbbbbbbbbbbbbb
11059
0
cull 7
0
pop 8
100
aaaaaaaaaaaaaaaaaa
11196
100
bbbbbbbbbbbbbbbbbb
11096
0
cull 8
0
pop 9
98
aaaaaaaaaaaaaaaaaa
11196
98
bbbbbbbbbbbbbbbbbb
11098
0
cull 9
0
3.843332052230835


Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU,reg,pred,fitness
0,58.577962,113.624846,0.0,34.951167,74.580085,35.053535,96.34815,72.241807,55.259217,66.879928,24.716751,18.942647,37.476165,58.164771,12.983727,92.713465,139.271245,110.78957,2.898821,6.888783,4.389141,2.944419,3.982453,24.191382,40.551759,17.874432,11.017591,6.318115,1.0,0.0,0.0,1.0,10.868989,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.751381,True,36.197918
1,63.007506,129.297629,0.0,41.219304,92.268798,18.705239,75.45197,88.135189,44.591121,53.142199,24.307806,8.505852,31.278909,81.032661,5.745098,92.093705,132.87566,98.641524,3.167822,1.163521,5.831815,7.756647,4.532616,19.144966,19.967541,14.955623,5.803261,14.172035,0.0,0.0,1.0,0.0,8.748531,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.63555,True,41.430031
2,72.562883,125.468386,1.0,32.891041,82.6518,13.934981,82.586108,88.135189,50.457232,55.655845,24.307806,5.580972,47.561566,81.032661,15.566466,82.573704,139.271245,110.78957,7.996217,4.781649,8.617942,10.107025,5.545891,32.414443,43.288796,9.282947,17.402186,8.089513,0.0,1.0,0.0,0.0,14.759312,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.632928,True,41.988929
3,58.559401,113.897127,1.0,34.993398,82.6518,29.385158,96.34815,88.135189,50.457232,55.655845,12.157465,5.580972,47.561566,91.690795,17.090153,72.995814,139.271245,110.78957,7.996217,4.781649,8.617942,13.021808,5.545891,16.277258,43.288796,18.029622,17.402186,3.283355,0.0,1.0,0.0,0.0,14.759312,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.574373,True,43.058453
4,58.559401,113.897127,1.0,34.993398,82.741106,29.385158,96.34815,82.672108,41.227676,55.477575,12.157465,9.377761,31.687369,91.690795,17.090153,72.995814,137.51317,102.426843,5.595268,2.522643,11.163307,13.021808,2.522757,16.277258,47.447301,18.029622,27.168613,3.283355,0.0,1.0,0.0,0.0,2.469833,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.696969,True,43.632553
5,52.428709,115.179786,0.0,32.810758,74.580085,8.648401,71.409703,72.241807,55.259217,66.879928,24.307806,18.942647,37.476165,51.079822,14.60776,74.865678,139.271245,110.78957,2.898821,6.888783,4.389141,5.367935,3.982453,24.933681,40.551759,27.854064,11.017591,7.703526,0.0,0.0,0.0,1.0,10.868989,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.529168,True,44.488228
6,58.577962,113.624846,1.0,34.951167,74.761323,35.053535,96.34815,74.270075,41.227676,41.392714,24.716751,19.474793,42.570926,58.164771,12.983727,92.713465,128.246338,100.559819,2.898821,6.87292,3.898836,2.944419,4.838611,24.191382,47.447301,17.874432,29.358594,6.318115,1.0,0.0,0.0,0.0,10.002411,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.903306,True,45.285544
7,52.194443,113.624846,0.0,41.510952,71.53369,35.053535,75.531898,93.479221,50.62044,74.579956,19.14586,11.25187,53.998241,52.270272,4.018241,86.09244,123.682316,110.78957,5.036888,6.888783,13.878012,6.172328,3.982453,22.981422,44.571114,11.570289,11.017591,12.626669,0.0,0.0,0.0,1.0,11.250209,0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.721818,True,45.396209
8,56.022922,128.267135,0.0,38.385365,71.53369,22.753458,67.022073,93.479221,50.62044,74.579956,7.723061,11.25187,53.998241,73.456914,14.60776,91.257935,123.682316,110.78957,5.036888,6.888783,13.878012,3.36012,3.982453,24.933681,44.571114,17.037058,11.017591,6.571531,1.0,1.0,0.0,1.0,11.250209,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.628699,True,46.636193
9,56.022922,128.267135,1.0,38.385365,74.761323,22.753458,67.022073,74.270075,41.227676,41.392714,7.723061,19.474793,42.570926,73.456914,14.60776,91.257935,128.246338,100.559819,2.898821,6.87292,3.898836,3.36012,4.838611,24.933681,47.447301,17.037058,29.358594,6.571531,1.0,1.0,0.0,0.0,10.002411,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.796207,True,46.677675


In [None]:
model.predict(xgb.DMatrix(cfs.drop(['fitness'],axis=1)))

In [21]:
is_categorical = ['gender', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke', 'AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown', 'CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']

lim_min = np.array(data.min(axis=0))
lim_max = np.array(data.max(axis=0))

np.sqrt(np.sum((lim_min-lim_max)**2))

1460.518631981596

In [22]:
is_categorical = ['gender', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke', 'AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown', 'CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']
is_categorical = data_no_target.columns.isin(is_categorical)
population = np.random.rand(data_no_target.shape[0], data_no_target.shape[1])
for i in range(population.shape[1]):
    population[:,i] = population[:,i] * (lim_max[i] - lim_min[i]) + lim_min[i]
    if is_categorical[i]:
        population[:,i] = population[:,i].round()
population[1]

array([ 55.93614756, 249.18761109,   0.        ,  31.21942301,
        73.251353  ,  36.02758701,  97.35780043,  72.92071341,
        46.37058076,  73.89042719,  53.27578879,   7.18090343,
       294.20374128, 118.7115833 ,  18.98159926, 348.23320919,
       158.06853139, 112.5464534 ,   4.36837016,   1.65078417,
         3.49622987,   2.93478562,   9.15424849,  71.60770226,
       114.48488651,  30.78039731,  43.48128318,   9.96474102,
         0.38661045,   0.93018928,   0.98221608,   0.82826723,
        11.01211375,   1.        ,   0.        ,   0.        ,
         0.        ,   1.        ,   0.        ,   1.        ,
         1.        ,   0.        ,   1.        ,   0.        ,
         0.        ,   1.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   1.        ,
         1.        ])

In [23]:
X = data_no_target.iloc[[1]]
model.predict(xgb.DMatrix(X))

array([0.36970806], dtype=float32)

In [30]:
def _fitness_euclidian(a, b):
        return np.sqrt(np.sum((a-b)**2))

_fitness_euclidian(data_no_target.iloc[1], population[1])


411.34580117519624

In [31]:
population.shape

(11196, 53)

In [32]:
a = np.array([[1,1],[2,2]])
b = np.array([[3,3],[4,4]])
np.concatenate([a,b])

array([[1, 1],
       [2, 2],
       [3, 3],
       [4, 4]])

In [33]:
# d = pd.DataFrame(data=population, columns=data_no_target.columns)
p = model.predict(xgb.DMatrix(pd.DataFrame(data=population, columns=data_no_target.columns)))
target = True
population2 = population[(p > 0.5)!=target]
population_size = population.shape[0]

print(population_size)
print(population2.shape)

while population2.shape[0] <= population_size/2:
    pop2 = population2.copy()
    for i in range(pop2.shape[1]):
        if random.randint(0, 1):
            pop2[:,i] = np.roll(pop2[:,i], 1)
    population2 = np.concatenate([population2, pop2])
    print(population2.shape)

if population2.shape[0] < population_size:
    pop2 = population2.copy()[:(population_size - population2.shape[0])]
    for i in range(pop2.shape[1]):
        if random.randint(0, 1):
            pop2[:,i] = np.roll(pop2[:,i], 1)
    population2 = np.concatenate([population2, pop2])
    print(population2.shape)

11196
(1959, 53)
(3918, 53)
(7836, 53)
(11196, 53)


In [34]:
is_categorical

['gender',
 'vent',
 'crrt',
 'vaso',
 'seda',
 'ami',
 'ckd',
 'copd',
 'hyperte',
 'dm',
 'aki',
 'stroke',
 'AISAN',
 'BLACK',
 'HISPANIC',
 'OTHER',
 'WHITE',
 'unknown',
 'CCU',
 'CVICU',
 'MICU',
 'MICU/SICU',
 'NICU',
 'SICU',
 'TSICU']

# Genetic - feature select

In [221]:
class GeneticCounterfactual:
    def __init__(self, data, model, is_categorical, use_feats=None, limit=10, population_size=10000):
        if use_feats == None:
            use_feats = data.columns
        self.data = data[use_feats]
        self.data_raw = data
        self.use_feats = use_feats
        self.use_feats_ind = data.columns.isin(use_feats)
        self.lim_min = np.array(self.data.min(axis=0))
        self.lim_max = np.array(self.data.max(axis=0))
        self.is_categorical = self.data.columns.isin(is_categorical)
        
        self.model = model
        self.population = None
        self.population_size = population_size
        self.limit = limit
        
    def generate(self, X, n):
        target = (self.model.predict(xgb.DMatrix(X)) > .5)[0]
        X = np.array(X)

        pred_template = np.zeros((self.population_size, self.data_raw.shape[1]))
        pred_template[:,np.where(np.invert(self.use_feats_ind))[0]] = X[0][np.where(np.invert(self.use_feats_ind))]

        X = X[0][np.where(self.use_feats_ind)]
        
        self._populate()
        self._cull(X, target, pred_template)
        for i in range(self.limit):
            self._repopulate()
            self._cull(X, target, pred_template)
            
        fitness = np.apply_along_axis(lambda x: self._fitness_euclidian(X, x), 1, self.population[:n])
        pred_template[:self.population.shape[0],np.where(self.use_feats_ind)[0]] = self.population
        
        res = pd.DataFrame(data=pred_template[:n], columns=self.data_raw.columns)
        res['reg'] = self.model.predict(xgb.DMatrix(res))
        res['pred'] = (res['reg'] > 0.5)
        res['fitness'] = fitness  # optioneel; voegt een kolom met de resultaten van de fitness functie toe
        return res
    
    def _populate(self):
        population = np.random.rand(self.population_size, self.data.shape[1])
        for i in range(population.shape[1]):
            population[:,i] = population[:,i] * (self.lim_max[i] - self.lim_min[i]) + self.lim_min[i]
            if self.is_categorical[i]:
                population[:,i] = population[:,i].round()
        self.population = population

    def _cull(self, X, target, pred_template):
        pred_template[:,np.where(self.use_feats_ind)[0]] = self.population
        pred = self.model.predict(xgb.DMatrix(pd.DataFrame(data=pred_template, columns=self.data_raw.columns)))
        self.population = self.population[np.invert((pred > .5) == target)]
        if self.population.shape[0] > (self.population_size / 2):
            self._sort(X)
            self.population = self.population[:int(self.population_size / 2), :]
        if self.population.shape[0]==0:
            raise Exception("geen mogelijke counterfactuals met deze parameters!") 

    def _sort(self, X):
        fitness = np.apply_along_axis(lambda x: self._fitness_euclidian(X, x), 1, self.population)  # hier kan de fitness functie aangepast worden.
        self.population = self.population[np.argsort(fitness)]

    def _repopulate(self):
        while self.population.shape[0] <= self.population_size/2:
            pop2 = self.population.copy()
            for i in range(pop2.shape[1]):
                if random.randint(0, 1):
                    pop2[:,i] = np.roll(pop2[:,i], 1)
            self.population = np.concatenate([self.population, pop2])
        
        if self.population.shape[0] < self.population_size:
            pop2 = self.population.copy()[:(self.population_size - self.population.shape[0])]
            for i in range(pop2.shape[1]):
                if random.randint(0, 1):
                    pop2[:,i] = np.roll(pop2[:,i], 1)
            self.population = np.concatenate([self.population, pop2])
            
    def _fitness_euclidian(self, a, b):
        return np.sqrt(np.sum((a-b)**2))

    def _fitness_euclidian_relative(self, a, b):
        return np.sqrt(np.sum(((np.abs(a-b)-self.lim_min)/self.lim_max)**2))

In [222]:
patient_a = data.iloc[[1]].drop(['sad'], axis=1)
print(model.predict(xgb.DMatrix(patient_a)))
patient_b = patient_a.copy()
patient_b['age'] = 56.060580
patient_b['weight'] = 119.161171
patient_b['temperature'] = 36.629207
patient_b['gcs'] = 14.927048
print(model.predict(xgb.DMatrix(patient_b)))
print(patient_a)
print(patient_b)
data_no_target.iloc[[1]]

[0.36970806]
[0.60333914]
    age      weight  gender  temperature  heart_rate  resp_rate  spo2   sbp  \
1  56.0  119.300003       0    36.720001        82.0       22.0  90.0  75.0   

    dbp   mbp   wbc  hemoglobin  platelet   bun   cr   glu     Na     Cl    K  \
1  56.0  61.0  13.0         7.2      36.0  70.0  2.7  83.0  128.0  103.0  4.0   

    Mg   Ca    P  inr    pt        ptt  bicarbonate  aniongap   gcs  vent  \
1  2.2  7.0  4.5  2.1  22.4  38.400002         15.0      14.0  15.0   0.0   

   crrt  vaso  seda  sofa_score  ami  ckd  copd  hyperte   dm  aki  stroke  \
1   0.0   1.0   0.0           8  0.0  0.0   0.0      0.0  0.0  1.0     0.0   

   AISAN  BLACK  HISPANIC  OTHER  WHITE  unknown  CCU  CVICU  MICU  MICU/SICU  \
1      0      0         0      0      1        0    0      0     1          0   

   NICU  SICU  TSICU  
1     0     0      0  
        age      weight  gender  temperature  heart_rate  resp_rate  spo2  \
1  56.06058  119.161171       0    36.629207        82

Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU
1,56.0,119.300003,0,36.720001,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,15.0,0.0,0.0,1.0,0.0,8,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,0,0,1,0,0,0,1,0,0,0,0


In [227]:
is_categorical = ['gender', 'vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke', 'AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown', 'CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']
use_feats = ['age', 'weight', 'temperature', 'gcs']
cf_g = GeneticCounterfactualFeatureSelect(data_no_target, model, is_categorical, use_feats=use_feats, limit=3, population_size=data_no_target.shape[0])

start = time.time()
cfs = cf_g.generate(data_no_target.iloc[[1]], 10)
print(time.time() - start)
cfs

2.086761951446533


Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU,reg,pred,fitness
0,56.005251,119.269784,0.0,36.692435,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.978537,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.602751,True,0.04649
1,56.005251,119.246231,0.0,36.692435,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.978537,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.602751,True,0.06434
2,56.005251,119.269784,0.0,36.629115,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.978537,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.098294
3,56.005251,119.246231,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.938974,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.120585
4,56.005251,119.246231,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.938974,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.120585
5,56.005251,119.365286,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.901753,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.147782
6,56.005251,119.458982,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.938974,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.192155
7,56.005251,119.458982,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.938974,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.192155
8,56.005251,119.458982,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.938974,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.192155
9,56.005251,119.365286,0.0,36.631133,82.0,22.0,90.0,75.0,56.0,61.0,13.0,7.2,36.0,70.0,2.7,83.0,128.0,103.0,4.0,2.2,7.0,4.5,2.1,22.4,38.400002,15.0,14.0,14.836773,0.0,0.0,1.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.603339,True,0.197054


In [73]:
is_categorical = ['gender', 'vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke', 'AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown', 'CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']
cf_g = GeneticCounterfactual2(data_no_target, model, is_categorical)

cf_g.limit = 10
start = time.time()
cfs = cf_g.generate(data_no_target.iloc[[1]], 10)
print(time.time() - start)
cfs

2.173279285430908


Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU,fitness
0,50.397846,107.957032,0.0,30.542786,82.51446,25.884028,91.299977,70.374381,55.722989,71.388789,9.931664,17.581225,35.62915,47.352333,8.972173,81.926187,144.479824,108.728748,6.620368,2.402952,3.416028,6.77827,7.705976,21.846222,48.43688,30.854593,20.587574,12.567325,0.0,1.0,0.0,0.0,3.55145,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,42.540191
1,65.169346,109.338137,0.0,36.944427,86.784456,9.793065,73.030664,80.237835,71.057817,79.182455,16.712325,18.229984,41.03161,56.795586,8.972173,81.926187,127.139932,104.333975,7.90793,1.384668,12.887598,9.182749,3.408822,33.473855,48.43688,16.501388,12.259129,14.293087,0.0,1.0,0.0,0.0,9.689845,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,43.87913
2,54.623226,107.957032,0.0,36.267334,82.928809,18.607208,93.51164,68.166792,46.927794,71.388789,34.94104,9.927418,47.464468,74.717092,1.930549,93.769144,127.284131,117.981001,7.90793,4.871273,2.13187,0.846715,5.706103,11.477729,29.88875,33.710071,20.587574,7.886653,1.0,1.0,1.0,0.0,11.766395,1.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0,45.955176
3,66.348347,97.895668,0.0,34.870428,72.59934,16.495515,86.797319,70.374381,36.541281,48.914375,13.207927,17.581225,26.692501,70.499172,8.972173,93.944021,126.711798,111.384582,3.054336,0.611866,8.477781,6.691679,7.705976,20.616484,48.43688,30.854593,8.625009,7.347658,0.0,0.0,0.0,1.0,11.895065,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,1.0,46.713889
4,59.36257,124.758894,1.0,37.265296,78.442072,35.027489,95.392742,75.942646,84.938256,47.456945,16.073593,14.341601,13.593062,67.283286,6.971336,93.702507,139.657939,104.043723,2.7087,2.856554,12.887598,6.190835,6.871417,15.630307,44.335897,13.121479,15.485406,11.400783,1.0,1.0,0.0,1.0,6.118917,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0,47.92651
5,66.348347,130.160387,1.0,36.267334,61.908767,8.169351,86.797319,76.832374,46.927794,47.456945,30.756467,7.988257,58.902485,74.717092,0.165713,93.769144,127.284131,106.643631,1.939806,0.611866,2.13187,5.945275,0.819593,11.477729,34.016938,10.713111,15.485406,7.886653,1.0,0.0,1.0,1.0,11.766395,0.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,0.0,48.687406
6,42.763009,120.649026,0.0,41.239252,88.459077,18.607208,93.51164,68.166792,67.521443,79.182455,4.323713,9.927418,22.360203,87.118192,1.930549,100.131601,126.711798,117.981001,5.643266,4.942806,14.396786,0.846715,5.706103,15.496178,29.88875,33.710071,12.259129,10.406265,1.0,0.0,1.0,0.0,10.613927,1.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,49.602939
7,71.636799,138.416239,1.0,40.969203,72.687715,16.495515,95.392742,69.761153,71.057817,59.609298,22.623232,13.06013,56.621511,70.919416,4.923046,77.99631,133.815822,77.510077,8.882551,4.192033,6.83653,1.554807,3.477784,29.656643,33.793169,14.67526,10.071832,11.769046,1.0,0.0,1.0,0.0,11.885897,1.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,49.610029
8,39.349549,107.957032,0.0,34.870428,82.51446,11.781939,91.299977,70.374381,71.057817,71.388789,9.931664,17.581225,35.62915,51.14275,8.972173,63.56694,126.711798,109.378509,6.620368,4.671441,12.887598,6.77827,7.705976,33.933606,48.43688,30.854593,20.587574,12.358182,0.0,0.0,0.0,0.0,2.253219,1.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,1.0,49.696884
9,53.410881,112.625713,0.0,31.756135,94.30276,12.332937,96.28192,80.237835,41.363331,77.492276,15.036496,18.229984,47.508737,51.14275,2.950416,94.093374,147.855689,114.059546,2.500823,3.04933,7.001134,14.685786,6.871417,33.933606,44.335897,16.501388,18.665,5.781388,0.0,1.0,1.0,0.0,3.681378,0.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,50.65334


In [186]:
a = np.array([[1,2,3],
              [4,5,6],
              [7,8,9]])
print(np.where([1,0,1,1,0,0,1]))
print(np.insert(a,[1,3],np.array([[5,5,5],[6,6,6]]).T,axis=1))

a[:,[1,2]] = np.array([[5,4,5],[6,4,6]]).T
a

(array([0, 2, 3, 6], dtype=int64),)
[[1 5 2 3 6]
 [4 5 5 6 6]
 [7 5 8 9 6]]


array([[1, 5, 6],
       [4, 4, 4],
       [7, 5, 6]])

# 'Random'

In [35]:
class SensitivityCounterfactual:
    def __init__(self, data, model, is_discrete, ):
        self.data = data
        self.model = model
        self.is_discrete = is_discrete

    def _replace_inplace(feat, X, x):
        X = X.copy()
        X[feat] = x
        return X

    def generate_single(self, X, feat):
        X = X.copy()
        x0 = X[feat][1]
        
        self.model.predict(xgb.DMatrix(X))

In [36]:
X = data_no_target.iloc[[5]]
feat = 'temperature'
x0 = X[feat].iloc[0]

def _replace_inplace(feat, X, x):
    X = X.copy()
    X[feat] = x
    return X

# model.predict(xgb.DMatrix(_replace_inplace(feat, X, 100)), )
x, infodict, ier, mesg = fsolve(lambda x: model.predict(xgb.DMatrix(_replace_inplace(feat, X, x)))-0.5, x0, full_output=True)
print(x)
print(infodict)
print(ier)
print(mesg)
# X

[1844.66990662]
{'nfev': 15, 'fjac': array([[1.]]), 'r': array([-0.]), 'qtf': array([0.05244273]), 'fvec': array([0.05244273])}
5
The iteration is not making good progress, as measured by the 
  improvement from the last ten iterations.


In [37]:
X[feat].iloc[0]

36.17

In [38]:
model.predict(xgb.DMatrix(_replace_inplace(feat, X, 69.7)))

array([0.5524427], dtype=float32)

In [39]:
model.predict(xgb.DMatrix(_replace_inplace(feat, X, 400)))

array([0.5524427], dtype=float32)

In [40]:
X

Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU
7,78.0,69.699997,0,36.169998,91.0,24.0,93.0,103.0,66.0,75.0,18.700001,10.7,122.0,11.0,1.0,89.0,136.0,108.0,3.9,1.6,7.1,2.9,2.0,21.5,40.700001,20.0,12.0,15.0,1.0,0.0,1.0,1.0,4,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1,0,0,0,0,0,0,0,0,0,0,1,0


In [41]:
_replace_inplace(['age', 'weight'], X, [2,3])

Unnamed: 0,age,weight,gender,temperature,heart_rate,resp_rate,spo2,sbp,dbp,mbp,wbc,hemoglobin,platelet,bun,cr,glu,Na,Cl,K,Mg,Ca,P,inr,pt,ptt,bicarbonate,aniongap,gcs,vent,crrt,vaso,seda,sofa_score,ami,ckd,copd,hyperte,dm,aki,stroke,AISAN,BLACK,HISPANIC,OTHER,WHITE,unknown,CCU,CVICU,MICU,MICU/SICU,NICU,SICU,TSICU
7,2,3,0,36.169998,91.0,24.0,93.0,103.0,66.0,75.0,18.700001,10.7,122.0,11.0,1.0,89.0,136.0,108.0,3.9,1.6,7.1,2.9,2.0,21.5,40.700001,20.0,12.0,15.0,1.0,0.0,1.0,1.0,4,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1,0,0,0,0,0,0,0,0,0,0,1,0


In [42]:
feat = ['age', 'weight']
x0 = X[feat].iloc[0]
np.array(x0)

array([78. , 69.7], dtype=float32)

In [45]:
X = data_no_target.iloc[[5]]
feat = ['age', 'weight']
x0 = X[feat].iloc[0]

def _replace_inplace_double(x, y, feat, X):
    X = X.copy()
    X[feat[0]] = x
    X[feat[1]] = y
    return X

def _replace_pred(x, y, feat, X):
    pass

# model.predict(xgb.DMatrix(_replace_inplace(feat, X, 100)), )
x, infodict, ier, mesg = fsolve(lambda x, y: model.predict(xgb.DMatrix(_replace_inplace(feat, X, x, y)))-0.5, x0, full_output=True)
print(x)
print(infodict)
print(ier)
print(mesg)
# X

TypeError: <lambda>() missing 1 required positional argument: 'y'

In [44]:
x0

age       78.000000
weight    69.699997
Name: 7, dtype: float32