## Run LRP on all test and validation results

In [1]:
#! pip install nb_black
#! pip install shap

In [2]:
%load_ext lab_black

%load_ext autoreload

%autoreload 2

In [15]:
import sys
import os
import json
import time
import torch
import pandas as pd
import numpy as np
import time
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from urllib.parse import urlparse
import tarfile
import pickle
import shutil
import matplotlib.pyplot as plt
import shap
import numpy as np

import deep_id_pytorch

from lstm_models import *
from lstm_lrp_models import *
from lstm_att_models import *
from lstm_self_att_models import *
from lstm_utils import *
from imp_utils import *

In [4]:
IS_SYNTHETIC = True  # If dataset is synthetic/real
# MODEL_NAME = 'lstm'
MODEL_NAME = "lstm"
USE_SELF_ATTENTION = True

NROWS = 1e9

TRAIN_MODEL = True
SEQ_LEN = 30
DATA_TYPE = "event"  # event/sequence

TRAIN_DATA_PATH = f"../data/sample_dataset/{DATA_TYPE}/{SEQ_LEN}/train.csv"
VALID_DATA_PATH = f"../data/sample_dataset/{DATA_TYPE}/{SEQ_LEN}/val.csv"
TEST_DATA_PATH = f"../data/sample_dataset/{DATA_TYPE}/{SEQ_LEN}/test.csv"
VOCAB_PATH = f"../data/sample_dataset/{DATA_TYPE}/{SEQ_LEN}/vocab.pkl"

MODEL_SAVE_PATH_PATTERN = (
    f"./output/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/model_weights/model_{'{}'}.pkl"
)
IMP_SAVE_DIR_PATTERN = f"./output/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/importances/{'{}'}_imp_{'{}'}.pkl"  # Feature importance values path for a given dataset split

OUTPUT_RESULTS_PATH = (
    f"./output/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/train_results/results.csv"
)
PARAMS_PATH = (
    f"./output/{DATA_TYPE}/{SEQ_LEN}/{MODEL_NAME}/train_results/model_params.json"
)


# Test Patients for Visualization (from Test Set)
SELECTED_PATIENTS = ["VHCSIEPRI8", "MIP0F9ZOUM", "79A4PHXNE6", "CW780GL0AX"]

BEST_EPOCH = 2
EARLY_STOPPING_NUM = 1  # For early stopping

TARGET_COLNAME = "label"
UID_COLNAME = "patient_id"
TARGET_VALUE = "1"

# Results path for val & test data
output_dir = os.path.dirname(IMP_SAVE_DIR_PATTERN)
VAL_RESULTS_PATH = os.path.join(output_dir, f"val_all_lrp_{BEST_EPOCH:02}.pkl")
TEST_RESULTS_PATH = os.path.join(output_dir, f"test_all_lrp_{BEST_EPOCH:02}.pkl")

### Load Vocab and Dataset

In [6]:
# Load model params
MODEL_PARAMS = None
with open(PARAMS_PATH, "r") as fp:
    MODEL_PARAMS = json.load(fp)
MODEL_PARAMS

{'min_freq': 1,
 'batch_size': 16,
 'embedding_dim': 8,
 'hidden_dim': 16,
 'nlayers': 2,
 'bidirectional': True,
 'dropout': 0.3,
 'linear_bias': False,
 'init_type': 'zero',
 'learning_rate': 0.01,
 'scheduler_step': 3,
 'clip': False,
 'rev': False,
 'n_background': 300,
 'background_negative_only': True,
 'background_positive_only': False,
 'num_eval_val': 64,
 'num_eval_test': 64,
 'test_positive_only': False,
 'is_test_random': False}

In [7]:
if os.path.exists(VOCAB_PATH):
    with open(VOCAB_PATH, "rb") as fp:
        vocab = pickle.load(fp)
    print(f"vocab len: {len(vocab)}")  # vocab + padding + unknown
else:
    raise ValueError(
        "Vocab path does not exist! Please create vocab from training data and save it first."
    )

valid_dataset, vocab = build_lstm_dataset(
    VALID_DATA_PATH,
    min_freq=MODEL_PARAMS["min_freq"],
    uid_colname=UID_COLNAME,
    target_colname=TARGET_COLNAME,
    max_len=SEQ_LEN,
    target_value=TARGET_VALUE,
    vocab=vocab,
    nrows=NROWS,
    rev=MODEL_PARAMS["rev"],
)

test_dataset, _ = build_lstm_dataset(
    TEST_DATA_PATH,
    min_freq=MODEL_PARAMS["min_freq"],
    uid_colname=UID_COLNAME,
    target_colname=TARGET_COLNAME,
    max_len=SEQ_LEN,
    target_value=TARGET_VALUE,
    vocab=vocab,
    nrows=NROWS,
    rev=MODEL_PARAMS["rev"],
)

valid_dataloader = DataLoader(
    valid_dataset, batch_size=MODEL_PARAMS["batch_size"], shuffle=False, num_workers=2
)

test_dataloader = DataLoader(
    test_dataset, batch_size=MODEL_PARAMS["batch_size"], shuffle=False, num_workers=2
)

vocab len: 47
Building dataset from ../data/sample_dataset/event/30/val.csv..
Success!
Building dataset from ../data/sample_dataset/event/30/test.csv..
Success!


### Load Best Model

In [8]:
model_path = MODEL_SAVE_PATH_PATTERN.format(f"{BEST_EPOCH:02}")
model_path

'./output/event/30/lstm/model_weights/model_02.pkl'

In [9]:
# Check if cuda is available
print(f"Cuda available: {torch.cuda.is_available()}")
model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Cuda available: True


In [10]:
lstm_model_best = AttNoHtLSTM(
    MODEL_PARAMS["embedding_dim"],
    MODEL_PARAMS["hidden_dim"],
    vocab,
    model_device,
    bidi=MODEL_PARAMS["bidirectional"],
    nlayers=MODEL_PARAMS["nlayers"],
    dropout=MODEL_PARAMS["dropout"],
    init_type=MODEL_PARAMS["init_type"],
    linear_bias=MODEL_PARAMS["linear_bias"],
)

In [11]:
lstm_model_best.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [12]:
# def get_eval_data(dataloader, num):
#     """Get more than one iteration of data"""
#     col_pid, col_lab, col_txt = None, None, None
#     col_num = 0
#     for pid, lab, txt in dataloader:
#         if col_pid is None:
#             col_pid, col_lab, col_txt = pid, lab, txt
#         else:
#             col_pid = tuple(list(col_pid) + list(pid))
#             col_lab = torch.cat((col_lab, lab), dim=0)
#             col_txt = torch.cat((col_txt, txt), dim=0)
#         col_num = len(col_pid)
#         if col_num > num:
#             break

#     return col_pid[:num], col_lab[:num], col_txt[:num]

In [13]:
valid_results_best = {}
valid_results_best[BEST_EPOCH] = {}

test_results_best = {}
test_results_best[BEST_EPOCH] = {}
# calculate relevancy and SHAP
lstm_model_best.eval()
lrp_model = LSTM_LRP_MultiLayer(lstm_model_best.cpu())

In [16]:
# Get test/val data
val_patient_ids, val_labels, val_idxed_text = get_eval_data(
    valid_dataloader, 7000
)  # Change the number if your dataset size is different

test_patient_ids, test_labels, test_idxed_text = get_eval_data(test_dataloader, 7000)

In [17]:
start = time.time()

for sel_idx in range(len(val_labels)):
    one_text = [
        int(token.numpy())
        for token in val_idxed_text[sel_idx]
        if int(token.numpy()) != 0
    ]
    lrp_model.set_input(one_text)
    lrp_model.forward_lrp()

    Rx, Rx_rev, _ = lrp_model.lrp(one_text, 0, eps=1e-6, bias_factor=0)
    R_words = np.sum(Rx + Rx_rev, axis=1)

    df = pd.DataFrame()
    df["lrp_scores"] = R_words
    df["idx"] = one_text
    df["seq_idx"] = [x for x in range(len(one_text))]
    df["token"] = [lrp_model.vocab.itos(x) for x in one_text]
    df["att_weights"] = lrp_model.get_attn_values()

    if val_patient_ids[sel_idx] not in valid_results_best[BEST_EPOCH]:
        valid_results_best[BEST_EPOCH][val_patient_ids[sel_idx]] = {}
    valid_results_best[BEST_EPOCH][val_patient_ids[sel_idx]] = {}
    valid_results_best[BEST_EPOCH][val_patient_ids[sel_idx]]["label"] = val_labels[
        sel_idx
    ]
    valid_results_best[BEST_EPOCH][val_patient_ids[sel_idx]]["pred"] = lrp_model.s[0]
    valid_results_best[BEST_EPOCH][val_patient_ids[sel_idx]]["imp"] = df.copy()

    if sel_idx % 500 == 0:
        print(sel_idx)

end = time.time()
mins, secs = epoch_time(start, end)
print(f"{mins}min: {secs}sec")

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
2min: 49sec


In [18]:
valid_results_best[BEST_EPOCH][val_patient_ids[sel_idx]]

{'label': tensor([1]),
 'pred': -0.9606567216006672,
 'imp':     lrp_scores  idx  seq_idx              token  att_weights
 0    -0.020867   11        0      dental_exam_N     0.039415
 1    -0.022586   11        1      dental_exam_N     0.043597
 2    -0.018259   17        2       cut_finger_N     0.042979
 3    -0.024119   18        3     ingrown_nail_N     0.046472
 4    -0.018696    6        4      quad_injury_N     0.045419
 5    -0.004450    9        5         backache_N     0.047086
 6    -0.020040   18        6     ingrown_nail_N     0.046999
 7    -0.014780   19        7         headache_N     0.047027
 8    -0.012223    8        8        hay_fever_N     0.047272
 9    -0.019245   19        9         headache_N     0.047400
 10   -0.016507   12       10        cold_sore_N     0.047696
 11    0.547189   31       11      sleep_apnea_H     0.024369
 12   -0.031667    3       12     ankle_sprain_N     0.028616
 13   -0.015872   18       13     ingrown_nail_N     0.029725
 14   -0.0

In [19]:
with open(VAL_RESULTS_PATH, "wb") as fp:
    pickle.dump(valid_results_best, fp)

In [20]:
start = time.time()
for sel_idx in range(len(test_labels)):
    one_text = [
        int(token.numpy())
        for token in test_idxed_text[sel_idx]
        if int(token.numpy()) != 0
    ]
    lrp_model.set_input(one_text)
    lrp_model.forward_lrp()

    Rx, Rx_rev, _ = lrp_model.lrp(one_text, 0, eps=1e-6, bias_factor=0)
    R_words = np.sum(Rx + Rx_rev, axis=1)

    df = pd.DataFrame()
    df["lrp_scores"] = R_words
    df["idx"] = one_text
    df["seq_idx"] = [x for x in range(len(one_text))]
    df["token"] = [lstm_model_best.vocab.itos(x) for x in one_text]
    df["att_weights"] = lrp_model.get_attn_values()

    if test_patient_ids[sel_idx] not in test_results_best[BEST_EPOCH]:
        test_results_best[BEST_EPOCH][test_patient_ids[sel_idx]] = {}
    test_results_best[BEST_EPOCH][test_patient_ids[sel_idx]] = {}
    test_results_best[BEST_EPOCH][test_patient_ids[sel_idx]]["label"] = test_labels[
        sel_idx
    ]
    test_results_best[BEST_EPOCH][test_patient_ids[sel_idx]]["pred"] = lrp_model.s[0]
    test_results_best[BEST_EPOCH][test_patient_ids[sel_idx]]["imp"] = df.copy()

    if sel_idx / 50 == 0:
        print(sel_idx)
        end = time.time()
        mins, secs = epoch_time(start, end)
        print(f"{mins}min: {secs}sec")

end = time.time()
mins, secs = epoch_time(start, end)
print(f"Total {mins}min: {secs}sec")

0
0min: 0sec
Total 2min: 54sec


In [21]:
len(test_results_best[BEST_EPOCH])

7000

In [22]:
with open(TEST_RESULTS_PATH, "wb") as fp:
    pickle.dump(test_results_best, fp)