In [None]:
!pip install -q langchain langchain-community langchain-openai faiss-cpu openai pandas openpyxl pypdf tqdm tenacity

In [None]:
import os
import re
import time
import logging
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from enum import Enum

import pandas as pd

from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

from openai import OpenAI


In [None]:
# Guided retrieval configuration
clause_to_sections = {
    "4.1": "ACC.01.00 ME1",
    "4.3": "ACC.04.05 ME1",
    "4.4": "ACC.01.00 ME3",
    "4.5": "ACC.01.00 ME2",
    "4.6": "ACC.03.00 ME1",
    "4.7": "ACC.03.00 ME3",
    "4.8": "ACC.03.01 ME3",
    "5.1.1": "ACC.03.01 ME2 and ME3",
    "5.1.2": "ACC.03.00 ME1",
    "5.1.3": "ACC.01.00 ME5, ACC.03.00 ME2, ME3 and ME4",
    "5.1.4": "ACC.03.00 ME6, ACC.02.02 ME2",
    "5.1.5": "ACC.03.00 ME6",
    "5.1.8": "ACC.03.00 ME5, ACC.03.01 ME1, ME2 and ME3",
    "5.3.1": "ACC.01.00 ME6",
    "5.3.2": "ACC.03.01 ME1, ME2 and ME3",
    "5.3.3": "ACC.03.01 ME3",
    "5.3.4": "ACC.03.01 ME3",
    "5.4.1": "ACC.03.01 ME3",
    "5.5.1": "ACC.03.01 ME1 and ME3",
    "5.5.2": "ACC.03.00 ME6",
    "5.6": "ACC.03.01 ME3",
    "5.6.1": "Ministry of Health Circular No. MH 53:08/4 vol 6 – Guidelines for Inter-Hospital Transfer",
    "5.6.2": "ACC.03.00 ME6",
    "5.6.3": "ACC.03.00 ME5",
    "5.6.4": "ACC.03.01 ME1, ME2 and ME3",
    "5.7": "ACC.03.01 ME and ME3",
}

synonyms = {
    "transfer": ["handover", "relocation", "patient movement"],
    "policy": ["guideline", "procedure"],
    "emergency": ["urgent", "critical", "immediate"],
}

def expand_query_with_synonyms(query: str) -> str:
    words = query.split()
    expanded = []
    for w in words:
        expanded.append(w)
        if w.lower() in synonyms:
            expanded.extend(synonyms[w.lower()])
    return " ".join(expanded)


In [None]:
# API key
if not os.getenv("OPENAI_API_KEY"):
    try:
        from getpass import getpass
        os.environ["OPENAI_API_KEY"] = getpass("Enter OPENAI_API_KEY (input hidden): ")
    except Exception:
        raise ValueError("OPENAI_API_KEY is required.")

print("API Key starts with:", os.getenv("OPENAI_API_KEY")[:5])


Enter OPENAI_API_KEY (input hidden): ··········
API Key starts with: sk-pr


In [None]:
def setup_logging() -> logging.Logger:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )
    return logging.getLogger("policy_pipeline")

logger = setup_logging()


In [None]:
class PolicyDocumentChunker:
    def __init__(self):
        self.clause_patterns = [
            r'(\d+\.\d+(?:\.\d+)?)\.\s*(.*?)(?=\n\d+\.\d+(?:\.\d+)?\.|\n[A-Z]+\s*:|\n\n[A-Z]|\Z)',
            r'(\d+\.\d+(?:\.\d+)?)\s+((?:[A-Z][^.]*\..*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z))',
            r'(\d+\.\d+\.\d+)\.\s*(.*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z)',
        ]
        self.definition_pattern = r'(\d+\.\d+)\.\s+([A-Z]+(?:\s+[A-Z]+)*)\s+refers\s+to\s+(.*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z)'
        self.section_header_pattern = r'(\d+)\.\s+([A-Z][A-Z\s]{3,})\s*\n'

    def extract_comprehensive_chunks(self, text: str) -> List[Dict]:
        text = self._preprocess_text(text)
        clauses = self._extract_clauses(text)
        definitions = self._extract_definitions(text)
        sections = self._extract_sections(text)
        all_chunks = clauses + definitions + sections
        enhanced_chunks = self._add_contextual_info(all_chunks, text)
        return enhanced_chunks

    def _preprocess_text(self, text: str) -> str:
        text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
        text = re.sub(r'(\d+\.\d+(?:\.\d+)?)\s*\.\s*', r'\1. ', text)
        text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text)
        text = re.sub(r'Document\s+No:\s+[A-Z0-9\-]+', '', text)
        return text

    def _extract_clauses(self, text: str) -> List[Dict]:
        clauses = []
        for pattern in self.clause_patterns:
            matches = re.finditer(pattern, text, re.DOTALL | re.MULTILINE)
            for match in matches:
                clause_num = match.group(1)
                content = match.group(2).strip()
                content = re.sub(r'\s+', ' ', content)
                content = re.sub(r'\n+', '\n', content)
                if len(content) > 20 and not self._is_duplicate_clause(clauses, clause_num):
                    clauses.append({
                        'clause_number': clause_num,
                        'content': content,
                        'content_type': 'clause',
                        'word_count': len(content.split()),
                        'char_count': len(content),
                    })
        return clauses

    def _extract_definitions(self, text: str) -> List[Dict]:
        definitions = []
        matches = re.finditer(self.definition_pattern, text, re.DOTALL | re.IGNORECASE)
        for match in matches:
            clause_num = match.group(1)
            term = match.group(2).strip()
            definition = re.sub(r'\s+', ' ', match.group(3).strip())
            definitions.append({
                'clause_number': clause_num,
                'term': term,
                'definition': definition,
                'content': f"{term} refers to {definition}",
                'content_type': 'definition',
                'word_count': len(definition.split()),
                'char_count': len(definition),
            })
        return definitions

    def _extract_sections(self, text: str) -> List[Dict]:
        sections = []
        matches = re.finditer(self.section_header_pattern, text, re.MULTILINE)
        for match in matches:
            section_num = match.group(1)
            section_title = match.group(2).strip()
            start_pos = match.end()
            next_section = re.search(r'\n\d+\.\s+[A-Z][A-Z\s]{3,}', text[start_pos:])
            if next_section:
                section_content = text[start_pos:start_pos + next_section.start()]
            else:
                section_content = text[start_pos:start_pos + 1000]
            section_content = section_content.strip()
            if len(section_content) > 50:
                sections.append({
                    'clause_number': f"{section_num}.0",
                    'section_title': section_title,
                    'content': f"Section {section_num}: {section_title}\n\n{section_content}",
                    'content_type': 'section',
                    'word_count': len(section_content.split()),
                    'char_count': len(section_content),
                })
        return sections

    def _add_contextual_info(self, chunks: List[Dict], full_text: str) -> List[Dict]:
        enhanced_chunks = []
        for chunk in chunks:
            enhanced_content = self._build_contextual_content(chunk, chunks)
            enhanced_chunk = chunk.copy()
            enhanced_chunk['enhanced_content'] = enhanced_content
            enhanced_chunks.append(enhanced_chunk)
        return enhanced_chunks

    def _build_contextual_content(self, chunk: Dict, all_chunks: List[Dict]) -> str:
        content_parts = [chunk['content']]
        if chunk['content_type'] == 'clause':
            related_defs = self._find_related_definitions(chunk, all_chunks)
            if related_defs:
                content_parts.append("Relevant definitions:")
                for def_chunk in related_defs[:2]:
                    content_parts.append(f"- {def_chunk['content']}")
        section_context = self._find_section_context(chunk, all_chunks)
        if section_context:
            content_parts.insert(0, f"Context: {section_context}")
        return "\n".join(content_parts)

    def _find_related_definitions(self, chunk: Dict, all_chunks: List[Dict]) -> List[Dict]:
        definitions = [c for c in all_chunks if c['content_type'] == 'definition']
        related = []
        chunk_content_lower = chunk['content'].lower()
        for def_chunk in definitions:
            term_lower = def_chunk['term'].lower()
            if term_lower in chunk_content_lower:
                related.append(def_chunk)
        return related

    def _find_section_context(self, chunk: Dict, all_chunks: List[Dict]) -> Optional[str]:
        sections = [c for c in all_chunks if c['content_type'] == 'section']
        clause_num = chunk['clause_number']
        if '.' in clause_num:
            section_num = clause_num.split('.')[0]
            for section in sections:
                if section['clause_number'].startswith(f"{section_num}."):
                    return section.get('section_title', '')
        return None

    def _is_duplicate_clause(self, existing_clauses: List[Dict], clause_num: str) -> bool:
        return any(c['clause_number'] == clause_num for c in existing_clauses)


In [None]:
def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    chunker = PolicyDocumentChunker()
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks


def analyze_chunking_results(chunks: List[Dict]) -> Dict:
    analysis = {
        'total_chunks': len(chunks),
        'clause_numbers': [c['clause_number'] for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown'],
        'unknown_clauses': len([c for c in chunks if c.get('clause_number') == 'unknown']),
        'chunk_types': {}
    }
    for chunk in chunks:
        chunk_type = chunk.get('content_type', 'unknown')
        analysis['chunk_types'][chunk_type] = analysis['chunk_types'].get(chunk_type, 0) + 1
    if chunks:
        analysis['extraction_success_rate'] = ((len(chunks) - analysis['unknown_clauses']) / len(chunks)) * 100
    return analysis


def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")


In [None]:
class ComplianceStatus(Enum):
    COMPLIANT = "Compliant"
    PARTIALLY_COMPLIANT = "Partially Compliant"
    NON_COMPLIANT = "Non-Compliant"
    INSUFFICIENT_INFO = "Insufficient Information"


@dataclass
class ComplianceAnalysis:
    clause_number: str
    compliance_status: ComplianceStatus
    confidence_score: float
    key_gaps: List[str]
    required_changes: List[str]
    jci_references: List[str]
    risk_level: str
    full_analysis: str


class ClauseComparisonAnalyzer:
    def __init__(self, api_key: str, model: str = "gpt-4o"):
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.config = {
            'max_tokens': 1000,
            'temperature': 0.1,
            'top_p': 0.9,
        }

    def compare_clause_to_jci(self, clause_number: str, clause_text: str, jci_references: List[Dict]) -> ComplianceAnalysis:
        prompt = self._build_comparison_prompt(clause_number, clause_text, jci_references)
        response = self._call_gpt4o(prompt)
        analysis = self._parse_response(response, clause_number)
        return analysis

    def _build_comparison_prompt(self, clause_number: str, clause_text: str, jci_references: List[Dict]) -> str:
        jci_text = self._format_jci_references(jci_references)
        return f"""You are a healthcare compliance expert specializing in JCI hospital accreditation standards.

**TASK:** Compare the policy clause against JCI standards and identify specific changes needed.

**POLICY CLAUSE {clause_number}:**
{clause_text}

**RELEVANT JCI STANDARDS:**
{jci_text}

**ANALYSIS FRAMEWORK:**

1. **REQUIREMENT COMPARISON:** Compare each requirement in the policy clause against the JCI standards above.

2. **GAP IDENTIFICATION:** Identify specific gaps between policy and JCI requirements.

3. **CHANGE SPECIFICATION:** For each gap, specify the exact change needed to achieve compliance.

**PROVIDE STRUCTURED OUTPUT:**

**COMPLIANCE STATUS:** [Compliant/Partially Compliant/Non-Compliant]

**CONFIDENCE LEVEL:** [0-100%]

**SPECIFIC GAPS IDENTIFIED:**
• [Gap 1]: [Specific difference between policy and JCI standard]
• [Gap 2]: [Specific difference between policy and JCI standard]

**REQUIRED CHANGES:**
• [Change 1]: [Exact modification needed] → [JCI Standard this addresses]
• [Change 2]: [Exact modification needed] → [JCI Standard this addresses]

**RISK LEVEL:** [High/Medium/Low]

**JCI STANDARDS REFERENCED:**
• [Standard]: [Specific requirement]

Focus on actionable, specific changes rather than general recommendations."""

    def _format_jci_references(self, references: List[Dict]) -> str:
        if not references:
            return "No specific JCI references available for comparison."
        formatted = []
        for i, ref in enumerate(references, 1):
            text = ref.get('text', '')
            if len(text) > 300:
                text = text[:300] + "..."
            formatted.append(f"**JCI Reference {i}:**\n{text}")
        return "\n\n".join(formatted)

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def _call_gpt4o(self, prompt: str) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": "You are a senior healthcare compliance consultant. Provide specific, actionable analysis focused on exact changes needed for JCI compliance."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            **self.config
        )
        return response.choices[0].message.content

    def _parse_response(self, response: str, clause_number: str) -> ComplianceAnalysis:
        import re as _re
        compliance_match = _re.search(r'\*\*COMPLIANCE STATUS:\*\*\s*([^\n]+)', response, _re.IGNORECASE)
        compliance_status = ComplianceStatus.INSUFFICIENT_INFO
        if compliance_match:
            status_text = compliance_match.group(1).strip().lower()
            if "non-compliant" in status_text:
                compliance_status = ComplianceStatus.NON_COMPLIANT
            elif "partially compliant" in status_text:
                compliance_status = ComplianceStatus.PARTIALLY_COMPLIANT
            elif "compliant" in status_text:
                compliance_status = ComplianceStatus.COMPLIANT

        confidence_match = _re.search(r'\*\*CONFIDENCE LEVEL:\*\*\s*(\d+)%', response)
        confidence_score = float(confidence_match.group(1)) / 100 if confidence_match else 0.0

        gaps_section = _re.search(r'\*\*SPECIFIC GAPS IDENTIFIED:\*\*(.*?)(?=\*\*[A-Z ]+:\*\*|\Z)', response, _re.DOTALL)
        key_gaps = []
        if gaps_section:
            gaps_text = gaps_section.group(1)
            key_gaps = [g.strip() for g in _re.findall(r'•\s*([^\n•]+)', gaps_text) if g.strip()]

        changes_section = _re.search(r'\*\*REQUIRED CHANGES:\*\*(.*?)(?=\*\*[A-Z ]+:\*\*|\Z)', response, _re.DOTALL)
        required_changes = []
        if changes_section:
            changes_text = changes_section.group(1)
            required_changes = [c.strip() for c in _re.findall(r'•\s*([^\n•]+)', changes_text) if c.strip()]

        jci_section = _re.search(r'\*\*JCI STANDARDS REFERENCED:\*\*(.*?)(?=\*\*[A-Z ]+:\*\*|\Z)', response, _re.DOTALL)
        jci_references = []
        if jci_section:
            jci_text = jci_section.group(1)
            jci_references = [r.strip() for r in _re.findall(r'•\s*([^\n•]+)', jci_text) if r.strip()]

        risk_match = _re.search(r'\*\*RISK LEVEL:\*\*\s*([^\n]+)', response, _re.IGNORECASE)
        risk_level = risk_match.group(1).strip() if risk_match else "Unknown"

        return ComplianceAnalysis(
            clause_number=clause_number,
            compliance_status=compliance_status,
            confidence_score=confidence_score,
            key_gaps=key_gaps,
            required_changes=required_changes,
            jci_references=jci_references,
            risk_level=risk_level,
            full_analysis=response
        )


In [None]:
# Load policy PDF, chunk, and save chunking results
POLICY_PDF = "/content/old Transfer 290523 .pdf"   # <- change as needed
JCI_PDF = "/content/ESS Standards_7th Ed .pdf"  # <- change as needed

validate_pdf_path(POLICY_PDF)
logger.info("Loading and chunking policy document...")
policy_chunks = load_and_chunk_pdf(POLICY_PDF)
policy_analysis = analyze_chunking_results(policy_chunks)
logger.info(f"Chunking complete: {policy_analysis}")

pd.DataFrame(policy_chunks).to_excel("enhanced_chunking_results.xlsx", index=False)
logger.info("Saved enhanced_chunking_results.xlsx")


In [None]:
# Embeddings
EMBEDDING_MODEL = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(
    model=EMBEDDING_MODEL,
    api_key=os.getenv("OPENAI_API_KEY"),
)


In [None]:
# Load and chunk JCI, split into smaller pieces, build FAISS
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")


In [None]:
# Hybrid retrieval (guided-first, automated fallback)

def retrieve_jci_hybrid(clause_number: str, clause_text: str, k: int = 5) -> List[Dict]:
    expanded_text = expand_query_with_synonyms(clause_text)
    mapped = clause_to_sections.get(clause_number)

    # fetch expanded candidates
    docs = jci_db.similarity_search(expanded_text, k=max(3*k, 15))
    if mapped:
        filtered = [d for d in docs if mapped.lower() in d.page_content.lower()]
        if filtered:
            return [{"text": d.page_content, "metadata": d.metadata} for d in filtered[:k]]

    # fallback to vanilla similarity on original clause text
    auto_docs = jci_db.similarity_search(clause_text, k=k)
    return [{"text": d.page_content, "metadata": d.metadata} for d in auto_docs]


In [None]:
# Analysis config and batch run
MODEL = "gpt-4o"
RATE_LIMIT_DELAY = 2.0
BATCH_SIZE = 10

analyzer = ClauseComparisonAnalyzer(api_key=os.getenv("OPENAI_API_KEY"), model=MODEL)

def batch_analyze_clauses_with_progress(chunks: List[Dict]) -> List[ComplianceAnalysis]:
    results: List[ComplianceAnalysis] = []
    work_items = [c for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown']

    with tqdm(total=len(work_items), desc="Analyzing clauses") as pbar:
        for chunk in work_items:
            clause_number = chunk['clause_number']
            clause_text = chunk.get('enhanced_content', chunk.get('content', ''))
            jci_refs = retrieve_jci_hybrid(clause_number, clause_text)
            try:
                analysis = analyzer.compare_clause_to_jci(clause_number, clause_text, jci_refs)
                results.append(analysis)
            except Exception as e:
                logger.error(f"Analysis failed for clause {clause_number}: {e}")
                results.append(
                    ComplianceAnalysis(
                        clause_number=clause_number,
                        compliance_status=ComplianceStatus.INSUFFICIENT_INFO,
                        confidence_score=0.0,
                        key_gaps=[f"Analysis failed: {str(e)}"],
                        required_changes=["Retry analysis"],
                        jci_references=[],
                        risk_level="Unknown",
                        full_analysis=f"Error: {str(e)}"
                    )
                )
            time.sleep(RATE_LIMIT_DELAY)
            pbar.update(1)
    return results


In [None]:
# Run batches and save
all_results: List[ComplianceAnalysis] = []

for i in range(0, len(policy_chunks), BATCH_SIZE):
    batch = policy_chunks[i:i+BATCH_SIZE]
    logger.info(f"Processing batch {i//BATCH_SIZE + 1} with {len(batch)} clauses...")
    batch_results = batch_analyze_clauses_with_progress(batch)
    all_results.extend(batch_results)

logger.info(f"Total analyzed: {len(all_results)}")

# Save to Excel
def save_comparison_results(analyses: List[ComplianceAnalysis], output_file: str):
    data = []
    for analysis in analyses:
        data.append({
            'clause_number': analysis.clause_number,
            'compliance_status': analysis.compliance_status.value,
            'confidence_score': f"{analysis.confidence_score:.0%}",
            'risk_level': analysis.risk_level,
            'gaps_identified': ' | '.join(analysis.key_gaps),
            'required_changes': ' | '.join(analysis.required_changes),
            'jci_standards_cited': ' | '.join(analysis.jci_references),
            'detailed_analysis': analysis.full_analysis
        })

    summary = {
        'total_analyzed': len(analyses),
        'compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.COMPLIANT]),
        'partially_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.PARTIALLY_COMPLIANT]),
        'non_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.NON_COMPLIANT]),
        'high_risk': len([a for a in analyses if 'high' in a.risk_level.lower()]),
    }

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        pd.DataFrame(data).to_excel(writer, sheet_name='Clause Analysis', index=False)
        pd.DataFrame([summary]).to_excel(writer, sheet_name='Summary', index=False)

save_comparison_results(all_results, "policy_vs_jci_hybrid.xlsx")
logger.info("Results saved to policy_vs_jci_hybrid.xlsx")


Analyzing clauses: 100%|██████████| 10/10 [01:28<00:00,  8.82s/it]
Analyzing clauses: 100%|██████████| 10/10 [01:35<00:00,  9.51s/it]
Analyzing clauses: 100%|██████████| 6/6 [00:56<00:00,  9.50s/it]


In [None]:
print("Chunking results:", policy_analysis)
display(pd.DataFrame(policy_chunks).head())

Chunking results: {'total_chunks': 26, 'clause_numbers': ['5.1', '5.2', '1.1', '1.2', '2.1', '3.1', '3.2', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '4.1', '4.3', '4.5', '4.6', '4.8', '6.1', '6.3', '1.0', '2.0', '3.0', '4.0', '5.0', '6.0'], 'unknown_clauses': 0, 'chunk_types': {'clause': 20, 'section': 6}, 'extraction_success_rate': 100.0}


Unnamed: 0,clause_number,content,content_type,word_count,char_count,enhanced_content,chunk_id,source_file,section_title
0,5.1,1 Request should be made by a member of the re...,clause,59,466,Context: GUIDELINES\n1 Request should be made ...,old Transfer 290523 .pdf::chunk_1,/content/old Transfer 290523 .pdf,
1,5.2,1 The decision to transfer patients to another...,clause,67,535,Context: GUIDELINES\n1 The decision to transfe...,old Transfer 290523 .pdf::chunk_2,/content/old Transfer 290523 .pdf,
2,1.1,To provide guidelines for healthcare team to c...,clause,16,116,Context: PURPOSE\nTo provide guidelines for he...,old Transfer 290523 .pdf::chunk_3,/content/old Transfer 290523 .pdf,
3,1.2,To ensure that patient’s information is convey...,clause,19,124,Context: PURPOSE\nTo ensure that patient’s inf...,old Transfer 290523 .pdf::chunk_4,/content/old Transfer 290523 .pdf,
4,2.1,Patienttransfersduringthecourseofahospitaladmi...,clause,44,323,Context: POLICY\nPatienttransfersduringthecour...,old Transfer 290523 .pdf::chunk_5,/content/old Transfer 290523 .pdf,


# Task
Integrate BM25 retrieval into the existing retrieval pipeline for the JCI dataset.

## Install bm25 library

### Subtask:
Add a new cell to install the `rank_bm25` library.


**Reasoning**:
The subtask is to install the `rank_bm25` library. This requires adding a new code cell with the pip installation command.



In [1]:
!pip install -q rank_bm25

## Update imports

### Subtask:
Add the necessary import for `BM25Okapi` to the existing imports cell.


**Reasoning**:
Add the import statement for `BM25Okapi` to the existing import cell.



In [2]:
import os
import re
import time
import logging
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from enum import Enum

import pandas as pd

from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

from openai import OpenAI

from rank_bm25 import BM25Okapi

ModuleNotFoundError: No module named 'langchain_community'

**Reasoning**:
The previous command failed because `langchain_community` was not found. This likely means the kernel session was reset or the environment changed. Re-run the initial installation cell to ensure all necessary libraries, including `langchain_community`, are installed.



# Task
Integrate BM25 retrieval into the existing Python notebook for hybrid search, modifying the code cells to include installation, imports, index building, and result combination with the existing FAISS retrieval.

## Install bm25 library

### Subtask:
Modify the first code cell to include the `rank_bm25` installation.


**Reasoning**:
The subtask is to modify the first code cell to include the `rank_bm25` installation. I will edit the code cell with cell_id `-MNgdbS9YtYt` and add `rank_bm25` to the list of libraries to be installed.



In [3]:
!pip install -q langchain langchain-community langchain-openai faiss-cpu openai pandas openpyxl pypdf tqdm tenacity rank_bm25

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m1.6/2.5 MB[0m [31m49.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.0/75.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m53.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m310.5/310.5 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m447.5/447.5 kB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.7/64.7 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Update imports

### Subtask:
Update imports


**Reasoning**:
Re-running the initial installation cell is necessary because the previous import cell failed due to a missing module, indicating a potential environment issue or kernel restart.



In [4]:
!pip install -q langchain langchain-community langchain-openai faiss-cpu openai pandas openpyxl pypdf tqdm tenacity rank_bm25

**Reasoning**:
Now that the required libraries are installed, the import cell needs to be re-run to include the `BM25Okapi` import and ensure all necessary modules are available for subsequent steps.



In [5]:
import os
import re
import time
import logging
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from enum import Enum

import pandas as pd

from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

from openai import OpenAI

from rank_bm25 import BM25Okapi

## Build bm25 index

### Subtask:
Create a new function to build a BM25 index from the JCI chunk texts and integrate it into the existing JCI processing logic.


**Reasoning**:
Define the `build_bm25_index` function and integrate it into the JCI processing code cell to build the BM25 index.



In [6]:
def build_bm25_index(texts: List[str]) -> BM25Okapi:
    """Builds a BM25 index from a list of text documents."""
    tokenized_corpus = [doc.split() for doc in texts]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25

# Load and chunk JCI, split into smaller pieces, build FAISS and BM25
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")

# Build BM25 index
jci_bm25 = build_bm25_index(split_jci_texts)
logger.info(f"BM25 index built with {len(split_jci_texts)} JCI pieces")

NameError: name 'validate_pdf_path' is not defined

**Reasoning**:
The previous command failed because functions like `validate_pdf_path`, `load_and_chunk_pdf`, and `build_faiss_in_batches` were not defined in the current execution context, even though they were defined in previous cells. The entire code block from the previous attempt needs to be re-run, including the definitions of these functions, to fix the `NameError`.



In [7]:
def build_bm25_index(texts: List[str]) -> BM25Okapi:
    """Builds a BM25 index from a list of text documents."""
    tokenized_corpus = [doc.split() for doc in texts]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    chunker = PolicyDocumentChunker()
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks

# Load and chunk JCI, split into smaller pieces, build FAISS and BM25
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")

# Build BM25 index
jci_bm25 = build_bm25_index(split_jci_texts)
logger.info(f"BM25 index built with {len(split_jci_texts)} JCI pieces")

NameError: name 'JCI_PDF' is not defined

**Reasoning**:
The previous command failed because the `JCI_PDF` variable was not defined in the current execution context. This variable is defined in a previous cell. The entire code block from the previous attempt needs to be re-run, including the definition of `JCI_PDF` and other necessary variables like `embeddings` and `PolicyDocumentChunker`, to fix the `NameError`.



In [8]:
# Policy and JCI PDF paths (re-defined for context)
POLICY_PDF = "/content/old Transfer 290523 .pdf"   # <- change as needed
JCI_PDF = "/content/ESS Standards_7th Ed .pdf"  # <- change as needed

# Embeddings (re-instantiated for context)
EMBEDDING_MODEL = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(
    model=EMBEDDING_MODEL,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Chunker (re-instantiated for context)
chunker = PolicyDocumentChunker()

def build_bm25_index(texts: List[str]) -> BM25Okapi:
    """Builds a BM25 index from a list of text documents."""
    tokenized_corpus = [doc.split() for doc in texts]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    # Use the chunker from the outer scope
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks

# Load and chunk JCI, split into smaller pieces, build FAISS and BM25
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")

# Build BM25 index
jci_bm25 = build_bm25_index(split_jci_texts)
logger.info(f"BM25 index built with {len(split_jci_texts)} JCI pieces")

OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable

**Reasoning**:
The previous command failed because the `OPENAI_API_KEY` environment variable was not set or accessible when the `OpenAIEmbeddings` object was instantiated. This likely indicates that the environment variable set in an earlier cell was lost. Re-running the cell that sets the API key environment variable should fix this issue. Then, re-run the code block that defines variables and functions and builds the indices.



In [9]:
# API key
if not os.getenv("OPENAI_API_KEY"):
    try:
        from getpass import getpass
        os.environ["OPENAI_API_KEY"] = getpass("Enter OPENAI_API_KEY (input hidden): ")
    except Exception:
        raise ValueError("OPENAI_API_KEY is required.")

print("API Key starts with:", os.getenv("OPENAI_API_KEY")[:5])

Enter OPENAI_API_KEY (input hidden): ··········
API Key starts with: sk-pr


**Reasoning**:
The API key is now set. Re-run the code block that defines the necessary functions and variables, loads and chunks the JCI PDF, splits the chunks into smaller text pieces, and builds both FAISS and BM25 indices to continue with the subtask.



In [10]:
# Policy and JCI PDF paths (re-defined for context)
POLICY_PDF = "/content/old Transfer 290523 .pdf"   # <- change as needed
JCI_PDF = "/content/ESS Standards_7th Ed .pdf"  # <- change as needed

# Embeddings (re-instantiated for context)
EMBEDDING_MODEL = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(
    model=EMBEDDING_MODEL,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Chunker (re-instantiated for context)
chunker = PolicyDocumentChunker()

def build_bm25_index(texts: List[str]) -> BM25Okapi:
    """Builds a BM25 index from a list of text documents."""
    tokenized_corpus = [doc.split() for doc in texts]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    # Use the chunker from the outer scope
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks

# Load and chunk JCI, split into smaller pieces, build FAISS and BM25
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")


# Build BM25 index
jci_bm25 = build_bm25_index(split_jci_texts)
logger.info(f"BM25 index built with {len(split_jci_texts)} JCI pieces")

NameError: name 'PolicyDocumentChunker' is not defined

**Reasoning**:
The previous command failed because `PolicyDocumentChunker` was not defined. This class is defined in an earlier cell. To fix this, the definition of `PolicyDocumentChunker` needs to be included in the current code block along with the other necessary definitions and the logic for building the indices.



In [11]:
# Policy and JCI PDF paths (re-defined for context)
POLICY_PDF = "/content/old Transfer 290523 .pdf"   # <- change as needed
JCI_PDF = "/content/ESS Standards_7th Ed .pdf"  # <- change as needed

# Embeddings (re-instantiated for context)
EMBEDDING_MODEL = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(
    model=EMBEDDING_MODEL,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# PolicyDocumentChunker definition (re-included for context)
class PolicyDocumentChunker:
    def __init__(self):
        self.clause_patterns = [
            r'(\d+\.\d+(?:\.\d+)?)\.\s*(.*?)(?=\n\d+\.\d+(?:\.\d+)?\.|\n[A-Z]+\s*:|\n\n[A-Z]|\Z)',
            r'(\d+\.\d+(?:\.\d+)?)\s+((?:[A-Z][^.]*\..*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z))',
            r'(\d+\.\d+\.\d+)\.\s*(.*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z)',
        ]
        self.definition_pattern = r'(\d+\.\d+)\.\s+([A-Z]+(?:\s+[A-Z]+)*)\s+refers\s+to\s+(.*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z)'
        self.section_header_pattern = r'(\d+)\.\s+([A-Z][A-Z\s]{3,})\s*\n'

    def extract_comprehensive_chunks(self, text: str) -> List[Dict]:
        text = self._preprocess_text(text)
        clauses = self._extract_clauses(text)
        definitions = self._extract_definitions(text)
        sections = self._extract_sections(text)
        all_chunks = clauses + definitions + sections
        enhanced_chunks = self._add_contextual_info(all_chunks, text)
        return enhanced_chunks

    def _preprocess_text(self, text: str) -> str:
        text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
        text = re.sub(r'(\d+\.\d+(?:\.\d+)?)\s*\.\s*', r'\1. ', text)
        text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text)
        text = re.sub(r'Document\s+No:\s+[A-Z0-9\-]+', '', text)
        return text

    def _extract_clauses(self, text: str) -> List[Dict]:
        clauses = []
        for pattern in self.clause_patterns:
            matches = re.finditer(pattern, text, re.DOTALL | re.MULTILINE)
            for match in matches:
                clause_num = match.group(1)
                content = match.group(2).strip()
                content = re.sub(r'\s+', ' ', content)
                content = re.sub(r'\n+', '\n', content)
                if len(content) > 20 and not self._is_duplicate_clause(clauses, clause_num):
                    clauses.append({
                        'clause_number': clause_num,
                        'content': content,
                        'content_type': 'clause',
                        'word_count': len(content.split()),
                        'char_count': len(content),
                    })
        return clauses

    def _extract_definitions(self, text: str) -> List[Dict]:
        definitions = []
        matches = re.finditer(self.definition_pattern, text, re.DOTALL | re.IGNORECASE)
        for match in matches:
            clause_num = match.group(1)
            term = match.group(2).strip()
            definition = re.sub(r'\s+', ' ', match.group(3).strip())
            definitions.append({
                'clause_number': clause_num,
                'term': term,
                'definition': definition,
                'content': f"{term} refers to {definition}",
                'content_type': 'definition',
                'word_count': len(definition.split()),
                'char_count': len(definition),
            })
        return definitions

    def _extract_sections(self, text: str) -> List[Dict]:
        sections = []
        matches = re.finditer(self.section_header_pattern, text, re.MULTILINE)
        for match in matches:
            section_num = match.group(1)
            section_title = match.group(2).strip()
            start_pos = match.end()
            next_section = re.search(r'\n\d+\.\s+[A-Z][A-Z\s]{3,}', text[start_pos:])
            if next_section:
                section_content = text[start_pos:start_pos + next_section.start()]
            else:
                section_content = text[start_pos:start_pos + 1000]
            section_content = section_content.strip()
            if len(section_content) > 50:
                sections.append({
                    'clause_number': f"{section_num}.0",
                    'section_title': section_title,
                    'content': f"Section {section_num}: {section_title}\n\n{section_content}",
                    'content_type': 'section',
                    'word_count': len(section_content.split()),
                    'char_count': len(section_content),
                })
        return sections

    def _add_contextual_info(self, chunks: List[Dict], full_text: str) -> List[Dict]:
        enhanced_chunks = []
        for chunk in chunks:
            enhanced_content = self._build_contextual_content(chunk, chunks)
            enhanced_chunk = chunk.copy()
            enhanced_chunk['enhanced_content'] = enhanced_content
            enhanced_chunks.append(enhanced_chunk)
        return enhanced_chunks

    def _build_contextual_content(self, chunk: Dict, all_chunks: List[Dict]) -> str:
        content_parts = [chunk['content']]
        if chunk['content_type'] == 'clause':
            related_defs = self._find_related_definitions(chunk, all_chunks)
            if related_defs:
                content_parts.append("Relevant definitions:")
                for def_chunk in related_defs[:2]:
                    content_parts.append(f"- {def_chunk['content']}")
        section_context = self._find_section_context(chunk, all_chunks)
        if section_context:
            content_parts.insert(0, f"Context: {section_context}")
        return "\n".join(content_parts)

    def _find_related_definitions(self, chunk: Dict, all_chunks: List[Dict]) -> List[Dict]:
        definitions = [c for c in all_chunks if c['content_type'] == 'definition']
        related = []
        chunk_content_lower = chunk['content'].lower()
        for def_chunk in definitions:
            term_lower = def_chunk['term'].lower()
            if term_lower in chunk_content_lower:
                related.append(def_chunk)
        return related

    def _find_section_context(self, chunk: Dict, all_chunks: List[Dict]) -> Optional[str]:
        sections = [c for c in all_chunks if c['content_type'] == 'section']
        clause_num = chunk['clause_number']
        if '.' in clause_num:
            section_num = clause_num.split('.')[0]
            for section in sections:
                if section['clause_number'].startswith(f"{section_num}."):
                    return section.get('section_title', '')
        return None

    def _is_duplicate_clause(self, existing_clauses: List[Dict], clause_num: str) -> bool:
        return any(c['clause_number'] == clause_num for c in existing_clauses)

# Chunker (re-instantiated after definition)
chunker = PolicyDocumentChunker()

def build_bm25_index(texts: List[str]) -> BM25Okapi:
    """Builds a BM25 index from a list of text documents."""
    tokenized_corpus = [doc.split() for doc in texts]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    # Use the chunker from the outer scope
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks

# Load and chunk JCI, split into smaller pieces, build FAISS and BM25
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")

# Build BM25 index
jci_bm25 = build_bm25_index(split_jci_texts)
logger.info(f"BM25 index built with {len(split_jci_texts)} JCI pieces")

NameError: name 'logger' is not defined

**Reasoning**:
The previous command failed because the `logger` object was not defined. This object is initialized in an earlier cell using the `setup_logging` function. To fix this, the `setup_logging` function definition and the initialization of the `logger` object need to be included in the current code block along with the other necessary definitions and the logic for building the indices.



In [12]:
# Setup logging (re-included for context)
def setup_logging() -> logging.Logger:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )
    return logging.getLogger("policy_pipeline")

logger = setup_logging()

# Policy and JCI PDF paths (re-defined for context)
POLICY_PDF = "/content/old Transfer 290523 .pdf"   # <- change as needed
JCI_PDF = "/content/ESS Standards_7th Ed .pdf"  # <- change as needed

# Embeddings (re-instantiated for context)
EMBEDDING_MODEL = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(
    model=EMBEDDING_MODEL,
    api_key=os.getenv("OPENAI_API_KEY"),
)

# PolicyDocumentChunker definition (re-included for context)
class PolicyDocumentChunker:
    def __init__(self):
        self.clause_patterns = [
            r'(\d+\.\d+(?:\.\d+)?)\.\s*(.*?)(?=\n\d+\.\d+(?:\.\d+)?\.|\n[A-Z]+\s*:|\n\n[A-Z]|\Z)',
            r'(\d+\.\d+(?:\.\d+)?)\s+((?:[A-Z][^.]*\..*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z))',
            r'(\d+\.\d+\.\d+)\.\s*(.*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z)',
        ]
        self.definition_pattern = r'(\d+\.\d+)\.\s+([A-Z]+(?:\s+[A-Z]+)*)\s+refers\s+to\s+(.*?)(?=\n\d+\.\d+|\n[A-Z]+\s*:|\Z)'
        self.section_header_pattern = r'(\d+)\.\s+([A-Z][A-Z\s]{3,})\s*\n'

    def extract_comprehensive_chunks(self, text: str) -> List[Dict]:
        text = self._preprocess_text(text)
        clauses = self._extract_clauses(text)
        definitions = self._extract_definitions(text)
        sections = self._extract_sections(text)
        all_chunks = clauses + definitions + sections
        enhanced_chunks = self._add_contextual_info(all_chunks, text)
        return enhanced_chunks

    def _preprocess_text(self, text: str) -> str:
        text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
        text = re.sub(r'(\d+\.\d+(?:\.\d+)?)\s*\.\s*', r'\1. ', text)
        text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text)
        text = re.sub(r'Document\s+No:\s+[A-Z0-9\-]+', '', text)
        return text

    def _extract_clauses(self, text: str) -> List[Dict]:
        clauses = []
        for pattern in self.clause_patterns:
            matches = re.finditer(pattern, text, re.DOTALL | re.MULTILINE)
            for match in matches:
                clause_num = match.group(1)
                content = match.group(2).strip()
                content = re.sub(r'\s+', ' ', content)
                content = re.sub(r'\n+', '\n', content)
                if len(content) > 20 and not self._is_duplicate_clause(clauses, clause_num):
                    clauses.append({
                        'clause_number': clause_num,
                        'content': content,
                        'content_type': 'clause',
                        'word_count': len(content.split()),
                        'char_count': len(content),
                    })
        return clauses

    def _extract_definitions(self, text: str) -> List[Dict]:
        definitions = []
        matches = re.finditer(self.definition_pattern, text, re.DOTALL | re.IGNORECASE)
        for match in matches:
            clause_num = match.group(1)
            term = match.group(2).strip()
            definition = re.sub(r'\s+', ' ', match.group(3).strip())
            definitions.append({
                'clause_number': clause_num,
                'term': term,
                'definition': definition,
                'content': f"{term} refers to {definition}",
                'content_type': 'definition',
                'word_count': len(definition.split()),
                'char_count': len(definition),
            })
        return definitions

    def _extract_sections(self, text: str) -> List[Dict]:
        sections = []
        matches = re.finditer(self.section_header_pattern, text, re.MULTILINE)
        for match in matches:
            section_num = match.group(1)
            section_title = match.group(2).strip()
            start_pos = match.end()
            next_section = re.search(r'\n\d+\.\s+[A-Z][A-Z\s]{3,}', text[start_pos:])
            if next_section:
                section_content = text[start_pos:start_pos + next_section.start()]
            else:
                section_content = text[start_pos:start_pos + 1000]
            section_content = section_content.strip()
            if len(section_content) > 50:
                sections.append({
                    'clause_number': f"{section_num}.0",
                    'section_title': section_title,
                    'content': f"Section {section_num}: {section_title}\n\n{section_content}",
                    'content_type': 'section',
                    'word_count': len(section_content.split()),
                    'char_count': len(section_content),
                })
        return sections

    def _add_contextual_info(self, chunks: List[Dict], full_text: str) -> List[Dict]:
        enhanced_chunks = []
        for chunk in chunks:
            enhanced_content = self._build_contextual_content(chunk, chunks)
            enhanced_chunk = chunk.copy()
            enhanced_chunk['enhanced_content'] = enhanced_content
            enhanced_chunks.append(enhanced_chunk)
        return enhanced_chunks

    def _build_contextual_content(self, chunk: Dict, all_chunks: List[Dict]) -> str:
        content_parts = [chunk['content']]
        if chunk['content_type'] == 'clause':
            related_defs = self._find_related_definitions(chunk, all_chunks)
            if related_defs:
                content_parts.append("Relevant definitions:")
                for def_chunk in related_defs[:2]:
                    content_parts.append(f"- {def_chunk['content']}")
        section_context = self._find_section_context(chunk, all_chunks)
        if section_context:
            content_parts.insert(0, f"Context: {section_context}")
        return "\n".join(content_parts)

    def _find_related_definitions(self, chunk: Dict, all_chunks: List[Dict]) -> List[Dict]:
        definitions = [c for c in all_chunks if c['content_type'] == 'definition']
        related = []
        chunk_content_lower = chunk['content'].lower()
        for def_chunk in definitions:
            term_lower = def_chunk['term'].lower()
            if term_lower in chunk_content_lower:
                related.append(def_chunk)
        return related

    def _find_section_context(self, chunk: Dict, all_chunks: List[Dict]) -> Optional[str]:
        sections = [c for c in all_chunks if c['content_type'] == 'section']
        clause_num = chunk['clause_number']
        if '.' in clause_num:
            section_num = clause_num.split('.')[0]
            for section in sections:
                if section['clause_number'].startswith(f"{section_num}."):
                    return section.get('section_title', '')
        return None

    def _is_duplicate_clause(self, existing_clauses: List[Dict], clause_num: str) -> bool:
        return any(c['clause_number'] == clause_num for c in existing_clauses)

# Chunker (re-instantiated after definition)
chunker = PolicyDocumentChunker()

def build_bm25_index(texts: List[str]) -> BM25Okapi:
    """Builds a BM25 index from a list of text documents."""
    tokenized_corpus = [doc.split() for doc in texts]
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    # Use the chunker from the outer scope
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks

# Load and chunk JCI, split into smaller pieces, build FAISS and BM25
validate_pdf_path(JCI_PDF)
logger.info("Loading and chunking JCI standards...")
jci_chunks = load_and_chunk_pdf(JCI_PDF)

def split_jci_pieces(chunks, size, overlap):
    splitter_local = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    out_texts, out_metas = [], []
    for c in chunks:
        content = c.get("content", "")
        if not content:
            continue
        pieces = splitter_local.split_text(content)
        base_meta = {"section": c.get("clause_number", "n/a")}
        for idx, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) >= 20:
                out_texts.append(piece)
                out_metas.append({**base_meta, "piece_index": idx})
    return out_texts, out_metas

PRIMARY_SIZE, PRIMARY_OVERLAP = 1500, 200
split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, PRIMARY_SIZE, PRIMARY_OVERLAP)

def build_faiss_in_batches(texts, metadatas, embedder, batch_size=64):
    assert len(texts) == len(metadatas)
    if not texts:
        raise ValueError("No texts to index.")
    first_end = min(batch_size, len(texts))
    db = FAISS.from_texts(texts[:first_end], embedder, metadatas=metadatas[:first_end])
    for i in range(first_end, len(texts), batch_size):
        j = min(i + batch_size, len(texts))
        db.add_texts(texts[i:j], metadatas=metadatas[i:j])
        time.sleep(0.2)
    return db

try:
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=64)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (size {PRIMARY_SIZE})")
except Exception as e:
    logger.warning(f"Primary JCI embedding failed ({e}). Retrying with smaller pieces.")
    FALLBACK_SIZE, FALLBACK_OVERLAP = 900, 150
    split_jci_texts, split_jci_metadatas = split_jci_pieces(jci_chunks, FALLBACK_SIZE, FALLBACK_OVERLAP)
    jci_db = build_faiss_in_batches(split_jci_texts, split_jci_metadatas, embeddings, batch_size=48)
    logger.info(f"FAISS DB built with {len(split_jci_texts)} JCI pieces (fallback size {FALLBACK_SIZE})")

# Build BM25 index
jci_bm25 = build_bm25_index(split_jci_texts)
logger.info(f"BM25 index built with {len(split_jci_texts)} JCI pieces")

## Modify retrieval function

### Subtask:
Update the `retrieve_jci_hybrid` function to perform both BM25 and FAISS searches and combine the results.


**Reasoning**:
Modify the `retrieve_jci_hybrid` function to incorporate BM25 retrieval, combine it with the existing FAISS retrieval, and return a combined list of unique results. This involves adding parameters for the BM25 index and text corpus, performing BM25 search, merging results with FAISS results, and ensuring uniqueness.



In [13]:
def retrieve_jci_hybrid(clause_number: str, clause_text: str, jci_bm25: BM25Okapi, split_jci_texts: List[str], split_jci_metadatas: List[Dict], k: int = 5) -> List[Dict]:
    """
    Retrieves relevant JCI standards using a hybrid approach (Guided/BM25 + FAISS).
    """
    expanded_text = expand_query_with_synonyms(clause_text)
    mapped = clause_to_sections.get(clause_number)

    # Perform BM25 search
    tokenized_query = expanded_text.split()
    bm25_scores = jci_bm25.get_scores(tokenized_query)
    # Get top K indices based on scores
    top_n_indices = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:k]
    bm25_results = [{"text": split_jci_texts[i], "metadata": split_jci_metadatas[i]} for i in top_n_indices]

    # Perform FAISS search
    faiss_docs = jci_db.similarity_search(expanded_text, k=k)
    faiss_results = [{"text": d.page_content, "metadata": d.metadata} for d in faiss_docs]


    # Combine results, prioritizing guided sections if mapped
    combined_results = []
    seen_texts = set()

    # Add BM25 results
    for result in bm25_results:
        if result['text'] not in seen_texts:
            combined_results.append(result)
            seen_texts.add(result['text'])

    # Add FAISS results, prioritizing mapped sections if available
    if mapped:
        # Add FAISS results from mapped sections first
        mapped_faiss_results = [d for d in faiss_results if mapped.lower() in d['metadata'].get('section', '').lower()]
        for result in mapped_faiss_results:
             if result['text'] not in seen_texts:
                combined_results.append(result)
                seen_texts.add(result['text'])

        # Add remaining FAISS results
        remaining_faiss_results = [d for d in faiss_results if mapped.lower() not in d['metadata'].get('section', '').lower()]
        for result in remaining_faiss_results:
             if result['text'] not in seen_texts:
                combined_results.append(result)
                seen_texts.add(result['text'])
    else:
        # If no mapping, add all FAISS results
        for result in faiss_results:
            if result['text'] not in seen_texts:
                combined_results.append(result)
                seen_texts.add(result['text'])


    # Return top K unique results
    return combined_results[:k]


## Combine results

### Subtask:
Implement a strategy to combine the results from BM25 and FAISS (e.g., re-ranking, reciprocal rank fusion).


**Reasoning**:
Modify the `retrieve_jci_hybrid` function to implement Reciprocal Rank Fusion (RRF) for combining BM25 and FAISS results, ensuring the function returns the top K unique results after re-ranking.



In [14]:
def reciprocal_rank_fusion(results: List[List[Dict]], k: int = 60) -> List[Dict]:
    """
    Applies Reciprocal Rank Fusion to a list of ranked lists of results.

    Args:
        results: A list where each element is a ranked list of dictionaries,
                 each dictionary containing 'text' and 'metadata'.
        k: The constant used in the RRF formula (1/ (k + rank)).

    Returns:
        A single list of unique results ranked by their RRF score.
    """
    fused_scores = {}
    document_map = {}

    for rank_list in results:
        for rank, doc in enumerate(rank_list):
            text = doc['text']
            if text not in fused_scores:
                fused_scores[text] = 0.0
                document_map[text] = doc # Store the original document object

            # RRF formula: 1 / (k + rank)
            score = 1.0 / (k + rank + 1)  # +1 because rank is 0-based
            fused_scores[text] += score

    # Sort documents by fused score in descending order
    sorted_docs = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)

    # Return the re-ranked documents
    return [document_map[text] for text, score in sorted_docs]


def retrieve_jci_hybrid(clause_number: str, clause_text: str, jci_bm25: BM25Okapi, split_jci_texts: List[str], split_jci_metadatas: List[Dict], k: int = 5) -> List[Dict]:
    """
    Retrieves relevant JCI standards using a hybrid approach (BM25 + FAISS) and
    combines results using Reciprocal Rank Fusion (RRF).
    """
    expanded_text = expand_query_with_synonyms(clause_text)

    # Perform BM25 search
    tokenized_query = expanded_text.split()
    bm25_scores = jci_bm25.get_scores(tokenized_query)
    # Get top N indices for RRF (N is typically larger than final K)
    N_rrf = 2 * k # Retrieve more documents for RRF
    top_n_indices_bm25 = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:N_rrf]
    bm25_results = [{"text": split_jci_texts[i], "metadata": split_jci_metadatas[i]} for i in top_n_indices_bm25]

    # Perform FAISS search
    faiss_docs = jci_db.similarity_search(expanded_text, k=N_rrf) # Retrieve more documents for RRF
    faiss_results = [{"text": d.page_content, "metadata": d.metadata} for d in faiss_docs]

    # Combine and re-rank using RRF
    combined_and_reranked = reciprocal_rank_fusion([bm25_results, faiss_results], k=60)

    # Apply guided retrieval filter *after* RRF for potential boost (optional, or prior to RRF)
    # For simplicity here, we'll just take the top K RRF results.
    # A more complex approach could boost RRF scores based on guided mapping.

    # Filter by mapped sections if a strong preference is needed (Alternative to RRF score boosting)
    # mapped = clause_to_sections.get(clause_number)
    # if mapped:
    #     filtered_reranked = [doc for doc in combined_and_reranked if mapped.lower() in doc['metadata'].get('section', '').lower()]
    #     # Take top K from filtered, then fill with top from original reranked if needed
    #     final_results = filtered_reranked[:k]
    #     if len(final_results) < k:
    #         seen_texts = {doc['text'] for doc in final_results}
    #         for doc in combined_and_reranked:
    #             if doc['text'] not in seen_texts and len(final_results) < k:
    #                 final_results.append(doc)
    # else:
    #     final_results = combined_and_reranked[:k]

    # Returning top K from RRF
    return combined_and_reranked[:k]

## Test and evaluate

### Subtask:
Run the analysis with the updated retrieval and potentially add evaluation metrics to assess the impact of hybrid retrieval.


**Reasoning**:
Execute the code cell that runs the batch analysis to use the updated hybrid retrieval function.



In [15]:
# Run batches and save
all_results: List[ComplianceAnalysis] = []

# Re-initialize the analyzer to use the latest model and potentially pick up new environment variables
analyzer = ClauseComparisonAnalyzer(api_key=os.getenv("OPENAI_API_KEY"), model=MODEL)

# Use the correct function signature for retrieve_jci_hybrid
def batch_analyze_clauses_with_progress(chunks: List[Dict], bm25_index: BM25Okapi, bm25_texts: List[str], bm25_metadatas: List[Dict]) -> List[ComplianceAnalysis]:
    results: List[ComplianceAnalysis] = []
    work_items = [c for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown']

    with tqdm(total=len(work_items), desc="Analyzing clauses") as pbar:
        for chunk in work_items:
            clause_number = chunk['clause_number']
            clause_text = chunk.get('enhanced_content', chunk.get('content', ''))
            # Pass the necessary BM25 index components
            jci_refs = retrieve_jci_hybrid(clause_number, clause_text, bm25_index, bm25_texts, bm25_metadatas)
            try:
                analysis = analyzer.compare_clause_to_jci(clause_number, clause_text, jci_refs)
                results.append(analysis)
            except Exception as e:
                logger.error(f"Analysis failed for clause {clause_number}: {e}")
                results.append(
                    ComplianceAnalysis(
                        clause_number=clause_number,
                        compliance_status=ComplianceStatus.INSUFFICIENT_INFO,
                        confidence_score=0.0,
                        key_gaps=[f"Analysis failed: {str(e)}"],
                        required_changes=["Retry analysis"],
                        jci_references=[],
                        risk_level="Unknown",
                        full_analysis=f"Error: {str(e)}"
                    )
                )
            time.sleep(RATE_LIMIT_DELAY)
            pbar.update(1)
    return results

# Call the batch analysis function with the BM25 index components
# Ensure split_jci_texts and split_jci_metadatas are accessible from the scope
# If not, you might need to re-run the cell that creates them or pass them explicitly.
# Assuming they are accessible from the previous execution:
for i in range(0, len(policy_chunks), BATCH_SIZE):
    batch = policy_chunks[i:i+BATCH_SIZE]
    logger.info(f"Processing batch {i//BATCH_SIZE + 1} with {len(batch)} clauses...")
    batch_results = batch_analyze_clauses_with_progress(batch, jci_bm25, split_jci_texts, split_jci_metadatas)
    all_results.extend(batch_results)

logger.info(f"Total analyzed: {len(all_results)}")

# Save to Excel
def save_comparison_results(analyses: List[ComplianceAnalysis], output_file: str):
    data = []
    for analysis in analyses:
        data.append({
            'clause_number': analysis.clause_number,
            'compliance_status': analysis.compliance_status.value,
            'confidence_score': f"{analysis.confidence_score:.0%}",
            'risk_level': analysis.risk_level,
            'gaps_identified': ' | '.join(analysis.key_gaps),
            'required_changes': ' | '.join(analysis.required_changes),
            'jci_standards_cited': ' | '.join(analysis.jci_references),
            'detailed_analysis': analysis.full_analysis
        })

    summary = {
        'total_analyzed': len(analyses),
        'compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.COMPLIANT]),
        'partially_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.PARTIALLY_COMPLIANT]),
        'non_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.NON_COMPLIANT]),
        'high_risk': len([a for a in analyses if 'high' in a.risk_level.lower()]),
    }

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        pd.DataFrame(data).to_excel(writer, sheet_name='Clause Analysis', index=False)
        pd.DataFrame([summary]).to_excel(writer, sheet_name='Summary', index=False)

save_comparison_results(all_results, "policy_vs_jci_hybrid.xlsx")
logger.info("Results saved to policy_vs_jci_hybrid.xlsx")

NameError: name 'ComplianceAnalysis' is not defined

**Reasoning**:
The previous command failed because `ComplianceAnalysis` was not defined. This indicates the class definition was not available in the current execution scope. The code cell defining `ComplianceAnalysis` needs to be re-executed or its definition included in the current cell. Since the definition is somewhat long, re-executing the cell that defines it is more efficient. I will also include the definitions of `ComplianceStatus` and `ClauseComparisonAnalyzer` as they are likely in the same original cell and are also needed.



In [16]:
class ComplianceStatus(Enum):
    COMPLIANT = "Compliant"
    PARTIALLY_COMPLIANT = "Partially Compliant"
    NON_COMPLIANT = "Non-Compliant"
    INSUFFICIENT_INFO = "Insufficient Information"


@dataclass
class ComplianceAnalysis:
    clause_number: str
    compliance_status: ComplianceStatus
    confidence_score: float
    key_gaps: List[str]
    required_changes: List[str]
    jci_references: List[str]
    risk_level: str
    full_analysis: str


class ClauseComparisonAnalyzer:
    def __init__(self, api_key: str, model: str = "gpt-4o"):
        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.config = {
            'max_tokens': 1000,
            'temperature': 0.1,
            'top_p': 0.9,
        }

    def compare_clause_to_jci(self, clause_number: str, clause_text: str, jci_references: List[Dict]) -> ComplianceAnalysis:
        prompt = self._build_comparison_prompt(clause_number, clause_text, jci_references)
        response = self._call_gpt4o(prompt)
        analysis = self._parse_response(response, clause_number)
        return analysis

    def _build_comparison_prompt(self, clause_number: str, clause_text: str, jci_references: List[Dict]) -> str:
        jci_text = self._format_jci_references(jci_references)
        return f"""You are a healthcare compliance expert specializing in JCI hospital accreditation standards.

**TASK:** Compare the policy clause against JCI standards and identify specific changes needed.

**POLICY CLAUSE {clause_number}:**
{clause_text}

**RELEVANT JCI STANDARDS:**
{jci_text}

**ANALYSIS FRAMEWORK:**

1. **REQUIREMENT COMPARISON:** Compare each requirement in the policy clause against the JCI standards above.

2. **GAP IDENTIFICATION:** Identify specific gaps between policy and JCI requirements.

3. **CHANGE SPECIFICATION:** For each gap, specify the exact change needed to achieve compliance.

**PROVIDE STRUCTURED OUTPUT:**

**COMPLIANCE STATUS:** [Compliant/Partially Compliant/Non-Compliant]

**CONFIDENCE LEVEL:** [0-100%]

**SPECIFIC GAPS IDENTIFIED:**
• [Gap 1]: [Specific difference between policy and JCI standard]
• [Gap 2]: [Specific difference between policy and JCI standard]

**REQUIRED CHANGES:**
• [Change 1]: [Exact modification needed] → [JCI Standard this addresses]
• [Change 2]: [Exact modification needed] → [JCI Standard this addresses]

**RISK LEVEL:** [High/Medium/Low]

**JCI STANDARDS REFERENCED:**
• [Standard]: [Specific requirement]

Focus on actionable, specific changes rather than general recommendations."""

    def _format_jci_references(self, references: List[Dict]) -> str:
        if not references:
            return "No specific JCI references available for comparison."
        formatted = []
        for i, ref in enumerate(references, 1):
            text = ref.get('text', '')
            if len(text) > 300:
                text = text[:300] + "..."
            formatted.append(f"**JCI Reference {i}:**\n{text}")
        return "\n\n".join(formatted)

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def _call_gpt4o(self, prompt: str) -> str:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": "You are a senior healthcare compliance consultant. Provide specific, actionable analysis focused on exact changes needed for JCI compliance."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            **self.config
        )
        return response.choices[0].message.content

    def _parse_response(self, response: str, clause_number: str) -> ComplianceAnalysis:
        import re as _re
        compliance_match = _re.search(r'\*\*COMPLIANCE STATUS:\*\*\s*([^\n]+)', response, _re.IGNORECASE)
        compliance_status = ComplianceStatus.INSUFFICIENT_INFO
        if compliance_match:
            status_text = compliance_match.group(1).strip().lower()
            if "non-compliant" in status_text:
                compliance_status = ComplianceStatus.NON_COMPLIANT
            elif "partially compliant" in status_text:
                compliance_status = ComplianceStatus.PARTIALLY_COMPLIANT
            elif "compliant" in status_text:
                compliance_status = ComplianceStatus.COMPLIANT

        confidence_match = _re.search(r'\*\*CONFIDENCE LEVEL:\*\*\s*(\d+)%', response)
        confidence_score = float(confidence_match.group(1)) / 100 if confidence_match else 0.0

        gaps_section = _re.search(r'\*\*SPECIFIC GAPS IDENTIFIED:\*\*(.*?)(?=\*\*[A-Z ]+:\*\*|\Z)', response, _re.DOTALL)
        key_gaps = []
        if gaps_section:
            gaps_text = gaps_section.group(1)
            key_gaps = [g.strip() for g in _re.findall(r'•\s*([^\n•]+)', gaps_text) if g.strip()]

        changes_section = _re.search(r'\*\*REQUIRED CHANGES:\*\*(.*?)(?=\*\*[A-Z ]+:\*\*|\Z)', response, _re.DOTALL)
        required_changes = []
        if changes_section:
            changes_text = changes_section.group(1)
            required_changes = [c.strip() for c in _re.findall(r'•\s*([^\n•]+)', changes_text) if c.strip()]

        jci_section = _re.search(r'\*\*JCI STANDARDS REFERENCED:\*\*(.*?)(?=\*\*[A-Z ]+:\*\*|\Z)', response, _re.DOTALL)
        jci_references = []
        if jci_section:
            jci_text = jci_section.group(1)
            jci_references = [r.strip() for r in _re.findall(r'•\s*([^\n•]+)', jci_text) if r.strip()]

        risk_match = _re.search(r'\*\*RISK LEVEL:\*\*\s*([^\n]+)', response, _re.IGNORECASE)
        risk_level = risk_match.group(1).strip() if risk_match else "Unknown"

        return ComplianceAnalysis(
            clause_number=clause_number,
            compliance_status=compliance_status,
            confidence_score=confidence_score,
            key_gaps=key_gaps,
            required_changes=required_changes,
            jci_references=jci_references,
            risk_level=risk_level,
            full_analysis=response
        )


**Reasoning**:
The `ComplianceAnalysis` class and other necessary definitions are now available. I will re-execute the code cell for batch analysis, ensuring the `retrieve_jci_hybrid` function is called with the correct arguments (including the BM25 index components).



In [17]:
# Run batches and save
all_results: List[ComplianceAnalysis] = []

# Re-initialize the analyzer to use the latest model and potentially pick up new environment variables
analyzer = ClauseComparisonAnalyzer(api_key=os.getenv("OPENAI_API_KEY"), model=MODEL)

# Use the correct function signature for retrieve_jci_hybrid
def batch_analyze_clauses_with_progress(chunks: List[Dict], bm25_index: BM25Okapi, bm25_texts: List[str], bm25_metadatas: List[Dict]) -> List[ComplianceAnalysis]:
    results: List[ComplianceAnalysis] = []
    work_items = [c for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown']

    with tqdm(total=len(work_items), desc="Analyzing clauses") as pbar:
        for chunk in work_items:
            clause_number = chunk['clause_number']
            clause_text = chunk.get('enhanced_content', chunk.get('content', ''))
            # Pass the necessary BM25 index components
            jci_refs = retrieve_jci_hybrid(clause_number, clause_text, bm25_index, bm25_texts, bm25_metadatas)
            try:
                analysis = analyzer.compare_clause_to_jci(clause_number, clause_text, jci_refs)
                results.append(analysis)
            except Exception as e:
                logger.error(f"Analysis failed for clause {clause_number}: {e}")
                results.append(
                    ComplianceAnalysis(
                        clause_number=clause_number,
                        compliance_status=ComplianceStatus.INSUFFICIENT_INFO,
                        confidence_score=0.0,
                        key_gaps=[f"Analysis failed: {str(e)}"],
                        required_changes=["Retry analysis"],
                        jci_references=[],
                        risk_level="Unknown",
                        full_analysis=f"Error: {str(e)}"
                    )
                )
            time.sleep(RATE_LIMIT_DELAY)
            pbar.update(1)
    return results

# Call the batch analysis function with the BM25 index components
# Ensure split_jci_texts and split_jci_metadatas are accessible from the scope
# If not, you might need to re-run the cell that creates them or pass them explicitly.
# Assuming they are accessible from the previous execution:
for i in range(0, len(policy_chunks), BATCH_SIZE):
    batch = policy_chunks[i:i+BATCH_SIZE]
    logger.info(f"Processing batch {i//BATCH_SIZE + 1} with {len(batch)} clauses...")
    batch_results = batch_analyze_clauses_with_progress(batch, jci_bm25, split_jci_texts, split_jci_metadatas)
    all_results.extend(batch_results)

logger.info(f"Total analyzed: {len(all_results)}")

# Save to Excel
def save_comparison_results(analyses: List[ComplianceAnalysis], output_file: str):
    data = []
    for analysis in analyses:
        data.append({
            'clause_number': analysis.clause_number,
            'compliance_status': analysis.compliance_status.value,
            'confidence_score': f"{analysis.confidence_score:.0%}",
            'risk_level': analysis.risk_level,
            'gaps_identified': ' | '.join(analysis.key_gaps),
            'required_changes': ' | '.join(analysis.required_changes),
            'jci_standards_cited': ' | '.join(analysis.jci_references),
            'detailed_analysis': analysis.full_analysis
        })

    summary = {
        'total_analyzed': len(analyses),
        'compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.COMPLIANT]),
        'partially_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.PARTIALLY_COMPLIANT]),
        'non_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.NON_COMPLIANT]),
        'high_risk': len([a for a in analyses if 'high' in a.risk_level.lower()]),
    }

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        pd.DataFrame(data).to_excel(writer, sheet_name='Clause Analysis', index=False)
        pd.DataFrame([summary]).to_excel(writer, sheet_name='Summary', index=False)

save_comparison_results(all_results, "policy_vs_jci_hybrid.xlsx")
logger.info("Results saved to policy_vs_jci_hybrid.xlsx")

NameError: name 'MODEL' is not defined

**Reasoning**:
The previous command failed because `MODEL` was not defined. This variable is used to specify the language model for the `ClauseComparisonAnalyzer`. I need to define this variable before initializing the analyzer. I will also define `RATE_LIMIT_DELAY` and `BATCH_SIZE` as they were likely defined in the same original cell and are used later in the code block. Then I will re-execute the batch analysis code.



In [18]:
# Analysis config and batch run
MODEL = "gpt-4o"
RATE_LIMIT_DELAY = 2.0
BATCH_SIZE = 10

# Run batches and save
all_results: List[ComplianceAnalysis] = []

# Re-initialize the analyzer to use the latest model and potentially pick up new environment variables
analyzer = ClauseComparisonAnalyzer(api_key=os.getenv("OPENAI_API_KEY"), model=MODEL)

# Use the correct function signature for retrieve_jci_hybrid
def batch_analyze_clauses_with_progress(chunks: List[Dict], bm25_index: BM25Okapi, bm25_texts: List[str], bm25_metadatas: List[Dict]) -> List[ComplianceAnalysis]:
    results: List[ComplianceAnalysis] = []
    work_items = [c for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown']

    with tqdm(total=len(work_items), desc="Analyzing clauses") as pbar:
        for chunk in work_items:
            clause_number = chunk['clause_number']
            clause_text = chunk.get('enhanced_content', chunk.get('content', ''))
            # Pass the necessary BM25 index components
            jci_refs = retrieve_jci_hybrid(clause_number, clause_text, bm25_index, bm25_texts, bm25_metadatas)
            try:
                analysis = analyzer.compare_clause_to_jci(clause_number, clause_text, jci_refs)
                results.append(analysis)
            except Exception as e:
                logger.error(f"Analysis failed for clause {clause_number}: {e}")
                results.append(
                    ComplianceAnalysis(
                        clause_number=clause_number,
                        compliance_status=ComplianceStatus.INSUFFICIENT_INFO,
                        confidence_score=0.0,
                        key_gaps=[f"Analysis failed: {str(e)}"],
                        required_changes=["Retry analysis"],
                        jci_references=[],
                        risk_level="Unknown",
                        full_analysis=f"Error: {str(e)}"
                    )
                )
            time.sleep(RATE_LIMIT_DELAY)
            pbar.update(1)
    return results

# Call the batch analysis function with the BM25 index components
# Ensure split_jci_texts and split_jci_metadatas are accessible from the scope
# If not, you might need to re-run the cell that creates them or pass them explicitly.
# Assuming they are accessible from the previous execution:
for i in range(0, len(policy_chunks), BATCH_SIZE):
    batch = policy_chunks[i:i+BATCH_SIZE]
    logger.info(f"Processing batch {i//BATCH_SIZE + 1} with {len(batch)} clauses...")
    batch_results = batch_analyze_clauses_with_progress(batch, jci_bm25, split_jci_texts, split_jci_metadatas)
    all_results.extend(batch_results)

logger.info(f"Total analyzed: {len(all_results)}")

# Save to Excel
def save_comparison_results(analyses: List[ComplianceAnalysis], output_file: str):
    data = []
    for analysis in analyses:
        data.append({
            'clause_number': analysis.clause_number,
            'compliance_status': analysis.compliance_status.value,
            'confidence_score': f"{analysis.confidence_score:.0%}",
            'risk_level': analysis.risk_level,
            'gaps_identified': ' | '.join(analysis.key_gaps),
            'required_changes': ' | '.join(analysis.required_changes),
            'jci_standards_cited': ' | '.join(analysis.jci_references),
            'detailed_analysis': analysis.full_analysis
        })

    summary = {
        'total_analyzed': len(analyses),
        'compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.COMPLIANT]),
        'partially_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.PARTIALLY_COMPLIANT]),
        'non_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.NON_COMPLIANT]),
        'high_risk': len([a for a in analyses if 'high' in a.risk_level.lower()]),
    }

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        pd.DataFrame(data).to_excel(writer, sheet_name='Clause Analysis', index=False)
        pd.DataFrame([summary]).to_excel(writer, sheet_name='Summary', index=False)

save_comparison_results(all_results, "policy_vs_jci_hybrid.xlsx")
logger.info("Results saved to policy_vs_jci_hybrid.xlsx")

NameError: name 'policy_chunks' is not defined

**Reasoning**:
The previous command failed because `policy_chunks` was not defined. This variable holds the chunks of the policy document and is needed to iterate through the clauses for analysis. I need to re-execute the code that loads and chunks the policy document to make `policy_chunks` available in the current scope. I will include the definition of `load_and_chunk_pdf` and `validate_pdf_path` as they are needed for loading and chunking, as well as the `POLICY_PDF` variable and the `logger`. I will also include the `chunker` object initialization as it's used by `load_and_chunk_pdf`.



In [19]:
# Load policy PDF, chunk, and save chunking results
# Define necessary variables and functions if not in scope
POLICY_PDF = "/content/old Transfer 290523 .pdf" # <- change as needed

# Re-define logger and chunker if necessary
# Assuming setup_logging and PolicyDocumentChunker are defined in previous cells
# logger = setup_logging() # Assuming this was run previously
# chunker = PolicyDocumentChunker() # Assuming this was run previously

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    # Use the chunker from the outer scope
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks


validate_pdf_path(POLICY_PDF)
logger.info("Loading and chunking policy document...")
policy_chunks = load_and_chunk_pdf(POLICY_PDF)
# policy_analysis = analyze_chunking_results(policy_chunks) # No need to re-calculate analysis here
logger.info(f"Chunking complete: {len(policy_chunks)} chunks found.")

# No need to save chunking results again unless they changed, which they didn't
# pd.DataFrame(policy_chunks).to_excel("enhanced_chunking_results.xlsx", index=False)
# logger.info("Saved enhanced_chunking_results.xlsx")


# Analysis config and batch run (re-included from previous attempt)
MODEL = "gpt-4o"
RATE_LIMIT_DELAY = 2.0
BATCH_SIZE = 10

# Run batches and save
all_results: List[ComplianceAnalysis] = []

# Re-initialize the analyzer to use the latest model and potentially pick up new environment variables
analyzer = ClauseComparisonAnalyzer(api_key=os.getenv("OPENAI_API_KEY"), model=MODEL)

# Use the correct function signature for retrieve_jci_hybrid
def batch_analyze_clauses_with_progress(chunks: List[Dict], bm25_index: BM25Okapi, bm25_texts: List[str], bm25_metadatas: List[Dict]) -> List[ComplianceAnalysis]:
    results: List[ComplianceAnalysis] = []
    work_items = [c for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown']

    with tqdm(total=len(work_items), desc="Analyzing clauses") as pbar:
        for chunk in work_items:
            clause_number = chunk['clause_number']
            clause_text = chunk.get('enhanced_content', chunk.get('content', ''))
            # Pass the necessary BM25 index components
            # Assuming jci_bm25, split_jci_texts, split_jci_metadatas are accessible
            jci_refs = retrieve_jci_hybrid(clause_number, clause_text, jci_bm25, split_jci_texts, split_jci_metadatas)
            try:
                analysis = analyzer.compare_clause_to_jci(clause_number, clause_text, jci_refs)
                results.append(analysis)
            except Exception as e:
                logger.error(f"Analysis failed for clause {clause_number}: {e}")
                results.append(
                    ComplianceAnalysis(
                        clause_number=clause_number,
                        compliance_status=ComplianceStatus.INSUFFICIENT_INFO,
                        confidence_score=0.0,
                        key_gaps=[f"Analysis failed: {str(e)}"],
                        required_changes=["Retry analysis"],
                        jci_references=[],
                        risk_level="Unknown",
                        full_analysis=f"Error: {str(e)}"
                    )
                )
            time.sleep(RATE_LIMIT_DELAY)
            pbar.update(1)
    return results

# Call the batch analysis function with the BM25 index components
# Ensure jci_bm25, split_jci_texts, and split_jci_metadatas are accessible from the scope.
# Assuming they are accessible from the previous execution where the BM25 index was built.
for i in range(0, len(policy_chunks), BATCH_SIZE):
    batch = policy_chunks[i:i+BATCH_SIZE]
    logger.info(f"Processing batch {i//BATCH_SIZE + 1} with {len(batch)} clauses...")
    batch_results = batch_analyze_clauses_with_progress(batch, jci_bm25, split_jci_texts, split_jci_metadatas)
    all_results.extend(batch_results)

logger.info(f"Total analyzed: {len(all_results)}")

# Save to Excel
def save_comparison_results(analyses: List[ComplianceAnalysis], output_file: str):
    data = []
    for analysis in analyses:
        data.append({
            'clause_number': analysis.clause_number,
            'compliance_status': analysis.compliance_status.value,
            'confidence_score': f"{analysis.confidence_score:.0%}",
            'risk_level': analysis.risk_level,
            'gaps_identified': ' | '.join(analysis.key_gaps),
            'required_changes': ' | '.join(analysis.required_changes),
            'jci_standards_cited': ' | '.join(analysis.jci_references),
            'detailed_analysis': analysis.full_analysis
        })

    summary = {
        'total_analyzed': len(analyses),
        'compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.COMPLIANT]),
        'partially_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.PARTIALLY_COMPLIANT]),
        'non_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.NON_COMPLIANT]),
        'high_risk': len([a for a in analyses if 'high' in a.risk_level.lower()]),
    }

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        pd.DataFrame(data).to_excel(writer, sheet_name='Clause Analysis', index=False)
        pd.DataFrame([summary]).to_excel(writer, sheet_name='Summary', index=False)

save_comparison_results(all_results, "policy_vs_jci_hybrid.xlsx")
logger.info("Results saved to policy_vs_jci_hybrid.xlsx")

Analyzing clauses:   0%|          | 0/10 [00:00<?, ?it/s]


NameError: name 'expand_query_with_synonyms' is not defined

**Reasoning**:
The previous command failed because `expand_query_with_synonyms` was not defined. This function is used within `retrieve_jci_hybrid` to expand the query using synonyms. I need to re-execute the code cell that defines this function and the `synonyms` dictionary, as well as the `clause_to_sections` dictionary which is also used in `retrieve_jci_hybrid`. Then I will re-execute the batch analysis code block.



In [20]:
# Guided retrieval configuration (re-included)
clause_to_sections = {
    "4.1": "ACC.01.00 ME1",
    "4.3": "ACC.04.05 ME1",
    "4.4": "ACC.01.00 ME3",
    "4.5": "ACC.01.00 ME2",
    "4.6": "ACC.03.00 ME1",
    "4.7": "ACC.03.00 ME3",
    "4.8": "ACC.03.01 ME3",
    "5.1.1": "ACC.03.01 ME2 and ME3",
    "5.1.2": "ACC.03.00 ME1",
    "5.1.3": "ACC.01.00 ME5, ACC.03.00 ME2, ME3 and ME4",
    "5.1.4": "ACC.03.00 ME6, ACC.02.02 ME2",
    "5.1.5": "ACC.03.00 ME6",
    "5.1.8": "ACC.03.00 ME5, ACC.03.01 ME1, ME2 and ME3",
    "5.3.1": "ACC.01.00 ME6",
    "5.3.2": "ACC.03.01 ME1, ME2 and ME3",
    "5.3.3": "ACC.03.01 ME3",
    "5.3.4": "ACC.03.01 ME3",
    "5.4.1": "ACC.03.01 ME3",
    "5.5.1": "ACC.03.01 ME1 and ME3",
    "5.5.2": "ACC.03.00 ME6",
    "5.6": "ACC.03.01 ME3",
    "5.6.1": "Ministry of Health Circular No. MH 53:08/4 vol 6 – Guidelines for Inter-Hospital Transfer",
    "5.6.2": "ACC.03.00 ME6",
    "5.6.3": "ACC.03.00 ME5",
    "5.6.4": "ACC.03.01 ME1, ME2 and ME3",
    "5.7": "ACC.03.01 ME and ME3",
}

synonyms = {
    "transfer": ["handover", "relocation", "patient movement"],
    "policy": ["guideline", "procedure"],
    "emergency": ["urgent", "critical", "immediate"],
}

def expand_query_with_synonyms(query: str) -> str:
    words = query.split()
    expanded = []
    for w in words:
        expanded.append(w)
        if w.lower() in synonyms:
            expanded.extend(synonyms[w.lower()])
    return " ".join(expanded)

# Define reciprocal_rank_fusion and retrieve_jci_hybrid if not in scope
# Assuming reciprocal_rank_fusion and retrieve_jci_hybrid are defined in a previous cell
# def reciprocal_rank_fusion(...)
# def retrieve_jci_hybrid(...)


# Load policy PDF, chunk, and save chunking results (re-included)
# Define necessary variables and functions if not in scope
POLICY_PDF = "/content/old Transfer 290523 .pdf" # <- change as needed

# Re-define logger and chunker if necessary
# Assuming setup_logging and PolicyDocumentChunker are defined in previous cells
# logger = setup_logging() # Assuming this was run previously
# chunker = PolicyDocumentChunker() # Assuming this was run previously

def validate_pdf_path(pdf_path: str) -> None:
    path = Path(pdf_path)
    if not path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    if path.suffix.lower() != '.pdf':
        raise ValueError(f"File must be a PDF: {pdf_path}")

def load_and_chunk_pdf(pdf_path: str) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    full_text = "\n".join([page.page_content for page in pages])
    # Use the chunker from the outer scope
    chunks = chunker.extract_comprehensive_chunks(full_text)
    for i, chunk in enumerate(chunks):
        chunk.update({
            'chunk_id': f"{os.path.basename(pdf_path)}::chunk_{i+1}",
            'source_file': pdf_path,
        })
    return chunks


validate_pdf_path(POLICY_PDF)
logger.info("Loading and chunking policy document...")
policy_chunks = load_and_chunk_pdf(POLICY_PDF)
# policy_analysis = analyze_chunking_results(policy_chunks) # No need to re-calculate analysis here
logger.info(f"Chunking complete: {len(policy_chunks)} chunks found.")

# No need to save chunking results again unless they changed, which they didn't
# pd.DataFrame(policy_chunks).to_excel("enhanced_chunking_results.xlsx", index=False)
# logger.info("Saved enhanced_chunking_results.xlsx")


# Analysis config and batch run (re-included from previous attempt)
MODEL = "gpt-4o"
RATE_LIMIT_DELAY = 2.0
BATCH_SIZE = 10

# Run batches and save
all_results: List[ComplianceAnalysis] = []

# Re-initialize the analyzer to use the latest model and potentially pick up new environment variables
# Assuming ComplianceAnalysis and ClauseComparisonAnalyzer are defined in a previous cell
# analyzer = ClauseComparisonAnalyzer(api_key=os.getenv("OPENAI_API_KEY"), model=MODEL) # Assuming this was run previously

# Use the correct function signature for retrieve_jci_hybrid
def batch_analyze_clauses_with_progress(chunks: List[Dict], bm25_index: BM25Okapi, bm25_texts: List[str], bm25_metadatas: List[Dict]) -> List[ComplianceAnalysis]:
    results: List[ComplianceAnalysis] = []
    work_items = [c for c in chunks if c.get('clause_number') and c['clause_number'] != 'unknown']

    with tqdm(total=len(work_items), desc="Analyzing clauses") as pbar:
        for chunk in work_items:
            clause_number = chunk['clause_number']
            clause_text = chunk.get('enhanced_content', chunk.get('content', ''))
            # Pass the necessary BM25 index components
            # Assuming jci_bm25, split_jci_texts, split_jci_metadatas are accessible
            jci_refs = retrieve_jci_hybrid(clause_number, clause_text, jci_bm25, split_jci_texts, split_jci_metadatas)
            try:
                analysis = analyzer.compare_clause_to_jci(clause_number, clause_text, jci_refs)
                results.append(analysis)
            except Exception as e:
                logger.error(f"Analysis failed for clause {clause_number}: {e}")
                results.append(
                    ComplianceAnalysis(
                        clause_number=clause_number,
                        compliance_status=ComplianceStatus.INSUFFICIENT_INFO,
                        confidence_score=0.0,
                        key_gaps=[f"Analysis failed: {str(e)}"],
                        required_changes=["Retry analysis"],
                        jci_references=[],
                        risk_level="Unknown",
                        full_analysis=f"Error: {str(e)}"
                    )
                )
            time.sleep(RATE_LIMIT_DELAY)
            pbar.update(1)
    return results

# Call the batch analysis function with the BM25 index components
# Ensure jci_bm25, split_jci_texts, and split_jci_metadatas are accessible from the scope.
# Assuming they are accessible from the previous execution where the BM25 index was built.
for i in range(0, len(policy_chunks), BATCH_SIZE):
    batch = policy_chunks[i:i+BATCH_SIZE]
    logger.info(f"Processing batch {i//BATCH_SIZE + 1} with {len(batch)} clauses...")
    batch_results = batch_analyze_clauses_with_progress(batch, jci_bm25, split_jci_texts, split_jci_metadatas)
    all_results.extend(batch_results)

logger.info(f"Total analyzed: {len(all_results)}")

# Save to Excel
def save_comparison_results(analyses: List[ComplianceAnalysis], output_file: str):
    data = []
    for analysis in analyses:
        data.append({
            'clause_number': analysis.clause_number,
            'compliance_status': analysis.compliance_status.value,
            'confidence_score': f"{analysis.confidence_score:.0%}",
            'risk_level': analysis.risk_level,
            'gaps_identified': ' | '.join(analysis.key_gaps),
            'required_changes': ' | '.join(analysis.required_changes),
            'jci_standards_cited': ' | '.join(analysis.jci_references),
            'detailed_analysis': analysis.full_analysis
        })

    summary = {
        'total_analyzed': len(analyses),
        'compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.COMPLIANT]),
        'partially_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.PARTIALLY_COMPLIANT]),
        'non_compliant': len([a for a in analyses if a.compliance_status == ComplianceStatus.NON_COMPLIANT]),
        'high_risk': len([a for a in analyses if 'high' in a.risk_level.lower()]),
    }

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        pd.DataFrame(data).to_excel(writer, sheet_name='Clause Analysis', index=False)
        pd.DataFrame([summary]).to_excel(writer, sheet_name='Summary', index=False)

save_comparison_results(all_results, "policy_vs_jci_hybrid.xlsx")
logger.info("Results saved to policy_vs_jci_hybrid.xlsx")

Analyzing clauses: 100%|██████████| 10/10 [01:11<00:00,  7.14s/it]
Analyzing clauses: 100%|██████████| 10/10 [01:16<00:00,  7.61s/it]
Analyzing clauses: 100%|██████████| 6/6 [00:56<00:00,  9.40s/it]
