# Towards Safe and Interpretable Optimal Treatment Regime Estimation Using Mechanistic Modeling and Interpolation
This notebook implements the optimal-regime estimation method based on nearest-neighbor matching.

Let's first load the necessary libraries and the project helper module.

In [None]:
%load_ext lab_black
# I use Lab Black to format my code. A well formatted code is the answer to happy living.
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tqdm
import helper
import importlib
import sklearn.linear_model as lm
import sklearn.tree as tree
import sklearn.ensemble as en
from tabulate import tabulate
import seaborn as sns
from sklearn.model_selection import StratifiedKFold


import warnings

# Global seaborn theme for the figures below (individual cells may override font_scale).
sns.set(style="whitegrid", font_scale=2)

# Re-import helper.py so edits to it are picked up without restarting the kernel.
importlib.reload(helper)

# NOTE(review): silences *all* warnings, including sklearn/numpy fit warnings; consider
# narrowing this filter when debugging.
warnings.filterwarnings("ignore")

Now we load the data from `dataset.csv` in the parent folder (`../dataset.csv`), together with the pharmacodynamic parameters estimated earlier. For reproducibility, we fix the random seed. I love 42: it is the answer to the Ultimate Question of Life, the Universe, and Everything.

In [None]:
# Fix the global NumPy RNG so downstream stochastic steps are reproducible.
np.random.seed(42)

# Main dataset (../dataset.csv): each row is one time point of one patient; index = patient ID.
df = pd.read_csv("../dataset.csv").set_index("SID")
uniq_sids = np.unique(df.index)  # sorted unique patient IDs

# Pharmacodynamic parameters estimated at an earlier stage, one row per patient.
df_params = pd.read_csv("pd_params.csv", index_col=0)

In [None]:
from IPython.display import display

# Mean and standard deviation of each estimated pharmacodynamic parameter across patients.
# An explicit display() is needed: a bare expression that is not the last line of a cell
# is evaluated but never rendered by Jupyter.
display(pd.concat([df_params.mean(), df_params.std()], axis=1))

# Three panels: ED50 per drug, Hill coefficient (alpha) per drug, and baseline (beta).
fig, ax = plt.subplots(ncols=3, figsize=(20, 6), dpi=800)
sns.boxplot(
    df_params[["propofol_50", "levetiracetam_50"]].rename(
        columns=lambda x: "propofol" if "propofol" in x else "levetiracetam"
    ),
    ax=ax[0],
    showfliers=False,
)
ax[0].set_ylabel("ED50")
sns.boxplot(
    df_params[["propofol_Hill", "levetiracetam_Hill"]].rename(
        columns=lambda x: "propofol" if "propofol" in x else "levetiracetam"
    ),
    ax=ax[1],
    showfliers=False,
)
ax[1].set_ylabel(r"$\alpha$   or   Hill Coefficient")

sns.boxplot(
    df_params[["baseline_avg"]].rename(columns=lambda x: ""),
    ax=ax[2],
    showfliers=False,
    color="C2",
)
ax[2].set_ylabel(r"$\beta$")
ax[2].set_ylim(0, 1.05)

plt.tight_layout()
fig.savefig("PD_estimates.pdf")

In [None]:
# Pre-treatment covariates: every "C_imputed_*" column, averaged per patient.
X = df.filter(like="C_imputed_").groupby(by="SID").mean()
X_wo_params = X.copy(deep=True)  # covariates without the PD parameters
# Attach the PD parameters; keep only patients with no missing values.
X = X.join(df_params).dropna()
X_normalized = helper.normalize(X)
X_normalized_2 = helper.normalize(X_wo_params)

# Outcome: per-patient mean of Y (mRS score per the original note) binarized at > 4,
# aligned to X. True/1 = bad outcome, False/0 = good outcome.
Y = (df[["Y"]].groupby(by="SID").mean() > 4).loc[X.index]

### Estimating $\pi$

Next, we estimate each patient's observed (administered) policy for both drugs of interest: propofol and levetiracetam.

In [None]:
# Observed-policy estimation, one patient at a time: summarize the EA-burden time series
# into state features (helper.get_features_*) and regress the administered dose on those
# features with a no-intercept RidgeCV. The fitted coefficients are that patient's policy.
m_prop, m_lev = [], []  # per-patient coefficient vectors
score_prop = []  # in-sample R^2 of each propofol policy fit

for sid in tqdm.tqdm(uniq_sids):
    df_i = df.loc[sid]  # this patient's rows, one per time point

    # State-space summaries of the EA burden, one frame per drug.
    v_i_prop = helper.get_features_prop(df_i)
    v_i_lev = helper.get_features_lev(df_i)

    prop_states = v_i_prop.drop(columns=["Prop_Act"])
    lev_states = v_i_lev.drop(columns=["Lev_Act"])

    m_prop_i = lm.RidgeCV(fit_intercept=False).fit(prop_states, v_i_prop["Prop_Act"])
    m_lev_i = lm.RidgeCV(fit_intercept=False).fit(lev_states, v_i_lev["Lev_Act"])

    m_prop.append(m_prop_i.coef_)
    m_lev.append(m_lev_i.coef_)
    score_prop.append(m_prop_i.score(prop_states, v_i_prop["Prop_Act"]))

# One row per patient, one column per state feature (feature names from the last patient).
m_prop = pd.DataFrame(
    np.array(m_prop),
    columns=v_i_prop.drop(columns=["Prop_Act"]).columns,
    index=uniq_sids,
)
m_lev = pd.DataFrame(
    np.array(m_lev),
    columns=v_i_lev.drop(columns=["Lev_Act"]).columns,
    index=uniq_sids,
)

In [None]:
# Treatment matrix: one row per patient in X, columns = propofol policy coefficients
# followed by levetiracetam ones. Feature names present in both frames receive a
# " [prop]" / " [lev]" suffix; values are rounded to one decimal.
T = m_prop.join(m_lev, lsuffix=" [prop]", rsuffix=" [lev]")
T = T.loc[X.index].round(1)

Visualizing the average observed policy coefficients for propofol.

In [None]:
sns.set(style="whitegrid", font_scale=1.5)
fig, ax = plt.subplots(figsize=(20, 5))

# Mean and standard error across patients of every policy coefficient, largest mean first.
obs_policy = pd.concat([T.mean(), T.sem()], axis=1).round(2)
obs_policy = obs_policy.rename(columns={0: "avg", 1: "std.err"})
obs_policy = obs_policy.sort_values(by="avg", ascending=False)

# Bar chart of the propofol coefficients only.
obs_policy.filter(like="[prop]", axis=0)["avg"].plot(kind="bar", ax=ax)

In [None]:
# Mean (column 0) and standard deviation (column 1) across patients of each propofol coefficient.
pd.concat([T.mean(), T.std()], axis=1).filter(like="[prop]", axis=0)

In [None]:
sns.set(style="whitegrid", font_scale=1.5)
fig, ax = plt.subplots(figsize=(20, 5))

# Mean and standard error across patients of every policy coefficient, largest mean first.
obs_policy = pd.concat([T.mean(), T.sem()], axis=1).round(2)
obs_policy = obs_policy.rename(columns={0: "avg", 1: "std.err"})
obs_policy = obs_policy.sort_values(by="avg", ascending=False)

# Bar chart of the levetiracetam coefficients only.
obs_policy.filter(like="[lev]", axis=0)["avg"].plot(kind="bar", ax=ax)

### Matching Step

Each row of the original data is a single time stamp. The pre-treatment covariates and the post-discharge outcome are not time-varying, so we extract them by grouping the data by patient ID (done above when constructing `X` and `Y`).
We then perform the matching.

#### Estimating $\pi^{opt}$

In [None]:
# Estimate the optimal regime pi^opt by cross-fitted caliper matching.
# Model-to-Match-esque approach: a gradient-boosting classifier estimates P(Y=1 | X) on one
# part of the data, and its normalized feature importances serve as per-covariate weights
# of the matching distance on the other part.
# Per fold, this cell appends:
#   opt_mean_policy_array: for each estimation unit, the mean treatment of the
#       smallest-outcome group within its matched group (plus the Y value of that group).
#   df_y_array: for each estimation unit whose matched group contains both outcomes, the
#       predicted P(Y=1) under its optimal (y_opt) and administered (y_adm) treatment, from
#       a logistic regression of Y on T fitted inside the matched group.
# m_dist_metric = lm.LassoCV().fit(X, Y)
np.random.seed(42)  # global NumPy seed; GradientBoostingClassifier below has random_state=None
skf = StratifiedKFold(n_splits=3)
opt_mean_policy_array = []
df_y_array = []
indices = X_normalized.index
# skf.split yields (train, test) positions. Here the larger part (e_index) is the estimation
# sample that gets matched, and the held-out third (t_index) is used only to learn the
# distance metric, so the metric is never fit on the units being matched.
# NOTE(review): train_index/est_index are swapped relative to sklearn's train/test wording.
for i, (e_index, t_index) in enumerate(skf.split(X, Y)):
    train_index = indices[t_index]
    est_index = indices[e_index]
    m_dist_metric = en.GradientBoostingClassifier(max_depth=1, n_estimators=100).fit(
        X_normalized.loc[train_index], Y.loc[train_index]
    )
    # Feature importances rescaled so the most important covariate gets weight 1.
    dist_metric = (
        m_dist_metric.feature_importances_ / m_dist_metric.feature_importances_.max()
    )

    # dist_metric = np.abs(m_dist_metric.coef_) / np.abs(m_dist_metric.coef_).max()

    # helper.caliper_match lives in helper.py (not shown). From its use below, MG has one
    # row per estimation unit and MG[i] > 0 is a boolean mask over the estimation units
    # selecting unit i's matched group; D presumably holds the distances — TODO confirm.
    MG, D = helper.caliper_match(
        X=X_normalized.loc[est_index].to_numpy(), metric=dist_metric, caliper=0.012
    )  # create the matched groups

    # For each unit (the comprehension's own `i`, which does not leak into the fold loop),
    # average its matched group's treatments separately per outcome value and keep the
    # first row. groupby sorts its keys, so that row is the smallest Y.
    opt_mean_policy_ = pd.concat(
        [
            T.loc[est_index][
                MG[i] > 0
            ]  # get treatment assignment for all units in the matched group of unit i
            .join(
                Y.loc[est_index][MG[i] > 0]
            )  # join outcomes for all units in the MG of i with treatments
            .groupby("Y")  # group by outcomes, here the outcome is binary
            .mean()  # get the average treatment for each outcome
            .iloc[[0]]  # smallest Y = False (good outcome) when present. NOTE(review): if every unit in the MG has Y=True, this keeps the bad-outcome mean
            for i in range(MG.shape[0])
        ],
        axis=0,
    )

    # Move the group key Y into a column; label rows with the estimation units' SIDs.
    opt_mean_policy_ = opt_mean_policy_.reset_index()
    opt_mean_policy_.index = X.loc[est_index].index

    # Treatments plus outcome of the units in each unit's matched group.
    MGs_T_Y = [
        T.loc[est_index][
            MG[i] > 0
        ].join(  # get treatment assignment for all units in the matched group of unit i
            Y.loc[est_index][MG[i] > 0]
        )
        for i in range(MG.shape[0])
    ]

    # Within-group logistic outcome model Y ~ T; only fittable when both outcomes occur.
    log_regs = {
        i: lm.LogisticRegression().fit(MGs_T_Y[i].drop(columns=["Y"]), MGs_T_Y[i]["Y"])
        for i in range(MG.shape[0])
        if len(np.unique(MGs_T_Y[i]["Y"])) > 1
    }

    # Predicted P(Y=1) for each modelled unit under its estimated optimal treatment.
    y_opt = [
        log_regs[i].predict_proba(
            opt_mean_policy_.drop(columns=["Y"]).iloc[i].values.reshape(1, -1)
        )[0, 1]
        for i in range(MG.shape[0])
        if len(np.unique(MGs_T_Y[i]["Y"])) > 1
    ]

    # Predicted P(Y=1) for the same units under their administered treatment.
    y_obs_proba = [
        log_regs[i].predict_proba(T.loc[est_index].iloc[i].values.reshape(1, -1))[0, 1]
        for i in range(MG.shape[0])
        if len(np.unique(MGs_T_Y[i]["Y"])) > 1
    ]

    # One row per modelled unit, indexed by SID.
    df_ = pd.DataFrame()
    df_["y_opt"] = y_opt
    df_["y_adm"] = y_obs_proba
    df_.index = [
        est_index[i] for i in range(MG.shape[0]) if len(np.unique(MGs_T_Y[i]["Y"])) > 1
    ]

    df_y_array.append(df_)
    opt_mean_policy_array.append(opt_mean_policy_)

In [None]:
# A patient may be an estimation unit in more than one fold; average its per-fold estimates.
opt_mean_policy_ = pd.concat(opt_mean_policy_array).groupby(by="SID").mean()
df_y_diff = (
    pd.concat(df_y_array)
    .rename_axis("SID")
    .groupby(level="SID")
    .mean()
)

In [None]:
# Mean change in the predicted probability of a bad outcome (Y=1) when moving from the
# administered to the optimal regime; negative means the optimal regime is predicted better.
paired_diff = df_y_diff["y_opt"] - df_y_diff["y_adm"]
print(paired_diff.mean())
# Half-width of the 95% normal-approximation CI for that mean. y_opt and y_adm are
# measured on the same patients (paired), so the correct standard error is that of the
# per-patient differences, not the sum of the two marginal standard errors.
print(1.96 * paired_diff.sem())

In [None]:
# Histogram of the matching distances D from the last cross-fitting fold, zoomed near 0.
flat_distances = D.reshape(-1)
sns.histplot(flat_distances)
plt.xlim(-0.0001, 0.05)

In [None]:
# Compare the average optimal regime with the average administered regime, coefficient
# by coefficient, using a two-sided z-test on the difference in means.
import scipy.stats  # explicit submodule import; plain `import scipy` may not load .stats

res_summary = pd.DataFrame(opt_mean_policy_.mean()).join(
    pd.DataFrame(Y.join(T).mean()), rsuffix="_obs"
)

# Difference in means (optimal - administered) for every column.
mean_diff = res_summary["0"] - res_summary["0_obs"]
# Standard error of that difference, treating the two sample means as independent.
se_diff = np.sqrt(opt_mean_policy_.sem() ** 2 + Y.join(T).sem() ** 2)
z_score = mean_diff / se_diff

# Keep the p-values as a Series indexed by column name so they stay attached to the
# right coefficient. Sorting first and then relabelling the bare array with another
# index would scramble the labels.
p_val = pd.Series(
    2 * scipy.stats.norm.sf(z_score.abs()), index=z_score.index, name="p_val"
).round(4)

pd.DataFrame(mean_diff).join(p_val).sort_values(by="p_val", ascending=True).round(
    2
).drop(index=["Y"])

### Some Summary Statistics About the Matching

How many units are there in each matched group (from the last cross-fitting fold)?

In [None]:
# Row sums of MG from the last cross-fitting fold (the matched-group sizes if MG is a
# 0/1 membership matrix — presumably; confirm against helper.caliper_match).
group_sizes = MG.sum(axis=1).reshape(-1)
sns.histplot(group_sizes)

In [None]:
# Per patient: the optimal policy next to the observed outcome and administered policy;
# overlapping columns from the administered side carry the "_obs" suffix.
res = opt_mean_policy_.join(Y.join(T), rsuffix="_obs")

### Further Analysis of Heterogeneity

In [None]:
# Propofol regime difference (optimal - administered policy coefficient) as a function of
# the patient's propofol ED50, restricted to patients with 0.1 < ED50 < 5.
#   Line A: coefficient on "E in last 1h (>25%) [prop]".
#   Line B: sum of the coefficients on "E in last 1h (>50%) [prop]" and "(>75%) [prop]".
fig, ax = plt.subplots(figsize=(10, 6), dpi=800)
# Element-wise AND of the two boolean conditions.
inclusion = (X["propofol_50"] > 0.1) & (X["propofol_50"] < 5)
sns.regplot(
    x=X["propofol_50"].loc[inclusion].round(0),
    y=(res["E in last 1h (>25%) [prop]"] - res["E in last 1h (>25%) [prop]_obs"]).loc[
        inclusion
    ],
    order=3,
    ax=ax,
    scatter=False,
    ci=95,
    line_kws={"linewidth": 4},
)
plt.axhline(0, c="black")

sns.regplot(
    x=X["propofol_50"].loc[inclusion].round(0),
    y=(
        res["E in last 1h (>50%) [prop]"]
        + res["E in last 1h (>75%) [prop]"]
        - (
            res["E in last 1h (>50%) [prop]_obs"]
            + res["E in last 1h (>75%) [prop]_obs"]
        )
    ).loc[inclusion],
    order=3,
    ax=ax,
    scatter=False,
    x_bins=14,
    line_kws={"linewidth": 4, "linestyle": "--"},
    ci=95,
)

plt.title(
    "Difference in the Propofol Dose between\n the optimal and administered regimes"
)
plt.ylabel("Difference in the \nDoses [in mg/kg/hr]")
plt.xlabel("ED50 (Propofol)")
g_legend = ax.legend(["A", None, None, "B"])

# Matplotlib 3.7 renamed Legend.legendHandles to Legend.legend_handles and 3.9 removed
# the old name; use whichever attribute this installation provides.
legend_handles = (
    g_legend.legend_handles
    if hasattr(g_legend, "legend_handles")
    else g_legend.legendHandles
)
# Drop the two unlabelled middle handles so only the entries labelled A and B remain.
legend_handles.pop(1)
legend_handles.pop(1)
ax.legend(
    legend_handles,
    [" 25% < E < 50%", "E > 75%"],
    title=r"EA burden in last 1h",
)
plt.tight_layout()
plt.ylim(-2, 1.5)
fig.savefig("../Figures/policy_diff_optimal_administered_propofol_ed50.pdf", dpi=800)

In [None]:
# Propofol regime difference (optimal - administered policy coefficient) as a function of
# the baseline PD parameter `baseline_avg` (beta), using the same 0.1 < ED50 < 5 inclusion
# as the ED50 figure.
fig, ax = plt.subplots(figsize=(10, 6), dpi=800)
# Defined here as well, so this cell does not depend on the ED50 cell having run first.
inclusion = (X["propofol_50"] > 0.1) & (X["propofol_50"] < 5)
# beta appears to lie in [0, 1] (the PD boxplot above uses ylim (0, 1.05)). Rounding to
# 0 decimals would collapse it onto {0, 1} and make the cubic fit degenerate, so keep 2.
sns.regplot(
    x=X["baseline_avg"].loc[inclusion].round(2),
    y=(res["E in last 1h (>25%) [prop]"] - res["E in last 1h (>25%) [prop]_obs"]).loc[
        inclusion
    ],
    order=3,
    ax=ax,
    scatter=False,
    ci=95,
)
plt.axhline(0, c="black")

sns.regplot(
    x=X["baseline_avg"].loc[inclusion].round(2),
    y=(
        res["E in last 1h (>50%) [prop]"]
        + res["E in last 1h (>75%) [prop]"]
        - (
            res["E in last 1h (>50%) [prop]_obs"]
            + res["E in last 1h (>75%) [prop]_obs"]
        )
    ).loc[inclusion],
    order=3,
    ax=ax,
    scatter=False,
    x_bins=14,
    ci=95,
)

plt.title(
    "Difference in the Propofol Dose between\n the optimal and administered regimes"
)
plt.ylabel("Difference in the \nDoses [in mg/kg/hr]")
# The x-axis is the baseline parameter, not the propofol ED50.
plt.xlabel(r"Baseline ($\beta$)")
g_legend = ax.legend(["A", None, None, "B"])

# Matplotlib 3.7 renamed Legend.legendHandles to Legend.legend_handles and 3.9 removed
# the old name; use whichever attribute this installation provides.
legend_handles = (
    g_legend.legend_handles
    if hasattr(g_legend, "legend_handles")
    else g_legend.legendHandles
)
# Drop the two unlabelled middle handles so only the entries labelled A and B remain.
legend_handles.pop(1)
legend_handles.pop(1)
ax.legend(
    legend_handles,
    [" 25% < E < 50%", "E > 75%"],
    title=r"EA burden in last 1h",
)
plt.tight_layout()
plt.ylim(-2, 1.5)
fig.savefig(
    "../Figures/policy_diff_optimal_administered_propofol_baseline.pdf", dpi=800
)

In [None]:
# Propofol regime difference (optimal - administered policy coefficient) as a function of
# the APACHE II score in the first 24h.
#   Line A: coefficient on "E in last 1h (>25%) [prop]".
#   Line B: sum of the coefficients on "E in last 1h (>50%) [prop]" and "(>75%) [prop]".
fig, ax = plt.subplots(figsize=(10, 6), dpi=800)

sns.regplot(
    x=X["C_imputed_APACHE II 1st 24h"].round(0),
    y=res["E in last 1h (>25%) [prop]"] - res["E in last 1h (>25%) [prop]_obs"],
    order=4,
    ax=ax,
    scatter=False,
    x_bins=14,
    ci=95,
    line_kws={"linewidth": 4},
)
plt.axhline(0, c="black")

sns.regplot(
    x=X["C_imputed_APACHE II 1st 24h"].round(0),
    y=res["E in last 1h (>50%) [prop]"]
    + res["E in last 1h (>75%) [prop]"]
    - (res["E in last 1h (>50%) [prop]_obs"] + res["E in last 1h (>75%) [prop]_obs"]),
    order=4,
    ax=ax,
    scatter=False,
    x_bins=14,
    ci=95,
    line_kws={"linewidth": 4, "linestyle": "--"},
)

plt.title(
    "Difference in the Propofol Dose between\n the optimal and administered regimes"
)
plt.ylabel("Difference in the \nDoses [in mg/kg/hr]")
plt.xlabel("APACHE II Score")
g_legend = ax.legend(["A", None, None, "B"])

# Matplotlib 3.7 renamed Legend.legendHandles to Legend.legend_handles and 3.9 removed
# the old name; use whichever attribute this installation provides.
legend_handles = (
    g_legend.legend_handles
    if hasattr(g_legend, "legend_handles")
    else g_legend.legendHandles
)
# Drop the two unlabelled middle handles so only the entries labelled A and B remain.
legend_handles.pop(1)
legend_handles.pop(1)
ax.legend(
    legend_handles,
    [" 25% < E < 50%", "E > 75%"],
    title=r"EA burden in last 1h",
)
plt.tight_layout()
plt.xlim(5, 35)
plt.ylim(-1.5, 1)
fig.savefig("../Figures/policy_diff_optimal_administered_apache.pdf", dpi=800)

In [None]:
# Propofol regime difference (optimal - administered policy coefficient) as a function of
# the Glasgow Coma Scale total.
#   Line A: coefficient on "E in last 1h (>25%) [prop]" (quartic fit).
#   Line B: sum of the coefficients on "E in last 1h (>50%) [prop]" and "(>75%) [prop]"
#           (cubic fit).
fig, ax = plt.subplots(figsize=(10, 6), dpi=800)

sns.regplot(
    x=X["C_imputed_iGCS-Total"].round(0),
    y=res["E in last 1h (>25%) [prop]"] - res["E in last 1h (>25%) [prop]_obs"],
    order=4,
    ax=ax,
    scatter=False,
    x_bins=14,
    ci=95,
    line_kws={"linewidth": 4},
)
plt.axhline(0, c="black")

sns.regplot(
    x=X["C_imputed_iGCS-Total"].round(0),
    y=res["E in last 1h (>50%) [prop]"]
    + res["E in last 1h (>75%) [prop]"]
    - (res["E in last 1h (>50%) [prop]_obs"] + res["E in last 1h (>75%) [prop]_obs"]),
    order=3,
    ax=ax,
    scatter=False,
    x_bins=14,
    ci=95,
    line_kws={"linewidth": 4, "linestyle": "--"},
)

plt.title(
    "Difference in the Propofol Dose between\n the optimal and administered regimes"
)
plt.ylabel("Difference in the \nDoses [in mg/kg/hr]")

plt.xlabel("Glasgow Coma Scale (GCS)")
g_legend = ax.legend(["A", None, None, "B"])

# Matplotlib 3.7 renamed Legend.legendHandles to Legend.legend_handles and 3.9 removed
# the old name; use whichever attribute this installation provides.
legend_handles = (
    g_legend.legend_handles
    if hasattr(g_legend, "legend_handles")
    else g_legend.legendHandles
)
# Drop the two unlabelled middle handles so only the entries labelled A and B remain.
legend_handles.pop(1)
legend_handles.pop(1)
ax.legend(
    legend_handles,
    [" 25% < E < 50%", "E > 75%"],
    title=r"EA burden in last 1h",
)
plt.tight_layout()
fig.savefig("../Figures/policy_diff_optimal_administered_gcs.pdf", dpi=800)

In [None]:
# Long-format frame: optimal and administered policies stacked, tagged by a "Regime" column.
res3 = T.assign(Regime="Administered")
res2 = pd.concat(
    [opt_mean_policy_.drop(columns=["Y"]).assign(Regime="Optimal"), res3]
)

In [None]:
# Distribution of the predicted probability of a bad outcome under each regime.
fig, ax = plt.subplots(figsize=(10, 4.5), dpi=800)
optimal_proba = np.array(df_y_diff["y_opt"])
administered_proba = np.array(df_y_diff["y_adm"])
sns.kdeplot(optimal_proba, fill=True, lw=3, ax=ax)
sns.kdeplot(administered_proba, color="C3", fill=True, lw=3, ls="--", ax=ax)
ax.set_xlim(0, 1)
ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
ax.set_yticks([0, 1, 2])
ax.set_xlabel(r"Estimated Outcome ($\mathbf{P}[Y_i(\pi)=1 \mid \mathbf{V}_i]$)")
ax.legend(["Optimal", "Administered"], title=r"Regime ($\pi$)")
fig.tight_layout()
fig.savefig("../Figures/outcome_optimal_administered.pdf", dpi=800)

In [None]:
# Per-patient change in predicted bad-outcome probability (optimal minus administered).
df_y_diff["diff"] = df_y_diff["y_opt"].sub(df_y_diff["y_adm"])

In [None]:
# Column means of y_opt, y_adm and diff across patients.
df_y_diff.agg("mean")

In [None]:
# History/diagnosis covariates used to explain heterogeneity, paired position by position
# with the human-readable labels used in figures (tree feature names, column renames).
X_interest_col = [
    "C_imputed_Hx CVA",
    "C_imputed_Hx HTN",
    "C_imputed_Hx brain surgery",
    "C_imputed_Hx CKD",
    "C_imputed_Hx CAD/MI",
    "C_imputed_Hx CHF",
    "C_imputed_Hx DM",
    "C_imputed_Hx liver failure",
    "C_imputed_Hx smoking",
    "C_imputed_Hx alcohol",
    "C_imputed_Hx substance abuse",
    "C_imputed_Hx cancer",
    "C_imputed_Hx PVD",
    "C_imputed_Hx dementia",
    "C_imputed_Hx COPD/Asthma",
    "C_imputed_Hx leukemia/lymphoma",
    "C_imputed_Hx AIDs",
    "C_imputed_acute SDH",
    "C_imputed_Sepsis/Shock",
    "C_imputed_NeuroDx:IschStroke",
    "C_imputed_NeuroDx:HemStroke",
    "C_imputed_NeuroDx:SAH",
    "C_imputed_NeuroDx:Brain tumor",
    "C_imputed_NeuroDx:CNS infection",
    "C_imputed_NeuroDx:HIE/ABI",
]
X_features = [
    "Cerebrovascular Accident",
    "Hypertension",
    "Brain Surgery",
    "Kidney Disease",
    "Coronary Artery Disease",
    "Congestive Heart Failure",
    "Diabetes Mellitus",
    "Liver Failure",
    "Smoking",
    "Alcohol",
    "Substance Abuse",
    "Cancer",
    "Peripheral Vascular Disease",
    "Dementia",
    "COPD/Asthma",  # column covers COPD as well as asthma
    "Leukemia/Lymphoma",  # column covers lymphoma as well as leukemia
    "AIDs",
    "Subdural Hematoma",
    "Sepsis/Shock",
    "Ischemic Stroke",
    "Hemorrhagic Stroke",
    "Subarachnoid Hemorrhage",
    "Brain Tumor",
    "CNS Infection",
    "HIE/ABI",
]
# The two lists are zipped by position downstream, so a missing/extra entry would silently
# mislabel every following feature.
assert len(X_interest_col) == len(X_features), "each covariate needs exactly one label"

In [None]:
# Depth-2 regression tree relating history/diagnosis flags to the predicted improvement,
# in percentage points (positive = lower predicted bad-outcome probability under the
# optimal regime than under the administered one).
tree_features = X.loc[df_y_diff.index, X_interest_col].astype(bool)
tree_target = (-100 * df_y_diff["diff"]).round(2)
explain_y_diff = tree.DecisionTreeRegressor(max_depth=2, min_samples_leaf=15).fit(
    X=tree_features, y=tree_target
)

sns.set(style="whitegrid", font_scale=2)
fig, ax = plt.subplots(figsize=(10, 5), dpi=800)
tree.plot_tree(
    explain_y_diff,
    feature_names=X_features,
    ax=ax,
    filled=True,
    fontsize=19,
    impurity=False,
    precision=1,
    proportion=True,
    rounded=True,
    max_depth=2,
)
fig.tight_layout()
fig.savefig("../Figures/outcome_dtree.pdf", dpi=800)

In [None]:
# Per-patient optimal minus administered levetiracetam coefficient on the
# "E > 25% in last 6h AND last 12h" state feature.
# NOTE: despite its name, E75lev_diff here holds the 6h/12h coefficient difference.
_lev_col = "E in last 6h (>25%) AND E in last 12h (>25%) [lev]"
E75lev_diff = res[_lev_col].sub(res[_lev_col + "_obs"])

In [None]:
# Depth-2 regression tree relating history/diagnosis covariates to the per-patient
# levetiracetam regime difference held in E75lev_diff (defined above).
# sklearn pairs X and y by row position, not by index label, so select the covariate
# rows explicitly by the target's patient IDs.
explain_E75lev_diff = tree.DecisionTreeRegressor(max_depth=2, min_samples_leaf=15).fit(
    X=X.loc[E75lev_diff.index, X_interest_col],
    y=E75lev_diff,
)

sns.set(style="whitegrid", font_scale=2)
fig, ax = plt.subplots(figsize=(10, 5), dpi=800)
tree.plot_tree(
    explain_E75lev_diff,
    feature_names=X_features,
    ax=ax,
    filled=True,
    fontsize=19,
    impurity=False,
    precision=1,
    proportion=True,
    rounded=True,
)
plt.tight_layout()
fig.savefig("../Figures/e6h_12h_lev_diff_explain.pdf", dpi=800)

In [None]:
# Per-patient optimal minus administered value of the summed levetiracetam coefficients
# on "Baseline Dose", "E in last 1h (>25%)", "E in last 6h (>25%)" and their 1h&6h
# interaction feature. NOTE: this reassigns E75lev_diff from the previous cell.
E75lev_diff = (
    res["E in last 1h (>25% ) AND E in last 6h (>25%) [lev]"]
    + res["Baseline Dose"]
    + res["E in last 1h (>25%) [lev]"]
    + res["E in last 6h (>25%) [lev]"]
    - (
        res["E in last 1h (>25% ) AND E in last 6h (>25%) [lev]_obs"]
        + res["Baseline Dose_obs"]
        + res["E in last 1h (>25%) [lev]_obs"]
        + res["E in last 6h (>25%) [lev]_obs"]
    )
)
# sklearn pairs X and y by row position, not by index label, so select the covariate
# rows explicitly by the target's patient IDs.
explain_E75lev_diff = tree.DecisionTreeRegressor(max_depth=2, min_samples_leaf=15).fit(
    X=X.loc[E75lev_diff.index, X_interest_col],
    y=E75lev_diff,
)

sns.set(style="whitegrid", font_scale=2)
fig, ax = plt.subplots(figsize=(10, 5), dpi=800)
tree.plot_tree(
    explain_E75lev_diff,
    feature_names=X_features,
    ax=ax,
    filled=True,
    fontsize=19,
    impurity=False,
    precision=1,
    proportion=True,
    rounded=True,
)
plt.tight_layout()
fig.savefig("../Figures/e75_lev_diff_explain.pdf", dpi=800)

In [None]:
# Stacked regimes joined with the history/diagnosis covariates under their readable labels
# (X[X_interest_col] keeps the list order, so labels are assigned positionally).
res_X = res2.join(X[X_interest_col].set_axis(X_features, axis=1))
res_X

In [None]:
# Sum of the levetiracetam policy coefficients on the 6h&12h, "Baseline Dose" and 6h
# features, plotted as the levetiracetam dose, by dementia status and regime.
lev_dose_6h_12h = (
    res_X["E in last 6h (>25%) AND E in last 12h (>25%) [lev]"]
    + res_X["Baseline Dose"]
    + res_X["E in last 6h (>25%) [lev]"]
)
fig, ax = plt.subplots(figsize=(6, 6), dpi=800)
sns.pointplot(
    y=lev_dose_6h_12h,
    x="Dementia",
    hue="Regime",
    data=res_X,
    ax=ax,
    dodge=0.05,
    ci=95,
    capsize=0.05,
    join=True,
    palette="Set1",
    markers="s",
    scale=2,
)
ax.legend_ = None  # discard seaborn's default legend before placing our own
ax.set_ylabel("Levetiracetam \n Drug Dose\n [in mg/kg]")
ax.legend(bbox_to_anchor=(1, -0.25), borderaxespad=0)
fig.tight_layout()
fig.savefig("../Figures/lev_dementia.pdf", dpi=800)

In [None]:
# Sum of the levetiracetam policy coefficients on the 1h&6h, "Baseline Dose", 1h and 6h
# features, plotted as the levetiracetam dose, by SAH status and regime.
lev_dose_1h_6h = (
    res_X["E in last 1h (>25% ) AND E in last 6h (>25%) [lev]"]
    + res_X["Baseline Dose"]
    + res_X["E in last 1h (>25%) [lev]"]
    + res_X["E in last 6h (>25%) [lev]"]
)
fig, ax = plt.subplots(figsize=(6, 6), dpi=800)
sns.pointplot(
    y=lev_dose_1h_6h,
    x="Subarachnoid Hemorrhage",
    hue="Regime",
    data=res_X,
    ax=ax,
    dodge=0.05,
    ci=95,
    capsize=0.05,
    palette="Set1",
    markers="s",
    scale=2,
)
ax.set_ylabel("Levetiracetam \n Drug Dose\n [in mg/kg]")
ax.legend(bbox_to_anchor=(1, -0.25), borderaxespad=0)
fig.tight_layout()
fig.savefig("../Figures/lev_1h_6h.pdf", dpi=800)

In [None]:
# Predicted outcomes per patient, joined with the APACHE II score and a boolean column "Y"
# that is True when the patient's mean Y equals 6 (presumably mRS 6 = death — confirm).
y_apache = (
    df_y_diff
    .join(X[["C_imputed_APACHE II 1st 24h"]])
    .join(df[["Y"]].groupby(by="SID").mean().eq(6))
)

In [None]:
# Ordinal APACHE II category: how many of the cut-offs 4, 9, ..., 34 the score exceeds
# (0 for scores <= 4 up to 7 for scores > 34; missing scores count as 0).
y_apache["APACHE_ord"] = sum(
    (y_apache["C_imputed_APACHE II 1st 24h"] > cutoff).astype(int)
    for cutoff in (4, 9, 14, 19, 24, 29, 34)
)

In [None]:
# Mean predicted bad-outcome probability under the administered regime per APACHE category.
y_apache.groupby(by="APACHE_ord")[["y_adm"]].mean()