In [1]:
import getpass
import os

import guidance
import numpy as np
import openai
import pandas as pd
import torch
from transformers import BertConfig, BertTokenizerFast, BertForSequenceClassification

from FewShotPipeline import FewShotPipeline
from XAI_Utils import MOR_Predictor

In [2]:
# Load the fine-tuned UmlsBERT mortality classifier, its tokenizer, and the CORe
# admission splits. The "path/to/..." values are placeholders for local copies.
MAX_SEQ_LEN = 512  # third argument to MOR_Predictor; presumably the max token length — confirm in XAI_Utils
CORE_DATA_DIR = "path/to/CORe_Dataset"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "path/to/mor_umlsbert_baseline.pt"
tokenizer_path = "path/to/UmlsBERT/checkpoint/"

config = BertConfig.from_pretrained(tokenizer_path)
model = BertForSequenceClassification.from_pretrained(model_path, config=config)
model.config.output_hidden_states = True  # hidden states are needed for the embeddings used below
# The predictor is given `device`; put the weights there too so inputs and model agree
# (a no-op on CPU, and harmless if MOR_Predictor also moves the model).
model.to(device)
model.eval()
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

Predictor = MOR_Predictor(model, tokenizer, MAX_SEQ_LEN, device)

train_admission = pd.read_csv(f'{CORE_DATA_DIR}/MP_IN_adm_train.csv')
test_admission = pd.read_csv(f'{CORE_DATA_DIR}/MP_IN_adm_test.csv')

In [3]:
# Never hardcode the key in the notebook (it leaks with every copy/commit of the file).
# Read it from the environment; fall back to an interactive, non-echoing prompt.
openai.api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")

In [4]:
# Few-shot explanation program (guidance template). Conversation layout:
#   1. one user turn per retrieved example in `prompt_list` (items expose `.text` and `.label`);
#   2. the test admission plus its true label -> GPT's top-3 reasons     ('explanation');
#   3. top-3 details that could support the opposite label              ('opposite_explanation');
#   4. top-3 features whose change would flip the label                 ('change_explanation').
# The opposite label is 'survived' when target_label == 'mortality', otherwise 'mortality'.
# `{{~ ... ~}}` strips adjacent whitespace. A trailing `\` is a Python line continuation
# inside the string, so the next line is joined WITHOUT a newline — keep a space before
# the `\` when a sentence continues on the next line.
program_explain = guidance('''
{{~#system~}}
You are a helpful assistant.
{{~/system~}}

{{#each prompt_list~}}
{{#user~}}
EXAMPLE:
{{this.text}}
LABEL: {{this.label}}
{{~/user~}}
{{/each}}

{{#user~}}
Given the TEST EXAMPLE:
{{test_prompt}}
The true LABEL for this example is: {{target_label}}
{{#block hidden=True}}
Remember not to refer to whole sections, but individual details.\
{{~/block}}
Can you explain the top 3 reasons that LABEL is {{target_label}} for this TEST EXAMPLE?\
{{~/user~}}

{{#assistant~}}
{{gen 'explanation' temperature=0}}
{{~/assistant~}}

{{#user~}}
Then, rank the top 3 pieces of information in the TEST EXAMPLE that could most likely contribute to the LABEL {{#if target_label == 'mortality'}}survived{{else}}mortality{{/if}}?\
{{#block hidden=True}}
Provide a rationale for why you believe each piece of information is influential in determining the label.\
{{~/block}}
{{~/user~}}

{{#assistant~}}
{{gen 'opposite_explanation' temperature=0}}
{{~/assistant~}}

{{#user~}}
Now, based on the TEST EXAMPLE, rank the top 3 features that, if altered, \
would most likely change the LABEL to {{#if target_label == 'mortality'}}survived{{else}}mortality{{/if}}?\
{{#block hidden=True}}
Provide a rationale for each feature selection, detailing what change will be made and why it would have a significant impact on the label if changed.
{{~/block}}
{{~/user~}}

{{#assistant~}}
{{gen 'change_explanation' temperature=0}}
{{~/assistant~}}
''')


In [5]:
# Embed every training admission note with Predictor.cls_embed and stack the results
# into one array; it is handed to Predictor.knn_prompt below (presumably as the k-NN
# search space for choosing few-shot examples — confirm in XAI_Utils).
train_embeddings = np.array(list(map(Predictor.cls_embed, train_admission.text)))

In [6]:
# Route every guidance program through GPT-4 via the OpenAI backend.
guidance.llm = guidance.llms.OpenAI("gpt-4")
# guidance keeps a separate on-disk completion cache per backend class. Clear the cache
# of the backend actually in use (OpenAI here), not the unrelated Transformers one,
# so the GPT-4 calls below are re-issued rather than replayed from a stale cache.
guidance.llm.cache.clear()

In [7]:
def convert_label(num_label):
    """Map a numeric outcome label to the word used in the prompts.

    Args:
        num_label: Numeric label read from the dataset.

    Returns:
        'survived' when ``num_label == 0``; 'mortality' for any other value.
    """
    if num_label == 0:
        return 'survived'
    return 'mortality'

In [8]:
idx = 10
instance = test_admission.iloc[idx, 1]
target_label = test_admission.iloc[idx, -1]
merged_indexes, prompt_list = Predictor.knn_prompt(train_embeddings, train_admission, instance, target_label, 'cosine')
# Convert numerical label to descriptive label
descriptive_label = convert_label(target_label)
prediction = program_explain(prompt_list=prompt_list, test_prompt=instance, target_label = descriptive_label)

In [9]:
idx = 1686
instance = test_admission.iloc[idx, 1]
target_label = test_admission.iloc[idx, -1]
merged_indexes, prompt_list = Predictor.knn_prompt(train_embeddings, train_admission, instance, target_label, 'cosine')
# Convert numerical label to descriptive label
descriptive_label = convert_label(target_label)
prediction = program_explain(prompt_list=prompt_list, test_prompt=instance, target_label = descriptive_label)