## Merge SHAP and LRP results for the test set

In [1]:
import pickle

In [2]:
# SEQ_LEN and MODEL_NAME are interpolated into the result-file paths below.
SEQ_LEN = 30
MODEL_NAME = "lstm-att-lrp"
# Key used to index the loaded result dictionaries (e.g. lrp_results[best_epoch]).
best_epoch = 5

In [14]:
# Directory holding every result file for this sequence length / model pair.
results_dir = f"./output/time_diff_toy_dataset_v3/event_based/{SEQ_LEN}/{MODEL_NAME}"

# The file names carry the epoch number; derive it from `best_epoch` so the
# files that are loaded always match the key used to index their contents.
shap_results_path = f"{results_dir}/test_results_shap_{best_epoch}.pkl"
lrp_results_path = f"{results_dir}/test_results_lrp_{best_epoch}.pkl"

# Destination of the combined (LRP + SHAP + similarity metrics) results.
output_path = f"{results_dir}/test_results_all_{best_epoch}.pkl"

In [4]:
def _load_pickle(path):
    """Read and return the single pickled object stored at ``path``.

    pickle.load can execute arbitrary code, so only point this at files
    produced by a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


shap_results = _load_pickle(shap_results_path)
lrp_results = _load_pickle(lrp_results_path)

In [5]:
# Echo the LRP results stored under `best_epoch` (the whole mapping, not a sample).
lrp_results[best_epoch]

{'6ZOB2DEVY7': {'label': tensor([0]),
  'pred': 1.8426749051979867,
  'imp':     lrp_scores  idx  seq_idx                   token  att_weights
  0    -0.004621   17        0           dental_exam_N     0.012767
  1    -0.004601   17        1           dental_exam_N     0.018179
  2     0.334022    8        2              eye_exam_N     0.024975
  3     0.248345   27        3           sleep_apnea_H     0.028488
  4    -0.006096   10        4             cold_sore_N     0.033411
  5     0.245898   30        5              troponin_H     0.036757
  6     0.527560   14        6            cut_finger_N     0.038834
  7    -0.005007   22        7          ankle_sprain_N     0.039472
  8     0.002656   21        8  ventricular_aneurysm_A     0.040978
  9     0.644247    8        9              eye_exam_N     0.041668
  10   -0.804630    7       10              headache_N     0.040377
  11   -0.033287    6       11                myopia_N     0.039762
  12    0.013179   13       12           

In [6]:
# Both result sets must have the same number of entries before they are merged.
n_lrp = len(lrp_results[best_epoch])
print(f"LRP: {n_lrp}")
n_shap = len(shap_results[best_epoch])
print(f"SHAP: {n_shap}")
assert n_lrp == n_shap

LRP: 7000
SHAP: 7000


## Merge LRP and SHAP scores

In [7]:
# Attach the SHAP scores to every patient's LRP importance table.
# The entries of `lrp_results` are updated in place.
for pid, lrp_entry in lrp_results[best_epoch].items():
    lrp_imp = lrp_entry["imp"]
    if "shap_scores" in lrp_imp.columns:
        # Already merged (e.g. the cell was re-run). Merging again would yield
        # shap_scores_x / shap_scores_y and break the column selection below.
        continue

    merged_imp = lrp_imp.merge(
        shap_results[best_epoch][pid]["imp"], on=["seq_idx", "token"]
    )
    # merge() defaults to an inner join, so a row-count change means rows were
    # dropped or duplicated while matching on (seq_idx, token).
    assert merged_imp.shape[0] == lrp_imp.shape[0], (
        f"row count changed while merging SHAP scores for {pid}"
    )
    lrp_entry["imp"] = merged_imp[
        ["idx", "seq_idx", "token", "att_weights", "lrp_scores", "shap_scores"]
    ]

In [8]:
# Alias, not a copy: `test_results` and `lrp_results` refer to the same object,
# so later writes through `test_results` are visible through `lrp_results` too.
test_results = lrp_results

## Calculate similarity scores

In [9]:
import sys
import rbo
from scipy import stats
import numpy as np

def get_wtau(x, y):
    """Return the weighted Kendall tau correlation between ``x`` and ``y``.

    ``rank=None`` makes scipy weight elements by their decreasing
    lexicographical rank by (x, y), so swapping the arguments can change
    the result. Only the correlation (first element of the scipy result)
    is returned.
    """
    tau_result = stats.weightedtau(x, y, rank=None)
    return tau_result[0]


def get_rbo(x, y, uid, p=0.7):
    """Return the rank-biased overlap of the rankings induced by ``x`` and ``y``.

    Args:
        x: scores defining the first ranking (largest score ranked first).
        y: scores defining the second ranking (largest score ranked first).
        uid: identifiers aligned position-by-position with ``x`` and ``y``.
        p: value forwarded as ``p`` to ``RankingSimilarity.rbo``.

    Returns:
        Whatever ``rbo.RankingSimilarity(...).rbo(p=p)`` returns.
    """
    ranked_by_x = [uid[pos] for pos in np.argsort(x)[::-1]]
    ranked_by_y = [uid[pos] for pos in np.argsort(y)[::-1]]
    similarity = rbo.RankingSimilarity(ranked_by_x, ranked_by_y)
    return similarity.rbo(p=p)


# calculate ground truth scores
def is_value(x):
    """Return True unless the string ``x`` contains the substring "_N".

    NOTE(review): this is a substring test, not a suffix test, so an "_N"
    anywhere inside ``x`` yields False; confirm that is intended.
    """
    return "_N" not in x

In [10]:
# calculate similarity indexes for test
# NOTE(review): none of these four lists is appended to in the cells visible
# here; confirm whether they are still needed.
epoch_test_lrp_shap_t_corr = []
epoch_test_lrp_shap_rbo = []
epoch_test_lrp_sim = []
epoch_test_shap_sim = []

# Passed as `p` to get_rbo in the metrics loop, overriding its 0.7 default.
rbo_p = 0.95

In [11]:
# For every patient: compare the LRP and SHAP rankings with each other, then
# measure how many of each method's top-ranked positions are ground-truth
# positions (tokens for which is_value() is True).
for pid in test_results[best_epoch].keys():
    imp_df = test_results[best_epoch][pid]["imp"]
    # Position-qualified token ("<seq_idx>_<token>"), presumably so repeated
    # tokens in one sequence stay distinct in the ranking comparison.
    imp_df["u_token"] = [
        str(seq) + "_" + str(token)
        for seq, token in zip(imp_df["seq_idx"], imp_df["token"])
    ]
    test_results[best_epoch][pid]["lrp_shap_t_corr"] = get_wtau(
        imp_df["lrp_scores"], imp_df["shap_scores"]
    )

    test_results[best_epoch][pid]["lrp_shap_rbo"] = get_rbo(
        imp_df["lrp_scores"],
        imp_df["shap_scores"],
        imp_df["u_token"].tolist(),
        p=rbo_p,
    )

    # gt similarity
    gt_idx = [x for x, tok in enumerate(imp_df.u_token) if is_value(tok)]
    n_gt = len(gt_idx)
    if n_gt > 0:
        # Top-n_gt positions by absolute score for each method; the similarity
        # is the fraction of those positions that are ground-truth positions.
        lrp_idx = np.argsort(np.abs(imp_df.lrp_scores.values))[::-1][:n_gt]
        shap_idx = np.argsort(np.abs(imp_df.shap_scores.values))[::-1][:n_gt]
        att_idx = np.argsort(np.abs(imp_df.att_weights.values))[::-1][:n_gt]
        lrp_sim = len(set(lrp_idx).intersection(gt_idx)) / n_gt
        shap_sim = len(set(shap_idx).intersection(gt_idx)) / n_gt
        att_sim = len(set(att_idx).intersection(gt_idx)) / n_gt
    else:
        # No ground-truth tokens in this sequence: store the -1 sentinel for
        # all three scores. att_sim must be set here as well, otherwise it is
        # undefined on the first patient or carries over from the previous one.
        print(f"-1 is the output for {pid}")
        lrp_sim = -1
        shap_sim = -1
        att_sim = -1
    test_results[best_epoch][pid]["lrp_sim"] = lrp_sim
    test_results[best_epoch][pid]["shap_sim"] = shap_sim
    test_results[best_epoch][pid]["att_sim"] = att_sim

In [12]:
# Preview one merged table. `pid` is whatever value the preceding loop left
# behind (the last key it iterated over).
test_results[best_epoch][pid]['imp'].head()

Unnamed: 0,idx,seq_idx,token,att_weights,lrp_scores,shap_scores,u_token
0,6,0,myopia_N,0.009735,-0.010823,-0.012647,0_myopia_N
1,12,1,quad_injury_N,0.012768,-0.008046,-0.008752,1_quad_injury_N
2,8,2,eye_exam_N,0.017213,0.290072,0.40271,2_eye_exam_N
3,6,3,myopia_N,0.021846,-0.015325,-0.017809,3_myopia_N
4,6,4,myopia_N,0.025275,-0.016718,-0.019076,4_myopia_N


## Save all combined results

In [19]:
# Echo the destination path before writing.
output_path

'./output/time_diff_toy_dataset_v3/event_based/30/lstm-att-lrp/test_results_all_5.pkl'

In [23]:
# Serialize the combined results dictionary to `output_path`.
with open(output_path, "wb") as out_file:
    pickle.dump(test_results, out_file)

In [24]:
# List the output directory to check that the combined results file is present.
! ls "./output/time_diff_toy_dataset_v3/event_based/30/lstm-att-lrp/"

model_weights		test_results_lrp_5.pkl	 val_results_lrp_5.pkl
shap			test_results_shap_5.pkl  val_results_shap_5.pkl
test_results_all_5.pkl	train_results
