# Example Walkthrough of the Pipeline

This notebook demonstrates the explainability pipeline on a dummy dataset of 5 patients, each with 10 to 15 tokens. All records are made up at random.

Before running the following code, please install the required environment and packages as described in the "Install" section of the README.md file.

In [1]:
import os
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
os.chdir(parent_dir)

## Loading Input Data

This is a sample dataset. You can replace it with your own dataset, as long as it follows the same format: 
- `person_id`: A unique identifier for each individual.
- `day_position_tokens`: An array representing the relative time (in days) of events, with 0 indicating demographic tokens.
- `sorted_event_tokens`: A list of event codes for the individual. The event at index `i` occurred on the relative day stored at index `i` of `day_position_tokens`.
  - The first five tokens are always assumed to be demographic tokens, in the order of age, ethnicity, gender, race, and region.
- `label`: Label for the patient for a specific prediction task.
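
Before swapping in your own data, it can help to check each record against the format above. The helper below is a minimal sketch, not part of the pipeline: `validate_record` is a hypothetical name, the `AGE:`/`ETHNICITY:`/`GENDER:`/`RACE:`/`REGION:` prefixes are inferred from the sample data, and the requirement that day positions never decrease is an assumption based on the field name `sorted_event_tokens`.

```python
def validate_record(record, n_demographic=5):
    """Return a list of format problems found in one patient record (empty if none)."""
    required = ("person_id", "day_position_tokens", "sorted_event_tokens", "label")
    missing = [key for key in required if key not in record]
    if missing:
        return [f"missing keys: {missing}"]

    problems = []
    days = list(record["day_position_tokens"])
    tokens = list(record["sorted_event_tokens"])

    # Every token needs exactly one day position.
    if len(days) != len(tokens):
        problems.append(f"length mismatch: {len(days)} days vs {len(tokens)} tokens")

    # Demographic tokens come first and sit at day 0.
    if any(day != 0 for day in days[:n_demographic]):
        problems.append("demographic tokens must have day position 0")

    # Assumed prefixes, in the documented demographic order.
    prefixes = ("AGE", "ETHNICITY", "GENDER", "RACE", "REGION")
    for prefix, token in zip(prefixes, tokens):
        if not token.startswith(prefix + ":"):
            problems.append(f"expected a {prefix} token, found {token!r}")

    # Assumption: events are sorted, so day positions never decrease.
    if any(later < earlier for earlier, later in zip(days, days[1:])):
        problems.append("day positions are not sorted")
    return problems


example = {
    "person_id": "2",
    "day_position_tokens": [0, 0, 0, 0, 0, 1, 21, 22, 22, 22],
    "sorted_event_tokens": [
        "AGE:60", "ETHNICITY:Hispanic", "GENDER:Female", "RACE:Caucasian", "REGION:South",
        "ICD10:z01818", "ICD10:z6829", "SNOMED:110466009", "SNOMED:15805002", "NDC:000932264",
    ],
    "label": 1,
}
print(validate_record(example))  # an empty list means no problems were found
```

Running the check over every record before writing the Parquet file surfaces misaligned arrays early, before they reach the model.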

In [2]:
import pandas as pd
dummy_data = [
    {
        'person_id': '1', 
        'day_position_tokens': [0, 0, 0, 0, 0, 1, 1, 1, 26, 26, 26, 57, 63, 76, 76], 
        'sorted_event_tokens': ['AGE:64', 'ETHNICITY:Not Hispanic', 'GENDER:Female', 'RACE:Caucasian', 'REGION:South', 
                                'NDC:637390588', 'NDC:101350697', 'NDC:101350697', 'NDC:136680136', 'NDC:101350697', 
                                'NDC:003784043', 'NDC:782060135', 'NDC:593100304', 'NDC:000930058', 'ICD10:z760'],
        'label': 1},
    {
        'person_id': '2', 
        'day_position_tokens': [0, 0, 0, 0, 0, 1, 21, 22, 22, 22],
        'sorted_event_tokens': ['AGE:60', 'ETHNICITY:Hispanic', 'GENDER:Female', 'RACE:Caucasian', 'REGION:South', 
                                'ICD10:z01818', 'ICD10:z6829', 'SNOMED:110466009', 'SNOMED:15805002', 'NDC:000932264'],
        'label': 1},
    {
        'person_id': '3', 
        'day_position_tokens': [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        'sorted_event_tokens': ['AGE:26', 'ETHNICITY:Not Hispanic', 'GENDER:Male', 'RACE:Other/Unknown', 'REGION:West', 
                                'NDC:694520105', 'NDC:704610120', 'ICD10:k529', 'ICD10:z95810', 'HCPCS:g0463'],
        'label': 1},
    {
        'person_id': '4', 
        'day_position_tokens': [0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 5, 5],
        'sorted_event_tokens': ['AGE:45', 'ETHNICITY:Hispanic', 'GENDER:Male', 'RACE:Caucasian', 'REGION:Other/Unknown', 
                                'LOINC:62292-8:40-50', 'LOINC:6690-2:30-40', 'LOINC:6768-6:20-30', 'NDC:162520601', 'ICD10:j4530',
                                'ICD10:z95810', 'HCPCS:g0463'],
        'label': 0},
    {
        'person_id': '5', 
        'day_position_tokens': [0, 0, 0, 0, 0, 1, 1, 1, 9, 37, 37, 52, 52, 56, 56],
        'sorted_event_tokens': ['AGE:65', 'ETHNICITY:Not Hispanic', 'GENDER:Male', 'RACE:Caucasian', 'REGION:South', 
                                'ICD10:z1211', 'ICD10:e119', 'HCPCS:g0402', 'ICD10:d010', 'HCPCS:g0403', 
                                'ICD10:r21', 'ICD10:z13220', 'ICD10:i10', 'ICD10:n401', 'ICD10:k219'],
        'label': 1},
]

dummy_df = pd.DataFrame(dummy_data)

In [3]:
# Save the dataset to a Parquet file, then load it back

parquet_file_path = "./data/dummy_data.parquet"
os.makedirs(os.path.dirname(parquet_file_path), exist_ok=True)  # create ./data if it is missing
dummy_df.to_parquet(parquet_file_path, engine='pyarrow')

loaded_df = pd.read_parquet(parquet_file_path, engine='pyarrow')

## Running the Pipeline

The explainability pipeline requires a BERT-based model trained on structured EHR data and fine-tuned for the specific disease prediction task. 
The pre-training procedure most closely follows the method described in the [TransformEHR](https://www.nature.com/articles/s41467-023-43715-z) paper, 
and the fine-tuning is primarily based on the approach used in the [Med-BERT](https://doi.org/10.1038/s41746-021-00455-y) model.

In [4]:
from src.explainability.explainability import (
    EHRSequenceClassificationExplainer,
    FixedVocabTokenizer,
    LabelledDataset,
    generate_explanation,
    run_explanation_generation,
)

# Necessary inputs to run the explainability pipeline.
config_path = './config/explainability_asthma.yaml'
model_location = './bert_finetuning_asthma_model.tar.gz'
input_dataset_path = './data/'
output_path = './output/'

# Run the function in the notebook
run_explanation_generation(config_path, model_location, input_dataset_path, output_path)

# Alternatively, to run this in the terminal:
# python3 -m src.explainability.explainability './config/explainability_asthma.yaml' './bert_finetuning_asthma_model.tar.gz' './data/' './output/'

Number of token type IDs: 11


Processing items:   0%|          | 0/5 [00:00<?, ?it/s]

[AVG_REGION] is not found in the vocabulary; creating one
Found the following tokens for category REGION: ['REGION:South', 'REGION:Other/Unknown', 'REGION:Midwest', 'REGION:Northeast', 'REGION:West']
Existing token embedding shape: torch.Size([114691, 768])
Category embeddings shape: torch.Size([5, 768])
Average category embedding shape: torch.Size([768])
New token embeddings shape: torch.Size([114692, 768])
Added [AVG_REGION]: 114691
[AVG_GENDER] is not found in the vocabulary; creating one
Found the following tokens for category GENDER: ['GENDER:Male', 'GENDER:Female', 'GENDER:Unknown']
Existing token embedding shape: torch.Size([114692, 768])
Category embeddings shape: torch.Size([3, 768])
Average category embedding shape: torch.Size([768])
New token embeddings shape: torch.Size([114693, 768])
Added [AVG_GENDER]: 114692
[AVG_ETHNICITY] is not found in the vocabulary; creating one
Found the following tokens for category ETHNICITY: ['ETHNICITY:Unknown', 'ETHNICITY:Not Hispanic', 'ETHN

Processing items: 100%|██████████| 5/5 [00:45<00:00,  9.01s/it]
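
The log above suggests that, for each demographic category, the pipeline averages the embeddings of all tokens in that category and appends the result to the embedding matrix as a new `[AVG_<CATEGORY>]` token, whose id equals the previous vocabulary size. The snippet below is a rough sketch of that idea inferred only from the logged shapes; it is not the pipeline's actual code. It uses NumPy instead of PyTorch to stay lightweight, and `add_average_token` and the toy vocabulary are hypothetical.

```python
import numpy as np


def add_average_token(embeddings, vocab, category):
    """Append the mean embedding of all `<category>:*` tokens as a new vocabulary entry."""
    avg_token = f"[AVG_{category}]"
    if avg_token in vocab:
        return embeddings, vocab  # already present, nothing to add

    # Collect the ids of every token that belongs to the category.
    ids = [idx for token, idx in vocab.items() if token.startswith(category + ":")]
    if not ids:
        raise ValueError(f"no tokens found for category {category}")

    category_embeddings = embeddings[ids]        # shape: (n_category_tokens, hidden)
    average = category_embeddings.mean(axis=0)   # shape: (hidden,)

    # The new row goes at the end, so its id is the old vocabulary size.
    new_embeddings = np.vstack([embeddings, average[None, :]])
    new_vocab = dict(vocab)
    new_vocab[avg_token] = embeddings.shape[0]
    return new_embeddings, new_vocab


# Toy example: 4 tokens with 4-dimensional embeddings.
vocab = {"[PAD]": 0, "REGION:South": 1, "REGION:West": 2, "GENDER:Male": 3}
embeddings = np.arange(16, dtype=float).reshape(4, 4)
new_embeddings, new_vocab = add_average_token(embeddings, vocab, "REGION")
print(new_embeddings.shape, new_vocab["[AVG_REGION]"])
```

In the toy example the matrix grows from 4 to 5 rows, mirroring the logged growth from 114691 to 114692 rows when `[AVG_REGION]` was added.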


In [5]:
# Read the output

output_df = pd.read_parquet(output_path, engine='pyarrow')

In [6]:
output_df

| | person_id | position_id | true_label | score_0 | pred_prob_0 | score_1 | pred_prob_1 | word | token_type_id |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 5 | 0 | 1 | 0.098704 | 0.869820 | -0.098704 | 0.130180 | [CLS] | 0.0 |
| 1 | 5 | 0 | 1 | 0.596119 | 0.869820 | -0.596119 | 0.130180 | AGE:65 | 0.0 |
| 2 | 5 | 0 | 1 | 0.003533 | 0.869820 | -0.003533 | 0.130180 | ETHNICITY:Not Hispanic | 0.0 |
| 3 | 5 | 0 | 1 | 0.422016 | 0.869820 | -0.422016 | 0.130180 | GENDER:Male | 0.0 |
| 4 | 5 | 0 | 1 | 0.016576 | 0.869820 | -0.016576 | 0.130180 | RACE:Caucasian | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 62 | 2 | 1 | 1 | 0.439130 | 0.845862 | -0.439130 | 0.154138 | ICD10:z01818 | 0.0 |
| 63 | 2 | 21 | 1 | 0.062352 | 0.845862 | -0.062352 | 0.154138 | ICD10:z6829 | 0.0 |
| 64 | 2 | 22 | 1 | 0.006741 | 0.845862 | -0.006741 | 0.154138 | SNOMED:110466009 | 0.0 |
| 65 | 2 | 22 | 1 | 0.198696 | 0.845862 | -0.198696 | 0.154138 | SNOMED:15805002 | 0.0 |


## Aggregation and Post-processing

The code below covers a sample use case: finding the tokens that contribute most to class 0 predictions.

In [7]:
# Generate the predicted labels
output_df['pred_label'] = (output_df['pred_prob_1'] > output_df['pred_prob_0']).astype(int)

# Keep only the tokens of patients predicted as class 0
filtered_df = output_df[output_df['pred_label'] == 0]

# Average each token's class 0 score and rank the tokens by it
result_df = filtered_df.groupby('word').agg({'score_0': 'mean'}).reset_index()
result_df = result_df.sort_values(by='score_0', ascending=False)

result_df


| | word | score_0 |
|---|---|---|
| 2 | AGE:60 | 0.771955 |
| 4 | AGE:65 | 0.596119 |
| 3 | AGE:64 | 0.568753 |
| 38 | NDC:704610120 | 0.546419 |
| 1 | AGE:45 | 0.488324 |
| 34 | NDC:162520601 | 0.47154 |
| 20 | ICD10:z01818 | 0.43913 |
| 16 | ICD10:k219 | 0.367215 |
| 36 | NDC:637390588 | 0.301734 |
| 8 | GENDER:Male | 0.30015 |
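
The same aggregation can be wrapped in a small helper that works for either class. This is a sketch under assumptions: `top_tokens_for_class` is a hypothetical name, the task is assumed to be binary, and the frame is assumed to carry the `word`, `score_0`, `score_1`, `pred_prob_0`, and `pred_prob_1` columns shown in the pipeline output. The toy frame is made-up data for illustration only.

```python
import pandas as pd


def top_tokens_for_class(df, target_class, k=10):
    """Mean score per token among patients predicted as `target_class` (binary task)."""
    # Same rule as in the cell above: predict class 1 only if its probability is higher.
    pred_label = (df["pred_prob_1"] > df["pred_prob_0"]).astype(int)
    score_col = f"score_{target_class}"
    return (
        df[pred_label == target_class]
        .groupby("word", as_index=False)[score_col]
        .mean()
        .sort_values(score_col, ascending=False)
        .head(k)
        .reset_index(drop=True)
    )


# Made-up rows: two patients, both predicted as class 0.
toy = pd.DataFrame({
    "word": ["AGE:65", "ICD10:i10", "AGE:65", "ICD10:i10"],
    "score_0": [0.6, 0.1, 0.4, -0.2],
    "score_1": [-0.6, -0.1, -0.4, 0.2],
    "pred_prob_0": [0.9, 0.9, 0.8, 0.8],
    "pred_prob_1": [0.1, 0.1, 0.2, 0.2],
})
result = top_tokens_for_class(toy, target_class=0, k=2)
print(result)
```

Calling `top_tokens_for_class(output_df, target_class=1)` would rank tokens for class 1 in the same way. With only five patients the averages are noisy, so the ranking is illustrative rather than conclusive.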
