## Merge LSTM LRP and SHAP of All Test/Val Data

This notebook merges the LRP and SHAP importance scores computed for all test and validation data. Before running it, first run [LSTM LRP](lstm_lrp_all_data.py) and [LSTM SHAP](lstm_shap_all_data.py) to compute the LRP and SHAP scores, respectively.

In [20]:
import pickle
import os
import sys
from scipy import stats
import numpy as np

import imp_utils

In [14]:
IS_SYNTHETIC = True  # Whether data is synthetic

# Experiment identifiers that select the results directory. They were missing
# from this cell (the f-string below raised NameError on a fresh kernel); the
# values match the output path recorded by the last saved run of this notebook.
DATA_TYPE = 'event'
SEQ_LEN = 30
MODEL_NAME = 'lstm'

BEST_EPOCH = 2  # epoch key used to index the loaded LRP/SHAP result dicts
DATA_SPLIT = 'val'  # val/test

# Directory holding the per-split importance pickles for this experiment.
RESULTS_DIR = f'./output/synthetic/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/importances/'

# Inputs (SHAP and LRP results) and the merged output file.
SHAP_RESULTS_PATH = os.path.join(RESULTS_DIR, f"{DATA_SPLIT}_all_shap_{BEST_EPOCH}.pkl")
LRP_RESULTS_PATH = os.path.join(RESULTS_DIR, f"{DATA_SPLIT}_all_lrp_{BEST_EPOCH}.pkl")
OUTPUT_PATH = os.path.join(RESULTS_DIR, f"{DATA_SPLIT}_all_lrp_shap_{BEST_EPOCH}.pkl")

In [15]:
# Read the SHAP results first, then the LRP results, from their pickle files.
# NOTE: pickle.load can execute arbitrary code; only open files you produced.
with open(SHAP_RESULTS_PATH, 'rb') as shap_file:
    shap_results = pickle.load(shap_file)

with open(LRP_RESULTS_PATH, 'rb') as lrp_file:
    lrp_results = pickle.load(lrp_file)

In [16]:
# Sanity check: both explainers must have scored the same examples for the
# selected epoch before the per-example merge below can work.
lrp_epoch_results = lrp_results[BEST_EPOCH]
shap_epoch_results = shap_results[BEST_EPOCH]

print(f"LRP: {len(lrp_epoch_results)}")
print(f"SHAP: {len(shap_epoch_results)}")
assert len(lrp_epoch_results) == len(shap_epoch_results)

# Equal sizes alone do not guarantee matching ids; the merge cell looks up each
# LRP id in the SHAP results, so fail here with a clear message instead of a
# bare KeyError later.
ids_missing_in_shap = [
    pid for pid in lrp_epoch_results.keys() if pid not in shap_epoch_results
]
assert not ids_missing_in_shap, (
    f"{len(ids_missing_in_shap)} LRP ids have no SHAP entry, "
    f"e.g. {ids_missing_in_shap[:5]}"
)

LRP: 7000
SHAP: 7000


## Merge LRP and SHAP scores

In [17]:
# Attach each example's SHAP scores to its LRP importance table by joining on
# (seq_idx, token), then keep only the columns used downstream.
# The merge is written back into `lrp_results`, which `results` aliases below.
for pid in lrp_results[BEST_EPOCH].keys():
    lrp_imp = lrp_results[BEST_EPOCH][pid]["imp"]

    if "shap_scores" in lrp_imp.columns:
        # Already merged by an earlier run of this cell. Merging again would
        # produce suffixed `shap_scores_x`/`shap_scores_y` columns and make the
        # column selection below fail, so skip to keep the cell re-runnable.
        continue

    merged_imp = lrp_imp.merge(
        shap_results[BEST_EPOCH][pid]["imp"], on=["seq_idx", "token"]
    )
    # The default inner join must neither drop nor duplicate rows; check this
    # before overwriting the stored frame so a failure leaves it untouched.
    assert lrp_imp.shape[0] == merged_imp.shape[0], (
        f"Row count changed while merging SHAP into LRP for id {pid}: "
        f"{lrp_imp.shape[0]} -> {merged_imp.shape[0]}"
    )
    lrp_results[BEST_EPOCH][pid]["imp"] = merged_imp[
        ["idx", "seq_idx", "token", "att_weights", "lrp_scores", "shap_scores"]
    ]
results = lrp_results

## Calculate similarity scores

In [18]:
# Empty accumulators for epoch-level similarity values: LRP-vs-SHAP
# correlation, LRP similarity, and SHAP similarity.
epoch_lrp_shap_t_corr, epoch_lrp_sim, epoch_shap_sim = [], [], []

In [21]:
# Result key -> score column whose similarity is computed with
# imp_utils.get_intersection_similarity.
similarity_columns = (
    ("lrp_sim", "lrp_scores"),
    ("shap_sim", "shap_scores"),
    ("att_sim", "att_weights"),
)

# NOTE(review): `get_wtau` is neither defined nor imported in the cells shown
# here; it presumably lives in `imp_utils` or an earlier kernel session.
# Confirm its source before running on a fresh kernel.
for pid, entry in results[BEST_EPOCH].items():
    imp_df = entry["imp"]

    # Position-qualified token label, e.g. "0_dental_exam_N".
    imp_df["u_token"] = [
        f"{seq!s}_{token!s}"
        for seq, token in zip(imp_df["seq_idx"], imp_df["token"])
    ]

    # Agreement between the LRP and SHAP scores of this example.
    entry["lrp_shap_t_corr"] = get_wtau(
        imp_df["lrp_scores"], imp_df["shap_scores"]
    )

    # Compute all three similarities before storing any of them.
    similarities = {
        result_key: imp_utils.get_intersection_similarity(
            imp_df[score_col],
            imp_df["token"],
            freedom=0,
            is_synthetic=IS_SYNTHETIC,
        )
        for result_key, score_col in similarity_columns
    }
    entry.update(similarities)

In [22]:
# Preview the merged importance table of one example. `pid` is whatever id the
# preceding loop visited last.
preview_imp_df = results[BEST_EPOCH][pid]["imp"]
preview_imp_df.head()

Unnamed: 0,idx,seq_idx,token,att_weights,lrp_scores,shap_scores,u_token
0,11,0,dental_exam_N,0.039415,-0.020867,0.002821,0_dental_exam_N
1,11,1,dental_exam_N,0.043597,-0.022586,0.000379,1_dental_exam_N
2,17,2,cut_finger_N,0.042979,-0.018259,0.002387,2_cut_finger_N
3,18,3,ingrown_nail_N,0.046472,-0.024119,0.000659,3_ingrown_nail_N
4,6,4,quad_injury_N,0.045419,-0.018696,-0.001051,4_quad_injury_N


## Save all combined results

In [23]:
# Persist the merged results (LRP + SHAP scores and similarity values) as a
# single pickle at OUTPUT_PATH.
with open(OUTPUT_PATH, 'wb') as out_file:
    pickle.dump(results, out_file)

print(f'Importance Scores Successfully Merged and Saved to {OUTPUT_PATH}!')

Importance Scores Successfully Merged and Saved to ./output/synthetic/event/30/lstm/importances/val_all_lrp_shap_2.pkl!
