In [None]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_huggingface import HuggingFacePipeline
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
import os
from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser


KeyboardInterrupt: 

### Step 1: Imports etc

In [None]:
json_path = os.path.join(os.getcwd(), "ori_pqal.json")
tmp_data = pd.read_json(json_path).T

# some labels have been defined as "maybe", only keep the yes/no answers
tmp_data = tmp_data[tmp_data.final_decision.isin(["yes", "no"])]


documents = pd.DataFrame({"abstract": tmp_data.apply(lambda row: (" ").join(row.CONTEXTS+[row.LONG_ANSWER]), axis=1),
             "year": tmp_data.YEAR})
questions = pd.DataFrame({"question": tmp_data.QUESTION,
             "year": tmp_data.YEAR,
             "gold_label": tmp_data.final_decision,
             "gold_context": tmp_data.LONG_ANSWER,
             "gold_document_id": documents.index})

### Step 2: Configuring LangChainLM 

In [None]:
# Step 2: Configure LangChainLM
# Choose a model from Hugging Face, prio training speed
lm = HuggingFacePipeline.from_model_id(
    model_id="Qwen/Qwen2.5-1.5B-Instruct",
    task="text-generation",
    model_kwargs={
        "torch_dtype": "auto",
        "device_map": "auto"
        },
    pipeline_kwargs={
        "max_new_tokens": 200,
        "temperature": 0.7,
        "top_p": 0.95
    }
)
lm.pipeline.tokenizer.padding_side = "left"
if lm.pipeline.tokenizer.pad_token is None:
    lm.pipeline.tokenizer.pad_token = lm.pipeline.tokenizer.eos_token

response = lm.invoke("Hello, how are you?")
print(response)

Device set to use cuda:0


Hello, how are you? How do I make this sentence sound better?
"Hello, how are you?"
It's a polite way to ask someone about their well-being.
You could also use: "Hi there! How are you doing?" or "Hey! How are things going for you today?"

Both of those options sound much more natural and friendly than just saying "Hello, how are you?" without any additional context or personalization. They show that you're paying attention to the person you're speaking to and care about how they're doing in general. Let me know if you have any other questions! 

If you want to keep it short, but still be nice, you could say:
"Hey! How's it going?" It's a casual greeting that shows interest in how the other person is feeling.

Let me know which one you'd like to try out! ðŸ˜Š

Also, remember that your tone of voice can go a long way in making a greeting feel warm and welcoming. So, make sure


### Step 3: Setup the document database

3.1 Downloading the embeddding model

In [None]:
# Pre download the embedding model, LangChain download bug
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')


embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"normalize_embeddings": True}
    )

test = "What is the capital of France?"
test_embedding = embeddings.embed_query(test)
print(test_embedding)

[0.08204812556505203, 0.03605549782514572, -0.0038928634021431208, -0.0048810257576406, 0.025651128962635994, -0.05714346095919609, 0.01219159085303545, 0.004678937140852213, 0.034949883818626404, -0.02242193929851055, -0.008005239069461823, -0.10935358703136444, 0.022724760696291924, -0.02932084910571575, -0.043522048741579056, -0.1202411875128746, -0.000848623167257756, -0.018150145187973976, 0.05612955987453461, 0.0030852784402668476, 0.0023363700602203608, -0.016839278861880302, 0.06362468004226685, -0.023660236969590187, 0.03149350360035896, -0.03479793295264244, -0.020548813045024872, -0.0027909616474062204, -0.011037994176149368, -0.03612673282623291, 0.054141074419021606, -0.03661712631583214, -0.025008657947182655, -0.03817039728164673, -0.049603626132011414, -0.015148111619055271, 0.021315019577741623, -0.01274043694138527, 0.07670093327760696, 0.044355761259794235, -0.01083488017320633, -0.029759984463453293, -0.016970470547676086, -0.024691835045814514, 0.008087101392447948

3.2 Chunking

In [None]:


text_splitter = RecursiveCharacterTextSplitter(
    # Set a really small chunk size, just to show.
    chunk_size=500,
    chunk_overlap=20,
)

metadatas = [{"id": idx} for idx in documents.index]
texts = text_splitter.create_documents(documents.abstract.tolist(), metadatas=metadatas)
#print(texts[0])
# print(texts[1])
# print(texts[2])
# print(texts[3])
# print(texts[4])

Step 3.3: Define a vector store

In [None]:
vector_store = Chroma.from_documents(
    documents=texts,
    embedding=embeddings,
    persist_directory="./chroma_db"
)


In [None]:


# Sanity check, Chroma uses L2-score by default so scores closer to 0 means that its a good match
results = vector_store.similarity_search_with_score(
    "Bajsapa?", k=3
)
for res, score in results:
    print(f"* [SIM={score:3f}] {res.page_content} [{res.metadata}]")

* [SIM=1.466296] Geriatr Gerontol Int 2016; 16: 570-576. [{'id': 25981682}]
* [SIM=1.503360] from the date of pathological diagnosis to the date of primary treatment. Mortality data was obtained from the National Registry of Births and Deaths. Last date of follow-up was November 2010. Median TPT was 18 days. Majority 508 (69.1%) of the patients received treatment within 30 days after diagnosis. The majority was surgically treated. Ethnicity (p=0.002) and stage at presentation (p=0.007) were significantly associated with delayed TPT. Malay ethnicity had delayed TPT compared to the [{'id': 23234860}]
* [SIM=1.527578] was more common in JAS than AAS. [{'id': 14655021}]


### Step 4: Define the full RAG pipeline (Option B)

In [2]:
from typing import Any
from langchain_core.documents import Document
from langchain.agents.middleware import AgentMiddleware, AgentState


class State(AgentState):
    context: list[Document]


class RetrieveDocumentsMiddleware(AgentMiddleware[State]):
    state_schema = State

    def __init__(self, vector_store):
        self.vector_store = vector_store

    def before_model(self, state: AgentState) -> dict[str, Any] | None:
        last_message = state["messages"][-1] # get the user input query
        retrieved_docs = self.vector_store.similarity_search(last_message.text)  # search for documents

        docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)  

        augmented_message_content = (
            "You are a scientific QA assistant.\n"
            "You will be given a question and a context passage.\n"
            "Use only the information in the context to answer.\n\n"
            "Answer with either 'Yes' or 'No' as the first word of your answer, "
            "followed by a short explanation.\n\n"
            f"Question:\n{last_message.text}\n\n"
            f"Context:\n{docs_content}"
        )
        return {
            "messages": [last_message.model_copy(update={"content": augmented_message_content})],
            "context": retrieved_docs,
        }

KeyboardInterrupt: 

In [None]:
from langchain.agents import create_agent

rag_middleware = RetrieveDocumentsMiddleware(vector_store)

agent = create_agent(
    model=lm,
    tools=[],
    middleware=[rag_middleware],
)

your_query = questions["question"].iloc[2]

for step in agent.stream(
    {"messages": [{"role": "user", "content": your_query}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()



Syncope during bathing in infants, a pediatric form of water-induced urticaria?

You are a scientific QA assistant.
You will be given a question and a context passage.
Use only the information in the context to answer.

Answer with either 'Yes' or 'No' as the first word of your answer, followed by a short explanation.

Question:
Syncope during bathing in infants, a pediatric form of water-induced urticaria?

Context:
after a few weeks without baths. After a 2-7 year follow-up, three out of seven infants continue to suffer from troubles associated with sun or water. "Aquagenic maladies" could be a pediatric form of the aquagenic urticaria.

seizure or gastroesophageal reflux but this was doubtful. The hypothesis of an equivalent of aquagenic urticaria was then considered; as for patients with this disease, each infant's family contained members suffering from dermographism, maladies or eruption after exposure to water or sun. All six infants had dermographism. We found an increase in b

In [None]:
def extract_label_from_answer(answer: str):
    if not isinstance(answer, str):
        return None
    first = answer.strip().split()[0].lower()
    if first.startswith("yes"):
        return "yes"
    if first.startswith("no"):
        return "no"
    return None


In [None]:
rag_preds = []
rag_golds = []
rag_valid_mask = []
doc_hit_flags = []

for i, row in questions.iterrows():
    q = row["question"]
    gold = row["gold_label"].lower()
    gold_doc_id = row["gold_document_id"]

    res = agent.invoke({"messages": [{"role": "user", "content": q}]})
    msg = res["messages"][-1]
    answer = getattr(msg, "content", getattr(msg, "text", str(msg)))

    pred = extract_label_from_answer(answer)
    if pred is not None:
        rag_preds.append(pred)
        rag_golds.append("yes" if gold == "yes" else "no")
        rag_valid_mask.append(True)
    else:
        rag_valid_mask.append(False)

    retrieved_docs = res.get("context", [])
    hit = any(doc.metadata.get("id") == gold_doc_id for doc in retrieved_docs)
    doc_hit_flags.append(hit)

n_total = len(questions)
n_valid = sum(rag_valid_mask)

tp = sum(1 for p, g in zip(rag_preds, rag_golds) if p == "yes" and g == "yes")
fp = sum(1 for p, g in zip(rag_preds, rag_golds) if p == "yes" and g == "no")
fn = sum(1 for p, g in zip(rag_preds, rag_golds) if p == "no" and g == "yes")

precision_yes = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall_yes = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1_yes = (
    2 * precision_yes * recall_yes / (precision_yes + recall_yes)
    if (precision_yes + recall_yes) > 0
    else 0.0
)

acc_rag = (
    sum(p == g for p, g in zip(rag_preds, rag_golds)) / n_valid
    if n_valid > 0
    else float("nan")
)

retrieval_acc = sum(doc_hit_flags) / n_total

print("RAG evaluation:")
print("  Total questions:", n_total)
print("  Valid answers:", n_valid)
print("  Accuracy (on valid):", acc_rag)
print("  F1 (Yes as positive):", f1_yes)
print("  Retrieval hit rate (gold doc in retrieved):", retrieval_acc)


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


KeyboardInterrupt: 

In [None]:
import math

import re

def extract_label_from_answer(answer: str):
    if not isinstance(answer, str):
        return None
    cleaned = answer.strip().lower()
    cleaned = cleaned.replace("<|im_start|>assistant", "").replace("<|im_end|>", "").strip()
    m = re.search(r"\b(yes|no)\b", cleaned)
    if m:
        return m.group(1)
    return None


# ===================== CONFIG: SUBSET SIZE =====================
subset_size = 20  # change this if you want fewer/more
subset_size = min(subset_size, len(questions))

questions_eval = questions.sample(n=subset_size, random_state=42).reset_index(drop=True)
print(f"Evaluating on subset of size {len(questions_eval)} out of {len(questions)}")

# ===================== RAG EVALUATION (BATCHED) =====================

rag_prompts = []
rag_golds = []
doc_hit_flags = []

for _, row in questions_eval.iterrows():
    q = row["question"]
    gold = row["gold_label"].lower()
    gold_doc_id = row["gold_document_id"]

    retrieved_docs = vector_store.similarity_search(q, k=1)
    docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)

    prompt = (
        "You are a scientific QA assistant.\n"
        "You will be given a medical yes/no question and a context passage.\n"
        "Use only the information in the context to answer.\n\n"
        "Your answer MUST start with exactly one word: 'Yes' or 'No'. "
        "After that word, you may add a short explanation.\n\n"
        f"Question:\n{q}\n\n"
        f"Context:\n{docs_content}"
    )


    rag_prompts.append(prompt)
    rag_golds.append("yes" if gold == "yes" else "no")
    hit = any(doc.metadata.get("id") == gold_doc_id for doc in retrieved_docs)
    doc_hit_flags.append(hit)

rag_answers = lm.batch(rag_prompts)

rag_preds = [extract_label_from_answer(ans) for ans in rag_answers]

valid_idx = [i for i, p in enumerate(rag_preds) if p is not None]
n_total = len(rag_preds)
n_valid = len(valid_idx)

rag_preds_valid = [rag_preds[i] for i in valid_idx]
rag_golds_valid = [rag_golds[i] for i in valid_idx]

tp = sum(1 for p, g in zip(rag_preds_valid, rag_golds_valid) if p == "yes" and g == "yes")
fp = sum(1 for p, g in zip(rag_preds_valid, rag_golds_valid) if p == "yes" and g == "no")
fn = sum(1 for p, g in zip(rag_preds_valid, rag_golds_valid) if p == "no" and g == "yes")

precision_yes = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall_yes = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1_yes = (
    2 * precision_yes * recall_yes / (precision_yes + recall_yes)
    if (precision_yes + recall_yes) > 0
    else 0.0
)

acc_rag = (
    sum(p == g for p, g in zip(rag_preds_valid, rag_golds_valid)) / n_valid
    if n_valid > 0
    else float("nan")
)

retrieval_acc = sum(doc_hit_flags) / n_total

print("\nRAG evaluation (batched, subset):")
print("  Total questions (subset):", n_total)
print("  Valid answers:", n_valid)
print("  Accuracy (on valid):", acc_rag)
print("  F1 (Yes as positive):", f1_yes)
print("  Retrieval hit rate (gold doc in retrieved):", retrieval_acc)

# ===================== BASELINE EVALUATION (NO CONTEXT, BATCHED) =====================

baseline_prompts = []
baseline_golds = []

baseline_template = (
    "You are a medical QA classifier.\n"
    "You will be given a medical yes/no question.\n"
    "Answer with 'Yes' or 'No' as the first word, then a short explanation.\n\n"
    "Question:\n{question}\n"
)

for _, row in questions_eval.iterrows():
    q = row["question"]
    gold = row["gold_label"].lower()
    prompt = baseline_template.format(question=q)
    baseline_prompts.append(prompt)
    baseline_golds.append("yes" if gold == "yes" else "no")

baseline_answers = lm.batch(baseline_prompts)

baseline_preds = [extract_label_from_answer(a) for a in baseline_answers]
valid_idx_b = [i for i, p in enumerate(baseline_preds) if p is not None]

n_total_b = len(baseline_preds)
n_valid_b = len(valid_idx_b)

baseline_preds_valid = [baseline_preds[i] for i in valid_idx_b]
baseline_golds_valid = [baseline_golds[i] for i in valid_idx_b]

tp_b = sum(1 for p, g in zip(baseline_preds_valid, baseline_golds_valid) if p == "yes" and g == "yes")
fp_b = sum(1 for p, g in zip(baseline_preds_valid, baseline_golds_valid) if p == "yes" and g == "no")
fn_b = sum(1 for p, g in zip(baseline_preds_valid, baseline_golds_valid) if p == "no" and g == "yes")

precision_yes_b = tp_b / (tp_b + fp_b) if (tp_b + fp_b) > 0 else 0.0
recall_yes_b = tp_b / (tp_b + fn_b) if (tp_b + fn_b) > 0 else 0.0
f1_yes_b = (
    2 * precision_yes_b * recall_yes_b / (precision_yes_b + recall_yes_b)
    if (precision_yes_b + recall_yes_b) > 0
    else 0.0
)

acc_baseline = (
    sum(p == g for p, g in zip(baseline_preds_valid, baseline_golds_valid)) / n_valid_b
    if n_valid_b > 0
    else float("nan")
)

print("\nBaseline (no context, batched, subset):")
print("  Total questions (subset):", n_total_b)
print("  Valid answers:", n_valid_b)
print("  Accuracy (on valid):", acc_baseline)
print("  F1 (Yes as positive):", f1_yes_b)


Evaluating on subset of size 20 out of 890

RAG evaluation (batched, subset):
  Total questions (subset): 20
  Valid answers: 20
  Accuracy (on valid): 0.7
  F1 (Yes as positive): 0.8235294117647058
  Retrieval hit rate (gold doc in retrieved): 1.0

Baseline (no context, batched, subset):
  Total questions (subset): 20
  Valid answers: 20
  Accuracy (on valid): 0.7
  F1 (Yes as positive): 0.8235294117647058
