## Merge SHAP and LRP results for the test/val splits

In [26]:
import pickle

In [27]:
# --- Run configuration: selects which pickled results get merged ---
MODEL_NAME = "lstm-att-lrp"  # model sub-directory in the output tree
DATA_TYPE = "seq_based"      # data variant sub-directory
SEQ_LEN = 30                 # sequence length (also a path component)
BEST_EPOCH = 8               # epoch key into the results dicts and file suffix
DATA_SPLIT = "test"          # file prefix of the split ("test" or "val" exist on disk)

In [28]:
# All three pickles (SHAP input, LRP input, merged output) share one directory.
_results_dir = f"./output/final_final/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/shap"
shap_results_path = f"{_results_dir}/{DATA_SPLIT}_results_shap_{BEST_EPOCH}.pkl"
lrp_results_path = f"{_results_dir}/{DATA_SPLIT}_results_lrp_{BEST_EPOCH}.pkl"
output_path = f"{_results_dir}/{DATA_SPLIT}_results_all_{BEST_EPOCH}.pkl"

In [29]:
def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    Only use on files produced by this project's own pipeline: unpickling
    can execute arbitrary code.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Attribution results saved by the SHAP and LRP runs, keyed by epoch.
shap_results = _load_pickle(shap_results_path)
lrp_results = _load_pickle(lrp_results_path)

In [30]:
#lrp_results[BEST_EPOCH]

In [31]:
# Report sizes, then require that both result sets cover exactly the same ids:
# the merge step looks up every LRP id in the SHAP results, so equal lengths
# alone are not enough (different ids would only fail later with a KeyError).
print(f"LRP: {len(lrp_results[BEST_EPOCH])}")
print(f"SHAP: {len(shap_results[BEST_EPOCH])}")
_only_lrp = set(lrp_results[BEST_EPOCH]) - set(shap_results[BEST_EPOCH])
_only_shap = set(shap_results[BEST_EPOCH]) - set(lrp_results[BEST_EPOCH])
assert not _only_lrp and not _only_shap, (
    f"id mismatch: {len(_only_lrp)} ids only in LRP, "
    f"{len(_only_shap)} ids only in SHAP"
)

LRP: 7000
SHAP: 7000


## Merge LRP and SHAP scores

In [32]:
# Attach each id's SHAP scores to its LRP importance table by joining on the
# (seq_idx, token) pair, then keep a fixed column order.
_join_keys = ["seq_idx", "token"]
_merged_columns = ["idx", "seq_idx", "token", "att_weights", "lrp_scores", "shap_scores"]
_shap_epoch = shap_results[BEST_EPOCH]
for pid, entry in lrp_results[BEST_EPOCH].items():
    n_rows_before = len(entry["imp"])
    entry["imp"] = entry["imp"].merge(_shap_epoch[pid]["imp"], on=_join_keys)
    # The (default inner) join must not change the row count, i.e. every
    # LRP row found a SHAP partner and none were multiplied.
    assert len(entry["imp"]) == n_rows_before
    entry["imp"] = entry["imp"][_merged_columns]

In [33]:
# `results` is an alias of `lrp_results`, not a copy: the per-id dicts already
# hold the merged LRP+SHAP tables, and scores added later land in both names.
results = lrp_results

## Calculate similarity scores

In [34]:
import sys
import rbo
from scipy import stats
import numpy as np

def get_wtau(x, y):
    """Return scipy's weighted Kendall tau between score vectors *x* and *y*.

    ``rank=None`` makes scipy rank elements by decreasing lexicographic
    (x, y) order, so the value is not symmetric in its two arguments.
    Only the correlation is returned; the p-value is discarded.
    """
    tau, _pvalue = stats.weightedtau(x, y, rank=None)
    return tau


def get_rbo(x, y, uid, p=0.7):
    """Return the rank-biased overlap (RBO) of the rankings induced by two score vectors.

    Each score vector is sorted in decreasing order; the resulting positions
    are mapped through *uid* to item labels, and the two label lists are
    compared with ``rbo.RankingSimilarity(...).rbo(p=p)``.

    Args:
        x, y: Score sequences (numpy arrays or pandas Series), positionally
            aligned with *uid*.
        uid: Item labels, one per position; should be unique so each ranking
            is a permutation of distinct items.
        p: RBO persistence parameter forwarded to ``rbo``.

    Returns:
        The similarity value computed by the ``rbo`` package.
    """
    # Sort plain ndarrays: np.argsort on a pandas Series dispatches to
    # Series.argsort, which reports NaN positions as -1; those would silently
    # select uid[-1] (possibly several times) instead of yielding a permutation.
    x_order = np.argsort(np.asarray(x))[::-1]
    y_order = np.argsort(np.asarray(y))[::-1]
    return rbo.RankingSimilarity(
        [uid[i] for i in x_order], [uid[i] for i in y_order]
    ).rbo(p=p)


# Ground-truth helper: positions whose token is *not* "_N"-suffixed are the
# ground-truth set used for the top-k similarity scores.
def is_value(x):
    """Return True unless token label *x* ends with the ``_N`` suffix.

    Tokens carry a trailing category suffix (e.g. ``headache_N`` vs
    ``metabolic_disorder_H`` in the sample output), and callers pass labels
    of the form ``"<seq_idx>_<token>"``. Only the suffix is inspected, so a
    label such as ``"3_Nausea_H"``, which merely *contains* ``"_N"`` right
    after the position prefix, is correctly treated as a value.
    """
    return not x.endswith("_N")

In [35]:
# calculate similarity indexes
epoch_lrp_shap_t_corr = []
epoch_lrp_shap_rbo = []
epoch_lrp_sim = []
epoch_shap_sim = []

rbo_p = 0.95

In [36]:
def _top_k_overlap(scores, gt_idx):
    """Fraction of *gt_idx* positions among the len(gt_idx) largest-|score| positions."""
    n_gt = len(gt_idx)
    top_idx = np.argsort(np.abs(scores))[::-1][:n_gt]
    return len(set(top_idx).intersection(gt_idx)) / n_gt


# For every id: LRP-vs-SHAP rank agreement (weighted tau, RBO) plus, for LRP,
# SHAP and attention, the overlap of their top-|score| positions with the
# ground-truth positions. All results are stored on the id's dict.
# `pid` is kept as the loop name because a later cell previews the last id.
for pid, entry in results[BEST_EPOCH].items():
    imp_df = entry["imp"]
    # Label each row "<seq_idx>_<token>" so RBO compares distinct items
    # (assumes seq_idx is unique per row, as in the sample output).
    imp_df["u_token"] = [
        str(seq) + "_" + str(token)
        for seq, token in zip(imp_df["seq_idx"], imp_df["token"])
    ]
    entry["lrp_shap_t_corr"] = get_wtau(imp_df["lrp_scores"], imp_df["shap_scores"])
    entry["lrp_shap_rbo"] = get_rbo(
        imp_df["lrp_scores"],
        imp_df["shap_scores"],
        imp_df["u_token"].tolist(),
        p=rbo_p,
    )

    # Ground-truth similarity: compare each method's top-n_gt positions
    # (by absolute score) with the positions flagged by is_value.
    gt_idx = [i for i, tok in enumerate(imp_df.u_token) if is_value(tok)]
    if gt_idx:
        lrp_sim = _top_k_overlap(imp_df.lrp_scores.values, gt_idx)
        shap_sim = _top_k_overlap(imp_df.shap_scores.values, gt_idx)
        att_sim = _top_k_overlap(imp_df.att_weights.values, gt_idx)
    else:
        # No ground-truth positions: mark every similarity as undefined.
        # (att_sim was previously left unset here, raising NameError or
        # reusing the previous id's value.)
        print(f"-1 is the output for {pid}")
        lrp_sim = shap_sim = att_sim = -1
    entry["lrp_sim"] = lrp_sim
    entry["shap_sim"] = shap_sim
    entry["att_sim"] = att_sim

In [37]:
# Preview the merged table of the last id processed (`pid` is left over from the similarity loop).
results[BEST_EPOCH][pid]["imp"].head(5)

Unnamed: 0,idx,seq_idx,token,att_weights,lrp_scores,shap_scores,u_token
0,3,0,ankle_sprain_N,0.002363,-0.570503,-0.019583,0_ankle_sprain_N
1,4,1,headache_N,0.003156,-0.833891,-0.01281,1_headache_N
2,2,2,backache_N,0.006852,-1.070633,-0.016969,2_backache_N
3,18,3,peanut_allergy_N,0.022434,-1.016982,-0.020836,3_peanut_allergy_N
4,20,4,metabolic_disorder_H,0.05969,7.179713,0.118016,4_metabolic_disorder_H


## Save all combined results

In [38]:
# Show where the combined results pickle will be written.
output_path

'./output/final_final/seq_based/30/lstm-att-lrp/shap/test_results_all_8.pkl'

In [39]:
# Persist the merged LRP/SHAP tables together with the similarity scores.
with open(output_path, mode="wb") as out_file:
    pickle.dump(results, out_file)

In [40]:
import os
# List the results directory to confirm the combined pickle was written.
output_dir = os.path.dirname(output_path)
# NOTE: `!` is IPython shell magic, so this cell only runs inside Jupyter/IPython.
! ls {output_dir}

test_results_all_8.pkl	 test_shap_6.pkl	 val_shap_2.pkl
test_results_lrp_8.pkl	 test_shap_7.pkl	 val_shap_3.pkl
test_results_shap_8.pkl  test_shap_8.pkl	 val_shap_4.pkl
test_shap_0.pkl		 test_shap_9.pkl	 val_shap_5.pkl
test_shap_1.pkl		 val_results_all_8.pkl	 val_shap_6.pkl
test_shap_2.pkl		 val_results_lrp_8.pkl	 val_shap_7.pkl
test_shap_3.pkl		 val_results_shap_8.pkl  val_shap_8.pkl
test_shap_4.pkl		 val_shap_0.pkl		 val_shap_9.pkl
test_shap_5.pkl		 val_shap_1.pkl
