In [None]:
import json
from pprint import pprint
import torch
# device = torch.device("cuda:1")


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
# Hugging Face Hub id of the GPTQ checkpoint loaded below.
model_name_or_path = "TheBloke/jackalope-7B-GPTQ"

# `revision` selects the repo branch to load; for another branch use e.g.
# revision="gptq-4bit-32g-actorder_True".
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="main",
    device_map='auto',
    trust_remote_code=False,
)

from auto_gptq import exllama_set_max_input_length
# Rebind `model` to the object returned after requesting a maximum input
# length of 8192 (presumably tokens -- confirm against the auto_gptq docs).
model = exllama_set_max_input_length(model, 8192)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

print("*** Pipeline:")
# Generation settings applied to every `pipe(...)` call made in this notebook.
generation_settings = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.4,
    "repetition_penalty": 1.1,
}
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **generation_settings,
)
# Hand cached-but-unused GPU memory back before inference starts.
torch.cuda.empty_cache()


In [None]:
# Load the document-level gold-label table. Later cells read its `text`,
# `labels`, `final_labels` and `doc_id` columns.
import pandas as pd
# `df_text` and `eg` are two names for the same DataFrame object.
eg = df_text = pd.read_csv('./doc_level_gold_labels.csv')
# Plain Python lists of the report texts and of their suggested-label strings.
input_text, input_labels = (eg[column].tolist() for column in ("text", "labels"))
eg

## data statistics


In [None]:
# Print the per-column summary (names, non-null counts, dtypes) of the loaded frame.
eg.info(verbose=True)

In [None]:
# Store each report's character count in a new `text_len` column, then print
# the mean / min / max of that column.
eg['text_len'] = eg['text'].map(len)
for caption, stat_name in (("mean: ", "mean"), ("min: ", "min"), ("max: ", "max")):
    print(caption, getattr(eg['text_len'], stat_name)())


In [None]:
# `labels` is a comma-separated string; count its pieces per document, keep the
# count in a new `labels_len` column, then print the mean / min / max.
eg['labels_len'] = eg['labels'].map(lambda label_str: len(label_str.split(",")))
for caption, stat_name in (("mean: ", "mean"), ("min: ", "min"), ("max: ", "max")):
    print(caption, getattr(eg['labels_len'], stat_name)())

In [None]:
# Peek at the suggested-label string of the first document.
eg['labels'].iloc[0]

In [None]:
# Peek at the raw (still serialized) gold-label string of the first document.
eg['final_labels'].iloc[0]

In [None]:
import re, ast

# Matches one brace-delimited chunk inside a serialized label list. Same
# pattern as before, compiled once instead of once per row.
_LABEL_DICT_RE = re.compile('{({*[^{}]*}*)}')


def _parse_final_labels(raw_labels):
    """Collapse one serialized `final_labels` cell into a single dict.

    The surrounding square brackets are stripped, every brace-delimited chunk
    is parsed with `ast.literal_eval`, and the first key/value pair of each
    parsed dict is copied into the result (a later chunk with the same key
    overwrites an earlier one).

    Parameters
    ----------
    raw_labels : str
        Text of one `final_labels` cell.

    Returns
    -------
    dict
        One entry per parsed chunk; empty if no chunk was found.

    Raises
    ------
    ValueError, SyntaxError
        Propagated from `ast.literal_eval` when a chunk is not a valid literal.
    """
    merged = {}
    for match in _LABEL_DICT_RE.finditer(raw_labels.strip('][')):
        parsed = ast.literal_eval(match.group())
        if not parsed:
            # An empty "{}" chunk has no first key; previously this raised
            # IndexError and aborted the whole cell.
            continue
        first_key = next(iter(parsed))
        merged[first_key] = parsed[first_key]
    return merged


# transform gold labels to be dictionaries
label_lst = []
label_ids = []
for _, row in eg.iterrows():
    label_lst.append(_parse_final_labels(row["final_labels"]))
    label_ids.append(row['doc_id'])

# Sanity check: both lists should have one entry per row of `eg`.
print(len(label_lst))
print(len(label_ids))
eg['clean_final_labels'] = label_lst
eg.info()

In [None]:
# Number of values stored under each label of the first document.
list(map(len, eg['clean_final_labels'].iloc[0].values()))

In [None]:
import numpy as np
# For every document, average the number of values attached to each label of
# its parsed gold-label dict; then print the mean / min / max over documents.
eg['avg_values_per_label'] = eg['clean_final_labels'].map(
    lambda label_dict: np.mean([len(values) for values in label_dict.values()])
)
for caption, stat_name in (("mean: ", "mean"), ("min: ", "min"), ("max: ", "max")):
    print(caption, getattr(eg['avg_values_per_label'], stat_name)())

In [None]:
# Spot-check one document: all column values of the row at position 309
# (positional, not by index label; raises IndexError if `eg` has fewer rows).
eg.iloc[309].to_list()

In [None]:
# Show up to five documents that have at least six suggested labels.
# random_state pins the sample so the displayed rows are the same on every
# run, and capping n avoids the ValueError `sample` raises when fewer than
# five rows match the filter.
many_labels = eg[eg["labels_len"] >= 6]
many_labels.sample(n=min(5, len(many_labels)), random_state=42)

In [None]:
# new prompt with two steps
# `example` is the sample output that gets substituted into a template's
# {example} placeholder.
# NOTE(review): the following cell assigns `example` again, so when the
# notebook is run top to bottom this value is replaced before it is used.
example = """
{"Symptom1": ["redness", "redness"], "Symptom2": ["fever"], "Symptom3": ["sore", "arm", "soreness"], "Symptom4": ["heartfailure"]}
{"Erythema": ["redness", "redness"], "Pain in extremity": ["sore", "arm", "soreness"], "Pruritus": ["none"]}
"""
# Prompt template with {text}, {suggest} and {example} placeholders, meant to
# be filled with str.format.
# NOTE(review): the inference cell in this notebook formats `input_template`,
# not this variant -- confirm whether `input_template_2nd` is still needed.
input_template_2nd = """
Ignore previous conversations.

Clinical Notes: {text}

First, extract all medical symptoms mentioned in the clinical text above. 
Second, match the extracted symptoms with the terms in the given suggest list below. If there's no match, provide 'none' for the term.
Please follow the order of this suggest list: {suggest} and generate output by following the requirements below:

Requirements:
1. Adverse event means any symptoms or irregular test results. Therefore, procedure description, negative test results, or only the mention of the test itself are not adverse events.
2. If any non symptom, vague mention, or non vaccine related terms appeared in the suggested terms, just provide "none" for them in output. 
3. The output should have the exatracted symptoms as acquired in the first step, and the matched terms with corresponded symptoms as in the second step. Output should be in json format like the example below shows.

Example: 
{example}

Here is the JSON output:
"""

In [None]:
# new two step prompt structure
# `example` is the sample output substituted into the {example} placeholder of
# `input_template`; it shows a "Symptom List" line followed by a "Suggest List"
# line.
example = """
Symptom List: {"Erythema": ["redness", "redness"], "Pain in extremity": ["sore", "arm", "soreness"], "Pruritus": ["none"]}
Suggest List: {"Erythema": ["redness", "redness"], "Pruritus": ["none"]}
"""
# Prompt template used by the inference cell. It is filled with str.format
# using three fields: {text} (the clinical note), {suggest} (the suggested
# terms for that note) and {example} (the sample output above).
input_template = """
Ignore previous conversations.

Clinical Notes: {text}

First, extract an adverse event list from the clinical text above.
Adverse event means any symptoms or irregular test results. Therefore, procedure description, negative test results, or only the mention of the test itself are not adverse events.
Then, extract adverse events that indicating each of the suggested terms below from the adverse event list in previous step. 
Include the terms in the output even if the terms are not explicitly mentioned in the provided report, just provide ‘none’ as the result. 

Please follow the order of this list: {suggest} and the output should only include a list of symptoms and a list of matched suggest list in json format like the example below shows.

Example: 
{example}

Here is the output:
"""

In [None]:
# build up the call
# Keys name the prompt variant being run; only the key is used below, as the
# prefix of the result column and of the output file name.
prompt_dic = {'new_prompt': 1}
for p in prompt_dic:
    # The pipeline is configured to sample, so pin torch's RNG state before
    # generating; without this every run writes different answers to disk.
    # NOTE(review): identical output across GPUs / library versions is still
    # not guaranteed.
    torch.manual_seed(42)

    answer_lst = []
    for txt, suggest in zip(eg['text'], eg['labels']):
        # Named `prompt` rather than `input` so the builtin is not shadowed.
        prompt = input_template.format(text=txt, suggest=suggest, example=example)

        answer = pipe(prompt)
        # Slice off the first len(prompt) characters of the generated text so
        # only what follows the prompt is stored.
        answer_lst.append(answer[0]['generated_text'][len(prompt):].strip())

    # Attach the answers to `eg` and persist them next to the gold labels.
    p_col_name = p + 'llm_result'
    eg[p_col_name] = answer_lst
    result_df = eg[[p_col_name, 'final_labels', 'doc_id', 'text']]
    f_name = p + '_result_jackalope_temp_4_2step_prompt.json'
    result_df.to_json(f_name)
    torch.cuda.empty_cache()



In [None]:
# Display the frame that the inference cell built and wrote to JSON
# (model answer, gold labels, document id and report text).
result_df

In [None]:
# Keep only the model output and the gold labels for side-by-side reading.
# NOTE: this rebinds `eg`. Cells above that read other columns of `eg`
# (e.g. `text`, `labels`) will fail until the data-loading cells are re-run.
eg = result_df[['new_promptllm_result', 'final_labels']]
eg

In [None]:
# Print every row's index label followed by the value in its first column.
# `.iloc[0]` makes the positional access explicit: `row[0]` on a Series whose
# labels are column names relies on the deprecated "integer key as position"
# fallback of Series.__getitem__.
for row_label, row in eg.iterrows():
    print(row_label)
    print(row.iloc[0])
    