In [None]:
from openai import OpenAI

def format_list(l):
    """Render items as a numbered list with one ``#<index> - "<item>"`` line each.

    Indices start at 0. Every line, including the last, ends with a newline,
    and an empty input yields the empty string.
    """
    return "".join(f'#{index} - "{item}"\n' for index, item in enumerate(l))

# Instruction prompt for the per-paragraph relevance question.
# predict() replaces %sentence-number% with the 0-based index of the paragraph
# being judged and sends this text as the first "user" message (not with the
# "system" role, despite the variable name -- presumably because the served
# model's chat template has no system role; confirm). The model is asked to
# reason inside <think>...</think> and then answer Yes or No; the reply is
# parsed by extract_yes_no_answer().
system_prompt = """You will receive a clinical question interpreted from a patient's question and a list of paragraphs extracted from a clinical note.

The questions are asked through the patient portal by patients.

Create a chain of thought to determine if the paragraph is relevant to answering the question.
Put your reasoning between <think> and </think> tags.

Is the paragraph number %sentence-number% (indexed from 0) relevant to answering the question?
The paragraph does not need to give a full answer, but should be relevant in formulating the answer.
Give a Yes or No answer."""

# Client for an OpenAI-compatible chat-completions server on localhost:8888;
# that server must be running before predict() is called.
# NOTE(review): api_key looks like a placeholder for a local server that does
# not check keys -- confirm before pointing base_url at a hosted endpoint.
client = OpenAI(
    base_url="http://localhost:8888/v1",
    api_key="anything"
)

def predict(question, sentences, sentence_number, model="loaded-model",
            temperature=0.1, max_tokens=4096):
    """Ask the served model whether one paragraph is relevant to a question.

    Three consecutive messages, all with role "user", are sent:
    1. the instruction prompt with ``%sentence-number%`` filled in;
    2. the question;
    3. the full numbered paragraph list.

    Args:
        question: The clinical question text.
        sentences: All paragraphs of the note excerpt, in order. The whole
            list is sent so the model sees the target paragraph in context.
        sentence_number: 0-based index into ``sentences`` of the paragraph
            being judged.
        model: Model name passed to the server (default ``"loaded-model"``).
        temperature: Sampling temperature (default 0.1).
        max_tokens: Completion token limit (default 4096).

    Returns:
        The raw completion text, or ``""`` if the server returned no content.
    """
    messages = [
        {
            "role": "user",
            "content": system_prompt.replace("%sentence-number%", str(sentence_number)),
        },
        {
            "role": "user",
            "content": f"Question: {question}",
        },
        {
            "role": "user",
            "content": f"List of paragraphs: {format_list(sentences)}",
        },
    ]

    response = client.chat.completions.create(
        messages=messages,
        stream=False,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # The OpenAI SDK types message.content as Optional[str]. Normalise None to ""
    # so downstream regex parsing gets a string instead of raising TypeError.
    content = response.choices[0].message.content
    return content if content is not None else ""

In [2]:
import json

# Load the test split. The prediction loop iterates it with .items(), so the
# file is expected to hold a JSON object (one entry per case).
with open("../v1_test.json", encoding="utf-8") as test_file:
    data = json.load(test_file)

In [None]:
import re

def _canonical_yes_no(candidate):
    r"""Normalise a captured ``\text{...}`` / ``\boxed{...}`` argument to "Yes"/"No".

    Leading non-letters are skipped and case is ignored, so "Yes", "no." and
    " YES " all match. Returns None when the argument does not begin with a
    standalone yes/no word (e.g. "mg/dL" or "not relevant"), so the caller can
    fall back to its next strategy.
    """
    match = re.match(r"[^a-z]*(yes|no)\b", candidate.lower())
    if match is None:
        return None
    return "Yes" if match.group(1) == "yes" else "No"


def extract_yes_no_answer(text):
    r"""Parse a model reply into exactly one of "Yes", "No" or "Unknown".

    Steps, in order:
      1. Remove <think>...</think> reasoning. An unclosed <think> (e.g. a reply
         cut off by the token limit) is removed through the end of the text,
         so unfinished reasoning is never mistaken for the answer.
      2. Return the last \text{...} argument that reads as yes/no.
      3. Otherwise return the last \boxed{...} argument that reads as yes/no.
      4. Otherwise normalise the text (collapse whitespace, lowercase, drop
         "*" and ":") and try the phrase patterns below. All "Yes" patterns
         are tried before any "No" pattern, so a reply matching both kinds is
         classified "Yes".

    Args:
        text: Raw completion text; None or "" is accepted.

    Returns:
        "Yes", "No" or "Unknown" -- never any other string, so the result can
        be used directly as a key into a label mapping.
    """
    if not text:
        return "Unknown"

    # \Z lets an unterminated <think> block swallow the rest of the reply.
    text = re.sub(r'<think>.*?(?:</think>|\Z)', '', text, flags=re.DOTALL | re.IGNORECASE)

    # LaTeX-style answers (\text{Yes}, \boxed{No}): the last capture that actually
    # reads as yes/no wins. Captures like \text{mg} are skipped rather than being
    # returned verbatim (which previously produced labels outside Yes/No/Unknown).
    for latex_pattern in (r'\\text\{([^}]*)\}', r'\\boxed\{([^}]*)\}'):
        for candidate in reversed(re.findall(latex_pattern, text)):
            answer = _canonical_yes_no(candidate)
            if answer is not None:
                return answer

    # \s+ also covers newlines, so no separate "\n" replacement is needed.
    text = re.sub(r"\s+", " ", text).strip().lower()
    text = text.replace("*", "").replace(":", "")

    yes_patterns = [
        r'^\s*yes\b',
        r'\byes,\s',
        r'\bthe answer is yes\b',
        r'\bthe final answer is yes\b',
        r'\bit is (very )?likely.*\byes\b',
        r'\byes it is\b',
        r'\banswer yes\b',
    ]

    no_patterns = [
        r'^\s*no\b',
        r'\bno,\s',
        r'\bthe answer is no\b',
        r'\bthe final answer is no\b',
        r'\bit is (very )?unlikely.*\bno\b',
        r'\bno it is not\b',
        r'\banswer no\b',
    ]

    for pat in yes_patterns:
        if re.search(pat, text):
            return "Yes"

    for pat in no_patterns:
        if re.search(pat, text):
            return "No"

    return "Unknown"


In [4]:
# Numeric labels stored in the predictions file for each parsed answer.
answer_to_number = {
    "No": 0,
    "Yes": 1,
    "Unknown": -1,
}

# case id -> {sentence index -> {"sentence", "raw_prediction", "prediction"}}
predictions_to_save = {}

for i, case in data.items():
    question = case["clinical_question"]
    sentences = case["note_excerpt_sentences"]

    case_to_save = {}
    result_list = []
    # One request per paragraph: each call asks about a single index but sends
    # the whole paragraph list as context.
    for sentence_number, sentence in enumerate(sentences):
        result = predict(question, sentences, sentence_number)
        answer = extract_yes_no_answer(result)
        # Any parser output outside the label map counts as Unknown instead of
        # raising KeyError and aborting the whole run; raw_prediction keeps the
        # original text for later inspection.
        prediction = answer_to_number.get(answer, answer_to_number["Unknown"])
        result_list.append(prediction)
        case_to_save[sentence_number] = {
            "sentence": sentence,
            "raw_prediction": result,
            "prediction": prediction,
        }

    predictions_to_save[i] = case_to_save

    print(i, result_list)

21 [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0]
22 [1, 1, 1, 0, 0, 0, 0, 0, 0]
23 [0, 0, 0, 0, 1, 1, 0, 0]
24 [1, 1, 0, 1, 0]
25 [1, 1, 1, 0, 0, 0, 0]
26 [1, 0, 1, 1, 0, 0, 0]
27 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0]
28 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0]
29 [0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]
30 [0, 1, 0, 0, 0, 1, 1, 1, 1]
31 [0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
32 [0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0]
33 [0, 0, 1, 1, 1, 1, 1, 0, 1, 0]
34 [1, 0, 1, 1, 1, 0, 0, 0, 0, 0]
35 [0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0]
36 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0]
37 [1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1]
38 [0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1]
39 [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1]
40 [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1]
41 [1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]
42 [

In [None]:
# Persist every per-paragraph prediction, including the raw model output, as
# pretty-printed UTF-8 JSON. json turns the integer sentence indices into string keys.
with open("../predictions-gemma-promptA.json", mode="w", encoding="utf-8") as out_file:
    json.dump(predictions_to_save, out_file, ensure_ascii=False, indent=4)