In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pdb
import terra
import meerkat as mk
from meerkat.contrib.eeg import build_stanford_eeg_dp
from domino.utils import split_dp, balance_dp, merge_in_split

from sklearn.metrics import precision_score, recall_score, roc_auc_score
import eeghdf


## Load EEG datapanel

In [3]:
# Fetch the output artifact of a previous `build_stanford_eeg_dp` run and load it.
# Run ids noted by the author: 618 -> 60 sec, 409 -> with text constraint.
dp_art = build_stanford_eeg_dp.out(run_id=618) # for 60sec: 618, with text constraint: 409
dp=dp_art.load()
# Display the first row to inspect the available columns.
dp.lz[0]

{'file_id': 'DA05510C_1-3+',
 'filepath': '/media/4tb_hdd/eeg_data/lpch/lpch/DA05510C_1-3+.eeghdf',
 'fm_split': 'train',
 'id': 'DA05510C_1-3+_-1.0',
 'sz_start_index': -1.0,
 'target': False,
 'index': '0',
 'input': LambdaCell(fn=functools.partial(<function stanford_eeg_loader at 0x7f2fefc85e50>, clip_len=60)),
 'age': 0.006124270674784373,
 'duration': 1325.9999999999998}

In [4]:
# Size of the full datapanel, before balancing.
len(dp)

86091

In [5]:
# Load the output of a previous `balance_dp` run instead of recomputing it.
# Run ids noted by the author: 623 -> 60 sec, 622 -> with text constraint,
# 928 -> listed without a description.
balanced_dp_art = balance_dp.out(623) # 928, 60 sec: 623, with text constraint: 622
balanced_dp = balanced_dp_art.load()

In [6]:
# Size of the datapanel after balancing.
len(balanced_dp)

58100

## Split data

In [7]:
# Assign a split to each `file_id` of the balanced datapanel.
# NOTE(review): presumably splitting on `file_id` keeps every clip of one
# recording in the same split — confirm in `split_dp`.
# NOTE(review): unlike the cells above, this calls `split_dp` directly rather
# than loading a stored run via `.out(...)`; it may start a new run each time the
# cell executes — confirm that is intended.
dp_splits_art = split_dp(balanced_dp_art, split_on="file_id")
dp_splits = dp_splits_art.load()
print(len(dp_splits))
# Preview of the split table.
dp_splits.head()

44128


Unnamed: 0,file_id (PandasSeriesColumn),split (PandasSeriesColumn),index (PandasSeriesColumn)
0,DA1120VJ_1-1+,train,0
1,DA0551CN_1-1+,valid,1
2,DA0552WC_1-2+,test,2
3,CA8312E5_1-7+,test,3
4,DA00106S_1-2+,train,4


In [30]:
# Threshold used to binarize the `age` column (rows with age < thresh are True).
thresh = 1
age_labels = balanced_dp["age"] < thresh
# Seizure target column; only needed by the commented-out correlation check below.
sz_labels = balanced_dp["target"]
#np.corrcoef([age_labels,sz_labels])[1,0]
# Fraction of rows in the balanced datapanel with age below the threshold.
print(age_labels.mean())

0.1786836562035982


## Slice dp based on metadata

In [50]:
from domino.slices.eeg import EegSliceBuilder

# Alternative setting (correlation between age and the target), kept for reference:
#dp_age = EegSliceBuilder().build_correlation_setting(balanced_dp, correlate="age", corr=0.0, n=8000, correlate_threshold=thresh)

# Build a "rare slice" setting on the `age` attribute: n=8000 rows with
# slice_frac=0.001. The attribute threshold is the shared `thresh` variable
# rather than a repeated literal, so the slice definition cannot drift from the
# `age < thresh` checks done elsewhere in the notebook.
dp_age = EegSliceBuilder().build_rare_setting(balanced_dp, attribute="age", attribute_thresh=thresh, slice_frac=0.001, n=8000)

In [51]:
# Size of the sliced datapanel (should match the requested n).
len(dp_age)

8000

In [52]:
# Recompute both label vectors on the sliced datapanel.
sz_labels = dp_age["target"]
age_labels = dp_age["age"] < thresh
#np.corrcoef([age_labels,sz_labels])[0,1]

# One value per line: fraction of rows with age < thresh, then mean of the target.
print(age_labels.mean(), sz_labels.mean(), sep="\n")

0.001
0.46975


## Get multiple slices

In [None]:
# from domino.slices.eeg import collect_correlation_settings

# # correlate_list = ["age"]
# # corr_list = [0, 0.3, 0.5, 0.9]
# # correlate_thresholds = [10]
# # dp_slices_art = collect_correlation_slices(correlate_list, corr_list, correlate_thresholds)
# dp_slices_art = collect_correlation_settings.out(516)
# dp_slices = dp_slices_art.load()
# dp_slices.head()

## Score slices

In [None]:
from domino.train import score_settings, score_model, train_model
from domino.metrics import compute_model_metrics


In [None]:
# Load the scored output of run 691; the rest of the notebook refers to it as
# the biased model's datapanel.
biased_model_dp = score_model.out(691).load()
print(len(biased_model_dp))
# NOTE(review): `num_iter=1000` is presumably the number of bootstrap
# iterations — confirm in `compute_model_metrics`.
print(compute_model_metrics(biased_model_dp, num_iter=1000, flat=True))

In [None]:
# import matplotlib.pyplot as plt

# metric = "auroc"

# scores_dp = score_slices.out(652).load() # on valid: 627, on test: 595

# plt.plot(scores_dp["corr"].data, scores_dp[f"in_slice_0_{metric}"].data, color="green")
# #plt.scatter(scores_dp["corr"].data, scores_dp[f"in_slice_0_{metric}_mean"].data, color="green")
# #plt.fill_between(scores_dp["corr"].data, scores_dp[f"in_slice_0_{metric}_lower"].data, scores_dp[f"in_slice_0_{metric}_upper"].data, alpha=0.3, color="green")

# plt.plot(scores_dp["corr"].data, scores_dp[f"out_slice_{metric}"].data, color="red")
# #plt.scatter(scores_dp["corr"].data, scores_dp[f"out_slice_{metric}_mean"].data, color="red")
# #plt.fill_between(scores_dp["corr"].data, scores_dp[f"out_slice_{metric}_lower"].data, scores_dp[f"out_slice_{metric}_upper"].data, alpha=0.3, color="red")

# plt.plot(scores_dp["corr"].data, scores_dp[f"overall_{metric}"].data, color="blue")
# #plt.scatter(scores_dp["corr"].data, scores_dp[f"overall_{metric}_mean"].data, color="blue")
# #plt.fill_between(scores_dp["corr"].data, scores_dp[f"overall_{metric}_lower"].data, scores_dp[f"overall_{metric}_upper"].data, alpha=0.3, color="blue")



# plt.legend(["C = Y", "C != Y","overall"])
# plt.ylabel(f"mean {metric}")
# plt.xlabel("correlation strength")
# plt.title("EEG seizure prediction, age slicing")
# plt.show()

In [None]:
# dp_ = train_model.inp(578)["dp"].load()
# mask = np.logical_and((dp_["slices"].data[:,0]==1),dp_.lz["split"]=="valid")

# dp_["target","binarized_age"].lz[mask][-10:]
# dp_["target"].lz[mask].sum()

## Embed EEGs and Text using multimodal model

In [None]:
from domino.emb.eeg import embed_eeg

In [None]:
# Load the scored outputs of two runs (677 into the "dev" variable, 689 into
# the "train" variable) and stack them into a single datapanel.
score_dp_dev = score_model.out(677).load()
score_dp_train = score_model.out(689).load()
score_dp = mk.concat([score_dp_dev, score_dp_train])

# Sizes, one per line: train, dev, combined.
print(*(len(part) for part in (score_dp_train, score_dp_dev, score_dp)), sep="\n")

In [None]:
# Run-id reference (from the author's notes):
#   eeg only:                                   embed_eeg_text run_id = 711, model run_id = 709
#   multimodal with narrative, only 10 epochs:  embed_eeg_text run_id 743, model run_id 715
#   multimodal with narrative:                  embed_eeg_text run_id = 713, model run_id = 704
#   multimodal without narrative:               run_id = 659

# Split table (run 697) and corpus (run 696) are both the "with narrative" runs;
# the split assignment is merged into the corpus datapanel.
multimodal_split_dp = split_dp.out(697, load=True)
multimodal_corpus_dp = merge_in_split(
    build_stanford_eeg_dp.out(run_id=696, load=True),
    multimodal_split_dp,
)

# Stored embedding runs. Call arguments kept from the author's note for run 743:
# (dp=multimodal_corpus_dp,model=terra.get(715, "best_chkpt")["model"], layers={"fc1": "model.fc1"}, device=0, batch_size=10)
multimodal_corpus_dp_emb = embed_eeg.out(743, load=True)
eeg_corpus_dp_emb = embed_eeg.out(711, load=True)

# Sizes, one per line: multimodal embeddings, EEG-only embeddings.
print(len(multimodal_corpus_dp_emb), len(eeg_corpus_dp_emb), sep="\n")

## Score train samples in multimodal corpus

In [None]:
# Row masks over the embedded multimodal corpus, one per split; only the
# training rows are scored below.
valid_mask = multimodal_corpus_dp_emb["split"]=="valid"
train_mask = multimodal_corpus_dp_emb["split"]=="train"
dp_emb_train = multimodal_corpus_dp_emb.lz[train_mask]

# Average embedding of the examples the biased model predicts as seizure
# (the author's notes describe `biased_model_dp` as the dev set).
# `argmax(1)` returns class indices (0/1), not a mask: indexing rows with it
# would gather rows 0 and 1 over and over. Compare against the positive class
# to obtain a boolean row mask instead.
sz_preds = biased_model_dp["output"].argmax(1)
sz_pred_mask = sz_preds == 1
pred_emb = biased_model_dp["fc1"].lz[sz_pred_mask].mean(0)

# Average embedding of the examples whose target is a seizure.
sz_emb = biased_model_dp["fc1"].lz[biased_model_dp["target"]].mean(0)

# Score each training sample by projecting its embedding onto the direction
# (mean predicted-seizure embedding - mean true-seizure embedding).
train_embs = dp_emb_train["fc1"]
train_scores = np.dot(train_embs, (pred_emb-sz_emb))
dp_emb_train["scores"] = train_scores

# Fraction of the 20 highest-scoring training samples with age < 1.
(dp_emb_train.lz[(-train_scores).argsort()[:20]]["age"]< 1).mean()

## Merge multimodal and EEG embeddings in biased_model_dp

In [None]:
# How the stored run was produced (kept for reference): two embedding passes
# over `biased_model_dp`, adding the "multimodal_fc1" and "eeg_fc1" columns.
#biased_model_emb_dp = embed_eeg_text(dp=biased_model_dp,model=terra.get(704, "best_chkpt")["model"], layers={"multimodal_fc1": "model.fc1"}, device=0, batch_size=1).load()
#biased_model_emb_dp = embed_eeg_text(dp=biased_model_emb_dp,model=terra.get(709, "best_chkpt")["model"], layers={"eeg_fc1": "model.fc1"}, device=0, batch_size=1).load()

# `embed_eeg_text` is never imported in this notebook, so referencing it fails
# with NameError on a fresh kernel. The notebook's other "embed_eeg_text" runs
# (711, 743) are loaded through `embed_eeg`, so the same task is used here.
# NOTE(review): confirm that run 718 is retrievable through `embed_eeg`.
biased_model_emb_dp = embed_eeg.out(718, load=True)
print(len(biased_model_emb_dp))
biased_model_emb_dp.head()

## Fit and Score SDMs

In [None]:
from domino.sdm import MixtureModelSDM, SpotlightSDM
from domino.metrics import compute_sdm_metrics, compute_bootstrap_ci

# Slice discovery method under test: a mixture model fitted on the
# "multimodal_fc1" embedding column.
sdm = MixtureModelSDM(
    n_slices=10, 
    n_clusters=10, 
    weight_y_log_likelihood=10, 
    init_params="error",
    emb="multimodal_fc1",
    pca_components=128 
)

# Alternative method on the EEG-only embedding, kept for reference:
# sdm = SpotlightSDM(
#     learning_rate=1e-3,
#     n_slices=10, 
#     emb="eeg_fc1",
#     min_weight=10,
# )

# Keep a handle on the stored "slices" column before it is overwritten below.
# NOTE(review): running this cell a second time captures the already-overwritten
# column, so the stored slices are no longer recoverable from `original_slices`.
original_slices = biased_model_emb_dp["slices"]
# Replace "slices" with two ground-truth slices (one column each after the transpose):
#   column 0: binarized_age == 0 and target == 1
#   column 1: binarized_age == 1 and target == 0
biased_model_emb_dp["slices"] = np.array([((biased_model_emb_dp["binarized_age"]==0) * (biased_model_emb_dp["target"]==1)),((biased_model_emb_dp["binarized_age"]==1) * (biased_model_emb_dp["target"]==0))]).T
# Model score for class index 1, passed through a sigmoid.
biased_model_emb_dp["pred"] = biased_model_emb_dp["output"][:,1].sigmoid().numpy()

# Refit the SDM `num_runs` times on the same datapanel and record, for
# ground-truth slice 0, the best AUROC found in the metrics table of each run.
# NOTE(review): no seed is set here, so run-to-run variation comes from whatever
# randomness the SDM uses internally — confirm this is the intended source of variance.
num_runs = 10
top_auroc = []
for n in range(num_runs):
    # fit SDM
    
    sdm.fit(biased_model_emb_dp)
    sdm_dp = sdm.transform(biased_model_emb_dp)

    # score slices
    slice_idx = 0
    metrics_df = compute_sdm_metrics(sdm_dp)
    # keep only the rows that evaluate against the chosen ground-truth slice
    metrics_df = metrics_df[metrics_df["slice_idx"] == slice_idx]
    top_auroc.append(metrics_df["auroc"].max())
    #metrics_df[metrics_df["slice_idx"] == slice_idx].sort_values(by="auroc", ascending=False)

# Summarize the per-run best AUROC values with a bootstrap confidence interval.
top_auroc = np.array(top_auroc)
print(compute_bootstrap_ci(top_auroc))

# Explain

In [None]:
# create vocabulary from all reports
from domino.emb.eeg import generate_words_dp

# Build the vocabulary datapanel from the "narrative" column of the multimodal corpus.
# NOTE(review): `min_threshold=1` presumably keeps every word that occurs at
# least once — confirm in `generate_words_dp`.
all_narratives = multimodal_corpus_dp["narrative"].data
words_dp = generate_words_dp(all_narratives, min_threshold=1)
print(len(words_dp))
words_dp.head()

In [None]:
# Spot-check the vocabulary with a random sample of up to 10 rows. A seeded
# generator makes the displayed sample reproducible across kernel restarts, and
# sampling without replacement guarantees no row is shown twice.
random_ndxs = np.random.default_rng(0).choice(
    len(words_dp), size=min(10, len(words_dp)), replace=False
)
words_dp[random_ndxs]

In [None]:
from domino.emb.eeg import embed_words 

# embed words in vocab
emb_words_dp = embed_words(words_dp,model=terra.get(704, "best_chkpt")["model"], device=0, batch_size=1).load()
emb_words_dp.head()

In [None]:
#biased_model_emb_dp["slices"] = original_slices

# Fit the SDM once more and compute its predicted slices for every row.
sdm.fit(biased_model_emb_dp)
sdm_dp = sdm.transform(biased_model_emb_dp)

# score slices
# Ground-truth slice 1 is inspected here (the repeated-fit loop earlier used slice 0).
slice_idx = 1
metrics_df = compute_sdm_metrics(sdm_dp)
metrics_df = metrics_df[metrics_df["slice_idx"] == slice_idx]

print("max auroc: ", metrics_df["auroc"].max())
# NOTE(review): `argmax()` gives the *position* of the best row inside the
# filtered frame, and that position is then used as a column index into
# "pred_slices". This is only correct if the filtered rows are ordered by
# predicted-slice index — confirm, or read the index from the metrics row itself.
pred_slice_idx = metrics_df["auroc"].argmax()
expl_dp = sdm.explain(words_dp=emb_words_dp, data_dp=sdm_dp)
# Ten rows of the explanation datapanel with the largest values in the selected
# "pred_slices" column (ranks 1-10).
expl_dp.lz[(-expl_dp["pred_slices"][:, pred_slice_idx]).argsort()[:10]]

In [None]:
# Same ranking by the selected "pred_slices" column, ranks 11-20.
expl_dp.lz[(-expl_dp["pred_slices"][:, pred_slice_idx]).argsort()[10:20]]

In [None]:
# Same ranking by the selected "pred_slices" column, ranks 21-30.
expl_dp.lz[(-expl_dp["pred_slices"][:, pred_slice_idx]).argsort()[20:30]]