# 🔬 TPN-RAG Full Evaluation with LangChain Pipeline

This notebook evaluates the **full production RAG pipeline**, including:
- ✅ LangChain integration
- ✅ Hybrid search (BM25 + vector)
- ✅ Cross-encoder reranking
- ✅ Multi-query expansion (off by default; toggle `RAG_CONFIG['enable_multi_query']`)
- ✅ All advanced features from `app/services/advanced_rag.py`

---

## 1️⃣ Setup & Configuration

In [None]:
# Configuration: evaluation size, checkpoint cadence and output location.
NUM_SAMPLES = 100        # Start with 100, increase to 941 for full eval
CHECKPOINT_EVERY = 10    # Write a checkpoint every N sample indices
OUTPUT_FILE = "full_eval_langchain_results.json"

# RAG Configuration (Advanced Features).
# All keys are consumed by AdvancedRAGConfig in section 3; `top_k` is also passed
# to PipelineConfig in section 2.
RAG_CONFIG = dict(
    top_k=10,                          # Retrieval depth
    enable_hybrid=True,                # BM25 + Vector
    enable_reranking=True,             # Cross-encoder reranking
    enable_multi_query=False,          # Query expansion (slower but better)
    reranker_model='BAAI/bge-reranker-v2-m3',
    initial_k=20,                      # Retrieve 20, rerank to top_k
)

# Model configuration: the fine-tuned model under test and the LLM judge for DeepEval.
LLM_MODEL = "chandramax/tpn-gpt-oss-20b"
JUDGE_MODEL = "gpt-5-mini"

print(f"üìä Will evaluate {NUM_SAMPLES} samples with advanced RAG")
for _label, _key in (("Hybrid Search", 'enable_hybrid'), ("Reranking", 'enable_reranking')):
    print(f"   {_label}: {RAG_CONFIG[_key]}")

In [None]:
# Setup paths and imports
import json
import math
import os
import sys
import time
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from tqdm.notebook import tqdm

# Disable HuggingFace tokenizers' internal parallelism (commonly set to avoid
# fork-related warnings once other libraries spawn workers).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Locate the project root: the working directory or the nearest ancestor that
# contains the `app` package imported below. Falls back to the parent of the
# working directory (the previous behaviour, correct when the notebook lives
# one level below the project root).
project_root = next(
    (candidate for candidate in (Path.cwd(), *Path.cwd().parents) if (candidate / "app").is_dir()),
    Path.cwd().parent,
)

# Put the project root first on sys.path without piling up duplicates when the
# cell is re-run.
_project_root_str = str(project_root)
if _project_root_str in sys.path:
    sys.path.remove(_project_root_str)
sys.path.insert(0, _project_root_str)

print(f"‚úÖ Project root: {project_root}")

## 2️⃣ Load LangChain RAG Pipeline (ONCE)

In [None]:
# Import the production RAG pipeline
from app.rag_pipeline import TPN_RAG, PipelineConfig, PipelineMode

# LLM used by the RAG pipeline to answer Phase 2 (with-RAG) questions.
# For testing; switch to "chandramax/tpn-gpt-oss-20b" for full eval.
PIPELINE_LLM_MODEL = "qwen2.5:7b"

print("üîÑ Initializing TPN_RAG pipeline...")
start = time.time()

# Create production config. Of RAG_CONFIG, only `top_k` is passed to the pipeline here.
config = PipelineConfig(
    llm_model=PIPELINE_LLM_MODEL,
    embed_model="qwen3-embedding:0.6b",
    mode=PipelineMode.STANDARD,
    top_k=RAG_CONFIG['top_k'],
    require_grounding=True
)

# Initialize RAG pipeline once; every query below reuses this `rag` object.
rag = TPN_RAG(config=config)
rag.initialize()

print(f"‚úÖ RAG pipeline initialized in {time.time()-start:.1f}s")
print(f"   Stats: {rag.get_stats()}")

# Phase 1 (no RAG) answers with LLM_MODEL loaded directly when the fine-tuned model is
# enabled, while Phase 2 answers through this pipeline's LLM. If the two differ, the
# reported "RAG lift" also includes the effect of swapping models.
if PIPELINE_LLM_MODEL != LLM_MODEL:
    print(f"   WARNING: pipeline LLM '{PIPELINE_LLM_MODEL}' differs from LLM_MODEL '{LLM_MODEL}'; "
          "Phase 1 vs Phase 2 scores are not a like-for-like RAG comparison.")

In [None]:
# OR: Load the fine-tuned 20B model directly for more control
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# True: Phase 1 (no RAG) is answered by LLM_MODEL loaded here.
# False: Phase 1 falls back to the RAG pipeline's LLM.
USE_FINETUNED_MODEL = True

if USE_FINETUNED_MODEL:
    print(f"üîÑ Loading fine-tuned LLM: {LLM_MODEL}...")
    start = time.time()

    # NOTE(review): trust_remote_code executes Python shipped with the model repo;
    # keep it only for repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, trust_remote_code=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # may place layers on several devices
        trust_remote_code=True
    )

    print(f"‚úÖ Fine-tuned LLM loaded in {time.time()-start:.1f}s")
    # Device of the first parameter only; sharded models may span more devices.
    print(f"   Device: {next(llm_model.parameters()).device}")
    if torch.cuda.is_available():
        # Sum over every visible GPU: with device_map="auto" the weights can be split
        # across devices, and memory_allocated() without an argument reports one device.
        gpu_bytes = sum(torch.cuda.memory_allocated(d) for d in range(torch.cuda.device_count()))
        print(f"   GPU Memory: {gpu_bytes/1e9:.1f} GB")
else:
    print("‚è© Using RAG pipeline's default LLM")
    llm_model = None
    tokenizer = None

## 3️⃣ Load Advanced RAG Components (Optional)

In [None]:
# Load advanced RAG components if available.
# NOTE(review): within this notebook, neither `advanced_config` nor `HAS_ADVANCED_RAG` is
# passed to `rag` or read again, so these settings do not appear to influence the
# evaluation below — confirm whether they are meant to be wired into the pipeline.
try:
    from app.services.advanced_rag import AdvancedRAGConfig, AdvancedRAGPipeline

    # Map the notebook-level RAG_CONFIG onto the service's config object.
    advanced_config = AdvancedRAGConfig(
        enable_bm25_hybrid=RAG_CONFIG['enable_hybrid'],
        enable_cross_encoder_rerank=RAG_CONFIG['enable_reranking'],
        enable_multi_query=RAG_CONFIG['enable_multi_query'],
        enable_rrf=True,
        rrf_k=60,
        reranker_model=RAG_CONFIG['reranker_model'],
        initial_k=RAG_CONFIG['initial_k'],
        rerank_top_k=RAG_CONFIG['top_k'],
    )
except ImportError as exc:
    print(f"‚ö†Ô∏è Advanced RAG not available: {exc}")
    print("   Using basic vector search only")
    HAS_ADVANCED_RAG = False
else:
    print("‚úÖ Advanced RAG config loaded")
    print(f"   Hybrid Search: {advanced_config.enable_bm25_hybrid}")
    print(f"   Reranking: {advanced_config.enable_cross_encoder_rerank}")
    HAS_ADVANCED_RAG = True

## 4️⃣ Load Test Dataset

In [None]:
# Load test data (JSONL: one chat-format sample per line)
test_file = project_root / "eval" / "data" / "test_with_citations.jsonl"

with open(test_file, 'r', encoding='utf-8') as f:
    # Skip blank/whitespace-only lines (e.g. a trailing newline) instead of letting
    # json.loads raise on them.
    samples = [json.loads(line) for line in f if line.strip()]

print(f"‚úÖ Loaded {len(samples)} test samples")
print(f"üìä Will evaluate first {min(NUM_SAMPLES, len(samples))} samples")

# Fail here with a clear message rather than later on samples[0].
assert samples, f"No samples found in {test_file}"

In [None]:
def extract_qa(sample: dict) -> Tuple[str, str]:
    """Extract the question and reference answer from a chat-format sample.

    The sample carries a ``messages`` list of ``{"role": ..., "content": ...}``
    dicts. When a role occurs more than once, the LAST message of that role wins.

    Args:
        sample: One parsed JSONL record.

    Returns:
        (question, expected): content of the last ``user`` message and of the last
        ``assistant`` message (``''`` when a message has no ``content`` key).

    Raises:
        ValueError: if the sample has no ``user`` or no ``assistant`` message, so a
            malformed record fails loudly instead of passing ``None`` downstream.
    """
    question: Optional[str] = None
    expected: Optional[str] = None
    for msg in sample.get('messages', []):
        role = msg.get('role')
        if role == 'user':
            question = msg.get('content', '')
        elif role == 'assistant':
            expected = msg.get('content', '')
    if question is None or expected is None:
        raise ValueError("Sample is missing a user question or an assistant reference answer")
    return question, expected

# Sanity-check the extractor on the first record (previews truncated to 100 characters).
(q, a) = extract_qa(samples[0])
for _prefix, _text in (("Sample Q", q), ("Sample A", a)):
    print(f"{_prefix}: {_text[:100]}...")

## 5️⃣ Inference Functions

In [None]:
def run_rag_query(question: str, use_rag: bool = True) -> Tuple[str, List[str], float]:
    """
    Answer a question with the RAG pipeline, or without retrieval for the baseline.

    Args:
        question: The user question.
        use_rag: If True, answer via ``rag.ask`` with grounding required. If False,
            answer via ``run_direct_inference`` when the fine-tuned model is loaded,
            otherwise via ``rag.ask(require_grounding=False)``.

    Returns:
        answer: Generated answer.
        context_list: Retrieved passages (for faithfulness / recall eval); always
            empty when ``use_rag`` is False.
        elapsed: Wall-clock seconds for the whole call, covering retrieval AND
            generation (not retrieval alone).
    """
    start = time.time()

    if use_rag:
        # Use the full RAG pipeline
        result = rag.ask(
            question=question,
            answer_type="single",
            require_grounding=True
        )
        answer = result.answer

        # Prefer individual source passages (finer-grained for faithfulness eval). Fall
        # back to the combined context only when no source carries usable content, so
        # sources without a 'content' field no longer wipe out the retrieved context.
        sources = getattr(result, 'sources', None) or []
        source_passages = [s['content'] for s in sources if s.get('content')]
        if source_passages:
            context_list = source_passages
        else:
            context_list = [result.context_used] if result.context_used else []
    else:
        # No RAG - just use the LLM directly
        if llm_model is not None:
            answer = run_direct_inference(question)
        else:
            # NOTE(review): this still goes through rag.ask with only grounding relaxed;
            # whether retrieval is skipped depends on TPN_RAG.ask — confirm before
            # treating this as a true no-RAG baseline.
            result = rag.ask(question=question, require_grounding=False)
            answer = result.answer
        context_list = []

    return answer, context_list, time.time() - start


def run_direct_inference(question: str, max_tokens: int = 1024, temperature: float = 0.1) -> str:
    """Answer ``question`` with the fine-tuned model directly, without RAG context.

    Args:
        question: The user question.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Sampling temperature. Values > 0 sample (default 0.1, as before);
            0 switches to greedy decoding for reproducible runs.

    Returns:
        The decoded completion (prompt tokens removed), stripped of surrounding whitespace.

    Raises:
        ValueError: if the fine-tuned model was not loaded (``llm_model`` is None).
    """
    if llm_model is None:
        raise ValueError("Fine-tuned model not loaded")

    system_prompt = ("You are a clinical expert in neonatal/pediatric TPN. \n"
                     "Answer accurately based on your training.")

    messages = [
        {"role": "developer", "content": system_prompt},
        {"role": "user", "content": question}
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the model's special tokens; letting the
    # tokenizer add them again (e.g. a second BOS) can degrade generation.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(llm_model.device)

    # Only pass `temperature` when sampling; generate() warns about it under greedy decoding.
    sampling_kwargs = {"do_sample": temperature > 0}
    if temperature > 0:
        sampling_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs
        )

    # Decode only the newly generated tokens.
    # NOTE(review): if the model emits separate reasoning/final channels (as gpt-oss-style
    # chat formats can), skip_special_tokens=True may leave reasoning text in the answer —
    # confirm that only the final answer is being scored.
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response.strip()

print("‚úÖ Inference functions ready")

## 6️⃣ Metric Functions

In [None]:
# DeepEval metrics
from deepeval.metrics import GEval, FaithfulnessMetric, AnswerRelevancyMetric, ContextualRecallMetric
from deepeval.test_case import LLMTestCase, LLMTestCaseParams

# Clinical GEval: the judge compares the generated answer against the reference answer.
# NOTE(review): LLMTestCaseParams.INPUT is not listed, so the judge never sees the question
# itself — confirm this is intended for the POPULATIONS / SAFETY criteria.
clinical_geval = GEval(
    name="Clinical Correctness",
    model=JUDGE_MODEL,
    threshold=0.7,
    evaluation_params=[LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT],
    criteria="""Evaluate clinical correctness for TPN:
1. DOSING VALUES: Are numeric values exactly correct?
2. UNITS: Are units correct? (mg vs g, mEq vs mmol)
3. RANGES: Are ranges accurate?
4. POPULATIONS: Is patient population correct?
5. SAFETY: Are contraindications mentioned?""",
)

# RAG-quality metrics, all judged by the same model at the same 0.7 threshold.
faithfulness, relevancy, recall = (
    metric_cls(threshold=0.7, model=JUDGE_MODEL)
    for metric_cls in (FaithfulnessMetric, AnswerRelevancyMetric, ContextualRecallMetric)
)

print("‚úÖ DeepEval metrics ready")

## 7️⃣ Run Full Evaluation

In [None]:
# Initialize or resume from a checkpoint written by the evaluation loop.
results = []
start_idx = 0

# The checkpoint sits next to OUTPUT_FILE as "<stem>_checkpoint<suffix>". Building it from
# the path parts (instead of str.replace) guarantees it never equals OUTPUT_FILE itself,
# which the save cell would otherwise delete right after writing the final results.
_output_path = Path(OUTPUT_FILE)
checkpoint_file = _output_path.with_name(f"{_output_path.stem}_checkpoint{_output_path.suffix or '.json'}")

if checkpoint_file.exists():
    with open(checkpoint_file, 'r', encoding='utf-8') as f:
        checkpoint = json.load(f)
    results = checkpoint['results']
    # Resume after the highest completed index. Failed samples are never stored, so
    # len(results) can lag behind the true position and would re-run (and duplicate)
    # samples that were already scored.
    start_idx = max((r['idx'] for r in results), default=-1) + 1
    print(f"üìÇ Resuming from checkpoint: {len(results)} samples done, next index {start_idx}")
else:
    print("üÜï Starting fresh evaluation")

In [None]:
# MAIN EVALUATION LOOP
# For each sample: answer without RAG (Phase 1) and with RAG (Phase 2), then score both
# with the judge metrics. Failures are logged and skipped; progress is checkpointed every
# CHECKPOINT_EVERY indices and once more when the loop finishes.
num_to_eval = min(NUM_SAMPLES, len(samples))
# Indices already scored (e.g. loaded from a checkpoint) are never re-evaluated, so a
# resume cannot create duplicate rows even if start_idx is too low.
completed_idx = {r['idx'] for r in results}
failed_idx = []


def _save_checkpoint() -> None:
    """Persist every result gathered so far to `checkpoint_file`."""
    with open(checkpoint_file, 'w', encoding='utf-8') as f:
        json.dump({'results': results}, f)


for idx in tqdm(range(start_idx, num_to_eval), desc="Evaluating"):
    if idx in completed_idx:
        continue

    try:
        # Parse inside the try so one malformed record cannot abort the whole run.
        question, expected = extract_qa(samples[idx])

        # Phase 1: No RAG (run_rag_query uses the fine-tuned model directly when loaded)
        phase1_answer, _, _ = run_rag_query(question, use_rag=False)

        # Phase 2: With RAG (using LangChain pipeline)
        phase2_answer, phase2_context, retrieval_time = run_rag_query(question, use_rag=True)

        # Evaluate Phase 1
        tc1 = LLMTestCase(input=question, actual_output=phase1_answer, expected_output=expected)
        clinical_geval.measure(tc1)
        p1_clinical = clinical_geval.score

        # Evaluate Phase 2
        tc2 = LLMTestCase(
            input=question,
            actual_output=phase2_answer,
            expected_output=expected,
            retrieval_context=phase2_context if phase2_context else ["No context retrieved"]
        )
        clinical_geval.measure(tc2)
        p2_clinical = clinical_geval.score

        # Faithfulness & Recall (Phase 2 only); scored 0 when nothing was retrieved.
        if phase2_context:
            faithfulness.measure(tc2)
            faith_score = faithfulness.score

            recall.measure(tc2)
            recall_score = recall.score
        else:
            faith_score = 0.0
            recall_score = 0.0

        # Relevancy
        relevancy.measure(tc2)
        rel_score = relevancy.score

        results.append({
            'idx': idx,
            'question': question[:150],
            'p1_clinical': p1_clinical,
            'p2_clinical': p2_clinical,
            'rag_lift': p2_clinical - p1_clinical,
            'faithfulness': faith_score,
            'relevancy': rel_score,
            'ctx_recall': recall_score,
            'retrieval_time': retrieval_time,  # end-to-end run_rag_query time
            'num_docs_retrieved': len(phase2_context)
        })
        completed_idx.add(idx)

    except Exception as e:
        failed_idx.append(idx)
        # tqdm.write keeps the progress bar intact while logging.
        tqdm.write(f"  ‚ùå Error on sample {idx}: {e}")
        traceback.print_exc()

    # Checkpoint on schedule whether or not this sample succeeded, so a failure on a
    # boundary index no longer skips the save.
    if (idx + 1) % CHECKPOINT_EVERY == 0:
        _save_checkpoint()
        tqdm.write(f"  üíæ Checkpoint at {idx+1}")

# Also save the tail that did not land on a checkpoint boundary.
_save_checkpoint()

print(f"\n‚úÖ Evaluation complete: {len(results)} samples")
if failed_idx:
    print(f"   WARNING: {len(failed_idx)} samples failed and were skipped: {failed_idx}")

## 8️⃣ Analyze Results

In [None]:
import pandas as pd
import numpy as np

df = pd.DataFrame(results)
assert not df.empty, "No evaluation results to analyze - run the evaluation loop first."

p1_mean = df['p1_clinical'].mean()
p2_mean = df['p2_clinical'].mean()
# Relative lift is undefined for a zero baseline; report NaN instead of dividing by zero.
rag_lift_rel_pct = (p2_mean - p1_mean) / p1_mean * 100 if p1_mean else float('nan')

print("="*70)
print(f"üìä FULL EVALUATION RESULTS (n={len(df)})")
print("="*70)
print()

print("CLINICAL CORRECTNESS:")
print(f"  Phase 1 (No RAG):   {p1_mean:.1%} (std: {df['p1_clinical'].std():.1%})")
print(f"  Phase 2 (With RAG): {p2_mean:.1%} (std: {df['p2_clinical'].std():.1%})")
# The `+` format flag prints the real sign: "-3.0%" rather than "+-3.0%".
print(f"  RAG Lift:           {df['rag_lift'].mean():+.1%}")
print(f"  RAG Lift %:         {rag_lift_rel_pct:+.1f}%")
print()

print("RAG QUALITY METRICS:")
print(f"  Faithfulness:       {df['faithfulness'].mean():.1%}")
print(f"  Answer Relevancy:   {df['relevancy'].mean():.1%}")
print(f"  Contextual Recall:  {df['ctx_recall'].mean():.1%}")
print()

print("SCORE DISTRIBUTION (Phase 2):")
print(f"  Perfect (100%):  {(df['p2_clinical'] >= 1.0).sum()} ({(df['p2_clinical'] >= 1.0).mean():.0%})")
print(f"  Passing (‚â•70%):  {(df['p2_clinical'] >= 0.7).sum()} ({(df['p2_clinical'] >= 0.7).mean():.0%})")
print(f"  Failing (<50%):  {(df['p2_clinical'] < 0.5).sum()} ({(df['p2_clinical'] < 0.5).mean():.0%})")
print()

print("PERFORMANCE:")
# `retrieval_time` is measured around the whole RAG call, so it includes generation.
print(f"  Avg RAG call time:  {df['retrieval_time'].mean():.2f}s (retrieval + generation)")
print(f"  Avg docs retrieved: {df['num_docs_retrieved'].mean():.1f}")

In [None]:
# Visualize
import matplotlib.pyplot as plt

p1_mean = df['p1_clinical'].mean()
p2_mean = df['p2_clinical'].mean()
mean_lift = df['rag_lift'].mean()
# Relative lift is undefined for a zero baseline; show NaN instead of dividing by zero.
rag_lift_rel_pct = (p2_mean - p1_mean) / p1_mean * 100 if p1_mean else float('nan')

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Clinical Correctness Distribution.
# A fixed [0, 1] range gives both histograms identical bin edges, so the overlay compares like with like.
ax = axes[0, 0]
ax.hist(df['p1_clinical'], bins=10, range=(0, 1), alpha=0.5, label='No RAG', color='red')
ax.hist(df['p2_clinical'], bins=10, range=(0, 1), alpha=0.5, label='With RAG', color='green')
ax.set_xlabel('Clinical Correctness')
ax.set_ylabel('Count')
ax.set_title('Score Distribution Comparison')
ax.legend()

# 2. RAG Lift by Sample (x = dataset index, so gaps reveal failed/skipped samples)
ax = axes[0, 1]
lift_colors = ['green' if x > 0 else 'red' for x in df['rag_lift']]
ax.bar(df['idx'], df['rag_lift'], color=lift_colors, alpha=0.7)
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
# `+` format flag: a negative mean renders as "-x%", not "+-x%".
ax.axhline(y=mean_lift, color='blue', linestyle='--', label=f'Mean: {mean_lift:+.1%}')
ax.set_xlabel('Sample Index')
ax.set_ylabel('RAG Lift')
ax.set_title('RAG Improvement by Sample')
ax.legend()

# 3. Key Metrics
ax = axes[1, 0]
metrics = ['Clinical\nCorrectness', 'Faithfulness', 'Relevancy', 'Recall']
values = [p2_mean, df['faithfulness'].mean(), df['relevancy'].mean(), df['ctx_recall'].mean()]
metric_colors = ['#2ecc71' if v >= 0.8 else '#f39c12' if v >= 0.6 else '#e74c3c' for v in values]
bars = ax.bar(metrics, values, color=metric_colors)
ax.set_ylim(0, 1)
ax.axhline(y=0.7, color='gray', linestyle='--', alpha=0.5, label='Threshold (70%)')
ax.set_title('RAG Quality Metrics')
ax.legend()
for bar, v in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, v + 0.02, f'{v:.0%}', ha='center', fontweight='bold')

# 4. Summary Box
ax = axes[1, 1]
ax.axis('off')
summary = f"""
EVALUATION SUMMARY
{'='*40}

Total Samples:     {len(df)}
RAG Config:        top_k={RAG_CONFIG['top_k']}, hybrid={RAG_CONFIG['enable_hybrid']}

RESULTS:
  Clinical (No RAG):   {p1_mean:.1%}
  Clinical (With RAG): {p2_mean:.1%}
  RAG Lift:            {rag_lift_rel_pct:+.1f}%

  Faithfulness:        {df['faithfulness'].mean():.1%}
  Contextual Recall:   {df['ctx_recall'].mean():.1%}

PASS RATE:
  ‚â•70%: {(df['p2_clinical'] >= 0.7).mean():.0%} ({(df['p2_clinical'] >= 0.7).sum()}/{len(df)})
"""
ax.text(0.1, 0.95, summary, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.savefig('eval_results_langchain.png', dpi=150)
plt.show()

## 9️⃣ Save Results

In [None]:
# Save final results
def _finite_or_none(value) -> Optional[float]:
    """Convert a numeric scalar to a plain float, mapping NaN/inf to None.

    json.dump would otherwise write NaN/Infinity, which strict JSON parsers reject.
    """
    value = float(value)
    return value if math.isfinite(value) else None


if df.empty:
    summary_stats = {}
else:
    p1_mean = df['p1_clinical'].mean()
    p2_mean = df['p2_clinical'].mean()
    summary_stats = {
        'p1_clinical': _finite_or_none(p1_mean),
        'p2_clinical': _finite_or_none(p2_mean),
        'rag_lift': _finite_or_none(df['rag_lift'].mean()),
        # Relative lift is undefined for a zero baseline; store null instead of inf.
        'rag_lift_pct': _finite_or_none((p2_mean - p1_mean) / p1_mean * 100) if p1_mean else None,
        'faithfulness': _finite_or_none(df['faithfulness'].mean()),
        'relevancy': _finite_or_none(df['relevancy'].mean()),
        'ctx_recall': _finite_or_none(df['ctx_recall'].mean()),
        'pass_rate': _finite_or_none((df['p2_clinical'] >= 0.7).mean()),
        'perfect_rate': _finite_or_none((df['p2_clinical'] >= 1.0).mean())
    }

final = {
    'num_samples': len(results),
    'config': RAG_CONFIG,
    'summary': summary_stats,
    'results': results
}

with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(final, f, indent=2)

print(f"‚úÖ Saved to {OUTPUT_FILE}")

# Cleanup checkpoint now that the full results are on disk
if checkpoint_file.exists():
    checkpoint_file.unlink()
    print("üóëÔ∏è Checkpoint removed")