In [None]:
# @title üöÄ 1. Install Dependencies & Setup
# This cell installs the necessary libraries to talk to ArangoDB and process the data.
# Run this cell first!

!pip install python-arango datasets ollama gradio sentence-transformers -q

import os
import pickle
import re
import subprocess
import sys
import time
import warnings
from getpass import getpass
from typing import List, Dict

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
from arango import ArangoClient
from arango.exceptions import ServerConnectionError, ArangoServerError
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Install the Ollama server binary; later cells drive it via `ollama list`,
# `ollama pull` and `ollama serve`.
!curl -fsSL https://ollama.com/install.sh | sh


print("‚úÖ Libraries installed.")

In [None]:
def check_and_pull_model(model_name="deepseek-r1:8b"):
    """
    Checks if the model exists in Ollama. If not, pulls it automatically.

    Args:
        model_name: Ollama model reference, e.g. "deepseek-r1:8b". A name without a
            ":tag" is also matched against "<name>:latest".

    Returns:
        True if the model was already listed or `ollama pull` exited with status 0;
        False if `ollama pull` ran but exited non-zero (e.g. server not running,
        unknown model name).

    Raises:
        SystemExit: if the `ollama pull` process could not be started/read at all
            (e.g. the `ollama` binary is missing).
    """
    print(f"üïµÔ∏è [Ollama] Checking for model: {model_name}...")

    wanted = {model_name}
    if ":" not in model_name:
        wanted.add(f"{model_name}:latest")

    # 1. Check list of models. Compare the NAME column exactly: a substring test
    #    would accept e.g. "deepseek-r1:8b-q8_0" as "deepseek-r1:8b".
    try:
        result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
        if result.returncode == 0:
            installed = set()
            for line in result.stdout.splitlines():
                tokens = line.split()
                if tokens and tokens[0] != "NAME":  # skip the header row
                    installed.add(tokens[0])
            if installed & wanted:
                print(f"‚úÖ [Ollama] Model '{model_name}' is ready.")
                return True
        else:
            print(f"‚ö†Ô∏è [Ollama] Could not check model list: {result.stderr.strip()}")
    except Exception as e:
        print(f"‚ö†Ô∏è [Ollama] Could not check model list: {e}")

    # 2. If missing, pull it
    print(f"‚¨áÔ∏è [Ollama] Model not found. Pulling {model_name} (This takes 2-5 mins)...")
    try:
        # Stream the output so you don't think it hung. stderr is merged into stdout
        # so only ONE pipe has to be drained; reading just one of two pipes can
        # deadlock once the other fills its OS buffer.
        process = subprocess.Popen(
            ["ollama", "pull", model_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        for output in iter(process.stdout.readline, b""):
            line = output.decode(errors="replace").strip()
            if line:
                print(line)
        process.stdout.close()
        returncode = process.wait()
    except Exception as e:
        print(f"‚ùå [Ollama] Failed to pull model: {e}")
        sys.exit(1) # Stop script if model fails

    # A finished process is not a successful pull: check the exit status.
    if returncode != 0:
        print(f"‚ùå [Ollama] Failed to pull model: `ollama pull` exited with code {returncode}")
        return False

    print(f"‚úÖ [Ollama] Successfully pulled {model_name}!")
    return True

MODEL_NAME = "deepseek-r1:8b"
OLLAMA_HOST = "http://localhost:11434"
OLLAMA_API = f"{OLLAMA_HOST}/api/chat"

# `ollama list` / `ollama pull` need a running Ollama server, and the install script
# may not leave one running (e.g. in containers without systemd). Start it in the
# background and wait until it answers; otherwise both commands below fail.
subprocess.Popen(["ollama", "serve"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
for _ in range(30):
    try:
        requests.get(OLLAMA_HOST, timeout=1)  # any HTTP response means the server is up
        break
    except requests.exceptions.RequestException:
        time.sleep(1)
else:
    print(f"[Ollama] Server at {OLLAMA_HOST} did not respond within ~30s; the model check may fail.")

check_and_pull_model(MODEL_NAME)

In [4]:
# @title
# --- 3. ROBUST UTILITIES ---

class FuzzyEvaluator:
    """Evaluates answers with logic to handle verbosity and synonyms.

    ``extract_answer`` pulls a yes/no/maybe label out of free-form model output;
    ``is_correct`` compares a ground-truth label with a possibly verbose prediction.
    """

    # Explicit declarations such as "Final Answer: yes" or "answer: **[no]**".
    # The trailing \b keeps "answer: nothing" from being read as "no".
    _DECLARATION = re.compile(r"(?:final answer|answer):?[\s*\[]*(yes|no|maybe)\b")
    _BARE_LABEL = re.compile(r"\b(yes|no|maybe)\b")
    # Whole-word synonyms: plain substring tests would let "unlikely" count as
    # "likely" and "improbable" as "probable".
    _POSITIVE = re.compile(r"\b(?:definitely yes|likely|probable|certainly)\b")
    _NEGATIVE = re.compile(r"\b(?:unlikely|doubtful|never)\b")

    def extract_answer(self, text: str) -> str:
        """Return 'yes', 'no' or 'maybe' parsed from ``text`` ('maybe' if nothing is found)."""
        # Strip DeepSeek "Thinking" blocks
        clean_text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).lower()
        # Prefer the LAST explicit declaration: earlier ones are often provisional.
        declarations = self._DECLARATION.findall(clean_text)
        if declarations:
            return declarations[-1]
        # Fallback: last isolated label word in the text.
        labels = self._BARE_LABEL.findall(clean_text)
        if labels:
            return labels[-1]
        return "maybe" # Default safety

    def is_correct(self, gt: str, pred: str) -> bool:
        """True if ``pred`` matches ``gt`` exactly, starts with it, or is a whole-word synonym."""
        gt, pred = gt.lower().strip(), pred.lower().strip()

        # 1. Exact Match
        if gt == pred:
            return True

        # 2. Starts With (e.g. "yes, because...")
        if pred.startswith(gt + " ") or pred.startswith(gt + ","):
            return True

        # 3. Synonyms (whole words only)
        if gt == "yes" and self._POSITIVE.search(pred):
            return True
        if gt == "no" and self._NEGATIVE.search(pred):
            return True

        return False

class ArangoConnectionManager:
    """Handles the 503 Service Unavailable errors by retrying.

    ``config`` must provide ``hosts``, ``username``, ``password`` and ``db_name``.
    After construction, ``self.db`` is a handle to ``config["db_name"]``.
    """

    def __init__(self, config):
        self.config = config
        self.client = ArangoClient(hosts=config["hosts"])
        self.db = self._connect_with_retry()

    def _connect_with_retry(self, max_retries=5):
        """Return a handle to ``config["db_name"]``, retrying with linear back-off (5s, 10s, ...).

        Raises:
            ConnectionError: if all ``max_retries`` attempts fail; chained to the last
                ArangoDB error so the root cause stays visible.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                # verify connection
                sys_db = self.client.db("_system", username=self.config["username"], password=self.config["password"])
                sys_db.version() # Ping

                # Connect to actual DB
                db = self.client.db(self.config["db_name"], username=self.config["username"], password=self.config["password"])
                print(f"‚úÖ [ArangoDB] Connected successfully.")
                return db
            except (ServerConnectionError, ArangoServerError) as e:
                last_error = e
                if attempt == max_retries - 1:
                    break  # final attempt failed: raise now instead of sleeping first
                wait = (attempt + 1) * 5
                print(f"‚ö†Ô∏è [ArangoDB] Connection failed ({e}). Retrying in {wait}s...")
                time.sleep(wait)

        raise ConnectionError("Could not connect to ArangoDB after retries.") from last_error

In [5]:
# ==========================================
# 1. THE CACHING FUNCTION (Defined locally)
# ==========================================
def load_vectors_smartly(db, collection_name, cache_file="pubmed_vectors_cache.pkl",
                         batch_size=5000, max_retries=5):
    """
    Handles the logic: Check Disk -> If Missing, Download -> Save to Disk.

    Args:
        db: python-arango database handle (only used when the cache is missing).
        collection_name: collection whose documents carry ``text`` and ``embedding``.
        cache_file: pickle cache path. ``pickle.load`` can execute arbitrary code,
            so only point this at a cache this notebook wrote itself.
        batch_size: documents fetched per AQL round trip.
        max_retries: consecutive "503" errors retried (5s apart) before giving up.

    Returns:
        (ids, texts, embeddings): parallel sequences of chunk ``_id``s, chunk texts
        and a numpy array built from the stored embeddings.

    A download that stops early (non-503 error, or too many 503s) is returned for
    this session but NOT written to the cache, so the next run retries instead of
    silently reusing a truncated index.
    """
    # A. Check Disk
    if os.path.exists(cache_file):
        print(f"üíæ [Cache] Found local file: {cache_file}")
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)  # trusted, self-written cache only (see docstring)
            ids = data.get('ids', [])
            texts = data.get('texts', [])
            embeddings = data.get('embeddings', [])

            if len(embeddings) > 0:
                print(f"‚úÖ [Cache] Loaded {len(embeddings)} vectors from disk instantly.")
                return ids, texts, np.asarray(embeddings)
        except Exception as e:
            print(f"‚ö†Ô∏è [Cache] File corrupted ({e}). Re-downloading...")

    # B. Download from Cloud (Only if A failed)
    print(f"‚òÅÔ∏è [Index] Cache missing. Downloading from ArangoDB (This happens only once)...")

    ids, texts, embeddings = [], [], []
    batch_size = max(1, int(batch_size))

    # Total document count; only used to size the progress bar.
    try:
        count = db.aql.execute("RETURN LENGTH(@@col)", bind_vars={"@col": collection_name}).next()
    except Exception:
        count = 200000

    # Keyset pagination on _key: each page resumes after the last key seen. Unlike
    # "LIMIT offset, n" without SORT, page order is well defined (no duplicated or
    # skipped documents) and the server never re-scans `offset` documents.
    # The collection is passed as a bind parameter instead of being f-string-interpolated.
    aql = """
    FOR c IN @@col
        FILTER c._key > @last_key
        FILTER c.embedding != null
        SORT c._key
        LIMIT @batch
        RETURN { "key": c._key, "id": c._id, "text": c.text, "emb": c.embedding }
    """
    last_key = ""
    consecutive_failures = 0
    complete = False

    with tqdm(total=count, desc="Downloading Index", unit="vec") as pbar:
        while True:
            try:
                # Materialize the page first so a failure mid-cursor cannot leave half
                # a page appended before the retry fetches it again.
                batch = list(db.aql.execute(
                    aql,
                    bind_vars={"@col": collection_name, "last_key": last_key, "batch": batch_size},
                    ttl=3600,
                ))
            except Exception as e:
                print(f"‚ö†Ô∏è Error on batch: {e}")
                consecutive_failures += 1
                if "503" in str(e) and consecutive_failures <= max_retries:
                    time.sleep(5)
                    continue
                break

            consecutive_failures = 0
            for doc in batch:
                ids.append(doc["id"])
                texts.append(doc["text"])
                embeddings.append(doc["emb"])
            pbar.update(len(batch))

            if len(batch) < batch_size:
                complete = True
                break
            last_key = batch[-1]["key"]
            time.sleep(0.1) # Be gentle on the server

    # C. Save to Disk (complete downloads only)
    embeddings_np = np.array(embeddings)
    if ids and complete:
        print(f"üíæ [Cache] Saving {len(ids)} vectors to {cache_file}...")
        with open(cache_file, 'wb') as f:
            pickle.dump({'ids': ids, 'texts': texts, 'embeddings': embeddings_np}, f)
        print("‚úÖ [Cache] Saved.")
    elif ids:
        print(f"[Cache] Download incomplete ({len(ids)} vectors); not caching so the next run retries.")

    return ids, texts, embeddings_np

class RobustGraphRAG:
    """GraphRAG question answering over PubMed chunks stored in ArangoDB.

    Pipeline: dense vector search over locally cached chunk embeddings ->
    cross-encoder re-ranking -> graph expansion from each top chunk to its parent
    paper over ``HAS_CONTEXT`` edges (rebuilding the full abstract) -> answer
    generation with a local Ollama model.

    ``config`` keys read: ``hosts``, ``username``, ``password``, ``db_name``,
    ``chunk_col``; optional ``model`` / ``ollama_url`` override the defaults below.
    """

    DEFAULT_MODEL = "deepseek-r1:8b"
    DEFAULT_OLLAMA_URL = "http://localhost:11434/api/chat"
    # How many vector-search hits are handed to the (slower) cross-encoder.
    CANDIDATE_POOL = 75

    def __init__(self, config):
        self.config = config
        self.client = ArangoClient(hosts=config["hosts"])
        self.db = self.client.db(config["db_name"], username=config["username"], password=config["password"])

        print("‚è≥ [Model] Loading Encoders...")
        self.encoder = SentenceTransformer("all-MiniLM-L6-v2")
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        self.chunk_ids, self.chunk_texts, self.chunk_embeddings = load_vectors_smartly(
            self.db,
            self.config['chunk_col']
        )

    def retrieve(self, query: str, top_k=3):
        """Return a context string for ``query``.

        Normally one "=== STUDY: <title> ===" section (with the reconstructed
        abstract) per distinct parent paper of the ``top_k`` best chunks. Falls back
        to "Excerpt: <chunk text>" lines when graph expansion errors or finds no
        parent paper, and returns "No context." when no embeddings are loaded.
        """
        if len(self.chunk_embeddings) == 0: return "No context."

        # 1. Wide vector search: a generous candidate pool catches e.g. "Conclusion"
        #    chunks that use different wording from the question.
        query_emb = self.encoder.encode([query])
        sims = cosine_similarity(query_emb, self.chunk_embeddings)[0]
        top_n_indices = np.argsort(sims)[-self.CANDIDATE_POOL:][::-1]
        candidate_pairs = [(self.chunk_texts[idx], self.chunk_ids[idx]) for idx in top_n_indices]

        # 2. Re-rank with the cross-encoder. Slicing (rather than indexing
        #    range(top_k)) stays safe when fewer than top_k chunks exist.
        scores = self.reranker.predict([[query, text] for text, _ in candidate_pairs])
        best_pairs = [candidate_pairs[idx] for idx in np.argsort(scores)[::-1][:max(top_k, 0)]]
        best_chunk_ids = [cid for _, cid in best_pairs]

        # 3. Graph Expansion (Parent Abstract Reconstruction)
        # NOTE(review): the traversal order of sibling chunks is not guaranteed; if
        # chunks carry a position field, SORT by it before concatenating.
        aql = """
        WITH Papers, Chunks
        FOR start_chunk_id IN @ids
            LET start_doc = DOCUMENT(start_chunk_id)

            // Find Parent Paper
            FOR paper IN 1..1 INBOUND start_doc HAS_CONTEXT

                // Get ALL chunks (Introduction + Results + Conclusion)
                LET full_text_chunks = (
                    FOR c IN 1..1 OUTBOUND paper HAS_CONTEXT
                    RETURN c.text
                )

                // Concatenate into a clean abstract
                LET full_abstract = CONCAT_SEPARATOR(" ", full_text_chunks)

                RETURN {
                    "id": paper._id,
                    "title": paper.title,
                    "abstract": full_abstract
                }
        """

        try:
            cursor = self.db.aql.execute(aql, bind_vars={"ids": best_chunk_ids})
            context_parts = []
            seen_papers = set()

            for res in cursor:
                title = res.get('title') or 'Unknown'
                # De-duplicate by paper id: several top chunks often share a parent,
                # and distinct untitled papers must not collapse into one "Unknown".
                paper_key = res.get('id') or title
                if paper_key in seen_papers: continue
                seen_papers.add(paper_key)

                # Add "Study X" header to help LLM distinguish separate papers
                context_parts.append(
                    f"=== STUDY: {title} ===\n"
                    f"ABSTRACT: {res.get('abstract')}\n"
                )

            if context_parts:
                return "\n".join(context_parts)
            print("[Graph] No parent papers found for the top chunks; using raw excerpts.")

        except Exception as e:
            print(f"‚ö†Ô∏è Graph Error ({e}).")

        return "\n".join(f"Excerpt: {text}" for text, _ in best_pairs)

    def _heuristic_override(self, response_text):
        """
        Python Safety Net: Catches 'Maybe' and flips it if strong keywords exist.

        Returns the model's final declared answer when it is "yes"/"no"; otherwise
        scans the whole response (reasoning included) for soft negative, then soft
        positive, phrases and returns "no", "yes" or "maybe".
        """
        clean_text = response_text.lower()

        # 1. Extract the explicit answer: the LAST declaration outside <think> blocks,
        #    because the reasoning may mention provisional answers first.
        declaration = r'(?:final answer|answer):?\s*(yes|no|maybe)\b'
        visible_text = re.sub(r'<think>.*?</think>', '', clean_text, flags=re.DOTALL)
        answers = re.findall(declaration, visible_text) or re.findall(declaration, clean_text)
        pred = answers[-1] if answers else "maybe"

        # 2. If prediction is YES or NO, trust the model.
        if pred in ("yes", "no"):
            return pred

        # 3. MAYBE: check the reasoning for "soft signals". Negative phrases go FIRST
        #    because several contain a positive cue as a substring ("no significant"
        #    contains "significant", "ineffective" contains "effective"); checking
        #    positives first turned clearly negative findings into "yes".
        soft_no = ["no significant", "did not", "unrelated", "ineffective", "no difference"]
        if any(phrase in clean_text for phrase in soft_no):
            return "no"

        soft_yes = ["suggests", "indicates", "significant", "associated with", "effective", "improved"]
        if any(phrase in clean_text for phrase in soft_yes):
            return "yes"

        return "maybe"

    def _post_chat(self, content, options):
        """POST one user message to Ollama's chat endpoint (non-streaming) and return the response.

        Model and URL come from ``config["model"]`` / ``config["ollama_url"]`` when
        present, otherwise from ``DEFAULT_MODEL`` / ``DEFAULT_OLLAMA_URL``.
        """
        payload = {
            "model": self.config.get("model", self.DEFAULT_MODEL),
            "messages": [{"role": "user", "content": content}],
            "stream": False,
            "options": options,
        }
        url = self.config.get("ollama_url", self.DEFAULT_OLLAMA_URL)
        return requests.post(url, json=payload, timeout=300)

    def query_ollama(self, prompt: str):
        """Classify ``prompt`` as yes/no/maybe with the PubMedQA-calibrated system prompt.

        Returns the raw model text followed by a
        "[Heuristic Override Result]: Final Answer: <label>" line, or an
        "Error <status>" / "Exception: <msg>" string on failure.
        """
        # The "Calibration" Prompt
        # We align the model with PubMedQA's specific annotation style.

        system_msg = """
        You are a PubMedQA annotator.
        Your task is to classify the answer as 'yes', 'no', or 'maybe' based on the Study Abstract.

        ANNOTATION GUIDELINES (CRITICAL):
        1. If the study suggests a positive outcome, even if "further study is needed", the answer is YES.
        2. If the study finds a correlation or association, the answer is YES.
        3. If the study finds "no significant difference", the answer is NO.
        4. ONLY use MAYBE if the abstract explicitly states "results were inconclusive" or provides zero data.

        Format:
        Final Answer: [yes/no/maybe]
        """

        full_prompt = f"{system_msg}\n\nContext:\n{prompt}"

        try:
            res = self._post_chat(full_prompt, {"temperature": 0.0, "num_ctx": 4096})
            if res.status_code == 200:
                raw_response = res.json()['message']['content']

                # --- APPLY THE PYTHON SAFETY NET ---
                final_decision = self._heuristic_override(raw_response)

                # Return a format that your evaluator can parse
                return f"{raw_response}\n\n[Heuristic Override Result]: Final Answer: {final_decision}"

            return f"Error {res.status_code}"
        except Exception as e:
            return f"Exception: {e}"

    def generate_chat_response(self, message, context):
        """
        A specific prompt for the Chat UI (Conversational, not Yes/No).

        Returns the model's reply, or an "Error: ..." string on failure.
        """
        system_msg = """
        You are a Helpful Medical AI Assistant.
        Use the provided Research Abstracts to answer the user's question accurately.

        Guidelines:
        1. Base your answer ONLY on the context provided.
        2. Cite the specific study titles when making claims (e.g., "According to the study on X...").
        3. If the studies are conflicting, explain the conflict.
        4. If the answer is not in the context, admit you don't have evidence but give your opinion.
        """

        full_prompt = f"{system_msg}\n\nContext:\n{context}\n\nUser Question: {message}"

        try:
            # Slight creativity allowed
            res = self._post_chat(full_prompt, {"temperature": 0.3, "num_ctx": 4096})
            if res.status_code == 200:
                return res.json()['message']['content']
            return "Error: Could not communicate with model."
        except Exception as e:
            return f"Error: {e}"

    # --- THE UI LAUNCHER ---
    def launch_gradio_ui(self):
        """Start a Gradio chat UI backed by ``retrieve`` + ``generate_chat_response``.

        Calls ``demo.launch(share=True, debug=True)``; per the Gradio docs,
        ``debug=True`` blocks the calling cell while the UI runs.
        """
        print("\nüöÄ Launching Gradio UI...")

        def chat_logic(message, history):
            # `history` is required by gr.ChatInterface's callback signature but is
            # unused: each question is answered from freshly retrieved context.
            # 1. Retrieve Context
            print(f"üîé Retrieving for: {message}...")
            retrieved_context = self.retrieve(message)

            # 2. Generate Answer
            print(f"ü§ñ Generating Answer...")
            response = self.generate_chat_response(message, retrieved_context)

            # 3. Append the retrieved study titles (none when only raw excerpts were used)
            titles = re.findall(r"=== STUDY: (.*?) ===", retrieved_context)
            if not titles:
                return response
            sources = "".join(f"- *{t}*\n" for t in titles)
            return f"{response}\n\n___\n**Sources Retrieved:**\n{sources}"

        # Create the Interface
        demo = gr.ChatInterface(
            fn=chat_logic,
            title="üß¨ PubMed GraphRAG Assistant",
            description="Ask detailed medical questions. I will retrieve full abstracts from the Knowledge Graph to answer you.",
            examples=[
                "Do preoperative statins reduce atrial fibrillation?",
                "Is obesity a risk factor for cirrhosis-related death or hospitalization?",
                "Does high-dose aspirin prevent cardiovascular events?"
            ],
            theme="soft"
        )

        demo.launch(share=True, debug=True)

In [6]:
class AdvancedEvaluator:
    """Collects (ground truth, prediction) pairs and reports timing, accuracy and a confusion matrix.

    Usage: ``start()`` -> ``record(gt, pred)`` per sample -> ``stop()`` -> ``generate_report()``.
    """

    # Fixed label order so every class shows up in the report/plot even with zero support.
    LABELS = ['yes', 'no', 'maybe']

    def __init__(self):
        self.y_true = []
        self.y_pred = []
        self.start_time = None
        self.end_time = None

    def start(self):
        """Starts the stopwatch."""
        self.start_time = time.time()
        print("‚è±Ô∏è Evaluation Timer Started...")

    def stop(self):
        """Stops the stopwatch."""
        self.end_time = time.time()

    def record(self, gt, pred):
        """Records a single prediction pair (lower-cased/stripped; unknown predictions become 'maybe')."""
        clean_gt = gt.lower().strip()
        clean_pred = pred.lower().strip()

        # Safety: If model output garbage, classify as 'maybe'
        if clean_pred not in self.LABELS:
            clean_pred = 'maybe'

        self.y_true.append(clean_gt)
        self.y_pred.append(clean_pred)

    def _elapsed_seconds(self):
        """Seconds from start() to stop() (or to now if stop() was not called); None if never started."""
        if self.start_time is None:
            return None
        end = self.end_time if self.end_time is not None else time.time()
        return end - self.start_time

    def generate_report(self):
        """Calculates and visualizes all requested metrics."""
        if not self.y_true:
            print("‚ö†Ô∏è No data to report.")
            return

        # 1. Total Time (tolerates a missing start()/stop() instead of crashing on None)
        total_seconds = self._elapsed_seconds()

        # 2. Accuracy
        acc = accuracy_score(self.y_true, self.y_pred) * 100

        print("\n" + "="*40)
        print(f"üìä FINAL EVALUATION REPORT")
        print("="*40)
        if total_seconds is None:
            print("Total Time:     n/a (start() was not called)")
        else:
            avg_per_sample = total_seconds / len(self.y_true)
            print(f"‚è±Ô∏è Total Time:     {total_seconds:.2f} seconds")
            print(f"‚ö° Avg Latency:    {avg_per_sample:.2f} seconds/query")
        print(f"üéØ Final Accuracy: {acc:.2f}%")
        print("-" * 40)

        # 3. Prediction Summary (Counts)
        df = pd.DataFrame({'Ground Truth': self.y_true, 'Prediction': self.y_pred})
        print("\nüìã Prediction Distribution:")
        print(df['Prediction'].value_counts())

        # 4. Classification Report
        print("\nüìà Detailed Classification Report:")
        print(classification_report(self.y_true, self.y_pred, labels=self.LABELS, zero_division=0))

        # 5. Confusion Matrix Visualization. A scoped plotting context replaces
        #    sns.set(), which would change the theme for every later plot too.
        cm = confusion_matrix(self.y_true, self.y_pred, labels=self.LABELS)
        with sns.plotting_context("notebook", font_scale=1.2):
            fig, ax = plt.subplots(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=self.LABELS, yticklabels=self.LABELS, ax=ax)
            ax.set_xlabel('Predicted Label')
            ax.set_ylabel('True Label')
            ax.set_title('Confusion Matrix: PubMedQA Evaluation')
            plt.show()

In [None]:
# @title
# --- 5. MAIN EXECUTION (MERGED) ---
if __name__ == "__main__":

    # 0. ArangoDB settings. If no earlier cell defined ARANGO_CONFIG, build it from
    #    environment variables (prompting for anything missing) so a fresh
    #    "Restart & Run All" does not stop with a NameError below.
    if "ARANGO_CONFIG" not in globals():
        ARANGO_CONFIG = {
            "hosts": os.environ.get("ARANGO_HOST") or input("ArangoDB host URL: "),
            # NOTE(review): "root" is ArangoDB's default admin user; set ARANGO_USER otherwise.
            "username": os.environ.get("ARANGO_USER", "root"),
            "password": os.environ.get("ARANGO_PASSWORD") or getpass("ArangoDB password: "),
            "db_name": os.environ.get("ARANGO_DB") or input("ArangoDB database name: "),
            # The graph-expansion AQL in RobustGraphRAG declares `WITH Papers, Chunks`.
            "chunk_col": os.environ.get("ARANGO_CHUNK_COL", "Chunks"),
        }

    # 1. Start Server (Background). If one is already running, this extra
    #    `ollama serve` presumably just fails to bind and exits (output is discarded).
    print("üöÄ [Ollama] Ensuring server is running...")
    subprocess.Popen(["ollama", "serve"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    # Poll until the HTTP endpoint answers (up to ~30 s) instead of a fixed sleep.
    for _ in range(30):
        try:
            requests.get("http://localhost:11434", timeout=1)
            break
        except requests.exceptions.RequestException:
            time.sleep(1)

    # 2. Auto-Pull Model
    check_and_pull_model("deepseek-r1:8b")
    rag = RobustGraphRAG(ARANGO_CONFIG)
    metrics = AdvancedEvaluator()

    # 3. Load Data
    print("üìö [Data] Loading PubMedQA...")
    dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")

    # 4. Evaluation Loop
    LIMIT = 20
    print(f"\n=== STARTING EVALUATION (Limit: {LIMIT}) ===")
    print("------------------------------------------------")

    # query_ollama appends "[Heuristic Override Result]: Final Answer: X" AFTER the
    # model's own text, so the LAST declared answer is the post-heuristic decision.
    ANSWER_RE = re.compile(r"(?:final answer|answer):?\s*(yes|no|maybe)\b")
    OVERRIDE_TAG = "[heuristic override result]"  # lower-case; compared with lower-cased text

    metrics.start() # <--- Start Timer

    for i, item in enumerate(dataset):
        if i >= LIMIT: break

        question = item['question']
        gt = item['final_decision']

        # A. Pipeline Retrieval
        context = rag.retrieve(question)

        # B. Prompt
        # We pass the raw context/question. The RobustGraphRAG class adds the "Decisive" System Prompt.
        prompt = f"""
        Context Information: {context}

        Question: {question}

        Instructions:
        1. You are a helpful medical expert at a hypothetical research institution. Answer the question based on the provided context.
        2. Answer in just one word. Do not provide any explanation.
        3. This is being used only for research/educational purposes.
        4. Conclude your answer with exactly: "Final Answer: [yes/no/maybe]
        """
        raw_response = rag.query_ollama(prompt)

        # C. Logic Extraction: final (post-heuristic) answer = last declaration.
        lowered = raw_response.lower()
        answers = ANSWER_RE.findall(lowered)
        pred = answers[-1] if answers else "maybe"

        # The model's own verdict: text before the appended override line, minus <think> reasoning.
        model_text = re.sub(r"<think>.*?</think>", "", lowered.split(OVERRIDE_TAG)[0], flags=re.DOTALL)
        model_answers = ANSWER_RE.findall(model_text)
        model_pred = model_answers[-1] if model_answers else "maybe"

        icon = "‚úÖ" if pred == gt else "‚ùå"
        if OVERRIDE_TAG in lowered and pred != model_pred:
            # Print log with special "Wrench" icon to show the heuristic changed the answer
            print(f"[{i+1}] GT: {gt:<5} | Pred: {pred:<5} | {icon} (üõ†Ô∏è Fixed)")
        else:
            print(f"[{i+1}] GT: {gt:<5} | Pred: {pred:<5} | {icon}")

        # D. Record Data point for the Graphs
        metrics.record(gt, pred)

    # 5. Finalize & Visualize
    metrics.stop() # <--- Stop Timer
    metrics.generate_report() # <--- Plots Confusion Matrix

In [None]:
# Launch UI
# Requires `rag` created by the main-execution cell above. `launch_gradio_ui()` ends
# with `demo.launch(share=True, debug=True)`, so this cell keeps running while the UI is up.
rag.launch_gradio_ui()