In [1]:
import sys

sys.path.append("../../src")

import pickle
from tuneable_counterfactuals_explainer.explainer import Explainer

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
# Load the fitted model and its training features. Files are opened with `with` so
# the handles are closed right away instead of leaking until garbage collection.
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
with open("model/model.pkl", "rb") as model_file:
    model = pickle.load(model_file)
with open("model/train_X.pkl", "rb") as train_file:
    training_data = pickle.load(train_file)

In [3]:
# Inspect the column names of the loaded training frame.
# NOTE(review): 'prediction' and 'day_28_flg' look like a model output and an outcome
# label rather than input features, and no 'target' column (the target name passed to
# Explainer below) is listed — confirm this is the frame the explainer should see.
training_data.columns

Index(['age', 'bmi', 'weight_first', 'icu_los_day', 'hospital_los_day',
       'day_icu_intime_num', 'hour_icu_intime', 'map_1st', 'hr_1st',
       'temp_1st', 'spo2_1st', 'abg_count', 'wbc_first', 'hgb_first',
       'platelet_first', 'sodium_first', 'potassium_first', 'tco2_first',
       'chloride_first', 'bun_first', 'creatinine_first', 'prediction',
       'day_28_flg'],
      dtype='object')

In [4]:
# Features the explainer may perturb, mapped to their changeability scores.
# Disabled candidates (previously left as commented-out entries):
#   'bmi': 0.5, 'sodium_first': 1, 'potassium_first': 1
CHANGEABILITY_SCORES = {
    "weight_first": 0.5,
    "platelet_first": 1,
    "tco2_first": 1,
    "chloride_first": 1,
    "creatinine_first": 1,
}

# NOTE(review): training_data.columns does not list a "target" column — confirm how
# Explainer interprets this positional argument.
explainer = Explainer(
    model,
    training_data,
    "target",
    bounding_method="quantile",
    number_samples=20,
    changeability_scores=CHANGEABILITY_SCORES,
)

In [5]:
# Explain the first training row. An `additional_threshold=0.45` argument was
# previously commented out here and is not passed.
query_row = training_data.iloc[0]
res = explainer.explain(query_row, store_historical_explainers=True)

  0%|          | 0/5 [00:00<?, ?it/s]

206it [00:02, 83.91it/s]                     

All nodes in this graph have been searched





In [6]:
import graphviz

In [7]:
# Display the explanation result: a dict keyed by tuples of changed feature names
# (the empty tuple presumably being the unmodified input), each mapped to a float score.
res

{(): 9.042861906878263e-05,
 ('weight_first',): 7.175638760592074e-05,
 ('platelet_first',): 0.002863131189964069,
 ('tco2_first',): 0.00032398194956892104,
 ('chloride_first',): 0.00019904045421525722,
 ('creatinine_first',): 0.00016466249826343641,
 ('platelet_first', 'weight_first'): 0.0011387710306409436,
 ('platelet_first', 'tco2_first'): 0.014325008126373329,
 ('platelet_first', 'chloride_first'): 0.007509208475529455,
 ('platelet_first', 'creatinine_first'): 0.015019129499198872,
 ('platelet_first', 'creatinine_first', 'weight_first'): 0.00650501543751181,
 ('platelet_first', 'creatinine_first', 'tco2_first'): 0.16570695485004192,
 ('platelet_first',
  'creatinine_first',
  'chloride_first'): 0.035634173449230616,
 ('platelet_first',
  'creatinine_first',
  'tco2_first',
  'weight_first'): 0.0658457896396117,
 ('platelet_first',
  'creatinine_first',
  'tco2_first',
  'chloride_first'): 0.38145822518474326,
 ('platelet_first',
  'creatinine_first',
  'tco2_first',
  'chloride_fi

In [8]:
# Human-readable display names for the changeable features.
converter = {
    "weight_first": "Weight",
    "platelet_first": "Platelet",
    "tco2_first": "TCO2",
    "chloride_first": "Chloride",
    "creatinine_first": "Creatinine",
}
# One entry per explored node, in the order `res` yields them:
#   ((node index as str, comma-joined display names), feature tuple, score)
# Features without a display name fall back to their raw column name, so enabling
# another changeable feature (e.g. 'bmi') does not raise KeyError here.
temp = [
    ((str(i), ", ".join(converter.get(z, z) for z in x)), x, y)
    for i, (x, y) in enumerate(res.items())
]

In [9]:
# Hand-entered number of explored nodes per search stage, in the order `res` lists them.
# NOTE(review): this must be kept in sync by hand whenever the explanation changes.
counts = [1, 5, 4, 3, 2, 1, 1, 2, 1, 1, 2, 1, 1]
# Expand to one stage index per node, e.g. [1, 2] -> [0, 1, 1].
count_expanded = [
    stage for stage, n_nodes in enumerate(counts) for _ in range(n_nodes)
]

In [10]:
import warnings

# Pair each explored node with its search stage. zip() stops at the shorter input,
# so a stale hand-entered `counts` would otherwise silently drop nodes; surface it.
if len(count_expanded) != len(temp):
    warnings.warn(
        f"`counts` assigns stages to {len(count_expanded)} nodes but `res` has "
        f"{len(temp)}; only the first {min(len(count_expanded), len(temp))} are paired."
    )
temp_2 = list(zip(count_expanded, temp))

In [11]:
# Show the (stage, ((index, label), feature_tuple, score)) pairs that the plot below draws.
temp_2

[(0, (('0', ''), (), 9.042861906878263e-05)),
 (1, (('1', 'Weight'), ('weight_first',), 7.175638760592074e-05)),
 (1, (('2', 'Platelet'), ('platelet_first',), 0.002863131189964069)),
 (1, (('3', 'TCO2'), ('tco2_first',), 0.00032398194956892104)),
 (1, (('4', 'Chloride'), ('chloride_first',), 0.00019904045421525722)),
 (1, (('5', 'Creatinine'), ('creatinine_first',), 0.00016466249826343641)),
 (2,
  (('6', 'Platelet, Weight'),
   ('platelet_first', 'weight_first'),
   0.0011387710306409436)),
 (2,
  (('7', 'Platelet, TCO2'),
   ('platelet_first', 'tco2_first'),
   0.014325008126373329)),
 (2,
  (('8', 'Platelet, Chloride'),
   ('platelet_first', 'chloride_first'),
   0.007509208475529455)),
 (2,
  (('9', 'Platelet, Creatinine'),
   ('platelet_first', 'creatinine_first'),
   0.015019129499198872)),
 (3,
  (('10', 'Platelet, Creatinine, Weight'),
   ('platelet_first', 'creatinine_first', 'weight_first'),
   0.00650501543751181)),
 (3,
  (('11', 'Platelet, Creatinine, TCO2'),
   ('platelet

In [12]:
# Only the first MAX_NODES explored nodes are drawn, to keep the figure legible.
# (Previously stored in a variable named `max`, which shadowed the builtin.)
MAX_NODES = 20


def build_search_digraph(staged_nodes, max_nodes=MAX_NODES):
    """Build a graphviz Digraph of the explainer's search, one row per search stage.

    Parameters
    ----------
    staged_nodes : list
        Items shaped ``(stage, ((index, label), key, score))`` as in ``temp_2``;
        ``key`` is the tuple of changed features and ``key[:-1]`` is drawn as its
        parent node.
    max_nodes : int
        Only ``staged_nodes[:max_nodes]`` are drawn.

    Returns
    -------
    graphviz.Digraph
        The graph, with the ``dot`` layout engine selected.
    """
    shown = staged_nodes[:max_nodes]
    graph = graphviz.Digraph(comment="Search Visualization")

    # Stage-label column: one box per stage present in `shown`, chained by invisible
    # edges (they only influence the layout), intended to line each label up with its
    # stage's row. The stage-0 label, beside the single root node, is hidden.
    previous_label = None
    for stage in sorted({stage for stage, _ in shown}):
        label_id = f"cluster_{stage + 1}"
        graph.node(
            label_id,
            label=f"Search Stage {stage}",
            color="lightgrey",
            shape="box",
            style="filled" if stage > 0 else "invis",
        )
        if previous_label is not None:
            graph.edge(previous_label, label_id, style="invis")
        previous_label = label_id

    for stage, ((_, label), key, score) in shown:
        node_id = str(key)
        if stage == 0:
            graph.node(node_id, "Initial Value")
        else:
            graph.node(node_id, label + "\nScore: " + format(score, "f"))
        if key:
            graph.edge(str(key[:-1]), node_id)
        # Invisible edge to the first node of any later stage pushes later stages
        # below this node in the layout.
        next_key = next((k for s, (_, k, _) in shown if s > stage), None)
        if next_key is not None:
            graph.edge(node_id, str(next_key), style="invis")

    graph.attr(layout="dot")
    return graph


dot = build_search_digraph(temp_2)
# Render to PDF; the returned path (normalised to forward slashes) is the cell's output.
dot.render("doctest-output/round-table.gv").replace("\\", "/")

'doctest-output/round-table.gv.pdf'