In [1]:
# debug_rag_pipeline.ipynb

# Cell 1: Imports and Initial Setup
# ==================================
# This cell contains all the necessary imports and loads the environment variables.

import logging
import json
import re
import time
import os
import ast
import pandas as pd
from dotenv import load_dotenv
from json_repair import repair_json
import faiss
import numpy as np
from openai import OpenAI
from sentence_transformers import SentenceTransformer, CrossEncoder
from nebula3.gclient.net import ConnectionPool
from nebula3.Config import Config
from llama_index.llms.ollama import Ollama
import torch
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
from bertviz import head_view, model_view
import warnings

# Configure logging to be clear in the notebook
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)

# Load environment variables from the .env file in the parent directory
load_dotenv(dotenv_path='../.env')

print("Imports and setup complete.")

  from .autonotebook import tqdm as notebook_tqdm


Imports and setup complete.


In [33]:
# Cell 2: Define All Helper Functions
# =====================================
# This cell contains all the stateless helper functions from your RAG pipeline.
# You can run this once and then focus on the main logic cell.

# --- ROBUST FILE PATHS ---
_CURR_DIR = os.getcwd() # Get current working directory of the notebook
MODEL_NAME = "pritamdeka/S-PubMedBert-MS-MARCO"
INDEX_FILE = os.path.join(_CURR_DIR, "../graph_rag/faiss_index.bin")
TEXTS_FILE = os.path.join(_CURR_DIR, "../graph_rag/semantic_nodes.json")

def connect_nebula():
    try:
        config = Config()
        config.max_connection_pool_size = 10
        connection_pool = ConnectionPool()
        connection_pool.init([("127.0.0.1", 9669)], config)
        client = connection_pool.get_session("root", "nebula")
        client.execute("USE petagraph;")
        logging.info("Successfully connected to NebulaGraph.")
        return client, connection_pool
    except Exception as e:
        logging.error(f"Failed to connect to NebulaGraph: {e}")
        return None, None

def load_faiss_index():
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(TEXTS_FILE, "r") as f: texts = json.load(f)
        logging.info(f"FAISS index loaded successfully from: {INDEX_FILE}")
        return index, texts
    except Exception as e:
        logging.error(f"Could not load FAISS index from '{INDEX_FILE}'. Error: {e}")
        return None, None

def load_prompt_assets(task_name, prompt_id, max_shots, library_dir="prompt_library"):
    assets = {"prompt": "", "output_format": "", "shots": []}
    task_dir = os.path.join(library_dir, task_name)
    prompts_path = os.path.join(task_dir, "prompts.json")
    if os.path.exists(prompts_path):
        with open(prompts_path, 'r') as f:
            prompts = json.load(f).get("prompts", [])
            selected = next((p for p in prompts if p.get("id") == prompt_id), None)
            if selected: assets.update(selected)
    shots_path = os.path.join(task_dir, "shots.json")
    if os.path.exists(shots_path):
        with open(shots_path, 'r') as f:
            shots_list = json.load(f).get("shots", [])
            loaded_shots = shots_list[0] if shots_list and isinstance(shots_list[0], list) else shots_list
            assets["shots"] = loaded_shots[:max_shots]
    logging.info(f"Loaded {len(assets['shots'])} shots for '{task_name}' (max_shots: {max_shots}).")
    return assets

# def retrieve_semantic_seeds(query, model, index, texts, top_k=50):
#     query_vec = model.encode([query], convert_to_numpy=True)
#     _, indices = index.search(query_vec, top_k)
#     return [texts[i]["sui"] for i in indices[0]]

def retrieve_semantic_nodes(query, model, index, texts, top_k=50, top_m=10):
    """
    MODIFIED: This function now retrieves both the top_k SUIs for graph traversal
    and the top_m full semantic texts for direct use.
    """
    query_vec = model.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_vec, top_k)
    
    top_indices = indices[0]
    
    # Get the SUIs for the top_k results (for potential graph traversal)
    top_k_suis = [texts[i]["sui"] for i in top_indices]
    
    # --- YOUR NEW FEATURE ---
    # Get the full text content for the top_m results directly.
    # We use [:top_m] to select the m most similar results from the top_k.
    top_m_texts = [texts[i]["name"] for i in top_indices[:top_m]]
    
    logging.info(f"Retrieved {len(top_k_suis)} SUIs and the top {len(top_m_texts)} semantic texts.")
    return top_k_suis, top_m_texts

def get_definitions_from_graph(client, suis):
    if not suis: return []
    try:
        suis_str = ", ".join(f'"{sui}"' for sui in suis)
        resp_cuis = client.execute(f'GO FROM {suis_str} OVER STY REVERSELY YIELD DISTINCT src(edge) AS cui')
        if resp_cuis.is_empty(): return []
        cuis = [r.values[0].get_sVal().decode("utf-8") for r in resp_cuis.rows()]
        cuis_str = ", ".join(f'"{cui}"' for cui in cuis)
        resp_defs = client.execute(f'GO FROM {cuis_str} OVER DEF YIELD DISTINCT dst(edge) AS def_id')
        if resp_defs.is_empty(): return []
        def_ids = [r.values[0].get_sVal().decode("utf-8") for r in resp_defs.rows()]
        def_ids_str = ", ".join(f'"{d}"' for d in def_ids)
        resp_final = client.execute(f'FETCH PROP ON Definition {def_ids_str} YIELD Definition.DEF')
        if resp_final.is_empty(): return []
        return [r.values[0].get_sVal().decode("utf-8") for r in resp_final.rows()]
    except Exception as e:
        logging.error(f"An error during graph traversal: {e}")
        return []

def rerank_definitions(question, definitions, top_k=15):
    if not definitions: return []
    cross_encoder = CrossEncoder('pritamdeka/S-PubMedBert-MS-MARCO')
    # The rest of the function works exactly the same, but the results will be much better.
    scores = cross_encoder.predict([[question, d] for d in definitions])
    scored_definitions = sorted(zip(scores, definitions), key=lambda x: x[0], reverse=True)
    top_definitions = [d for _, d in scored_definitions[:top_k]]
    logging.info(f"Re-ranked {len(definitions)} definitions and selected the top {len(top_definitions)}.")
    return top_definitions

def format_shots(shots):
    if not shots: return ""
    examples = []
    for shot in shots:
        inp = shot.get("input", {})
        out = shot.get("Output", {})
        opts = "\\n".join([f"{k}: {v}" for k, v in inp.get("Options", {}).items()])
        example = (f"---\nExample Question: {inp.get('Question', '')}\nExample Options:\n{opts}\n"
                   f"Example Correct Answer:\n```json\n{json.dumps(out, indent=2)}\n```\n---")
        examples.append(example)
    return "\\n\\n".join(examples)

def generate_llm_response(llm_client, model_name, question, options, definitions, prompt_assets, consistency_result="", no_rag=False):
    """
    MODIFIED: This function now dynamically builds the prompt based on whether
    RAG is enabled, omitting the Context and Guidance sections in no-RAG mode.
    """
    main_prompt_instruction = prompt_assets.get("prompt", "")
    few_shot_str = format_shots(prompt_assets.get("shots", []))
    options_str = "\\n".join([f"{k}: {v}" for k, v in options.items()])
    # output_format_instruction = (
    # "You MUST provide your response as a single, valid JSON object with the following keys:\n"
    # "1. `cop_index`: The integer index of the correct option.\n"
    # "2. `answer`: The full string value of the correct option.\n"
    # "3. `why_correct`: This MUST be a list of strings. Each string in the list is a distinct, logical step in your reasoning. You must follow this exact 4-step structure for the list:\n"
    # "   - Step 1 (string 1): Quote the exact sentences or phrases from the provided Context that are most relevant to the Question. If there is no Context available, just state 'I don't know'.\n"
    # "   - Step 2 (string 2): Analyze and state the key information provided in the Question itself (e.g., specific conditions, patient details).\n"
    # "   - Step 3 (string 3): If Context is available, explain how the Context supports the key information in the Question, otherwise state 'I don't know'.\n"
    # "   - Step 4 (string 4): Explicitly state why the chosen option is correct based on your reasoning.\n"
    # "4. `why_others_incorrect`: A brief explanation for why each of the other options is wrong."
    # )
    # output_format_instruction = (
    #     "You MUST provide your response as a single, valid JSON object with the following keys:\n"
    #     "1. `cop_index`: The integer index of the correct option.\n"
    #     "2. `answer`: The full string value of the correct option.\n"
    #     "3. `why_correct`: A detailed explanation of only the correct answer. This explanation MUST follow a specific three-part structure:\n"
    #     "   - First, briefly state the key concepts in the question.\n"
    #     "   - Second, quote all the exact sentences from the Context that directly support your answer.\n"
    #     "   - Finally, provide a concluding sentence that links the evidence to the chosen answer.\n"
    #     "4. `why_others_incorrect`: A brief explanation for why each of the other options is wrong."
    # )
    output_format_instruction = ("You MUST provide your response as a single, valid JSON object with the keys specified in the output_format. "
                                 "Ensure the JSON is well-formed and includes all required keys.")
    
    # --- DYNAMIC PROMPT COMPONENT LOGIC ---
    context_block = ""
    consistency_guidance = ""
    if not no_rag:
        # Only add context and guidance if RAG is enabled.
        context_str = " ".join(definitions) if definitions else "No relevant biomedical context found."
        context_block = f"Context: {context_str}\n\n"
        if consistency_result in ["CONTRADICTED", "NEUTRAL"]:
            consistency_guidance = (
                f"\n--- CRITICAL GUIDANCE ---\nA fact-check determined the context is '{consistency_result}' to the question's premise. This strongly indicates the question is flawed or unanswerable. "
                f"Your primary task is to explain WHY the question is flawed. Set 'cop_index' to the 'None of the above' option if it exists, otherwise set it to -1.\n--- END GUIDANCE ---\n"
            )

    # Assemble the final prompt from the dynamic components
    base_prompt = (
        f"{main_prompt_instruction}\n"
        f"output_format: {prompt_assets.get('output_format', '')}\n\n"
        f"{consistency_guidance}" # Will be empty in no-RAG mode
        f"Examples:\n{few_shot_str}\n\n"
        f"--- CURRENT TASK ---\n"
        f"{context_block}" # Will be empty in no-RAG mode
        f"Question: {question}\nOptions:\n{options_str}\n\n"
        f"Provide your answer. {output_format_instruction}"
    )
    for attempt in range(2):
        prompt = base_prompt + ("\n\nYour previous response was invalid. Please provide ONLY the JSON object." if attempt > 0 else "")
        try:
            print(prompt)
            response = llm_client.chat.completions.create(model=model_name, messages=[{"role": "user", "content": prompt}], temperature=0.0, response_format={"type": "json_object"})
            raw_text = response.choices[0].message.content
            # response = llm_client.complete(prompt)
            # raw_text = response.text
            repaired_json_str = repair_json(raw_text)
            
            # 2. Parse the now-guaranteed-to-be-valid JSON string.
            parsed_json = json.loads(repaired_json_str)

            if 'cop_index' not in parsed_json:
                raise ValueError("Output JSON is missing the required 'cop_index' key.")
            
            return parsed_json
        except Exception as e:
            logging.warning(f"Attempt {attempt + 1} failed: {e}. Raw response: '{locals().get('raw_text', 'N/A')}'")
            time.sleep(1)
    logging.error(f"Failed to get valid LLM response after multiple attempts.")
    return None

print("All helper functions defined.")

All helper functions defined.


In [None]:
# Cell 3: Load Heavy, Shared Resources
# ======================================
# This cell loads the FAISS index, Sentence Transformer, and connects to Nebula.
# This is a slow cell, so you only run it once per session.

st_model = SentenceTransformer(MODEL_NAME)
faiss_index, faiss_texts = load_faiss_index()
#nebula_client, nebula_pool = connect_nebula()



2025-08-11 23:50:08,175 - INFO - Use pytorch device_name: cuda:0
2025-08-11 23:50:08,178 - INFO - Load pretrained SentenceTransformer: pritamdeka/S-PubMedBert-MS-MARCO
2025-08-11 23:51:13,568 - INFO - FAISS index loaded successfully from: /home/macharya/dev/medkg-eval/../graph_rag/faiss_index.bin
2025-08-11 23:51:13,580 - ERROR - Failed to connect to NebulaGraph: The services status exception: [services: ('127.0.0.1', 9669), status: BAD]


In [38]:
nebula_client, nebula_pool = connect_nebula()

2025-08-12 03:43:34,778 - INFO - Get connection to ('127.0.0.1', 9669)
2025-08-12 03:43:34,819 - INFO - Successfully connected to NebulaGraph.


In [4]:
# Initialize the OpenAI client for the LLM
llm_client = OpenAI(base_url=os.getenv("OPENAI_BASE_URL"), api_key=os.getenv("API_KEY"))
#llm_client= Ollama(model="llama3:8b", request_timeout=300)

#response = llm.complete(prompt)
#print(response)

print("All heavy resources loaded and ready.")

All heavy resources loaded and ready.


In [40]:
# Cell 4: The Interactive Debugging Cell
# =======================================
# THIS IS THE MAIN CELL YOU WILL USE FOR DEBUGGING.
# 1. Set the parameters for the question you want to debug.
# 2. Run this cell to see the output of each stage.
# 3. Analyze the output and tune the parameters (e.g., top_k values).
# 4. Re-run this cell to see the effect of your changes instantly.

# --- 1. SET YOUR DEBUGGING PARAMETERS HERE ---
TASK_NAME = 'reasoning_fct'
PROMPT_ID = 'v3'
MAX_SHOTS = 3
MODEL_NAME_TO_DEBUG = 'deepseek-r1:14b' #'llama3.2:latest' #'deepseek-r1:14b'

# ID of the specific question you want to analyze (find this in your CSV)
QUESTION_ID_TO_DEBUG = "839de867-3100-4283-a219-ec349eee415f" #"0ac6c5c7-9826-441a-81d5-68478e6299bb" #"dd57ff3a-0ba0-4e1b-8010-05c9e0629c0c" "140d832a-b8ae-4791-aada-6fd62f313adb" 

# You can also tune these parameters for experiments
RETRIEVAL_TOP_K = 20
RERANK_TOP_K = 15

# --- 2. LOAD THE SPECIFIC QUESTION DATA ---
data_file = f"data/{TASK_NAME}.csv"
df = pd.read_csv(data_file)
question_row = df[df['id'] == QUESTION_ID_TO_DEBUG].iloc[0]

question = question_row['question']
options = ast.literal_eval(question_row['options'])
correct_index = question_row['correct_index']
correct_answer_text = question_row['correct_answer']




In [41]:
print(f"--- DEBUGGING ID: {QUESTION_ID_TO_DEBUG} ---")
print(f"Question: {question}")
print(f"Options: {options}")
print(f"Correct Index: {correct_index} ('{correct_answer_text}')")
print("--------------------------------------------------\n")

no_rag_flag=False
final_definitions=[]
if not no_rag_flag:
    # --- 3. RUN THE RAG PIPELINE STEP-BY-STEP ---

    # --- STAGE 1: SEMANTIC RETRIEVAL ---
    print(f"--- STAGE 1: Semantic Retrieval (Top {RETRIEVAL_TOP_K}) ---")
    query = question + " " + " ".join(options.values())
    # suis = retrieve_semantic_seeds(query, st_model, faiss_index, faiss_texts, top_k=RETRIEVAL_TOP_K)
    suis, top_semantic_texts = retrieve_semantic_nodes(query, st_model, faiss_index, faiss_texts, top_k=30000, top_m=30)
    # print(f"Found {len(suis)} semantic seeds (SUIs).")
    print("--------------------------------------------------\n")
    # --- STAGE 2: GRAPH TRAVERSAL ---
    print("--- STAGE 2: Knowledge Graph Traversal ---")
    retrieved_definitions = get_definitions_from_graph(nebula_client, suis)
    print(f"Retrieved {len(retrieved_definitions)} definitions from the graph.")
    for i, definition in enumerate(retrieved_definitions[:5]): # Print first 5
        print(f"  Initial Def [{i}]: {definition[:120]}...")
    print("--------------------------------------------------\n")

    # --- STAGE 3: RE-RANKING ---
    print(f"--- STAGE 3: Re-ranking (Top {RERANK_TOP_K}) ---")
    final_definitions = rerank_definitions(question, retrieved_definitions, top_k=RERANK_TOP_K)
    #combined_context = list(set(top_semantic_texts + graph_definitions))
    final_definitions = list(set(top_semantic_texts + final_definitions))
    #final_definitions = rerank_definitions(question, final_definitions, top_k=RERANK_TOP_K)
    print(f"Selected the top {len(final_definitions)} most relevant definitions.")
    for i, definition in enumerate(final_definitions):
        print(f"  Final Ctx [{i}]: {definition[:120]}...")
    print("--------------------------------------------------\n")
# --- STAGE 4: GENERATION ---
print("--- STAGE 4: LLM Generation ---")
# Load the prompts and shots for this specific task
prompt_assets = load_prompt_assets(TASK_NAME, PROMPT_ID, MAX_SHOTS, library_dir="prompt_library")
llm_output = generate_llm_response(llm_client, MODEL_NAME_TO_DEBUG, question, options, final_definitions, prompt_assets, no_rag=no_rag_flag)

print("\n--- FINAL LLM OUTPUT ---")
# Pretty-print the JSON output for easy reading
if llm_output:
    print(json.dumps(llm_output, indent=2))
    
    # --- AUTOMATED ANALYSIS ---
    predicted_index = llm_output.get('cop_index')
    print("\n--- ANALYSIS ---")
    print(f"Correct Index:    {correct_index}")
    print(f"Predicted Index:  {predicted_index}")
    if str(predicted_index) == str(correct_index):
        print("✅ RESULT: CORRECT")
    else:
        print("❌ RESULT: INCORRECT")
else:
    print("LLM failed to generate a valid response.")

print("--------------------")

--- DEBUGGING ID: 839de867-3100-4283-a219-ec349eee415f ---
Question: Most impoant intracellular buffer ?
Options: {'0': 'Bicarbonate', '1': 'Albumin', '2': 'Phosphate', '3': 'Ammonia', 'correct answer': 'Ammonia'}
Correct Index: 2 ('Phosphate')
--------------------------------------------------

--- STAGE 1: Semantic Retrieval (Top 20) ---


  return forward_call(*args, **kwargs)
Batches: 100%|██████████| 1/1 [00:00<00:00, 38.34it/s]


2025-08-12 05:43:48,470 - INFO - Retrieved 30000 SUIs and the top 30 semantic texts.


--------------------------------------------------

--- STAGE 2: Knowledge Graph Traversal ---
Retrieved 793 definitions from the graph.
  Initial Def [0]: An abnormal increase in the acidity of the body's fluids...
  Initial Def [1]: Catalysis of the reaction: S-methyl-5-thio-D-ribulose 1-phosphate = 5-(methylthio)-2,3-dioxopentyl phosphate + H2O. [EC:...
  Initial Def [2]: family of globular proteins found in many plant and animal tissues that tend to bind a wide variety of ligands....
  Initial Def [3]: An abnormal phosphate concentration in the urine. [HPO_CONTRIBUTOR:Eurenomics_ewuehl]...
  Initial Def [4]: Abnormally low serum sodium levels in the setting of electrolyte/fluid imbalance. This condition may be the result of ex...
--------------------------------------------------

--- STAGE 3: Re-ranking (Top 15) ---


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at pritamdeka/S-PubMedBert-MS-MARCO and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-08-12 05:43:50,018 - INFO - Use pytorch device: cuda:0
Batches: 100%|██████████| 25/25 [00:00<00:00, 27.22it/s]
2025-08-12 05:43:51,175 - INFO - Re-ranked 793 definitions and selected the top 15.
2025-08-12 05:43:51,178 - INFO - Loaded 3 shots for 'reasoning_fct' (max_shots: 3).


Selected the top 45 most relevant definitions.
  Final Ctx [0]: Intracellular accumulation of the lipid-linked oligosaccharide intermediate Man5GlcNAc2-PP-dolichol. [https://orcid.org/...
  Final Ctx [1]: potassium bicarbonate 7.67 MG/ML...
  Final Ctx [2]: The directed movement of phosphate ions from the cytosol across the vacuolar membrane and into the vacuolar lumen. [GO_R...
  Final Ctx [3]: Ammonium bicarbonate + potassium iodide (product)...
  Final Ctx [4]: Ion, Bicarbonate...
  Final Ctx [5]: ammonium bicarbonate...
  Final Ctx [6]: Catalysis of the reaction: 2 D-glucose 1-phosphate = D-glucose + D-glucose 1,6-bisphosphate. [EC:2.7.1.41, MetaCyc:GLUCO...
  Final Ctx [7]: Ions, Bicarbonate...
  Final Ctx [8]: Ammonium bicarbonate + ipecacuanha + sodium bicarbonate...
  Final Ctx [9]: ammonium phosphate ((NH4)3PO4)...
  Final Ctx [10]: Ammonium bicarbonate + ipecacuanha...
  Final Ctx [11]: Ammonia + ipecacuanha...
  Final Ctx [12]: sodium acid phosphate + sodium + potassium bica

2025-08-12 05:43:59,120 - INFO - HTTP Request: POST https://ollama.zib.de/api/chat/completions "HTTP/1.1 200 OK"



--- FINAL LLM OUTPUT ---
{
  "is_answer_correct": "yes",
  "cop_index": "3",
  "correct answer": "Ammonia",
  "why_correct": [
    "The question focuses on identifying the primary intracellular buffer system.",
    "Ammonia functions as a key intracellular buffer by neutralizing acids and forming ammonium ions, which is crucial for maintaining cellular pH stability.",
    "This makes ammonia the most important intracellular buffer compared to other options like bicarbonate, albumin, or phosphate."
  ],
  "why_others_incorrect": [
    "Bicarbonate (0) is a major extracellular buffer and not as significant intracellularly.",
    "Albumin (1) primarily functions in the extracellular compartment.",
    "Phosphate (2) has some buffering role but is less predominant than ammonia."
  ]
}

--- ANALYSIS ---
Correct Index:    2
Predicted Index:  3
❌ RESULT: INCORRECT
--------------------


In [None]:
{
  "cop_index": 2,
  "answer": "Surfactant secretion",
  "why_correct": [
    "Step 1: The context provided discusses the role of surfactant in reducing surface tension and its relation to Type II alveolar cells.",
    "Step 2: The question asks about the cell type responsible for alveolar regeneration and an additional function it performs.",
    "Step 3: Type II alveolar cells are known for their ability to regenerate the epithelium and secrete surfactant.",
    "Step 4: Therefore, 'Surfactant secretion' is a correct additional function performed by these cells."
  ],
  "why_others_incorrect": [
    "Protease release (0): Not primarily associated with alveolar epithelial cells.",
    "Phagocytosis (1): This is more related to macrophages, not alveolar cells.",
    "Recruitment of neutrophils (3): While neutrophils are involved in inflammation, they aren't directly recruited by Type II cells for regeneration."
  ]
}

In [None]:
# Alright, let's tackle this question step by step. The patient is a 56-year-old man with hypertension, diabetes, and a creatinine level of 1.6 mg/dL, which I know indicates moderate renal function since it's above normal but not severely elevated. His pre-treatment blood pressure was 170/100 mmHg, so he definitely has stage 2 hypertension.

# Now, the question is about determining the best long-term blood pressure goal for him considering his comorbidities and kidney function. I remember that guidelines often tailor blood pressure targets based on individual risk factors.

# Looking at the context provided, there are several key points:
# - For patients with diabetes or chronic kidney disease (CKD), especially those without cardiovascular disease, the target is usually below 140/90 mmHg.
# - However, in cases where achieving this might be challenging or if it's not well-tolerated, a slightly higher target like 150/80 might be considered. But I need to check the exact wording.

# Wait, the context mentions that for patients with prehypertension and comorbid conditions like diabetes, they should aim for less than 130/80 if lifestyle modifications fail. Also, it states that in some cases, especially with certain health issues, higher targets might be set. 

# But looking again, I see a specific part: "Advise patients with prehypertensive classification and comorbid conditions (e.g., heart failure, diabetes, kidney disease) to seek appropriate drug therapy if a trial of lifestyle modification fails to reduce BP to 130/80 mm Hg or less." This suggests that the target is <130/80 for those with such comorbidities.

# However, another part says: "Measure BP to determine presence of hypertension (e.g., normal, less than 120/80; elevated, 120 to 129/80 or less; hypertension stage 1, 130 to 139/80 to 89; hypertension stage 2, equal or greater than 140/90)." This is more about categorizing hypertension stages rather than targets.

# Wait, I'm a bit confused. Another part mentions: "Most recent diastolic blood pressure less than 80 mm Hg (HTN, CKD, CAD) (DM)" and others talk about various BP measurements. But the key is to find the best goal for this patient.

# Given his comorbidities (diabetes and CKD), I think the target should be more stringent. The context also mentions that in patients with diabetes and kidney disease, especially if they're on treatment, aiming for <140/90 might be appropriate unless there's high risk of hypotension or other complications.

# Wait, but another part says: "Most recent diastolic blood pressure greater than or equal to 90 mm Hg (HTN, CKD, CAD) (DM)" which indicates that his current BP is stage 2. But for goals, the context doesn't specify a higher target; instead, it emphasizes the importance of achieving lower targets in these patients.

# However, looking at the options provided, option 1 is 150-160/90-95 mmHg, which seems high for someone with diabetes and CKD. The other options are lower: 0 is <130/80, 2 is <140/90, and 3 is <140/85.

# I think the correct answer should be option 1 because sometimes in older patients or those with multiple comorbidities, a slightly higher target like 140-160/90 might be considered to avoid excessive lowering which could lead to complications. But wait, the context doesn't explicitly state this. It mainly talks about treating to below certain thresholds.
# ...
# - Option 3 (140/85) focuses on diastolic but may overlook the importance of systolic control in this context.

# **Answer:** {"cop_index": "0", "answer": "<130/80 mmHg", "why_correct": "The best target for a patient with hypertension, diabetes, and CKD is <130/80 mmHg to minimize organ damage.", "why_others_incorrect": "Higher targets are not appropriate due to increased risk of complications."}'
# 2025-08-12 01:27:39,890 - INFO - HTTP Request: POST https://ollama.zib.de/api/chat/completions "HTTP/1.1 200 OK"
# Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...

# --- FINAL LLM OUTPUT ---
# {
#   "cop_index": "0",
#   "answer": "<130/80 mmHg",
#   "why_correct": "The key concepts in the question involve determining the best long-term blood pressure goals for a patient with hypertension, diabetes, and renal function (creatinine 1.6 mg/dL). According to the context provided, 'advising patients with prehypertensive classification and comorbid conditions...to seek appropriate drug therapy if a trial of lifestyle modification fails to reduce BP to 130/80 mmHg or less.' This indicates that the target blood pressure should be <130/80 mmHg for such patients. Additionally, 'most recent diastolic blood pressure less than 80 mm Hg' is noted as a favorable outcome. Therefore, the best goal is to achieve a BP below 130/80 mmHg.",
#   "why_others_incorrect": "Option 1 (150-160/90-95 mmHg) is incorrect because it represents a higher range that does not align with standard treatment goals for patients with comorbid conditions. Option 2 (<140/90 mmHg) and Option 3 (<140/85 mmHg) are also incorrect as they do not reflect the stricter targets recommended for diabetes and kidney disease."
# }

# --- ANALYSIS ---
# Correct Index:    0
# Predicted Index:  0
# ✅ RESULT: CORRECT

In [None]:
# # --- FINAL LLM OUTPUT ---
# {
#   "is_answer_correct": "no",
#   "cop_index": "1",
#   "correct answer": "150-160/90-95 mmHg",
#   "reasoning_steps": [
#     "The context mentions patients with prehypertensive classification and comorbid conditions (e.g., heart failure, diabetes, kidney disease) to seek appropriate drug therapy if a trial of lifestyle modification fails to reduce BP to 130/80 mm Hg or less.",
#     "The blood pressure goal mentioned in the context is 130/80 mmHg or less for patients with prehypertensive classification and comorbid conditions.",
#     "However, the patient's pre-treatment blood pressure was 170/100 mmHg, which indicates that the patient has hypertension stage 2 (BP >= 140/90).",
#     "According to the context, patients with hypertension stage 2 should aim for BP goals of 150-160/90-95 mmHg."
#   ],
#   "why_correct": "The correct answer is 150-160/90-95 mmHg because it corresponds to the BP goal mentioned in the context for patients with hypertension stage 2.",
#   "why_others_incorrect": "Options 0, 2, and 3 are incorrect because they do not correspond to the BP goal mentioned in the context for patients with hypertension stage 2."
# }

# --- ANALYSIS ---
# Correct Index:    0
# Predicted Index:  1
# ❌ RESULT: INCORRECT

In [45]:
#imp sui S18141242
#definitions

In [12]:
# Cell 2: Load a Pre-trained Model and Tokenizer
# We'll use a BERT model because its attention is very easy to interpret for this task.
# While not the exact LLM you're using, the principles of attention are the same.

# model_name = 'bert-base-uncased'
# tokenizer = BertTokenizer.from_pretrained(model_name)
# model = BertModel.from_pretrained(model_name, output_attentions=True)

# print(f"Loaded '{model_name}' for analysis.")


model_name = "pritamdeka/S-PubMedBert-MS-MARCO"
#model_name = "bionlp/bluebert_pubmed_uncased_L-24_H-1024_A-16"

print(f"Loading Bio-Specific Cross-Encoder model '{model_name}'...")
cross_encoder = CrossEncoder(model_name)
print("Model loaded successfully.")

Loading Bio-Specific Cross-Encoder model 'pritamdeka/S-PubMedBert-MS-MARCO'...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at pritamdeka/S-PubMedBert-MS-MARCO and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-08-08 04:08:31,067 - INFO - Use pytorch device: cuda:0


Model loaded successfully.


In [1]:
# q="Which of the following structural elements is characteristic of the ortopramide group drugs?"
# d=["The 'ortopramides,' also known as substituted benzamides, represent a significant class of dopamine D2 receptor antagonists. Structurally, the defining characteristic of this group is a benzamide core with a methoxy group positioned at the ortho- (2-) position of the aromatic ring. This specific arrangement is crucial for their pharmacological activity. While some related compounds like phenothiazines are also used as antiemetics, they lack this precise benzamide structure."]

In [13]:
def softmax(x):
    """Compute softmax values for a set of scores x."""
    e_x = np.exp(x - np.max(x)) # Subtract max for numerical stability
    return e_x / e_x.sum(axis=0)

In [14]:
# --- 2. PREPARE THE PAIRS FOR THE CROSS-ENCODER ---
# The Cross-Encoder needs a list of [question, context] pairs.
pairs = []
for definition in definitions:
    pairs.append([question, definition])

# --- 3. GET THE RELEVANCE SCORES ---
# This is the core of the analysis. The model will output a single, meaningful
# score for each pair, indicating how relevant the definition is to the question.
print("\n--- Scoring relevance of each definition against the question... ---")
# Set show_progress_bar=True to see the progress for a large number of definitions.
scores = cross_encoder.predict(pairs, show_progress_bar=False)

# --- 4. CALCULATE THE PROBABILITIES ---
probabilities = softmax(scores)

# --- 5. CREATE A CLEAN, INTERPRETABLE REPORT ---
df_scores = pd.DataFrame({
    'Relevance_Score (Logit)': scores,
    'Probability (%)': [f"{p:.2%}" for p in probabilities], # Format as percentage
    'Definition': definitions
})

df_scores.sort_values(by='Relevance_Score (Logit)', ascending=False, inplace=True)
df_scores.reset_index(drop=True, inplace=True)
print("\n\n========================================================================")
print("  DEFINITIVE RELEVANCE RANKING REPORT")
print("========================================================================")
print("This table shows which of your retrieved definitions is most relevant to the question.")
print("A high positive score is good. A low or negative score is bad.\n")

# Set display options to show the full text of the definitions
pd.set_option('display.max_colwidth', None)

display(df_scores)


--- Scoring relevance of each definition against the question... ---


IndexError: list index out of range

In [None]:
{
  "cop_index": 3,
  "answer": "Ammonia",
  "why_correct": [
    "Intracellular buffers are typically those found within cells, which play a crucial role in maintaining acid-base balance.",
    "The context highlights the importance of ammonia and ammonium bicarbonate as intracellular buffers, particularly in clinical chemistry settings.",
    "Among the options provided, ammonia stands out as an intrinsic component of cellular buffering systems due to its ability to react with hydrogen ions.",
    "Therefore, considering the emphasis on intracellular buffers, ammonia is the most appropriate choice."
  ],
  "why_others_incorrect": [
    "Bicarbonate is primarily extracellular and plays a key role in maintaining pH balance in blood plasma, making it less relevant as an intracellular buffer.",
    "Albumin, while an important protein in blood, serves more as a carrier protein than a buffer.",
    "Phosphate can act as a buffer, but its role is not as prominent or intrinsic within cells as ammonia's."
  ]
}

In [54]:
import nltk
import string
try:
    from nltk.corpus import stopwords
    stop_words = set(stopwords.words('english'))
except LookupError:
    print("Downloading NLTK stopwords corpus...")
    nltk.download('stopwords')
    from nltk.corpus import stopwords
    stop_words = set(stopwords.words('english'))

warnings.filterwarnings('ignore')
print("Libraries imported and stopwords are ready.")

Downloading NLTK stopwords corpus...
Libraries imported and stopwords are ready.


[nltk_data] Downloading package stopwords to
[nltk_data]     /home/macharya/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [56]:
# Cell 3: The Definitive Analysis Cell with AGGRESSIVE Filtering
# ======================================================================
# This version is designed to completely eliminate noise and show only
# meaningful semantic connections between non-stop words.

# --- 1. PASTE YOUR DATA HERE ---
q = "Which of the following structural elements is characteristic of the ortopramide group drugs?"

d = [
    "The 'ortopramides,' also known as substituted benzamides, represent a significant class of dopamine D2 receptor antagonists. Structurally, the defining characteristic of this group is a benzamide core with a methoxy group positioned at the ortho- (2-) position of the aromatic ring. This specific arrangement is crucial for their pharmacological activity. While some related compounds like phenothiazines are also used as antiemetics, they lack this precise benzamide structure."
]

# --- 2. LOOP THROUGH EACH DEFINITION AND ANALYZE ---
for i, context_sentence in enumerate(d):
    print(f"========================================================================")
    print(f"  ANALYZING ATTENTION FOR DEFINITION [{i+1} / {len(d)}]")
    print(f"========================================================================")
    print(f"CONTEXT: \"{context_sentence}\"")
    
    # --- Tokenization and Preparation ---
    inputs = tokenizer.encode_plus(question, context_sentence, return_tensors='pt', add_special_tokens=True, truncation=True, max_length=512)
    input_ids = inputs['input_ids'][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    
    # --- Run Model and Get Attention ---
    with torch.no_grad():
        outputs = model(**inputs)
        attentions = outputs.attentions

    # --- Display the Clean Matrix View ---
    #print(f"\n--- ATTENTION HEATMAP [{i+1}] ---")
    #display(model_view(attentions, tokens))
    
    # --- PROGRAMMATIC ANALYSIS WITH AGGRESSIVE FILTERING ---
    sep_index = tokens.index('[SEP]')
    question_tokens = tokens[1:sep_index]
    context_tokens = tokens[sep_index+1:-1]
    
    attentions_per_layer = torch.stack(attentions).squeeze(1).mean(dim=1)
    final_layer_attention = attentions_per_layer[-1]

    # --- THE DEFINITIVE FIX FOR THE ANALYSIS ---
    # 1. Define what to ignore: special tokens, punctuation, AND stop words.
    special_tokens = set(tokenizer.special_tokens_map.values())
    punctuation = set(string.punctuation)
    # Combine all ignored tokens into one set for efficiency
    tokens_to_ignore = special_tokens.union(punctuation).union(stop_words)

    attention_scores = []
    for q_idx, q_token in enumerate(question_tokens, 1):
        # 2. Ignore question tokens that are noise.
        if q_token in tokens_to_ignore or q_token.startswith('##'):
            continue
            
        for c_idx, c_token in enumerate(context_tokens, sep_index + 1):
            # 3. Ignore context tokens that are noise.
            if c_token in tokens_to_ignore or c_token.startswith('##'):
                continue

            # 4. Ignore "self-attention" (optional but good practice)
            if q_token == c_token:
                continue

            score = final_layer_attention[q_idx, c_idx].item()
            attention_scores.append(((q_token, c_token), score))

    attention_scores.sort(key=lambda x: x[1], reverse=True)
    
    print("\n--- TOP 5 MOST MEANINGFUL ATTENDED-TO WORD PAIRS (from final layer) ---")
    df = pd.DataFrame(attention_scores[:5], columns=['(Question_Token, Context_Token)', 'Attention_Score'])
    
    if df.empty or df['Attention_Score'].iloc[0] < 0.1:
        print("\n*** WARNING: LOW MEANINGFUL ATTENTION. This context may be irrelevant. ***")

    display(df)
    
    print(f"\n--- END OF ANALYSIS FOR DEFINITION [{i+1}] ---\n\n")

  ANALYZING ATTENTION FOR DEFINITION [1 / 1]
CONTEXT: "The 'ortopramides,' also known as substituted benzamides, represent a significant class of dopamine D2 receptor antagonists. Structurally, the defining characteristic of this group is a benzamide core with a methoxy group positioned at the ortho- (2-) position of the aromatic ring. This specific arrangement is crucial for their pharmacological activity. While some related compounds like phenothiazines are also used as antiemetics, they lack this precise benzamide structure."

--- TOP 5 MOST MEANINGFUL ATTENDED-TO WORD PAIRS (from final layer) ---



Unnamed: 0,"(Question_Token, Context_Token)",Attention_Score
0,"(gene, structure)",0.00097
1,"(protein, compounds)",0.000815
2,"(hem, benz)",0.000745
3,"(bones, compounds)",0.000736
4,"(protein, structure)",0.000731



--- END OF ANALYSIS FOR DEFINITION [1] ---




In [15]:
inputs = tokenizer.encode_plus(question, definitions, return_tensors='pt', add_special_tokens=True)
token_type_ids = inputs['token_type_ids'] # This tells the model which part is the question and which is the context
input_ids = inputs['input_ids']

In [None]:
# for i, context_sentence in enumerate(final_definitions):
#     print(f"\n--- Analyzing Attention for Definition [{i}] ---")
#     print(f"Context: {context_sentence}")
    
#     # --- Tokenization and Preparation ---
#     # This now combines the question with only ONE definition at a time.
#     # We also add truncation as a safety measure.
#     inputs = tokenizer.encode_plus(
#         question, 
#         context_sentence, 
#         return_tensors='pt', 
#         add_special_tokens=True,
#         truncation=True, # This will cut off any text that is still too long
#         max_length=512   # The model's maximum length
#     )
    
#     token_type_ids = inputs['token_type_ids']
#     input_ids = inputs['input_ids']
#     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    
#     # --- Run the Model and Get Attention ---
#     with torch.no_grad():
#         outputs = model(input_ids, token_type_ids=token_type_ids)
#         attentions = outputs.attentions
    
#     # --- Visualization ---
#     # This will render a new interactive heatmap for each definition.
#     display(head_view(attentions, tokens))```

In [None]:
# Cell 3: The Analysis Cell - Plug in Your Data Here
# ===================================================
# 1. Find a question and the exact context from your previous runs' logs.
# 2. Paste them into the `question` and `context` variables below.
# 3. Run this cell to see the attention heatmap.

# --- PASTE YOUR DATA HERE ---
question = "Which of the following structural elements is characteristic of the ortopramide group drugs?"

# Use a small, focused piece of context from your logs for a clear visualization
context = "An orally bioavailable benzamide type inhibitor of histone deacetylase isoenzymes 1, 2, 3 and 10, with potential antineoplastic activity. Tucidinostat is an ortho-halogenated derivative of phenothiazine."

# --- TOKENIZATION AND PREPARATION ---
# This prepares the text in the special format BERT expects: [CLS] question [SEP] context [SEP]
inputs = tokenizer.encode_plus(question, context, return_tensors='pt', add_special_tokens=True)
token_type_ids = inputs['token_type_ids'] # This tells the model which part is the question and which is the context
input_ids = inputs['input_ids']

# --- RUN THE MODEL AND GET ATTENTION ---
with torch.no_grad():
    outputs = model(input_ids, token_type_ids=token_type_ids)
    attentions = outputs.attentions # This contains the attention weights for all layers and heads

# --- VISUALIZATION ---
# Convert token IDs back to human-readable tokens
tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) 

# Create the visualization
# This will render an interactive heatmap directly in your notebook.
head_view(attentions, tokens)