# Pipeline for generating a question to further prompt the user.


In [2]:
import pandas as pd
import os
import json
from tqdm import trange
import time
import re

from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from groq import Groq


## load env variables
# Indexing os.environ (rather than .get) makes a missing key fail fast with a
# KeyError at import time instead of on the first API call.
GROQ_API_KEY       = os.environ["GROQ_API_KEY"]
# Identifier of the chat model requested in every completion call.
CHAT_MODEL         = "llama3-70b-8192"
# Pass the key explicitly so the client's credentials visibly come from the
# value validated above rather than from an implicit environment lookup.
client = Groq(api_key=GROQ_API_KEY)

# Prompt template for the guiding-question request.
# The single-brace fields ({question}, {user_ans}, {model_ans}) are filled in
# by generate_guiding_question. The doubled braces around the JSON example
# are escapes, so the formatted prompt contains a literal "{" and "}".
# The reply format requested here ("guiding_question": "...") is what the
# default pattern of extract_answer searches for.
PROMPT = \
'''You are an expert in linguistic variation and medical communication. 
You are provided with a medical question, a persons answer to that question, and the model answer to that question. 
Your task is to generate a follow up question that aids the user in the answering of the question. 
You are to focus in on what is missing from the users answer in comparison to the model answer, and hint that to the user in the follow up question.
You are to reply in a cheerful, friendly tone.

Question:
{question}

Users Answer:
{user_ans}

Model answer:
{model_ans}

The returning format should be like this:
{{"guiding_question": "Question to help the user better answer the question"}}

Before returning the answer, ensure and double check that the answer is in accordance to the format above.
'''

In [8]:
def load_data(index, path="../data/further_prompting/combined_data.json"):
    """Return the record at position *index* of the JSON dataset at *path*.

    Args:
        index: Position of the wanted record in the top-level JSON value.
        path: Location of the JSON file. Defaults to the combined
            further-prompting dataset, so existing ``load_data(i)`` calls
            behave as before.

    Returns:
        Whatever object is stored at ``data[index]`` in the parsed file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        IndexError, KeyError, TypeError: If *index* does not address an
            element of the parsed data.
    """
    with open(path, "r", encoding="utf-8") as fin:
        data = json.load(fin)
    return data[index]

{'QA_pair_number': 0,
 'Question': 'What is the percentage of total mortality in Singapore attributed to ischemic heart disease, based on national health statistics from the Ministry of Health in 2019?',
 'Answer': 'Ischemic heart disease contributes to 18.8% of total mortality in Singapore, based on national health statistics from the Ministry of Health in 2019.',
 '3/10': ["I think it's around 20% or something, I'm not really sure.",
  "Ischemic heart disease is a big deal in Singapore, but I don't have the exact number.",
  "Heart disease is a major killer in Singapore, but I'm not sure what the exact percentage is.",
  "I'm pretty sure it's higher than 10%, but I don't have the exact stat.",
  "Singapore has a lot of heart problems, but I don't know the exact mortality rate."],
 '6/10': ["According to the Ministry of Health's 2019 statistics, ischemic heart disease accounts for nearly 19% of total mortality in Singapore.",
  'In 2019, ischemic heart disease was responsible for appr

In [20]:
def _unescape_json_string(raw):
    """Return *raw* with JSON backslash escapes decoded, or unchanged on failure."""
    if not isinstance(raw, str) or "\\" not in raw:
        # Nothing to decode; hand the captured value back exactly as matched.
        return raw
    try:
        # strict=False tolerates literal newlines/tabs inside the value.
        return json.loads(f'"{raw}"', strict=False)
    except ValueError:
        # Not valid JSON string content (e.g. from a custom pattern).
        return raw


def extract_answer(answer, pattern=r'"guiding_question"\s*:\s*"((?:[^"\\]|\\.)+)"'):
    """
    Extracts the answer from the text associated with the key "guiding_question".

    The default pattern accepts backslash-escaped characters inside the value,
    so a reply such as ``"guiding_question": "What about \\"CPR\\"?"`` is
    captured in full rather than cut off at the first escaped quote. JSON
    escape sequences in the captured text are decoded before returning.

    Args:
        answer: Raw text returned by the model.
        pattern: Regular expression whose first capture group is the
            guiding question.

    Returns:
        The captured question; "No guiding question found in the input text."
        when the pattern does not match; or "An error occurred: ..." when the
        search itself fails (non-string input, invalid pattern, or a pattern
        without a capture group). This function does not raise for those
        cases.
    """
    try:
        # Search for the pattern in the input text
        match = re.search(pattern, answer)

        if match is None:
            # If no match is found, return an error message
            return "No guiding question found in the input text."

        # Return the captured group (the question) with escapes decoded
        return _unescape_json_string(match.group(1))

    except (TypeError, IndexError, re.error) as e:
        # TypeError: answer/pattern is not a string; re.error: bad pattern;
        # IndexError: pattern has no capture group.
        return f"An error occurred: {str(e)}"
        



def generate_guiding_question(qn, user_ans, model_ans) -> str:
    """Ask the chat model for a follow-up question that hints at what is missing.

    The question, the user's answer and the model answer are substituted into
    PROMPT, the prompt is sent as a single user message, and the streamed reply
    is concatenated and passed through extract_answer. The inputs and the
    result are also printed to stdout.

    Args:
        qn: The question that was asked.
        user_ans: The answer given by the user.
        model_ans: The reference answer to compare against.

    Returns:
        The string produced by extract_answer for the model's reply.
    """
    template = PromptTemplate(
        template=PROMPT,
        input_variables=["question", "user_ans", "model_ans"],
    )
    rendered_prompt = template.format(
        question=qn,
        user_ans=user_ans,
        model_ans=model_ans,
    )

    stream = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=[{"role": "user", "content": rendered_prompt}],
        temperature=0,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )

    # A chunk's delta content may be None; treat that as an empty piece.
    llm_answer = "".join(
        chunk.choices[0].delta.content or "" for chunk in stream
    )

    # Improve error handling here
    guiding_qn = extract_answer(llm_answer)

    # Echo every section between separator rules, closing with a final rule.
    separator = "==================================="
    report_sections = (
        ("Question", qn),
        ("User's Answer", user_ans),
        ("Model Answer", model_ans),
        ("Guiding Question", guiding_qn),
    )
    for heading, content in report_sections:
        print(separator)
        print(f"{heading}:\n{content}\n")
    print(separator)

    return guiding_qn


def main():
    """Generate a guiding question for one fixed sample from the dataset.

    Loads record 4, takes the answer at position 2 of its '3/10' list as the
    user's answer, and runs generate_guiding_question on it.
    """
    record = load_data(4)
    ans_index = 2
    ans = generate_guiding_question(
        record['Question'],
        record['3/10'][ans_index],
        record['Answer'],
    )


Question:
What is the minimum requirement for performing basic life-saving skills of cardio-pulmonary resuscitation (CPR) and use of automated external defibrillators (AEDs)?

User's Answer:
You need to have a lot of experience in emergency situations.

Model Answer:
The minimum requirement for performing basic life-saving skills of cardio-pulmonary resuscitation (CPR) and use of automated external defibrillators (AEDs) is the rescuer's two hands.

Guiding Question:
What specific physical resources do you think are necessary to perform CPR and use an AED, aside from experience?



In [None]:
# TODO for Ethan (Generate answer similarity pipeline)
def get_answer_similarity(retry_ans, model_ans) -> float:
    """Placeholder for scoring the user's retried answer against the model answer.

    Intended to use a finetuned cross encoder to produce a similarity score;
    that is not implemented yet, so the function currently returns None.
    """
    return None