In [None]:
import numpy as np
import terra

from meerkat.contrib.eeg import build_stanford_eeg_dp

from domino.emb.eeg import generate_words_dp, embed_words, embed_eeg
from domino.evaluate import run_sdms, score_sdm_explanations, score_sdms, run_sdm
from domino.sdm import MixtureModelSDM, SpotlightSDM
from domino.slices import collect_settings
from domino.train import score_settings, synthetic_score_settings, train_settings
from domino.utils import split_dp, balance_dp

import meerkat as mk
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import precision_score, confusion_matrix, accuracy_score, roc_auc_score


In [None]:
# Fetch the output of `collect_settings` and display it.
# NOTE(review): `.out(load=True)` presumably loads the stored result of an
# earlier run of that task rather than recomputing it — confirm.
setting_dp = collect_settings.out(load=True)
setting_dp

In [None]:
# Fetch the output of `run_sdms` and preview it with `.head()`.
# NOTE(review): `.out(load=True)` presumably loads the stored result of an
# earlier run of that task rather than recomputing it — confirm.
run_sdms_dp = run_sdms.out(load=True)
run_sdms_dp.head()

In [None]:
# Fetch the output of `score_sdms` and wrap it in a meerkat DataPanel so it can
# be merged with `run_sdms_dp` later.
# NOTE(review): `from_pandas` implies the task output is a pandas DataFrame — confirm.
score_dp = mk.DataPanel.from_pandas(score_sdms.out(load=True))
score_dp

In [None]:
# Columns carried over from the SDM runs; include any other columns here
# you'd like to analyze.
run_columns = (
    "sdm_class",
    "config/sdm",
    "alpha",
    "run_sdm_run_id",
    "build_setting_kwargs",
    "slice_category",
)

# Join the scores with the per-run metadata on the run id.
results_dp = mk.merge(score_dp, run_sdms_dp[run_columns], on="run_sdm_run_id")


def _first_embedding(sdm_entry):
    """Return the first element of the ``emb`` entry of an SDM config."""
    return sdm_entry["sdm_config"]["emb"][0]


def _correlation_threshold(setting_kwargs):
    """Return ``correlate_threshold`` when present, else ``attribute_thresh``."""
    if "correlate_threshold" in setting_kwargs:
        return setting_kwargs["correlate_threshold"]
    return setting_kwargs["attribute_thresh"]


# Derived column: which embedding the SDM was configured with.
emb_col = results_dp["config/sdm"].map(_first_embedding)
results_dp["emb_type"] = emb_col

# Derived column: threshold used when the setting was built.
corr_thresh_col = results_dp["build_setting_kwargs"].map(_correlation_threshold)
results_dp["corr_thresh"] = corr_thresh_col

# Switch to pandas for the groupby / plotting done in the following cells.
results_df = results_dp.to_pandas()
results_df

In [None]:
metric = "precision_at_10"

# One group per (SDM, slice, embedding, alpha, threshold, category) combination.
group_cols = [
    "sdm_class",
    "slice_name",
    "slice_idx",
    "emb_type",
    "alpha",
    "corr_thresh",
    "slice_category",
]

# For every group, find the row with the highest `metric`. `reset_index()`
# gives a fresh 0..n-1 index, so the labels returned by `idxmax()` are row
# positions that can be fed to `iloc`.
best_row_positions = (
    results_df.reset_index().groupby(group_cols)[metric].idxmax().astype(int)
)

# `.copy()` makes `grouped_df` an independent frame: without it, adding the
# "success" column below assigns into an `iloc` selection of `results_df`
# and pandas raises SettingWithCopyWarning.
grouped_df = results_df.iloc[best_row_positions].copy()

# A group counts as a success when its precision@10 exceeds 0.6.
grouped_df["success"] = grouped_df["precision_at_10"] > 0.6

print(len(grouped_df))
#ax = sns.lineplot(data=grouped_df,x="alpha",y=metric, hue="emb_type",style="slice_name")
#ax = sns.violinplot(data=grouped_df,x=metric,y="emb_type")
ax = sns.barplot(data=grouped_df,y=metric,x="emb_type",hue="sdm_class")
#ax = sns.displot(data=grouped_df,x=metric,hue="sdm_class")
# Place the legend outside the axes so it does not cover the bars.
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()


In [None]:
# metric = "auroc"
# grouped_df = results_df.iloc[results_df.reset_index().groupby(["sdm_class", "slice_name", "slice_idx","emb_type","alpha","corr_thresh"])[metric].idxmax().astype(int)]


# #ax = sns.lineplot(data=grouped_df,x="alpha",y=metric, hue="emb_type",style="slice_name")
# ax = sns.barplot(data=grouped_df,x=metric,y="slice_category",hue="sdm_class")
# plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
# plt.show()

In [None]:
# Select the run(s) to inspect. All conditions are combined into a single
# boolean mask: chaining `df[m1][m2]...` applies full-length masks to frames
# that were already filtered, which forces pandas to reindex every mask and
# emits "Boolean Series key will be reindexed to match DataFrame index".
run_mask = (
    (grouped_df["emb_type"] == "multimodal")
    & (grouped_df["slice_idx"] == 0)
    & (grouped_df["alpha"] == 0)
    & (grouped_df["sdm_class"] == "domino.sdm.confusion.ConfusionSDM")
)
specific_run = grouped_df[run_mask]
specific_run

In [None]:
# Run id of the first row selected in `specific_run`.
selected_run_id = specific_run["run_sdm_run_id"].values[0]

# `run_sdm.out` yields a pair for that run; only the first element is kept here.
sdm_dp, _ = run_sdm.out(selected_run_id, load=True)
sdm_dp.head()

In [None]:
from domino.slices.abstract import build_setting

# Select the run to inspect. All conditions are combined into a single boolean
# mask: chaining `df[m1][m2]...` applies full-length masks to frames that were
# already filtered, which forces pandas to reindex every mask and emits
# "Boolean Series key will be reindexed to match DataFrame index".
run_mask = (
    (grouped_df["emb_type"] == "multimodal")
    & (grouped_df["slice_idx"] == 0)
    & (grouped_df["alpha"] == 0)
    & (grouped_df["sdm_class"] == "domino.sdm.confusion.ConfusionSDM")
)
specific_run = grouped_df[run_mask]
sdm_dp, _ = run_sdm.out(specific_run["run_sdm_run_id"].values[0],load=True)

# Hard predictions (argmax over axis 1 of "probs"), ground-truth targets, and
# the first column of "slices" as the slice-membership indicator.
preds = np.array(sdm_dp.lz["probs"].argmax(1))
targets = np.array(sdm_dp.lz["target"])
in_slice = sdm_dp.lz["slices"][:,0]

# Mean of the targets (the positive rate if targets are 0/1 — TODO confirm).
print(targets.mean())
# normalize="true" normalizes over the true labels, so fp/fn are rates, not
# counts. The 4-way unpacking assumes a binary (2x2) confusion matrix.
tn, fp, fn, tp = confusion_matrix(targets,preds,normalize="true").ravel()
print(f"FP: {fp}, FN: {fn}")

pred_slices = sdm_dp.lz["pred_slices"]
#for ndx in range(10):
ndx=1
# NOTE(review): this assignment is dead — `pp` is overwritten on the next line.
# Comment out the next line to score the SDM's predicted slice `ndx` instead.
pp = pred_slices[:,ndx]
# Baseline "predicted slice": examples with target 0 whose class-1 probability
# exceeds 0.5.
pp = (targets==0)*np.array(sdm_dp.lz["probs"][:,1]>0.5)
#print(roc_auc_score(in_slice,pp))
print(precision_score(in_slice,pp))

#print((preds != targets).mean())
#print((preds[in_slice]!= targets[in_slice]).mean())
# NOTE(review): target == 0 with pred == 1 is a false positive if class 1 is
# the positive class, so the name `fns` looks misleading. The name is kept in
# case other cells rely on it — confirm and rename.
fns = np.array((targets==0)*(preds==1))
# Mean slice membership among those errors.
print(in_slice[fns].mean())

#synth_dp.head()

In [None]:
# Take the argmax of `pred_slices` along axis/dim 1, then the minimum of those
# indices, i.e. the lowest index that is ever the maximum along that axis.
# NOTE(review): presumably each row is an example and each column a predicted
# slice — confirm.
pred_slices.argmax(1).min()

In [None]:
# grouped_df["success"] = np.logical_and(grouped_df["auroc"] > 0.7, grouped_df["precision_at_10"] > 0.4)

# #ax = sns.pointplot(data=grouped_df,x="alpha",y="success", hue="emb_type", dodge=True, join=False)
# ax = sns.barplot(data=grouped_df,x="success",y="emb_type")

# plt.show()

# Explanations

In [None]:
# Fetch the output of `embed_words`.
words_dp = embed_words.out(load=True)

# get multimodal sdm run ids
slice_id = 0
alpha = 0.8

# Combine the conditions into a single boolean mask: chaining `df[m1][m2]...`
# applies full-length masks to frames that were already filtered, which forces
# pandas to reindex every mask and emits "Boolean Series key will be reindexed
# to match DataFrame index".
run_mask = (
    (grouped_df["emb_type"] == "multimodal")
    & (grouped_df["slice_idx"] == slice_id)
    & (grouped_df["alpha"] == alpha)
)
specific_run = grouped_df[run_mask]
specific_run

In [None]:
# Predicted-slice index and run id recorded for the first selected run.
pred_slice_idx = specific_run["pred_slice_idx"].values[0]
explained_run_id = specific_run["run_sdm_run_id"].values[0]

# `run_sdm.out` yields a pair for that run; only the second element is kept here.
_, expl_dp = run_sdm.out(explained_run_id, load=True)

# Order rows by descending value in the chosen "pred_slices" column and show
# the first ten.
descending_order = (-expl_dp["pred_slices"].data[:, pred_slice_idx]).argsort()
expl_dp[descending_order][:10]

In [None]:
# Same descending ordering on the chosen "pred_slices" column, showing the
# rows at positions 10-19 (the ten that follow the top ten).
chosen_slice_values = expl_dp["pred_slices"].data[:, pred_slice_idx]
expl_dp[(-chosen_slice_values).argsort()][10:20]