# Prompting LLMs for Patient-Oriented Language 

In [1]:
import os
import transformers
import json

import torch
import tqdm

from typing import Dict, List, Optional

from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer
from datasets import load_dataset


## Load dataset

In [2]:
data = load_dataset("csv", data_files="../../test_split.csv")

In [3]:
#data = load_dataset("Mathilde/test_data_pol")
subset = data['train']  # TODO changer en test dans le ds original
print(subset)
"""
statements_pol = data['train']['statement_pol']
statements_med = data['train']['statement_med']
premises =  data['train']['premise']
"""

Dataset({
    features: ['id', 'topic_id', 'statement_medical', 'statement_pol', 'premise', 'NCT_title', 'NCT_id', 'label'],
    num_rows: 1578
})


"\nstatements_pol = data['train']['statement_pol']\nstatements_med = data['train']['statement_med']\npremises =  data['train']['premise']\n"

## Load Model

TODO: change the model's path

In [4]:
model_path = "/gpfsdswork/dataset/HuggingFace_Models/mistralai/Mixtral-8x7B-Instruct-v0.1" # "/gpfsdswork/dataset/HuggingFace_Models/meta-llama/Llama-2-7b-chat-hf" #"/lustre/fsn1/projects/rech/hjp/ulj12fo/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

In [5]:
streamer = TextStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

In [None]:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True, device_map="auto",)  # torch_dtype=torch.float16, low_cpu_mem_usage=True

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards:  42%|####2     | 8/19 [00:31<00:43,  3.97s/it]

## Format the prompt

In [None]:
def get_input_text(premise, hypothesis):
    # TODO the persona option + the wrapping of the pol + med "patients with this medical profile..."
    options_prefix = "Answer in 1 word only with: \n- " #"OPTIONS:\n- "
    separator = "\n- "
    options_ = options_prefix + f"{separator}".join(["Entailment OR","Contradiction"])  #  "Neutral"
    return f"{premise} \n Question: Imagine that you are a doctor reviewing patients profiles to enroll them for a clinical trial. Does the previous eligibility criteria imply that the following patient can participate to the trial?\n Patient profile:\n {hypothesis}\n {options_}"
    #return f"{premise} \n Question: Does this imply that {hypothesis}? {options_}"
    # Does the previous eligibility criteria imply that the following patient can participate to the trial?
    # return f"Classification: {premise} \n Question: Does this imply that {hypothesis}? Entailment or Contradiction?Answer:"

In [None]:
#subset[149]

In [None]:
samples = []
# TODO adapt in function of the med or POL (l.5)
for instance in subset:
    premise = instance['premise']
    sentence = f"Eligibility criteria of the trial are:\n {premise}"
    input_text = get_input_text(sentence, instance['statement_medical'])
    # temp = {"text":input_text, "label":sample['label']}
    temp = {"text":input_text, "label":0}
    print(input_text)
    samples.append(temp)

## Define the chat function

TODO: adjust the hyperparameters

In [None]:
def chat(
    query: str,
    history: Optional[List[Dict]] = None,
    temperature: float = 0.7,
    top_p: float = 1.0,
    top_k: float = 0,
    repetition_penalty: float = 1.1,
    max_new_tokens: int = 5, # 1024,
    **kwargs,
):
    if history is None:
        history = []

    history.append({"role": "user", "content": query})

    input_ids = tokenizer.apply_chat_template(history, add_generation_prompt=True, return_tensors="pt").to(model.device)
    input_length = input_ids.shape[1]

    generated_outputs = model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(
            temperature=temperature,
            do_sample=temperature > 0.0,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            **kwargs,
        ),
        streamer=streamer,
        return_dict_in_generate=True,
    )

    generated_tokens = generated_outputs.sequences[0, input_length:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    history.append({"role": "assistant", "content": generated_text})

    return generated_text, history

In [None]:
response, history = chat(samples[54]['text'], history=None)

In [None]:
#samples[52]['text']

## Call the chat function on the whole test set

In [None]:
labels = []
pred = []
#with torch.inference_mode():
for sample in tqdm.tqdm(samples):
    labels.append(sample["label"])
    # input_ids = tokenizer(sample["text"], return_tensors="pt",).input_ids.to("cuda")
    # input_ids = tokenizer.apply_chat_template(sample["text"], return_tensors="pt",).to("cuda")
    # outputs = model.generate(input_ids, max_new_tokens=20)
    response, history = chat(sample['text'], history=None)
    pred.append(response)

In [None]:
pred

In [None]:
# pred = [p[5:][:-4].strip() for p in pred]

## Parse model's output

TODO: adapt the pattern in function of the model's output style

In [None]:
pattern = "entailment|contradiction|Entailment|Contradiction|entail"  # |Neutral|neutral

In [None]:
import re

# Function to match text with regular expression
def match_text_with_regexp(text, pattern):
    text = text.strip()
    # print(text)
    # Compile the regular expression pattern
    regexp = re.compile(pattern)
    
    # Search for a match in the text
    match = regexp.search(text)
    # print(match)
    
    if match:
        # If a match is found, return the matched text
        return match.group()
    else:
        # If no match is found
        return None

parsed_preds = []

for p in pred:
    # text = pred[0]
    pattern = "entailment|contradiction|Entailment|Contradiction|entail|yes|Yes|No|no"  # |Neutral|neutral
    
    result = match_text_with_regexp(p, pattern)
    
    if result:
        if result in ['entailment', 'entail', 'yes', 'Yes']:
            result = 'Entailment'
        elif result in ['contradiction', 'contradicts', 'no', 'No']:
            result = 'Contradiction'
        #elif result in ['neutral', 'Neutral']:
            #result = 'Neutral'
        # print(f"Match found: {result}")
        parsed_preds.append(result) 
    else:
        # print("No match found.")
        # if nothing exploitable predicted --> assert "Neutral"
        parsed_preds.append("Contradiction")
        


In [None]:
parsed_preds

In [None]:
set(parsed_preds)
from collections import Counter
Counter(parsed_preds)

## Save the predictions in a JSON file

In [None]:
prediction_dict = {}
for _id,pred_x in zip(data['train']['id'], parsed_preds):
    prediction_dict[str(_id)] = {"Prediction":pred_x}

In [None]:
# from sklearn.metrics import f1_score
# uuid_list = list(prediction_dict.keys())
# results_pred = []
# gold_labels = []
# for i in range(len(uuid_list)):
#     if prediction_dict[uuid_list[i]]["Prediction"] in ["Entailment", "Yes"]:
#         results_pred.append(1)
#     else:
#         results_pred.append(0)
#     if data[uuid_list[i]]["Label"] in ["Entailment", "No"]:
#         gold_labels.append(1)
#     else:
#         gold_labels.append(0)
# f1_score(gold_labels,results_pred)

In [None]:
json.dump(prediction_dict, open("results_mixtral_zs_med_persona.json",'w'),indent=4)

In [None]:
prediction_dict