In [None]:
import pandas as pd
# import modin.pandas as pd
import numpy as np
import pickle as pkl
import datetime
import json
import time
import matplotlib.pyplot as plt
from statsmodels.stats.weightstats import ztest, ttest_ind
import scienceplots

plt.style.use(['science', 'nature', 'no-latex'])

pd.options.mode.chained_assignment = None
from tqdm import tqdm
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression, LinearRegression, Ridge
from sklearn.metrics import roc_auc_score, average_precision_score
%matplotlib inline
%load_ext autoreload
%autoreload 2
from plotnine import *

import sys
import os
sys.path.append('../../src')
sys.path.append('../../slicefinder')
from slice_finder import SliceFinder
from clustering_analysis import ClusteringEstimator
from stability_analysis import LatentSubgroupShiftEstimator
import utils
import sirus
from sklearn.model_selection import GridSearchCV


# Load Data

In [None]:
# Load the raw complete dataset; a later cell takes `enc_id`, `obs_time` and `age` from it.
# NOTE(review): absolute, user-specific path — consider a configurable data directory.
raw_df = pd.read_hdf('/data/adarsh/fda_project_data/raw_complete_dataset.h5')

In [None]:
# Load the test-set frame; it supplies the `label` column and is the left side of the merges below.
# NOTE(review): absolute, user-specific path — consider a configurable data directory.
scaled_test_df = pd.read_hdf('/data/adarsh/fda_project_data/hcgh_combined_test_df_7_11_2022.h5')

In [None]:
# Load the comorbidity feature table; it is merged on `enc_id` further down.
# NOTE(review): absolute, user-specific path — consider a configurable data directory.
commorbidity_df = pd.read_csv("/data/adarsh/fda_project_data/study_cohort_commorbidity_features.csv")

In [None]:
# Preview the first rows of the comorbidity table.
commorbidity_df.head()

In [None]:
# List the columns available in the raw dataset.
raw_df.columns

In [None]:
# Attach `age` from the raw dataset to the test rows, matching on encounter id and
# observation time.
# `validate='many_to_one'` makes pandas raise a MergeError when (enc_id, obs_time) is not
# unique in `raw_df`; without it a left merge silently duplicates test rows, which would
# inflate the frame that labels, predictions and subset masks are later derived from.
wc_test_df = scaled_test_df.merge(
    raw_df[['enc_id', 'obs_time', 'age']],
    on=['enc_id', 'obs_time'],
    how='left',
    validate='many_to_one',
)

In [None]:
# Labels of the test frame as a NumPy array.
# NOTE(review): `test_y` is not referenced again in the visible notebook (`y_test` is used instead).
test_y = scaled_test_df.label.values

In [None]:
# Add the comorbidity features, keyed by `enc_id`.
# `validate='many_to_one'` raises a MergeError if `enc_id` is duplicated in
# `commorbidity_df`, instead of silently multiplying the matching test rows.
wc_test_df = wc_test_df.merge(
    commorbidity_df,
    on='enc_id',
    how='left',
    validate='many_to_one',
)

In [None]:
# Features over which shift is examined: demographics, circumstances of admission and
# every comorbidity column except the first one.
# NOTE(review): skipping the first column presumably drops `enc_id` — confirm the CSV layout.
features = [
    'age', 'sex', 'ed_admit',
    'season_1', 'season_2', 'season_3',
    'trauma_level', 'hospital_size_large',
    *commorbidity_df.columns[1:],
]

continuous_features = ['age']

In [None]:
# Age contains the string bucket '>= 90'; map it to 90 so the column can be cast to int.
wc_test_df['age'] = wc_test_df['age'].replace('>= 90', 90).astype(int)

In [None]:
# Feature frame used for subgroup discovery, an ndarray snapshot of it, the label column,
# and an independent copy of the feature-name list.
X_test = wc_test_df.loc[:, features]
orig_X = np.copy(X_test)
y_test = wc_test_df.label

all_features = list(features)

In [None]:
# Same columns as `wc_test_df`, with the feature columns moved to the end of the frame.
non_feature_df = wc_test_df.drop(columns=features)
full_df = pd.concat([non_feature_df, X_test], axis=1)

# AFISP Step 1: Stability Analysis

In [None]:
# AAM predictions and labels as plain arrays, then the hinge surrogate of the ROC AUC
# computed by `utils`; the result is later indexed with the subset masks.
test_preds = full_df['aam_prediction'].values
y_test = full_df['label'].values
hinge_auc_loss = utils.torch_roc_auc_surrogate(y_test, test_preds, 'hinge')

In [None]:
# Alias: the hinge AUC surrogate is the loss used by every analysis below.
test_loss = hinge_auc_loss

In [None]:
# Feature matrix as a NumPy array, passed to the stability and clustering estimators.
subgroup_feature_data = X_test.values

In [None]:
%%time
# Fit the latent-subgroup stability analysis on the feature matrix, using the hinge AUC
# surrogate as the loss; `fit` returns the risks plotted against the subset fractions below.
# NOTE(review): np.arange with a float step nominally yields 0.05 ... 0.95 here, so no subset
# covers the whole test set — confirm against `stability_analysis.subset_fractions`.
stability_analysis = LatentSubgroupShiftEstimator(cv=5, 
                                                  verbose=True, 
                                                  eps=1e-5, 
                                                  subset_fractions=np.arange(0.05, 1, 0.05)
                                                 )
sa_risks = stability_analysis.fit(subgroup_feature_data, hinge_auc_loss, feature_names=features)

In [None]:
# Subset masks from the fitted estimator; below they are indexed in step with
# `subset_fractions` and used to index per-sample arrays.
sa_masks = stability_analysis.subset_masks()

In [None]:
# Run the estimator's subset-size check; the return value is discarded.
_ = stability_analysis.check_subset_sizes()

In [None]:
# Sanity check: the loss should increase as the worst-performing subset gets smaller.
# The x-axis label was missing, so the figure could not be read on its own.
plt.plot(stability_analysis.subset_fractions, sa_risks, '.-', label='AAM')
plt.xlabel('Subset Fraction')
plt.ylabel('loss')
plt.legend(loc='best')
plt.show()

In [None]:
# AUROC of the AAM predictions and of the baseline predictions inside every subset mask.
baseline_preds = full_df.baseline_prediction.values
subpop_aucs = [roc_auc_score(y_test[m], test_preds[m]) for m in sa_masks]
news_aucs = [roc_auc_score(y_test[m], baseline_preds[m]) for m in sa_masks]


In [None]:
# Bootstrap confidence interval of the AUROC inside every worst-case subset.
# Column 0 holds the lower bound and column 1 the upper bound.
bootstrap_cis = np.zeros((len(sa_masks), 2))
n_subsets = len(stability_analysis.subset_fractions)
for i in tqdm(range(n_subsets)):
    mask = sa_masks[i]
    _, bound_a, bound_b = utils.bootstrap_ci(y_test[mask], test_preds[mask])
    # This cell used to unpack (mean, upper, lower) while every other call site in the
    # notebook unpacks (mean, lower, upper). Ordering the two bounds explicitly keeps
    # column 0 the lower bound whichever order `bootstrap_ci` returns.
    # NOTE(review): confirm the return order of utils.bootstrap_ci.
    bootstrap_cis[i, 0] = min(bound_a, bound_b)
    bootstrap_cis[i, 1] = max(bound_a, bound_b)

plt.plot(stability_analysis.subset_fractions, subpop_aucs)
plt.fill_between(stability_analysis.subset_fractions,
                 bootstrap_cis[:, 0],
                 bootstrap_cis[:, 1],
                 alpha=0.25,
                 label='Bootstrap')
plt.xlabel('Subset Fraction')
# The plotted quantity is the per-subset AUROC, not a Brier score.
plt.ylabel('AUROC')
plt.legend(loc='best')
plt.show()

In [None]:
# Long-format frames for the stability curve: the AAM AUROC per subset fraction with its
# bootstrap bounds, and a constant baseline reference value.
# NOTE(review): `news_aucs[-1]` is the baseline AUROC on the last subset mask; with
# subset_fractions = np.arange(0.05, 1, 0.05) that is presumably the 0.95 subset rather than
# the full test set — confirm this is the intended reference.
stab_curve_df = pd.DataFrame(
    {
        'Subset Fraction': stability_analysis.subset_fractions,
        'AUROC': subpop_aucs, 
        'Lower': bootstrap_cis[:, 0],
        'Upper': bootstrap_cis[:, 1],
        'Name': 'AAM'
    })
baseline_curve_df = pd.DataFrame(
    {
        'Subset Fraction': stability_analysis.subset_fractions,
        'AUROC': news_aucs[-1], 
        'Name': 'Baseline'
    })

In [None]:
# Stability curve: AAM AUROC on each worst-case subset with its bootstrap band, against a
# horizontal baseline reference line.
fractions = stability_analysis.subset_fractions
plt.plot(fractions, subpop_aucs, '.-', label='AAM')

# Reference line at news_aucs[-1], i.e. the baseline AUROC on the last subset mask.
# NOTE(review): the label says "Full Performance"; confirm the last mask covers the
# population intended as the reference.
plt.plot(fractions,
         np.ones_like(news_aucs) * news_aucs[-1],
         '--',
         label='Baseline Full Performance')
plt.fill_between(fractions,
                 bootstrap_cis[:, 0],
                 bootstrap_cis[:, 1],
                 alpha=0.25)
plt.ylabel('AUROC')
plt.xlabel('Subset Fraction')
plt.legend(loc='lower right')
plt.xlim(0, 1.05)

# savefig does not create missing directories; make sure the output folder exists.
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/stability_curve.pdf', dpi=360)

In [None]:
# For every subset: Cohen's d between the loss inside and outside the subset, and the
# p-value of a one-sided ('larger'), unequal-variance t-test on the same two samples.
cds, p_vals = [], []

for i in range(len(stability_analysis.subset_fractions)):
    in_subset = sa_masks[i]
    loss_inside = test_loss[in_subset]
    loss_outside = test_loss[~in_subset]

    cds.append(sirus.cohens_d(loss_inside, loss_outside))

    test_result = ttest_ind(loss_inside,
                            x2=loss_outside,
                            value=0.,
                            alternative='larger',
                            usevar='unequal')
    # Index 1 of the returned tuple is taken as the p-value.
    p_vals.append(test_result[1])

plt.plot(stability_analysis.subset_fractions, cds)
plt.xlabel('Subset Fraction')
plt.ylabel('Cohen\'s d (Effect Size)')
plt.show()

In [None]:
# Ask `sirus` for the subset index with the maximum effect size, then print that index,
# its subset fraction and the effect size.
max_ind, max_cd = sirus.find_max_effect_size(sa_masks, test_loss)
print(max_ind, stability_analysis.subset_fractions[max_ind], max_cd)

In [None]:
# Overrides the effect-size-based `max_ind` from the previous cell: take the last subset
# index whose AAM AUROC is below news_aucs[-1].
# NOTE(review): raises IndexError when no subset AUROC is below that value.
max_ind = np.where(np.array(subpop_aucs) < news_aucs[-1])[0][-1]

# AFISP Step 2: Subgroup Phenotype Identification

In [None]:
# Work on a copy so the feature frame itself is left untouched.
phenotype_df = X_test.copy()

In [None]:
# Cast the integer-coded columns to pandas categoricals so that the following
# `pd.get_dummies` call expands them into indicator columns.
# NOTE(review): 'admit_source' is not among the names spelled out in `features`
# ('ed_admit' is); unless it comes from the comorbidity columns this raises a KeyError.
for categorical_col in ('trauma_level', 'admit_source'):
    phenotype_df[categorical_col] = (
        phenotype_df[categorical_col].astype(int).astype("category")
    )

In [None]:
# Expand the categorical columns into indicator columns.
phenotype_df = pd.get_dummies(phenotype_df)

In [None]:
# Membership in the selected subset; `* 1` turns the mask into 0/1 integers.
phenotype_df['subset_label'] = sa_masks[max_ind]*1

In [None]:
# Export the phenotype frame, including `subset_label`, as CSV.
# NOTE(review): the `sirus_files/` directory must already exist.
phenotype_df.to_csv("sirus_files/aam_for_sirus.csv", index=False)

In [None]:
import subprocess  # was used here without being imported anywhere in the notebook

# SIRUS settings and the file the R script writes the learned rules to.
depth = 3
rule_max = 50
sirus_rules_fname = f"tmp/afisp_sirus_rules_{rule_max}_rules_{depth}_depth.txt"

# `df_fname` was referenced without being defined, so this cell failed on a fresh kernel.
# NOTE(review): assumed to be the CSV written by the preceding `to_csv` cell — confirm.
df_fname = "sirus_files/aam_for_sirus.csv"

# Make sure the output folder exists before the R script tries to write into it.
os.makedirs(os.path.dirname(sirus_rules_fname), exist_ok=True)

# Argument list instead of a shell string: no shell quoting issues with the file names.
# `check=True` raises CalledProcessError on a non-zero exit code, so a failed SIRUS run
# cannot go unnoticed and leave a stale or missing rules file for the next cell.
# NOTE(review): the Rscript path is machine-specific.
subprocess.run(
    [
        "/home/adarsh.subbaswamy/anaconda3/envs/afisp/bin/Rscript",
        "run_sirus.r",
        "--input", df_fname,
        "--output", sirus_rules_fname,
        "--depth", str(depth),
        "--rule.max", str(rule_max),
        "--cv",
    ],
    check=True,
)

In [None]:
# Read the rules written by SIRUS, compute a p-value per rule against the loss, apply the
# Holm-Bonferroni correction, then keep the rules passing the effect-size threshold of 0.4.
# (Step descriptions follow the `sirus` helper names; their implementations are not visible here.)
sirus_rules = sirus.get_sirus_rules(sirus_rules_fname)
rule_p_values = sirus.precompute_p_values(sirus_rules, phenotype_df, test_loss)
significant_rules = sirus.holm_bonferroni_correction(rule_p_values)

extracted_rules = sirus.effect_size_filtering(significant_rules, phenotype_df, test_loss, 
                                                  effect_threshold=0.4)

In [None]:
# Show the rules that survived filtering.
extracted_rules

In [None]:
# For every extracted AFISP rule: select the rows it covers, bootstrap the AUROC on them
# (result unpacked as estimate, lower, upper) and record the number of covered rows.
r_aucs, r_ls, r_us, ns = [], [], [], []

for rule in tqdm(extracted_rules):
    rows = phenotype_df.eval(str(rule))
    auc, auc_lo, auc_hi = utils.bootstrap_ci(y_test[rows], test_preds[rows])

    ns.append(np.sum(rows))
    r_aucs.append(auc)
    r_ls.append(auc_lo)
    r_us.append(auc_hi)

In [None]:
# One row per AFISP rule with its AUROC, row count and bootstrap bounds, sorted by
# ascending AUROC (lowest-AUROC phenotype first).
aam_sirus_df = pd.DataFrame(
    {
        'Phenotype': extracted_rules, 
        'AUROC': r_aucs, 
        'N': ns, 
        'Lower': r_ls, 
        'Upper': r_us
    }).sort_values(by='AUROC')

In [None]:
# Show the AFISP phenotype table.
aam_sirus_df

## Look at prevalence of AFISP subgroups in each worst-case subset

In [None]:
# Prevalence of every AFISP subgroup inside each worst-case subset:
# prevalences[k][s] is the fraction of rows of subset s matched by rule k.
afisp_rules = aam_sirus_df.Phenotype.values
prevalences = [[] for _ in range(len(aam_sirus_df))]
subgroup_num = [[k + 1] * len(sa_masks) for k in range(len(aam_sirus_df))]

for subset_mask in tqdm(sa_masks):
    # Filter the frame once per subset instead of once per rule.
    subset_rows = phenotype_df[subset_mask]
    for rule_prevalences, rule in zip(prevalences, afisp_rules):
        rule_prevalences.append(subset_rows.eval(rule).mean())

In [None]:
# One line per AFISP subgroup: its prevalence as a function of the subset fraction.
cmap = plt.get_cmap('rainbow', len(prevalences))
for i, subgroup_prevalence in enumerate(prevalences):
    plt.plot(stability_analysis.subset_fractions,
             subgroup_prevalence,
             label=f'Subgroup {i+1}',
             color=cmap(i))

plt.legend(loc='best')
plt.xlabel('Subset Fraction')
plt.ylabel('Prevalence in Subset')

# savefig does not create missing directories; make sure the output folder exists.
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/prevalences.pdf', dpi=360)

# Run SliceFinder

In [None]:
%%time
# Enumerate candidate slices with SliceFinder on a copy of the feature frame: the result
# of `slicing()` plus the result of `crossing(..., 2)` on it (presumably single-feature
# slices and their degree-2 crossings — confirm against the slicefinder implementation).
sfX = X_test.copy()
sf = SliceFinder(None, (sfX, pd.DataFrame({'y': y_test})))
d1_slices = sf.slicing()
d2_slices = sf.crossing(d1_slices, 2)
candidate_slices = d1_slices + d2_slices
print("Slices acquired")
# Conversion of the slices to rules and the significance / effect-size filtering happen
# in the next cell.

In [None]:
%%time
# Turn every candidate slice into an equality rule, then apply the same p-value,
# Holm-Bonferroni and effect-size (0.4) filtering as for the AFISP rules.
# NOTE(review): this rebinds `rule_p_values` and `significant_rules`, overwriting the AFISP
# values computed earlier; re-running earlier cells out of order will see the SliceFinder ones.
candidate_rules = [sirus.slice_to_equality_rule(s) for s in candidate_slices]

rule_p_values = sirus.precompute_p_values(candidate_rules, sfX, test_loss)
significant_rules = sirus.holm_bonferroni_correction(rule_p_values)

sf_extracted_rules = sirus.effect_size_filtering(significant_rules, sfX, test_loss, 
                                              effect_threshold=0.4)

In [None]:
# Number of SliceFinder rules that survived filtering.
len(sf_extracted_rules)

In [None]:
# Bootstrap the AUROC of every SliceFinder rule. Rules whose rows contain only label 1 or
# only label 0 (including rules that match no rows) are skipped.
r_rs, r_aucs, r_ls, r_us, ns = [], [], [], [], []

for rule in tqdm(sf_extracted_rules):
    rows = sfX.eval(str(rule))
    slice_labels = y_test[rows]

    if np.all(slice_labels == 1) or np.all(slice_labels == 0):
        continue

    auc, auc_lo, auc_hi = utils.bootstrap_ci(slice_labels, test_preds[rows])

    r_rs.append(rule)
    r_aucs.append(auc)
    r_ls.append(auc_lo)
    r_us.append(auc_hi)
    ns.append(np.sum(rows))

sf_columns = {
    'Phenotype': r_rs,
    'AUROC': r_aucs,
    'N': ns,
    'Lower': r_ls,
    'Upper': r_us,
}
sf_sirus_df = pd.DataFrame(sf_columns).sort_values(by='AUROC')

In [None]:
# Show only the SliceFinder slices covering at least 400 rows.
sf_sirus_df.query('N >= 400')# .iloc[0:10]

In [None]:
# Clustering-based subgroup discovery (ClusteringEstimator + SIRUS rules)

In [None]:
# Fit the clustering estimator on the same feature matrix and loss as the stability
# analysis; its `masks_` attribute is used as the subset label in the next cell.
cl = ClusteringEstimator(verbose=True)
cl.fit(subgroup_feature_data, hinge_auc_loss)

In [None]:
# Rebuild the one-hot phenotype frame, this time labelled with the clustering mask, and
# export it as CSV.
# NOTE(review): this rebinds `phenotype_df`, replacing the AFISP `subset_label` column.
phenotype_df = X_test.copy()
for categorical_col in ('trauma_level', 'admit_source'):
    phenotype_df[categorical_col] = (
        phenotype_df[categorical_col].astype(int).astype("category")
    )
phenotype_df = pd.get_dummies(phenotype_df)
phenotype_df['subset_label'] = cl.masks_ * 1
phenotype_df.to_csv("sirus_files/clustering_for_sirus.csv", index=False)

In [None]:
# Read the rules file for the clustering subset and apply the same p-value,
# Holm-Bonferroni and effect-size (0.4) filtering as above.
# NOTE(review): no visible cell produces this rules file; it is presumably generated by
# running SIRUS on the clustering CSV outside the notebook.
clustering_rules_fname = "sirus_files/clustering_rules_d3.txt"
clustering_rules = sirus.get_sirus_rules(clustering_rules_fname)
clustering_rule_p_values = sirus.precompute_p_values(clustering_rules, phenotype_df, test_loss)
clustering_significant_rules = sirus.holm_bonferroni_correction(clustering_rule_p_values)

clustering_extracted_rules = sirus.effect_size_filtering(clustering_significant_rules, phenotype_df, test_loss, 
                                                  effect_threshold=0.4)

In [None]:
# Show the clustering rules that passed the Holm-Bonferroni correction.
clustering_significant_rules

In [None]:
# Bootstrap the AUROC of every significant clustering rule (first element of each entry
# in `clustering_significant_rules`). Rules whose rows contain only label 1 or only
# label 0 (including rules that match no rows) are skipped.
r_rs, r_aucs, r_ls, r_us, ns = [], [], [], [], []

for rule in tqdm([x[0] for x in clustering_significant_rules]):
    # The clustering rules were learned and tested on the one-hot `phenotype_df`, so they
    # are evaluated on that frame as well, as is done for the AFISP rules. The raw `sfX`
    # frame used here before lacks the indicator columns those rules can reference.
    rows = phenotype_df.eval(str(rule))
    rule_labels = y_test[rows]

    if np.all(rule_labels == 1) or np.all(rule_labels == 0):
        continue

    # Result unpacked as (estimate, lower, upper), like the other rule loops.
    m, l, u = utils.bootstrap_ci(rule_labels, test_preds[rows])
    ns.append(np.sum(rows))

    r_aucs.append(m)
    r_ls.append(l)
    r_us.append(u)
    r_rs.append(rule)

clustering_sirus_df = pd.DataFrame(
    {
        'Phenotype': r_rs,
        'AUROC': r_aucs,
        'N': ns,
        'Lower': r_ls,
        'Upper': r_us
    }).sort_values(by='AUROC')

In [None]:
# Keep only the SliceFinder slices covering at least 400 rows for the comparison below.
sf_filtered_df = sf_sirus_df.query("N >= 400")

# Compare groups found by SF and AFISP

In [None]:
# Draw up to 1000 candidate rules without replacement as a random reference set.
# A seeded generator makes the sample, and every figure derived from it, reproducible
# across kernel restarts. Capping the size avoids the ValueError raised when fewer than
# 1000 candidates exist.
random_slice_rng = np.random.default_rng(0)
random_slices = random_slice_rng.choice(
    candidate_rules,
    size=min(1000, len(candidate_rules)),
    replace=False,
)

In [None]:
# Bootstrap the AUROC of every randomly drawn slice. Slices whose rows contain only
# label 1 or only label 0 (including slices that match no rows) are skipped.
r_rs, r_aucs, r_ls, r_us, ns = [], [], [], [], []

for rule in tqdm(random_slices):
    rows = sfX.eval(str(rule))
    slice_labels = y_test[rows]

    if np.all(slice_labels == 1) or np.all(slice_labels == 0):
        continue

    auc, auc_lo, auc_hi = utils.bootstrap_ci(slice_labels, test_preds[rows])

    r_rs.append(rule)
    r_aucs.append(auc)
    r_ls.append(auc_lo)
    r_us.append(auc_hi)
    ns.append(np.sum(rows))

random_columns = {
    'Phenotype': r_rs,
    'AUROC': r_aucs,
    'N': ns,
    'Lower': r_ls,
    'Upper': r_us,
}
random_slice_df = pd.DataFrame(random_columns).sort_values(by='AUROC')

In [None]:
# Indicator matrix with one row per test sample and one column per slice, laid out as
# [SliceFinder slices | AFISP rules | random slices].
n_sf_cols = len(sf_filtered_df)
n_afisp_cols = len(aam_sirus_df)
n_random_cols = len(random_slice_df)
afisp_start = n_sf_cols
random_start = n_sf_cols + n_afisp_cols

full_slice_matrix = np.zeros((len(X_test), n_sf_cols + n_afisp_cols + n_random_cols))

for j, r in tqdm(enumerate(sf_filtered_df.Phenotype.values)):
    full_slice_matrix[:, j] = sfX.eval(r) * 1

# AFISP rules are evaluated on the one-hot phenotype frame.
for j, r in tqdm(enumerate(aam_sirus_df.Phenotype.values)):
    full_slice_matrix[:, afisp_start + j] = phenotype_df.eval(r) * 1

for j, r in tqdm(enumerate(random_slice_df.Phenotype.values)):
    full_slice_matrix[:, random_start + j] = sfX.eval(r) * 1

In [None]:
%%time
# NOTE(review): GridSearchCV is already imported in the first cell; the PLSRegression import
# would normally live there too.
from sklearn.model_selection import GridSearchCV
from sklearn.cross_decomposition import PLSRegression

# perform PLS; use CV to pick num components
# 5-fold grid search over n_components in 2..19, scored by negative mean squared error.
# NOTE(review): the CV target here is utils.cross_entropy(y_test, test_preds), while the
# final PLS model below is fitted on `test_loss` (the hinge AUC surrogate) — confirm that
# selecting n_components on a different target is intended.
# NOTE(review): n_jobs=32 is machine-specific.
plsr = PLSRegression()
cv = GridSearchCV(plsr, cv=5, param_grid = {
    'n_components':list(range(2, 20))
}, verbose=5, n_jobs=32, scoring="neg_mean_squared_error")
cv.fit(full_slice_matrix, utils.cross_entropy(y_test, test_preds))

In [None]:
# Number of PLS components selected by the grid search.
cv.best_estimator_.n_components

In [None]:
# Refit PLS with the selected number of components, regressing `test_loss` on the slice
# indicator matrix; the x-loadings of this model are plotted below.
# NOTE(review): `plsr_X` and `plsr_Y` are not used in the visible cells.
plsr = PLSRegression(n_components=cv.best_estimator_.n_components)
plsr.fit(full_slice_matrix, test_loss)
plsr_X, plsr_Y = plsr.transform(full_slice_matrix, test_loss)

In [None]:
# Scatter of the first two PLS x-loadings for the SliceFinder slices and the random
# slices. Rows of the loading matrix follow the column layout of `full_slice_matrix`
# ([SliceFinder | AFISP | random]); the AFISP rows are left out of this figure.
pca_embeddings = plsr.x_loadings_
n_sf_rows = len(sf_filtered_df)
random_start = n_sf_rows + len(aam_sirus_df)

plt.scatter(pca_embeddings[:n_sf_rows, 0],
            pca_embeddings[:n_sf_rows, 1],
            label='SliceFinder Slices', color='purple')
plt.scatter(pca_embeddings[random_start:, 0],
            pca_embeddings[random_start:, 1],
            label='Random Slices', color='orange')

plt.xlabel("Partial Least Squares Dimension 1")
plt.ylabel("Partial Least Squares Dimension 2")
plt.axvline(x=0, color='k')
plt.axhline(y=0, color='k')
plt.box(False)

plt.legend()

# savefig does not create missing directories; make sure the output folder exists.
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/pls_random_vs_sf_nature_no_afisp.pdf', dpi=240)

In [None]:
# SliceFinder slices in the plane of the first two PLS x-loadings, coloured by AUROC,
# with every AFISP rule drawn as a labelled arrow from the origin.
pca_embeddings = plsr.x_loadings_
n_sf_rows = len(sf_filtered_df)

plt.scatter(pca_embeddings[:n_sf_rows, 0],
            pca_embeddings[:n_sf_rows, 1],
            c=sf_filtered_df['AUROC'], cmap='viridis')

# Hand-tuned vertical nudges (keyed by AFISP rule position) that keep the text labels
# from overlapping in the published figure; every other label keeps its y position.
label_y_offsets = {
    0: -0.0025,
    4: -0.0025,
    6: 0.0025,
    2: 0.001,
    8: 0.001,
    9: -0.005,
}

for i in range(len(aam_sirus_df)):
    # Only the first two loadings are plotted. Slicing explicitly matters because a
    # loading row has `n_components` entries, while `annotate` needs an (x, y) pair.
    loading_x, loading_y = pca_embeddings[n_sf_rows + i, :2]

    plt.annotate("", xy=(loading_x, loading_y), xytext=(0, 0),
                 arrowprops=dict(arrowstyle="->", color='k', linewidth=1))
    # Labels sit 0.007 to the left of the arrow head.
    plt.text(loading_x - 0.007,
             loading_y + label_y_offsets.get(i, 0.0),
             f"AFISP {i+1}", size=6)

plt.xlabel("Partial Least Squares Dimension 1")
plt.ylabel("Partial Least Squares Dimension 2")
plt.xlim(-0.06, 0.0025)
plt.colorbar(label='AUROC')

# savefig does not create missing directories; make sure the output folder exists.
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/pls_auroc.pdf', dpi=240)