In [1]:
import ollama

# basic function to prompt the model
def get_response_from_model(prompt, model="qwen2.5:3b"):
    response = ollama.generate(model=model, prompt=prompt)
    return response["response"]

In [2]:
def generate_template (symptoms, demographics):

    prompt_template = f"""
    You are a clinical AI assistant.

    You will be provided:
    - A description of the patient's symptoms
    - Key demographic details (age, sex, medical history, etc.)

    You have access to the NHS A–Z webpages of conditions and treatments.

    Your task:
    - Analyze the symptoms + demographics
    - Generate 5–10 targeted and varied search queries (phrases or keyword combinations)
    - Each query should:
    1. Reflect different angles: e.g., symptom-based, demographic-specific, etc.
    2. Use terminology aligning with NHS documentation (leveraging query transformation best practices)
    3. Use word variations and synonyms to capture search intent beyond exact matches
    4. Include both narrow (specific symptom + demographic) and broad (symptom cluster) variants
    5. Use only demographic details provided if useful for retrieval on NHS A–Z webpages
    6. Aim to maximize precision and recall in retrieval.

    Return exactly valid JSON only, no extra text.
    Input:
    {{"symptoms": "...", "demographics": ...}}

    Output schema:
    {{
    "queries": [
        {{
        "query": "string",
        "focus": "symptom" | "demographic" | "risk_factor"
        }}
    ]
    }}

    Example:
    Input:
    {{"symptoms": "headache and nausea", "demographics": {{"age": 45, "sex": "female"}}}}
    Output:
    {{
    "queries": [
        {{"query": "headache nausea adult female", "focus": "symptom"}},
        {{"query": "female 45 headache", "focus": "demographic"}},
        {{"query": "nausea dehydration risk factor", "focus": "risk_factor"}}
    ]
    }}
    These are the patient's symptoms and demographics:
    {symptoms}
    {demographics}
    """
    return prompt_template

In [3]:
import json

filename ="../data/synthetic_queries/5147cd8_gpt-4o_1000_synthetic_queries.jsonl"
with open(filename, 'r') as f:
    lines = f.readlines()
    data = [json.loads(line) for line in lines]


In [4]:
import requests

def query_retriever(query):
    host = "localhost"
    port = 8000
    req = requests.get(f"http://{host}:{port}/query", params={"query": query})
    results = req.json()
    results = results['response']
    pages = [x['metadata']['source'] for x in results]
    return pages

In [None]:
# def generate_rerank_template(symptoms_description, demographics, condition):
#     """
#     Generates a template for reranking the retrieved conditions based on clinical relevance.
    
#     Args:
#         symptoms_description (str): The description of the patient's symptoms.
#         demographics (str): Key demographic details of the patient.
#         condition (str): The description of the condition to be evaluated.
    
#     Returns:
#         str: A formatted template for reranking.
#     """

#     rerank_template = f"""
#     You are an expert clinical AI evaluator.

#     Your task is to assess how clinically appropriate the retrieved condition is, given the patient’s symptom description and their general demographics.

#     We will provide you with:
#     - A **symptom description**: this may include patient-reported symptoms and potentially some contextual details.
#     - Key **demographic details**: this includes sex, age, medical history, and other relevant factors that may influence the clinical picture.
#     - A **condition description**: this typically includes the clinical definition, common symptoms, and related information for a specific condition.

#     # Symptoms:
#     {symptoms_description}

#     # Demographics:
#     {demographics}

#     # Condition:
#     {condition}

#     # Scoring Criteria
#     Read the symptom description carefully. Then, evaluate whether the retrieved condition plausibly explains or matches the presented symptoms. Your goal is to judge **clinical relevance**, i.e., how well this condition could account for what the patient is experiencing.

#     Before providing a score, briefly explain your reasoning. Refer to specific symptoms or aspects of the condition text when appropriate.

#     Output a JSON object containing:
#     - a short **reason** for your score
#     - a **score** from 1 to 10, where:

#     - **1–2**: The condition is clearly **irrelevant** to the symptoms. It describes a clinical picture that is inconsistent with the patient presentation.
#     - **3–4**: The condition has **minimal relevance**. There may be some weak thematic or superficial connections, but it does not plausibly explain the symptoms.
#     - **5–6**: The condition has **partial relevance**. It overlaps with some symptoms, but misses key features or includes inconsistent elements.
#     - **7–8**: The condition is **clinically appropriate** and aligns with most symptoms. Some aspects may not be fully explained, but it is a reasonable differential diagnosis.
#     - **9–10**: The condition is an **excellent match**. It provides a comprehensive and specific explanation for the symptoms, with no significant inconsistencies.

#     Return your output as a JSON object, with this format:
#     {{
#     "reason": "...",
#     "score": ...
#     }}
#     Return exactly valid JSON only, no extra text.
#     """
#     return rerank_template


In [16]:
def generate_rerank_template(symptoms_description, document_text, k=5):
    rerank_template = f"""
    You are part of a retrieval system for a medical domain.
    Given a description of symptoms provided by a patient, an initial retriever has shortlisted several possible conditions.

    The condition titles and after the column their descriptions:
    {document_text}

    The patient's symptom description:
    {symptoms_description}

    Your task is to select the {k} most likely conditions based on the symptoms.
    Please return only the titles of the selected conditions, comma-separated.
    Do not include any additional text or explanations.
    """
    return rerank_template


In [6]:
filename ="../data/nhs-conditions/v4/conditions.jsonl"
with open(filename, 'r') as f:
    lines = f.readlines()
    conditions = [json.loads(line) for line in lines]
conditions_dict = {condition['condition_title']: condition['condition_content'] for condition in conditions}

In [None]:
tot = 0
correct = 0
original_correct = 0
top_10_correct = 0
rerank_correct = 0
avg_length = []


for line in data:
        demographics = line['general_demographics']
        symptoms = line['symptoms_description']
        true_conditions = line['conditions_title']
        possible_conditions = query_retriever(symptoms)
        if true_conditions in set(possible_conditions):
            original_correct += 1
        template = generate_template(symptoms, demographics)
        sub_queries = get_response_from_model(template)
        sub_queries = json.loads(sub_queries)
        for sub_query in sub_queries['queries']:
            other_conditions = query_retriever(sub_query['query'])
            possible_conditions.extend(other_conditions)


        if true_conditions in set(possible_conditions):
            correct += 1
            avg_length.append(len(set(possible_conditions)))

        # Sort possible conditions by frequency and keep the top 5
        possible_conditions.sort(key=lambda x: possible_conditions.count(x), reverse=True)

        if true_conditions in possible_conditions[:10]:
            top_10_correct += 1

        reranking_results = {}
        possible_conditions = list(set(possible_conditions))  # Remove duplicates
        rerank_template = generate_rerank_template(symptoms, "\n\n".join([cond+": "+conditions_dict[cond] for cond in possible_conditions]))
        rerank_response = get_response_from_model(rerank_template)
        rerank_response = rerank_response.strip().split(',')
        rerank_response = [cond.strip() for cond in rerank_response if cond.strip() in conditions_dict]
        if true_conditions in rerank_response:
            rerank_correct += 1
        tot += 1
        if tot % 10 == 0:
            print(f"Processed {tot} queries")
            print (f"Original correct: {original_correct/tot:.2f}")
            print (f'Extended correct:{correct/tot:.2f}')
            print (f'Top 10 correct: {top_10_correct/tot:.2f}')
            print (f'Rerank correct: {rerank_correct/tot:.2f}')
            print(f"Average length of results: {sum(avg_length)/len(avg_length):.2f}")

print(f"Total: {tot}, Correct: {correct}, Accuracy: {correct/tot:.2f}")

Processed 10 queries
Original correct: 0.60
Extended correct:0.80
Top 10 correct: 0.60
Rerank correct: 0.40
Average length of results: 23.38
Total: 10, Correct: 8, Accuracy: 0.80
