In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pdb
import terra
import meerkat as mk
from meerkat.contrib.eeg import build_stanford_eeg_dp
from domino.utils import split_dp, balance_dp, merge_in_split

from sklearn.metrics import precision_score, recall_score, roc_auc_score


## Load EEG datapanel

In [3]:
dp_art = build_stanford_eeg_dp.out(run_id=618) # with text constraint: 409
dp=dp_art.load()
dp.lz[0]

{'file_id': 'DA05510C_1-3+',
 'filepath': '/media/4tb_hdd/eeg_data/lpch/lpch/DA05510C_1-3+.eeghdf',
 'fm_split': 'train',
 'id': 'DA05510C_1-3+_-1.0',
 'sz_start_index': -1.0,
 'target': False,
 'index': '0',
 'input': LambdaCell(fn=functools.partial(<function stanford_eeg_loader at 0x7f0628ce78b0>, clip_len=60)),
 'age': 0.006124270674784373,
 'duration': 1325.9999999999998}

In [4]:
len(dp)

86091

In [5]:
balanced_dp_art = balance_dp.out(623) # with text constraint: 622
balanced_dp = balanced_dp_art.load()

In [6]:
len(balanced_dp)

58100

## split data

In [7]:
dp_splits_art = split_dp.out(625) #split_dp(balanced_dp_art, split_on="file_id")
dp_splits = dp_splits_art.load()
print(len(dp_splits))
dp_splits.head()

44128


Unnamed: 0,file_id (PandasSeriesColumn),split (PandasSeriesColumn),index (PandasSeriesColumn)
0,DA1120VJ_1-1+,train,0
1,DA0551CN_1-1+,valid,1
2,DA0552WC_1-2+,test,2
3,CA8312E5_1-7+,test,3
4,DA00106S_1-2+,train,4


In [8]:
thresh = 1
age_labels = balanced_dp["age"] > thresh
sz_labels = balanced_dp["target"]
np.corrcoef([age_labels,sz_labels])[1,0]

-0.0973282739180349

## Slice dp based on metadata

In [9]:
from domino.slices.eeg import EegSliceBuilder

dp_age = EegSliceBuilder().build_correlation_setting(balanced_dp, correlate="age", corr=0.9, n=8000, correlate_threshold=thresh)

In [10]:
len(dp_age)

7998

In [11]:
age_labels = dp_age["age"] > thresh
sz_labels = dp_age["target"]
np.corrcoef([age_labels,sz_labels])[0,1]

0.8999749937484369

## Get multiple slices

In [12]:
# from domino.slices.eeg import collect_correlation_settings

# # correlate_list = ["age"]
# # corr_list = [0, 0.3, 0.5, 0.9]
# # correlate_thresholds = [10]
# # dp_slices_art = collect_correlation_slices(correlate_list, corr_list, correlate_thresholds)
# dp_slices_art = collect_correlation_settings.out(516)
# dp_slices = dp_slices_art.load()
# dp_slices.head()

## Score slices

In [13]:
from domino.train import score_settings, score_model, train_model
from domino.metrics import compute_model_metrics




In [14]:
biased_model_dp = score_model.out(691).load()
print(len(biased_model_dp))
print(compute_model_metrics(biased_model_dp, num_iter=1000, flat=True))

774
{'overall_auroc': 0.7802465826856071, 'overall_recall': 0.9536585365853658, 'overall_precision': 0.6378466557911908, 'overall_f1_score': 0.7644183773216031, 'in_slice_0_auroc': 0.017647058823529415, 'in_slice_0_recall': 0.35, 'in_slice_0_precision': 0.2916666666666667, 'in_slice_0_f1_score': 0.31818181818181823, 'out_slice_auroc': 0.822966082908446, 'out_slice_recall': 0.9846153846153847, 'out_slice_precision': 0.6519524617996605, 'out_slice_f1_score': 0.7844739530132789}


In [15]:
# import matplotlib.pyplot as plt

# metric = "auroc"

# scores_dp = score_slices.out(652).load() # on valid: 627, on test: 595

# plt.plot(scores_dp["corr"].data, scores_dp[f"in_slice_0_{metric}"].data, color="green")
# #plt.scatter(scores_dp["corr"].data, scores_dp[f"in_slice_0_{metric}_mean"].data, color="green")
# #plt.fill_between(scores_dp["corr"].data, scores_dp[f"in_slice_0_{metric}_lower"].data, scores_dp[f"in_slice_0_{metric}_upper"].data, alpha=0.3, color="green")

# plt.plot(scores_dp["corr"].data, scores_dp[f"out_slice_{metric}"].data, color="red")
# #plt.scatter(scores_dp["corr"].data, scores_dp[f"out_slice_{metric}_mean"].data, color="red")
# #plt.fill_between(scores_dp["corr"].data, scores_dp[f"out_slice_{metric}_lower"].data, scores_dp[f"out_slice_{metric}_upper"].data, alpha=0.3, color="red")

# plt.plot(scores_dp["corr"].data, scores_dp[f"overall_{metric}"].data, color="blue")
# #lt.scatter(scores_dp["corr"].data, scores_dp[f"overall_{metric}_mean"].data, color="blue")
# #plt.fill_between(scores_dp["corr"].data, scores_dp[f"overall_{metric}_lower"].data, scores_dp[f"overall_{metric}_upper"].data, alpha=0.3, color="blue")



# plt.legend(["C = Y", "C != Y","overall"])
# plt.ylabel(f"mean {metric}")
# plt.xlabel("correlation strength")
# plt.title("EEG seizure prediction, age slicing")
# plt.show()

In [16]:
# dp_ = train_model.inp(578)["dp"].load()
# mask = np.logical_and((dp_["slices"].data[:,0]==1),dp_.lz["split"]=="valid")

# dp_["target","binarized_age"].lz[mask][-10:]
# dp_["target"].lz[mask].sum()

## Emed EEGs and Text using multimodal model

In [17]:
from domino.emb.eeg import embed_eeg_text

In [18]:
score_dp_dev = score_model.out(677).load()
score_dp_train = score_model.out(689).load()
score_dp = mk.concat([score_dp_dev, score_dp_train])

print(len(score_dp_train))
print(len(score_dp_dev))
print(len(score_dp))

5693
771
6464




In [19]:
multimodal_corpus_dp = build_stanford_eeg_dp.out(run_id=696, load=True)  # for run with narrative: run_id = 696
multimodal_split_dp = split_dp.out(697, load=True) # for run with narrative: run_id = 697
multimodal_corpus_dp = merge_in_split(multimodal_corpus_dp, multimodal_split_dp)

multimodal_corpus_dp_emb = embed_eeg_text.out(713,load=True) #(dp=multimodal_corpus_dp,model=terra.get(704, "best_chkpt")["model"], layers={"fc1": "model.fc1"}, device=0, batch_size=1) 
eeg_corpus_dp_emb = embed_eeg_text.out(711,load=True)
# for eeg only embed_eeg_text run_id = 711, model run_id=709
# for multimodal run with narrative embed_eeg_text run_id = 713, model run_id = 704, 
# for multimodal run without narrative run_id=659

print(len(multimodal_corpus_dp_emb))
print(len(eeg_corpus_dp_emb))

6254
6254


## Score train samples in multimodal corpus

In [20]:
valid_mask = multimodal_corpus_dp_emb["split"]=="valid"
train_mask = multimodal_corpus_dp_emb["split"]=="train"
dp_emb_train = multimodal_corpus_dp_emb.lz[train_mask]

# get the average embedding of predicted seizure on the dev set
sz_preds = biased_model_dp["output"].argmax(1)
pred_emb = biased_model_dp["fc1"].lz[sz_preds].mean(0)

#get the average embedding of seizures on the dev set
sz_emb = biased_model_dp["fc1"].lz[biased_model_dp["target"]].mean(0)

# score the training samples 
train_embs = dp_emb_train["fc1"]
train_scores = np.dot(train_embs, (pred_emb-sz_emb))
dp_emb_train["scores"] = train_scores

(dp_emb_train.lz[(-train_scores).argsort()[:20]]["age"]< 1).mean()

0.7

## Merge multimodal and EEG embeddings in biased_model_dp

In [21]:
# biased_model_emb_dp = embed_eeg_text(dp=biased_model_dp,model=terra.get(704, "best_chkpt")["model"], layers={"multimodal_fc1": "model.fc1"}, device=0, batch_size=1)
# biased_model_emb_dp = embed_eeg_text(dp=embed_eeg_text,model=terra.get(709, "best_chkpt")["model"], layers={"eeg_fc1": "model.fc1"}, device=0, batch_size=1)
biased_model_emb_dp = embed_eeg_text.out(718,load=True)
print(len(biased_model_emb_dp))
biased_model_emb_dp.head()

774


Unnamed: 0,file_id (PandasSeriesColumn),sz_start_index (NumpyArrayColumn),fm_split (PandasSeriesColumn),filepath (PandasSeriesColumn),target (NumpyArrayColumn),input (LambdaColumn),duration (NumpyArrayColumn),id (PandasSeriesColumn),age (NumpyArrayColumn),index (PandasSeriesColumn),binarized_age (NumpyArrayColumn),slices (NumpyArrayColumn),split (PandasSeriesColumn),output (ClassificationOutputColumn),fc1 (TensorColumn),multimodal_fc1 (NumpyArrayColumn),eeg_fc1 (NumpyArrayColumn)
0,ZA0054I5_1-1+,3.0,train,/media/4tb_hdd/eeg_data/stanford/stanford_mini/ZA0054I5_1-1+.eeghdf,True,"LambdaCell(fn=functools.partial(, clip_len=60))",1971.0,ZA0054I5_1-1+_3.0,66.137876,0,1,"np.ndarray(shape=(1,))",valid,torch.Tensor(shape=torch.Size([2])),torch.Tensor(shape=torch.Size([128])),"np.ndarray(shape=(128,))","np.ndarray(shape=(128,))"
1,ZA0054I5_1-1+,15.0,train,/media/4tb_hdd/eeg_data/stanford/stanford_mini/ZA0054I5_1-1+.eeghdf,True,"LambdaCell(fn=functools.partial(, clip_len=60))",1971.0,ZA0054I5_1-1+_15.0,66.137876,1,1,"np.ndarray(shape=(1,))",valid,torch.Tensor(shape=torch.Size([2])),torch.Tensor(shape=torch.Size([128])),"np.ndarray(shape=(128,))","np.ndarray(shape=(128,))"
2,ZA0054I5_1-1+,4.0,train,/media/4tb_hdd/eeg_data/stanford/stanford_mini/ZA0054I5_1-1+.eeghdf,True,"LambdaCell(fn=functools.partial(, clip_len=60))",1971.0,ZA0054I5_1-1+_4.0,66.137876,2,1,"np.ndarray(shape=(1,))",valid,torch.Tensor(shape=torch.Size([2])),torch.Tensor(shape=torch.Size([128])),"np.ndarray(shape=(128,))","np.ndarray(shape=(128,))"
3,ZA0054I5_1-1+,18.0,train,/media/4tb_hdd/eeg_data/stanford/stanford_mini/ZA0054I5_1-1+.eeghdf,True,"LambdaCell(fn=functools.partial(, clip_len=60))",1971.0,ZA0054I5_1-1+_18.0,66.137876,3,1,"np.ndarray(shape=(1,))",valid,torch.Tensor(shape=torch.Size([2])),torch.Tensor(shape=torch.Size([128])),"np.ndarray(shape=(128,))","np.ndarray(shape=(128,))"
4,ZA0054I5_1-1+,11.0,train,/media/4tb_hdd/eeg_data/stanford/stanford_mini/ZA0054I5_1-1+.eeghdf,True,"LambdaCell(fn=functools.partial(, clip_len=60))",1971.0,ZA0054I5_1-1+_11.0,66.137876,4,1,"np.ndarray(shape=(1,))",valid,torch.Tensor(shape=torch.Size([2])),torch.Tensor(shape=torch.Size([128])),"np.ndarray(shape=(128,))","np.ndarray(shape=(128,))"


## Fit and Score SDMs

In [22]:
from domino.sdm import MixtureModelSDM, SpotlightSDM
from domino.metrics import compute_sdm_metrics, compute_bootstrap_ci

sdm = MixtureModelSDM(
    n_slices=10, 
    n_clusters=10, 
    weight_y_log_likelihood=10, 
    init_params="error",
    emb="multimodal_fc1",
    pca_components=128 
)

# sdm = SpotlightSDM(
#     learning_rate=1e-3,
#     n_slices=10, 
#     emb="eeg_fc1",
#     min_weight=10,
# )

biased_model_emb_dp["slices"] = np.array([((biased_model_emb_dp["binarized_age"]==0) * (biased_model_emb_dp["target"]==1)),((biased_model_emb_dp["binarized_age"]==1) * (biased_model_emb_dp["target"]==0))]).T
biased_model_emb_dp["pred"] = biased_model_emb_dp["output"][:,1].sigmoid().numpy()

num_runs = 10
top_auroc = []
for n in range(num_runs):
    # fit SDM
    
    sdm.fit(biased_model_emb_dp)
    sdm_dp = sdm.transform(biased_model_emb_dp)

    # score slices
    slice_idx = 0
    metrics_df = compute_sdm_metrics(sdm_dp)
    metrics_df = metrics_df[metrics_df["slice_idx"] == slice_idx]
    top_auroc.append(metrics_df["auroc"].max())
    #metrics_df[metrics_df["slice_idx"] == slice_idx].sort_values(by="auroc", ascending=False)

top_auroc = np.array(top_auroc)
print(np.mean(top_auroc))
print(compute_bootstrap_ci(top_auroc))

 40%|████      | 40/100 [00:00<00:00, 649.48it/s]
 73%|███████▎  | 73/100 [00:00<00:00, 784.33it/s]
 29%|██▉       | 29/100 [00:00<00:00, 810.95it/s]
 81%|████████  | 81/100 [00:00<00:00, 759.04it/s]
 37%|███▋      | 37/100 [00:00<00:00, 642.42it/s]
 63%|██████▎   | 63/100 [00:00<00:00, 463.36it/s]
 40%|████      | 40/100 [00:00<00:00, 369.74it/s]
 52%|█████▏    | 52/100 [00:00<00:00, 524.56it/s]
 49%|████▉     | 49/100 [00:00<00:00, 757.69it/s]
 47%|████▋     | 47/100 [00:00<00:00, 731.16it/s]


0.839973474801061
{'mean': 0.839973474801061, 'lower': 0.8289350132625993, 'upper': 0.849990550397878}


# Explain

In [31]:
# create vocabulary from all reports
from domino.emb.eeg import generate_words_dp

all_narratives = multimodal_corpus_dp["narrative"].data
words_dp = generate_words_dp(all_narratives, min_threshold=50)
print(len(words_dp))
words_dp.head()

1056


Unnamed: 0,word (PandasSeriesColumn),frequency (PandasSeriesColumn),index (PandasSeriesColumn)
0,of,22102,0
1,mg,18075,1
2,for,14418,2
3,and,14246,3
4,eeg,13178,4


In [32]:
from domino.emb.eeg import embed_words 

# embed words in vocab
emb_words_dp = embed_words(words_dp,model=terra.get(704, "best_chkpt")["model"], device=0, batch_size=1).load()
emb_words_dp.head()

task: embed_words, run_id=731


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


CheXbert: skipping param linear_heads.0.weight
CheXbert: skipping param linear_heads.0.bias
CheXbert: skipping param linear_heads.1.weight
CheXbert: skipping param linear_heads.1.bias
CheXbert: skipping param linear_heads.2.weight
CheXbert: skipping param linear_heads.2.bias
CheXbert: skipping param linear_heads.3.weight
CheXbert: skipping param linear_heads.3.bias
CheXbert: skipping param linear_heads.4.weight
CheXbert: skipping param linear_heads.4.bias
CheXbert: skipping param linear_heads.5.weight
CheXbert: skipping param linear_heads.5.bias
CheXbert: skipping param linear_heads.6.weight
CheXbert: skipping param linear_heads.6.bias
CheXbert: skipping param linear_heads.7.weight
CheXbert: skipping param linear_heads.7.bias
CheXbert: skipping param linear_heads.8.weight
CheXbert: skipping param linear_heads.8.bias
CheXbert: skipping param linear_heads.9.weight
CheXbert: skipping param linear_heads.9.bias
CheXbert: skipping param linear_heads.10.weight
CheXbert: skipping param linear_



  0%|          | 0/1056 [00:00<?, ?it/s]



Unnamed: 0,word (PandasSeriesColumn),frequency (PandasSeriesColumn),index (PandasSeriesColumn),emb (NumpyArrayColumn)
0,of,22102,0,"np.ndarray(shape=(128,))"
1,mg,18075,1,"np.ndarray(shape=(128,))"
2,for,14418,2,"np.ndarray(shape=(128,))"
3,and,14246,3,"np.ndarray(shape=(128,))"
4,eeg,13178,4,"np.ndarray(shape=(128,))"


In [35]:
pred_slice_idx = 0
expl_dp = sdm.explain(words_dp=words_dp, data_dp=sdm_dp)
expl_dp.lz[(-expl_dp["pred_slices"][:, pred_slice_idx]).argsort()[:10]]

Unnamed: 0,word (PandasSeriesColumn),pred_slices (NumpyArrayColumn),frequency (PandasSeriesColumn)
0,razavi,"np.ndarray(shape=(10,))",85
1,possibly,"np.ndarray(shape=(10,))",50
2,norcuron,"np.ndarray(shape=(10,))",75
3,possible,"np.ndarray(shape=(10,))",2304
4,ca,"np.ndarray(shape=(10,))",63
5,nihon,"np.ndarray(shape=(10,))",465
6,picu,"np.ndarray(shape=(10,))",319
7,suspected,"np.ndarray(shape=(10,))",54
8,tube,"np.ndarray(shape=(10,))",60
9,graber,"np.ndarray(shape=(10,))",69
