## Merge LSTM LRP and SHAP Importance Scores for All Val/Test Data

In [10]:
import pickle
import os
import sys
from scipy import stats
import numpy as np

from imp_utils import *

In [11]:
# ---- Run configuration -------------------------------------------------
IS_SYNTHETIC = True  # whether the data is synthetic
SEQ_LEN = 30
MODEL_NAME = "lstm"
DATA_TYPE = "event"  # "event" or "sequence"
BEST_EPOCH = 2  # epoch whose importances are merged; also part of the file names
DATA_SPLIT = "val"  # "val" or "test"

# Directory that holds the per-method importance pickles for this run.
RESULTS_DIR = f"./output/synthetic/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/importances/"

# Input files (SHAP, LRP) and the merged output file differ only in the
# method tag embedded in the file name.
SHAP_RESULTS_PATH, LRP_RESULTS_PATH, OUTPUT_PATH = (
    os.path.join(RESULTS_DIR, f"{DATA_SPLIT}_all_{method_tag}_{BEST_EPOCH}.pkl")
    for method_tag in ("shap", "lrp", "lrp_shap")
)

In [29]:
# Read the previously pickled SHAP results for the chosen split/epoch.
# NOTE: pickle must only be used on trusted files (here: this project's own output).
with open(SHAP_RESULTS_PATH, "rb") as shap_file:
    shap_results = pickle.load(shap_file)

# Read the matching LRP results.
with open(LRP_RESULTS_PATH, "rb") as lrp_file:
    lrp_results = pickle.load(lrp_file)

In [31]:
# Both methods must cover the same number of entries for BEST_EPOCH,
# otherwise the per-entry merge below cannot line them up.
n_lrp = len(lrp_results[BEST_EPOCH])
n_shap = len(shap_results[BEST_EPOCH])
print(f"LRP: {n_lrp}")
print(f"SHAP: {n_shap}")
assert n_lrp == n_shap

LRP: 7000
SHAP: 7000


## Merge LRP and SHAP scores

In [32]:
# Join each entry's SHAP frame onto its LRP frame (matched on position and
# token), then keep a fixed column order. The merge happens in place inside
# `lrp_results`, which is afterwards exposed under the name `results`.
merged_columns = ["idx", "seq_idx", "token", "att_weights", "lrp_scores", "shap_scores"]
for pid, lrp_entry in lrp_results[BEST_EPOCH].items():
    rows_before = lrp_entry["imp"].shape[0]
    lrp_entry["imp"] = lrp_entry["imp"].merge(
        shap_results[BEST_EPOCH][pid]["imp"], on=["seq_idx", "token"]
    )
    # The default (inner) merge must neither drop nor duplicate rows.
    assert rows_before == lrp_entry["imp"].shape[0]
    lrp_entry["imp"] = lrp_entry["imp"][merged_columns]
results = lrp_results

## Calculate similarity scores

In [35]:
# Empty accumulators for the similarity indexes (three independent lists).
# NOTE(review): nothing in the visible cells appends to these — confirm whether
# they are still needed.
epoch_lrp_shap_t_corr, epoch_lrp_sim, epoch_shap_sim = [], [], []

In [36]:
# For every entry of the selected epoch, store:
#   * "lrp_shap_t_corr": get_wtau() of the LRP scores vs. the SHAP scores
#     (presumably a weighted Kendall tau, judging by the name — confirm in imp_utils)
#   * "lrp_sim" / "shap_sim" / "att_sim": get_intersection_similarity() of each
#     score column against the tokens ("gt similarity" per the original comment).
#
# The helpers are called unqualified because the file only does
# `from imp_utils import *`; the module name `imp_utils` itself is never bound,
# so the former `imp_utils.get_intersection_similarity(...)` calls raised
# NameError on a fresh run.
#
# The loop variable keeps the name `pid` because the preview cell below reads it
# after the loop finishes.
for pid, entry in results[BEST_EPOCH].items():
    imp_df = entry["imp"]

    # Position-qualified token label: "<seq_idx>_<token>".
    imp_df["u_token"] = [
        str(seq) + "_" + str(token)
        for seq, token in zip(imp_df["seq_idx"], imp_df["token"])
    ]

    entry["lrp_shap_t_corr"] = get_wtau(imp_df["lrp_scores"], imp_df["shap_scores"])

    # gt similarity, computed identically for each attribution method
    for method, score_column in (
        ("lrp", "lrp_scores"),
        ("shap", "shap_scores"),
        ("att", "att_weights"),
    ):
        entry[f"{method}_sim"] = get_intersection_similarity(
            imp_df[score_column],
            imp_df["token"],
            freedom=0,
            is_synthetic=IS_SYNTHETIC,
        )

In [37]:
# Preview the merged frame of the entry the preceding loop visited last
# (`pid` is that loop's leftover variable).
last_imp_df = results[BEST_EPOCH][pid]["imp"]
last_imp_df.head()

Unnamed: 0,idx,seq_idx,token,att_weights,lrp_scores,shap_scores,u_token
0,3,0,ankle_sprain_N,0.002363,-0.570503,-0.019583,0_ankle_sprain_N
1,4,1,headache_N,0.003156,-0.833891,-0.01281,1_headache_N
2,2,2,backache_N,0.006852,-1.070633,-0.016969,2_backache_N
3,18,3,peanut_allergy_N,0.022434,-1.016982,-0.020836,3_peanut_allergy_N
4,20,4,metabolic_disorder_H,0.05969,7.179713,0.118016,4_metabolic_disorder_H


## Save all combined results

In [39]:
# Persist the merged LRP + SHAP results (including the similarity scores
# added above) as a single pickle.
with open(OUTPUT_PATH, "wb") as out_file:
    pickle.dump(results, out_file)
print(f'Importance Scores Successfully Merged and Saved to {OUTPUT_PATH}!')