In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RAG System Testing\n",
    "## End-to-end testing of the RAG pipeline with BiomedLM\n",
    "\n",
    "This notebook tests:\n",
    "- Query embedding and retrieval\n",
    "- Context construction\n",
    "- BiomedLM response generation\n",
    "- System performance metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project root importable so the `src.*` imports below resolve.\n",
    "# NOTE(review): assumes the kernel's working directory is one level below the\n",
    "# project root (e.g. a notebooks/ folder) - confirm if run from elsewhere.\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import json\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from pathlib import Path\n",
    "from typing import List, Dict\n",
    "import warnings\n",
    "# NOTE(review): this silences *every* warning for the rest of the kernel\n",
    "# session, including any raised by the project modules imported below.\n",
    "# Consider narrowing it with a category/module filter.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Project modules: embedding, FAISS indexing/retrieval, LLM generation, config\n",
    "from src.embedding.model_loader import ModelLoader\n",
    "from src.embedding.embedder import EmbeddingGenerator\n",
    "from src.indexing.faiss_indexer import FAISSIndexer\n",
    "from src.indexing.retriever import Retriever\n",
    "from src.llm.biomedlm import BiomedLMGenerator\n",
    "from src.llm.prompt_builder import PromptBuilder\n",
    "from src.config import settings\n",
    "\n",
    "# Default plot style and figure size for this notebook\n",
    "sns.set_style('whitegrid')\n",
    "plt.rcParams['figure.figsize'] = (12, 6)\n",
    "\n",
    "print(\"✓ Imports successful\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Initialize RAG Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Initializing RAG components...\\n\")\n",
    "\n",
    "# 1. Embedding model: used to embed each query before retrieval\n",
    "print(\"1. Loading embedding model...\")\n",
    "model_loader = ModelLoader(settings.EMBEDDING_MODEL, settings.DEVICE)\n",
    "tokenizer, model = model_loader.load()\n",
    "embedder = EmbeddingGenerator(tokenizer, model, settings.DEVICE, batch_size=32)\n",
    "print(\"   ✓ Embedding model loaded\")\n",
    "\n",
    "# 2. FAISS index. Every later cell needs it, so fail fast here instead of\n",
    "#    building a retriever over a missing index and failing later with a\n",
    "#    less obvious error.\n",
    "print(\"\\n2. Loading FAISS index...\")\n",
    "faiss_indexer = FAISSIndexer()\n",
    "if not Path(settings.INDEX_PATH).exists():\n",
    "    print(\"   ❌ Index not found! Run: python scripts/build_index.py\")\n",
    "    raise FileNotFoundError(f\"FAISS index not found at {settings.INDEX_PATH}\")\n",
    "faiss_indexer.load(settings.INDEX_PATH)\n",
    "print(f\"   ✓ FAISS index loaded ({faiss_indexer.index.ntotal} vectors)\")\n",
    "\n",
    "# 3. Chunk records handed to the retriever. Same fail-fast policy as the index.\n",
    "print(\"\\n3. Loading chunks metadata...\")\n",
    "if not Path(settings.CHUNKS_PATH).exists():\n",
    "    print(\"   ❌ Chunks not found!\")\n",
    "    raise FileNotFoundError(f\"Chunks metadata not found at {settings.CHUNKS_PATH}\")\n",
    "# JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding\n",
    "with open(settings.CHUNKS_PATH, 'r', encoding='utf-8') as f:\n",
    "    chunks = json.load(f)\n",
    "print(f\"   ✓ Loaded {len(chunks)} chunks\")\n",
    "\n",
    "# 4. Retriever over the index and the chunk records\n",
    "print(\"\\n4. Initializing retriever...\")\n",
    "retriever = Retriever(faiss_indexer, chunks, top_k=settings.TOP_K_RETRIEVAL)\n",
    "print(\"   ✓ Retriever ready\")\n",
    "\n",
    "# 5. BiomedLM is optional: if loading raises, the rest of the notebook\n",
    "#    falls back to retrieval-only testing via the `llm_available` flag.\n",
    "print(\"\\n5. Loading BiomedLM (this may take a while)...\")\n",
    "llm_generator = BiomedLMGenerator(\n",
    "    settings.LLM_MODEL,\n",
    "    settings.DEVICE,\n",
    "    settings.MAX_NEW_TOKENS,\n",
    "    settings.TEMPERATURE,\n",
    "    settings.TOP_P\n",
    ")\n",
    "try:\n",
    "    llm_generator.load()\n",
    "    print(\"   ✓ BiomedLM loaded\")\n",
    "    llm_available = True\n",
    "except Exception as e:\n",
    "    print(f\"   ⚠ BiomedLM not available: {e}\")\n",
    "    print(\"   Will test retrieval only\")\n",
    "    llm_available = False\n",
    "\n",
    "# 6. Prompt builder: combines the query and retrieved context into the LLM prompt\n",
    "print(\"\\n6. Initializing prompt builder...\")\n",
    "prompt_builder = PromptBuilder()\n",
    "print(\"   ✓ Prompt builder ready\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"RAG SYSTEM INITIALIZED\")\n",
    "print(\"=\"*80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Single Query Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_single_query(query: str, top_k: int = 3, show_context: bool = True) -> Dict:\n",
    "    \"\"\"Run one query through the RAG pipeline and report per-stage timings.\n",
    "\n",
    "    Embeds the query, retrieves the top-k chunks, formats them into a context\n",
    "    string, builds the prompt and, when the LLM loaded successfully, generates\n",
    "    an answer. Progress, retrieved sources and the answer are printed.\n",
    "\n",
    "    Uses the globals created by the initialization cell: ``embedder``,\n",
    "    ``retriever``, ``prompt_builder``, ``llm_generator`` and ``llm_available``.\n",
    "\n",
    "    Args:\n",
    "        query: Natural-language question.\n",
    "        top_k: Number of chunks to retrieve.\n",
    "        show_context: If True, print the LLM context (first 1000 characters).\n",
    "\n",
    "    Returns:\n",
    "        Dict with ``query``, ``retrieved_chunks``, ``embed_time``,\n",
    "        ``retrieval_time`` and ``total_time`` (seconds). When the LLM is\n",
    "        available it also has ``answer`` and ``generation_time``.\n",
    "    \"\"\"\n",
    "    print(f\"\\n{'='*80}\")\n",
    "    print(f\"QUERY: {query}\")\n",
    "    print(f\"{'='*80}\\n\")\n",
    "\n",
    "    # Step 1: embed the query. time.perf_counter() is monotonic and\n",
    "    # high-resolution, so it is the right clock for measuring durations\n",
    "    # (time.time() can jump if the system clock is adjusted).\n",
    "    start_time = time.perf_counter()\n",
    "    query_embedding = embedder.embed_query(query)\n",
    "    embed_time = time.perf_counter() - start_time\n",
    "    print(f\"1. Query Embedding: {embed_time:.3f}s\")\n",
    "\n",
    "    # Step 2: retrieve the top-k chunks for the query embedding\n",
    "    start_time = time.perf_counter()\n",
    "    retrieved_chunks = retriever.retrieve(query_embedding, k=top_k)\n",
    "    retrieval_time = time.perf_counter() - start_time\n",
    "    print(f\"2. Retrieval: {retrieval_time:.3f}s\")\n",
    "    print(f\"   Retrieved {len(retrieved_chunks)} chunks\\n\")\n",
    "\n",
    "    # Show retrieved chunks; the text preview gets '...' only when it was cut\n",
    "    print(\"Retrieved Sources:\")\n",
    "    print(\"-\" * 80)\n",
    "    for i, chunk in enumerate(retrieved_chunks, 1):\n",
    "        text = chunk['text']\n",
    "        preview = text[:200] + \"...\" if len(text) > 200 else text\n",
    "        print(f\"\\n[{i}] Disease: {chunk['disease_name']}\")\n",
    "        print(f\"    Field: {chunk['field']}\")\n",
    "        print(f\"    Score: {chunk['score']:.4f}\")\n",
    "        print(f\"    URL: {chunk['url']}\")\n",
    "        print(f\"    Text: {preview}\")\n",
    "\n",
    "    # Step 3: build the context string passed to the LLM\n",
    "    context = retriever.format_context(retrieved_chunks)\n",
    "\n",
    "    if show_context:\n",
    "        print(f\"\\n{'='*80}\")\n",
    "        print(\"CONTEXT FOR LLM:\")\n",
    "        print(\"=\"*80)\n",
    "        print(context[:1000] + \"...\" if len(context) > 1000 else context)\n",
    "\n",
    "    # Step 4: build the prompt\n",
    "    prompt = prompt_builder.build_prompt(query, context)\n",
    "    print(f\"\\n3. Prompt Length: {len(prompt)} characters\")\n",
    "\n",
    "    report = {\n",
    "        'query': query,\n",
    "        'retrieved_chunks': retrieved_chunks,\n",
    "        'embed_time': embed_time,\n",
    "        'retrieval_time': retrieval_time,\n",
    "    }\n",
    "\n",
    "    # Step 5: generate a response (only if the LLM loaded)\n",
    "    if not llm_available:\n",
    "        print(\"\\n⚠ Skipping generation (LLM not available)\")\n",
    "        # Report total_time for retrieval-only runs too, so every successful\n",
    "        # result carries the same timing keys.\n",
    "        report['total_time'] = embed_time + retrieval_time\n",
    "        return report\n",
    "\n",
    "    print(\"\\n4. Generating response...\")\n",
    "    start_time = time.perf_counter()\n",
    "    answer = llm_generator.generate(prompt)\n",
    "    generation_time = time.perf_counter() - start_time\n",
    "\n",
    "    print(f\"   Generation time: {generation_time:.2f}s\")\n",
    "    print(f\"\\n{'='*80}\")\n",
    "    print(\"GENERATED ANSWER:\")\n",
    "    print(\"=\"*80)\n",
    "    print(answer)\n",
    "    print(f\"\\n{'='*80}\")\n",
    "\n",
    "    total_time = embed_time + retrieval_time + generation_time\n",
    "    print(f\"\\nTotal Time: {total_time:.2f}s\")\n",
    "\n",
    "    report.update(\n",
    "        answer=answer,\n",
    "        generation_time=generation_time,\n",
    "        total_time=total_time,\n",
    "    )\n",
    "    return report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: push one well-known question through the whole pipeline\n",
    "test_query = \"What are the symptoms of malaria?\"\n",
    "result = test_single_query(query=test_query, top_k=3, show_context=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Multiple Queries Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation set: one question per disease, covering symptoms, causes,\n",
    "# treatment, prevention, transmission, complications and diagnosis.\n",
    "test_queries = [\n",
    "    \"What are the symptoms of malaria?\",\n",
    "    \"How is tuberculosis treated?\",\n",
    "    \"What causes diabetes?\",\n",
    "    \"How can I prevent COVID-19?\",\n",
    "    \"What are the risk factors for heart disease?\",\n",
    "    \"How is HIV transmitted?\",\n",
    "    \"What are the complications of dengue fever?\",\n",
    "    \"How is cancer diagnosed?\",\n",
    "    \"What is the treatment for hypertension?\",\n",
    "    \"What are the symptoms of depression?\",\n",
    "]\n",
    "\n",
    "print(f\"Testing {len(test_queries)} queries...\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every evaluation query. A failing query is recorded with its error\n",
    "# message instead of aborting the batch, so the summary can count failures.\n",
    "results = []\n",
    "\n",
    "for i, query in enumerate(test_queries, start=1):\n",
    "    print(f\"\\n[{i}/{len(test_queries)}] Testing: {query}\")\n",
    "    print(\"-\" * 80)\n",
    "\n",
    "    try:\n",
    "        result = test_single_query(query, top_k=3, show_context=False)\n",
    "    except Exception as e:\n",
    "        print(f\"✗ Error: {e}\")\n",
    "        results.append({'query': query, 'error': str(e)})\n",
    "    else:\n",
    "        results.append(result)\n",
    "        print(\"✓ Success\")\n",
    "\n",
    "    time.sleep(0.5)  # brief pause between consecutive queries\n",
    "\n",
    "print(f\"\\n\\nCompleted {len(results)} queries\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Performance Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of timings per successful query; failed queries (recorded with an\n",
    "# 'error' key) have no timings and are skipped. Results without a\n",
    "# generation step get generation_time = 0.\n",
    "perf_data = []\n",
    "for result in results:\n",
    "    if 'error' in result:\n",
    "        continue\n",
    "    query_text = result['query']\n",
    "    perf_data.append({\n",
    "        # Short label for tables; '...' only when the query was actually cut\n",
    "        'query': query_text if len(query_text) <= 50 else query_text[:50] + '...',\n",
    "        'embed_time': result['embed_time'],\n",
    "        'retrieval_time': result['retrieval_time'],\n",
    "        'generation_time': result.get('generation_time', 0),\n",
    "        'total_time': result.get('total_time', result['embed_time'] + result['retrieval_time'])\n",
    "    })\n",
    "\n",
    "perf_df = pd.DataFrame(perf_data)\n",
    "\n",
    "if len(perf_df) > 0:\n",
    "    print(\"Performance Statistics:\")\n",
    "    print(\"=\" * 80)\n",
    "    print(perf_df[['embed_time', 'retrieval_time', 'generation_time', 'total_time']].describe())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Timing breakdown: per-query stacked bars (left) and per-stage spread (right)\n",
    "if len(perf_df) > 0:\n",
    "    # Rename the timing columns to display names so pandas labels the\n",
    "    # box-plot ticks directly.\n",
    "    stage_labels = {\n",
    "        'embed_time': 'Embedding',\n",
    "        'retrieval_time': 'Retrieval',\n",
    "        'generation_time': 'Generation',\n",
    "    }\n",
    "    stage_times = perf_df[list(stage_labels)].rename(columns=stage_labels)\n",
    "\n",
    "    fig, (bar_ax, box_ax) = plt.subplots(1, 2, figsize=(14, 6))\n",
    "\n",
    "    stage_times.plot(kind='bar', stacked=True, ax=bar_ax, color=['#3498db', '#2ecc71', '#e74c3c'])\n",
    "    bar_ax.set(xlabel='Query Index', ylabel='Time (seconds)', title='Time Breakdown by Query')\n",
    "    bar_ax.legend(list(stage_labels.values()))\n",
    "    bar_ax.grid(axis='y', alpha=0.3)\n",
    "\n",
    "    stage_times.boxplot(ax=box_ax)\n",
    "    box_ax.set(ylabel='Time (seconds)', title='Time Distribution by Component')\n",
    "    box_ax.grid(axis='y', alpha=0.3)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Retrieval Quality Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the similarity scores of every retrieved chunk across all queries\n",
    "all_scores = [\n",
    "    chunk['score']\n",
    "    for result in results\n",
    "    if 'retrieved_chunks' in result\n",
    "    for chunk in result['retrieved_chunks']\n",
    "]\n",
    "\n",
    "if all_scores:\n",
    "    scores_array = np.array(all_scores)\n",
    "\n",
    "    print(\"Retrieval Score Statistics:\")\n",
    "    print(f\"Mean: {scores_array.mean():.4f}\")\n",
    "    print(f\"Std:  {scores_array.std():.4f}\")\n",
    "    print(f\"Min:  {scores_array.min():.4f}\")\n",
    "    print(f\"Max:  {scores_array.max():.4f}\")\n",
    "\n",
    "    # Histogram of scores with the mean marked\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    ax.hist(all_scores, bins=30, edgecolor='black', alpha=0.7, color='skyblue')\n",
    "    ax.axvline(scores_array.mean(), color='red', linestyle='--', label=f'Mean: {scores_array.mean():.3f}')\n",
    "    ax.set(xlabel='Similarity Score', ylabel='Frequency', title='Distribution of Retrieval Similarity Scores')\n",
    "    ax.legend()\n",
    "    ax.grid(alpha=0.3)\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count how often each disease appears among the retrieved chunks\n",
    "disease_counts = {}\n",
    "for result in results:\n",
    "    for chunk in result.get('retrieved_chunks', []):\n",
    "        disease = chunk['disease_name']\n",
    "        disease_counts[disease] = disease_counts.get(disease, 0) + 1\n",
    "\n",
    "if disease_counts:\n",
    "    # Up to 10 diseases by count; sorted() is stable, so ties keep first-seen order\n",
    "    top_diseases = sorted(disease_counts.items(), key=lambda x: x[1], reverse=True)[:10]\n",
    "\n",
    "    # Report the real count, which can be fewer than 10\n",
    "    print(f\"\\nTop {len(top_diseases)} Most Retrieved Diseases:\")\n",
    "    for disease, count in top_diseases:\n",
    "        print(f\"  {disease}: {count} times\")\n",
    "\n",
    "    diseases, counts = zip(*top_diseases)\n",
    "    fig, ax = plt.subplots(figsize=(12, 6))\n",
    "    colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(diseases)))\n",
    "    ax.barh(diseases, counts, color=colors)\n",
    "    # barh draws the first category at the bottom; invert the y-axis so the\n",
    "    # most-retrieved disease is on top, matching the printed ranking.\n",
    "    ax.invert_yaxis()\n",
    "    ax.set_xlabel('Retrieval Count')\n",
    "    ax.set_ylabel('Disease')\n",
    "    ax.set_title('Most Frequently Retrieved Diseases')\n",
    "    ax.grid(axis='x', alpha=0.3)\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Response Quality Inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show up to three generated answers. Select answered results first so a\n",
    "# failed early query neither shrinks the sample nor leaves gaps in numbering.\n",
    "if llm_available:\n",
    "    answered = [r for r in results if 'answer' in r]\n",
    "    print(\"Sample Generated Responses:\")\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "    for i, result in enumerate(answered[:3], 1):\n",
    "        print(f\"\\n{i}. Query: {result['query']}\")\n",
    "        print(\"-\" * 80)\n",
    "        print(f\"Answer: {result['answer']}\")\n",
    "        print(\"=\" * 80)\n",
    "\n",
    "    if not answered:\n",
    "        print(\"No successful responses to display\")\n",
    "else:\n",
    "    print(\"Response generation skipped (LLM not available)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Custom Query Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional interactive check. input() is unavailable when the notebook is\n",
    "# executed non-interactively (e.g. nbconvert/papermill), so treat missing\n",
    "# stdin as \"skip\" instead of aborting a Restart & Run All.\n",
    "print(\"Enter your own query to test the RAG system:\")\n",
    "print(\"(or press Enter to skip)\\n\")\n",
    "\n",
    "try:\n",
    "    custom_query = input(\"Your query: \").strip()\n",
    "except (EOFError, NotImplementedError):\n",
    "    # EOFError: stdin closed. NotImplementedError: base class of IPython's\n",
    "    # StdinNotImplementedError, raised when the frontend cannot provide input.\n",
    "    custom_query = \"\"\n",
    "\n",
    "if custom_query:\n",
    "    result = test_single_query(custom_query, top_k=5, show_context=True)\n",
    "else:\n",
    "    print(\"Skipped custom query\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. System Summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"RAG SYSTEM TEST SUMMARY\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "success_count = sum('error' not in r for r in results)\n",
    "error_count = len(results) - success_count\n",
    "\n",
    "print(f\"\\nTest Queries: {len(test_queries)}\")\n",
    "print(f\"Successful: {success_count}\")\n",
    "print(f\"Failed: {error_count}\")\n",
    "\n",
    "if len(perf_df) > 0:\n",
    "    print(\"\\nPerformance:\")\n",
    "    print(f\"  Average embedding time: {perf_df['embed_time'].mean():.3f}s\")\n",
    "    print(f\"  Average retrieval time: {perf_df['retrieval_time'].mean():.3f}s\")\n",
    "    # perf_df always has a generation_time column (0 for retrieval-only runs),\n",
    "    # so a column check can't tell whether generation ran; use llm_available.\n",
    "    if llm_available:\n",
    "        print(f\"  Average generation time: {perf_df['generation_time'].mean():.3f}s\")\n",
    "    print(f\"  Average total time: {perf_df['total_time'].mean():.3f}s\")\n",
    "\n",
    "if all_scores:\n",
    "    print(\"\\nRetrieval Quality:\")\n",
    "    print(f\"  Average similarity score: {np.mean(all_scores):.4f}\")\n",
    "    print(f\"  Min similarity score: {np.min(all_scores):.4f}\")\n",
    "    print(f\"  Max similarity score: {np.max(all_scores):.4f}\")\n",
    "\n",
    "print(f\"\\nLLM Status: {'Available' if llm_available else 'Not Available'}\")\n",
    "print(f\"Device: {settings.DEVICE}\")\n",
    "# NOTE: this is the retriever's configured default; the evaluation loop\n",
    "# passes top_k=3 explicitly, so the two may differ.\n",
    "print(f\"Top-K Retrieval: {settings.TOP_K_RETRIEVAL}\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Export Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save test results as JSON. Full chunk dicts are not exported; each result\n",
    "# keeps only (disease, score) pairs for the chunks it retrieved.\n",
    "output_path = Path('../data/processed/rag_test_results.json')\n",
    "output_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "export_results = []\n",
    "for result in results:\n",
    "    export_result = {\n",
    "        'query': result.get('query', ''),\n",
    "        'answer': result.get('answer', ''),\n",
    "        'embed_time': result.get('embed_time', 0),\n",
    "        'retrieval_time': result.get('retrieval_time', 0),\n",
    "        'generation_time': result.get('generation_time', 0),\n",
    "        'error': result.get('error', None)\n",
    "    }\n",
    "    if 'retrieved_chunks' in result:\n",
    "        export_result['retrieved_diseases'] = [\n",
    "            {\n",
    "                'disease': chunk['disease_name'],\n",
    "                # Scores may be NumPy scalars (e.g. float32 from FAISS), which\n",
    "                # json.dump rejects; float() accepts those and Python floats alike.\n",
    "                'score': float(chunk['score'])\n",
    "            }\n",
    "            for chunk in result['retrieved_chunks']\n",
    "        ]\n",
    "    export_results.append(export_result)\n",
    "\n",
    "with open(output_path, 'w', encoding='utf-8') as f:\n",
    "    json.dump(export_results, f, indent=2)\n",
    "\n",
    "print(f\"✓ Results saved to {output_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}