# KRT Diabetes Knowledge Evaluation

This notebook evaluates the impact of injecting curated Type 2 Diabetes Q&A (Key Risk Topic approach) into the TrustMed AI system.

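**Evaluation Metrics:**
- Answer accuracy vs ground truth, measured as key-term coverage (the share of the ground-truth answer's key terms that also appear in the response)
- Before/after comparison (baseline vs KRT-enhanced)
- Factual correctness (a heuristic check that numbers and key concepts from the ground truth appear in the response)
- Source citations (mentions of CDC, NIDDK, NIH, WHO, or the American Diabetes Association)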

## 1. Setup & Imports

In [None]:
import sys
import os
from dotenv import load_dotenv
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
import re

# Load environment variables (the NEO4J_* settings read in the next cell come from here)
load_dotenv()

# Import project modules
# NOTE(review): Counter, sys and get_ollama_embedding are not referenced in the visible
# cells of this notebook; confirm they are still needed.
from krt_diabetes_knowledge import TOP_TYPE2_QA, format_top_qa_text
from creat_graph_ollama import call_ollama, get_ollama_embedding
from camel.storages import Neo4jGraph

print(f"Loaded {len(TOP_TYPE2_QA)} curated diabetes Q&A pairs")

In [None]:
# Initialize Neo4j connection.
# Settings are read from the environment, falling back to local-development defaults.
# NOTE(review): `n4j` is not used by any visible cell of this notebook, and the fallback
# password is a placeholder credential; confirm both are intended.
neo4j_url, neo4j_user, neo4j_password = (
    os.getenv("NEO4J_URL", "bolt://localhost:7687"),
    os.getenv("NEO4J_USERNAME", "neo4j"),
    os.getenv("NEO4J_PASSWORD", "password"),
)

n4j = Neo4jGraph(url=neo4j_url, username=neo4j_user, password=neo4j_password)

print(f"Connected to Neo4j at {neo4j_url}")

## 2. Define Evaluation Functions

In [None]:
def generate_baseline_response(question, model="llama3"):
    """
    Baseline generation: ask the model the question with no KRT reference material.

    The prompt is the instruction text, a blank line, then ``Question: <question>``,
    so the answer relies only on the model's own knowledge.

    Args:
        question: The question to answer.
        model: Model name passed through to ``call_ollama``.

    Returns:
        Whatever ``call_ollama`` returns for the combined prompt.
    """
    instructions = (
        "You are a medical assistant providing information about diabetes.\n"
        "Answer the question concisely and accurately based on your medical knowledge.\n"
        "Keep your answer to 2-3 paragraphs."
    )
    prompt = "\n\n".join([instructions, f"Question: {question}"])
    return call_ollama(prompt, model)


def generate_krt_enhanced_response(question, model="llama3"):
    """
    KRT-enhanced generation: ask the question with curated reference material injected.

    The prompt is the instruction text, then a delimited "Reference Material" section
    holding the output of ``format_top_qa_text()``, then ``Question: <question>``.

    NOTE(review): if ``format_top_qa_text()`` is built from ``TOP_TYPE2_QA``, the injected
    reference contains the same answers later used as ground truth. That would measure
    how well the model uses supplied context, not its unaided accuracy. Confirm this is
    the intended measurement.

    Args:
        question: The question to answer.
        model: Model name passed through to ``call_ollama``.

    Returns:
        Whatever ``call_ollama`` returns for the combined prompt.
    """
    reference_block = format_top_qa_text()

    instructions = """You are a medical assistant providing information about diabetes.
Answer the question concisely and accurately.
You have access to curated reference material below - use it to provide accurate information.
Keep your answer to 2-3 paragraphs."""

    full_prompt = f"""{instructions}

Reference Material:
======================================================================
** REFERENCE: CURATED TYPE 2 DIABETES KNOWLEDGE (KRT) **
{reference_block}
======================================================================

Question: {question}"""

    return call_ollama(full_prompt, model)


print("Response generation functions defined")

In [None]:
def extract_key_terms(text):
    """
    Extract key medical/factual terms from text.

    Recognized terms include:
      * numbers with units (million, percent, %, mg, year(s), day(s));
      * "type 1/2 diabetes";
      * the organization acronyms CDC, NIDDK, NIH and WHO;
      * a fixed list of clinical vocabulary (insulin resistance, A1C, metformin, ...).

    Args:
        text: Text to scan.

    Returns:
        Set of lowercase terms, with internal whitespace collapsed to single spaces, so
        that "type  2 diabetes" and "type 2 diabetes" compare equal.
    """
    # Patterns matched case-insensitively.
    patterns = [
        # Numbers with units. `%` is not a word character, so it must not be followed
        # by `\b`: that would demand a letter/digit right after it, so "10% of" would
        # never match.
        r'\b\d+(?:\.\d+)?\s*(?:(?:million|percent|mg|years?|days?)\b|%)',
        r'\btype\s*[12]\s*diabetes\b',
        r'\binsulin\s*resistance\b',
        r'\bblood\s*(?:sugar|glucose)\b',
        r'\bA1C\b',
        r'\b(?:hyper|hypo)glycemia\b',
        r'\bprediabetes\b',
        r'\b(?:heart|kidney|liver)\s*disease\b',
        r'\bneuropathy\b',
        r'\b(?:GLP-1|SGLT2)\b',
        r'\bmetformin\b',
        r'\bbeta\s*cells?\b',
        r'\bpancreas\b',
    ]
    # Organization acronyms are matched case-sensitively. Case-insensitive matching
    # would turn the ordinary pronoun "who" into the term for the WHO.
    acronym_pattern = r'\b(?:CDC|NIDDK|NIH|WHO)\b'

    terms = set()
    for pattern in patterns:
        for match in re.findall(pattern, text, re.IGNORECASE):
            terms.add(' '.join(match.lower().split()))
    terms.update(match.lower() for match in re.findall(acronym_pattern, text))

    return terms


def calculate_key_term_overlap(response, ground_truth):
    """
    Return the percentage (0-100) of the ground truth's key terms found in the response.

    Key terms come from ``extract_key_terms``. When the ground truth contains no
    recognizable key terms, the score is 0.0.
    """
    expected = extract_key_terms(ground_truth)
    found_in_response = extract_key_terms(response)

    if len(expected) == 0:
        return 0.0

    matched = expected & found_in_response
    return len(matched) / len(expected) * 100


def check_source_citation(response):
    """
    Return the authoritative sources mentioned in a response.

    Acronyms (CDC, NIDDK, NIH, WHO) are matched case-sensitively as whole words. A
    case-insensitive substring test would count the ordinary word "who", and words
    containing it such as "whole" or "whom", as a WHO citation. The full name
    "American Diabetes Association" is matched case-insensitively.

    Args:
        response: Model response text.

    Returns:
        List of cited source names, in the fixed order CDC, NIDDK, NIH, WHO,
        American Diabetes Association.
    """
    source_patterns = [
        ('CDC', r'\bCDC\b', 0),
        ('NIDDK', r'\bNIDDK\b', 0),
        ('NIH', r'\bNIH\b', 0),
        ('WHO', r'\bWHO\b', 0),
        ('American Diabetes Association', r'\bAmerican\s+Diabetes\s+Association\b', re.IGNORECASE),
    ]

    return [
        name
        for name, pattern, flags in source_patterns
        if re.search(pattern, response, flags)
    ]


def check_critical_facts(response, ground_truth):
    """
    Check which critical facts from the ground truth also appear in the response.

    Two kinds of checks are produced:
      * ``mentions_<num>``: one per number in the ground truth that is followed by
        "million", "percent" or "%". True if that exact number appears in the response
        as a standalone number.
      * concept checks (``insulin_resistance``, ``autoimmune``, ``lifestyle``,
        ``complications``): included only when the concept's pattern matches the
        ground truth. True if it also matches the response (case-insensitive).

    Args:
        response: Model response text.
        ground_truth: Curated reference answer.

    Returns:
        Dict mapping check name -> bool. Empty if the ground truth has no checkable facts.
    """
    facts_found = {}

    # Extract numbers with a magnitude/percentage unit from the ground truth
    numbers = re.findall(r'\b(\d+(?:\.\d+)?)\s*(?:million|percent|%)', ground_truth)

    for num in numbers:
        # Require the number to stand alone. A plain substring test would let "38"
        # match inside "138", "380" or "1.38".
        standalone = rf'(?<!\d)(?<!\d\.){re.escape(num)}(?!\.?\d)'
        facts_found[f"mentions_{num}"] = bool(re.search(standalone, response))

    # Check for key concepts
    key_concepts = [
        ('insulin_resistance', r'insulin\s*resistance'),
        ('autoimmune', r'autoimmune'),
        ('lifestyle', r'lifestyle|diet|exercise|weight'),
        ('complications', r'complication|heart|kidney|nerve|eye'),
    ]

    for concept_name, pattern in key_concepts:
        # Only score concepts the ground truth actually mentions
        if re.search(pattern, ground_truth, re.IGNORECASE):
            facts_found[concept_name] = bool(re.search(pattern, response, re.IGNORECASE))

    return facts_found


print("Evaluation metrics functions defined")

## 3. Run Evaluation: Baseline vs KRT-Enhanced

In [None]:
def _score_response(response, ground_truth):
    """Score one response against the ground truth.

    Returns:
        Tuple ``(key_term_overlap_pct, cited_sources, fact_score_pct)``. The fact score
        is the share of applicable fact checks that passed, or 0 when none apply.
    """
    overlap = calculate_key_term_overlap(response, ground_truth)
    sources = check_source_citation(response)
    facts = check_critical_facts(response, ground_truth)
    fact_score = sum(facts.values()) / max(len(facts), 1) * 100
    return overlap, sources, fact_score


# Store results: one dict per curated Q&A pair, holding both responses and their scores
evaluation_results = []

print("Starting evaluation...")
print("=" * 70)

for i, qa_pair in enumerate(TOP_TYPE2_QA):
    question = qa_pair['question']
    ground_truth = qa_pair['answer']
    krt_focus = qa_pair['krt_focus']

    print(f"\nQ{i+1}: {question[:60]}...")

    # Generate baseline response (without KRT)
    print("  Generating baseline response...")
    baseline_response = generate_baseline_response(question)

    # Generate KRT-enhanced response
    print("  Generating KRT-enhanced response...")
    krt_response = generate_krt_enhanced_response(question)

    # Both responses are scored with the same metrics
    baseline_overlap, baseline_sources, baseline_fact_score = _score_response(baseline_response, ground_truth)
    krt_overlap, krt_sources, krt_fact_score = _score_response(krt_response, ground_truth)

    evaluation_results.append({
        'question_id': qa_pair['id'],
        'question': question,
        'krt_focus': krt_focus,
        'ground_truth': ground_truth,
        'baseline_response': baseline_response,
        'krt_response': krt_response,
        'baseline_key_term_overlap': baseline_overlap,
        'krt_key_term_overlap': krt_overlap,
        'baseline_sources': baseline_sources,
        'krt_sources': krt_sources,
        'baseline_fact_score': baseline_fact_score,
        'krt_fact_score': krt_fact_score,
    })

    print(f"  Baseline: {baseline_overlap:.1f}% key terms, {baseline_fact_score:.1f}% facts")
    print(f"  KRT-Enhanced: {krt_overlap:.1f}% key terms, {krt_fact_score:.1f}% facts")
    # `:+` prints the sign explicitly, so a regression reads "-5.0%" rather than "+-5.0%"
    print(f"  Improvement: {krt_overlap - baseline_overlap:+.1f}% terms, {krt_fact_score - baseline_fact_score:+.1f}% facts")

print("\n" + "=" * 70)
print("Evaluation complete!")

## 4. Results Analysis & Visualization

In [None]:
# Convert to DataFrame for analysis
results_df = pd.DataFrame(evaluation_results)

# Summary statistics
print("=" * 70)
print("EVALUATION SUMMARY")
print("=" * 70)

avg_baseline_terms = results_df['baseline_key_term_overlap'].mean()
avg_krt_terms = results_df['krt_key_term_overlap'].mean()
avg_baseline_facts = results_df['baseline_fact_score'].mean()
avg_krt_facts = results_df['krt_fact_score'].mean()

# Differences are formatted with an explicit sign (`:+`), so a regression reads
# "-x.x%" instead of the misleading "+-x.x%".
print("\nKey Term Overlap (avg):")
print(f"  Baseline: {avg_baseline_terms:.1f}%")
print(f"  KRT-Enhanced: {avg_krt_terms:.1f}%")
print(f"  Improvement: {avg_krt_terms - avg_baseline_terms:+.1f}%")

print("\nFactual Accuracy (avg):")
print(f"  Baseline: {avg_baseline_facts:.1f}%")
print(f"  KRT-Enhanced: {avg_krt_facts:.1f}%")
print(f"  Improvement: {avg_krt_facts - avg_baseline_facts:+.1f}%")

# Source citations: number of responses that cite at least one source
baseline_with_sources = sum(1 for r in evaluation_results if r['baseline_sources'])
krt_with_sources = sum(1 for r in evaluation_results if r['krt_sources'])

print("\nSource Citations:")
print(f"  Baseline: {baseline_with_sources}/{len(evaluation_results)} responses cite sources")
print(f"  KRT-Enhanced: {krt_with_sources}/{len(evaluation_results)} responses cite sources")

print("\n" + "=" * 70)

In [None]:
# Visualization: grouped bar charts comparing baseline and KRT-enhanced scores per question.
# Left panel: key-term overlap; right panel: factual-accuracy score. The figure is saved
# to PNG before being shown.
question_numbers = [f"Q{i+1}" for i in range(len(evaluation_results))]
x_positions = np.arange(len(question_numbers))
bar_width = 0.35

baseline_terms = [r['baseline_key_term_overlap'] for r in evaluation_results]
krt_terms = [r['krt_key_term_overlap'] for r in evaluation_results]
baseline_facts_scores = [r['baseline_fact_score'] for r in evaluation_results]
krt_facts_scores = [r['krt_fact_score'] for r in evaluation_results]

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# (axis, baseline values, KRT values, y-label, title) for each panel
panel_specs = (
    (axes[0], baseline_terms, krt_terms,
     'Key Term Overlap (%)', 'Key Term Coverage: Baseline vs KRT-Enhanced'),
    (axes[1], baseline_facts_scores, krt_facts_scores,
     'Factual Accuracy Score (%)', 'Factual Accuracy: Baseline vs KRT-Enhanced'),
)
for panel_ax, base_vals, krt_vals, y_label, panel_title in panel_specs:
    panel_ax.bar(x_positions - bar_width/2, base_vals, bar_width, label='Baseline', color='#ff7f7f')
    panel_ax.bar(x_positions + bar_width/2, krt_vals, bar_width, label='KRT-Enhanced', color='#7fbf7f')
    panel_ax.set_xlabel('Question')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.set_xticks(x_positions)
    panel_ax.set_xticklabels(question_numbers)
    panel_ax.legend()
    panel_ax.grid(axis='y', alpha=0.3)

plt.suptitle('KRT Diabetes Knowledge Integration - Performance Evaluation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('krt_evaluation_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("Visualization saved as 'krt_evaluation_results.png'")

In [None]:
# Improvement distribution: per-question change in key-term overlap (KRT minus baseline)
improvements = [
    r['krt_key_term_overlap'] - r['baseline_key_term_overlap']
    for r in evaluation_results
]
# Rebuilt here so this cell does not depend on the previous plotting cell having run
question_numbers = [f"Q{i+1}" for i in range(len(evaluation_results))]


def _improvement_color(delta):
    """Green for a gain, red for a regression, gray for no change."""
    if delta > 0:
        return 'green'
    if delta < 0:
        return 'red'
    return 'gray'


fig, ax = plt.subplots(figsize=(10, 5))
colors = [_improvement_color(imp) for imp in improvements]
ax.bar(question_numbers, improvements, color=colors, alpha=0.7)
ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
ax.set_xlabel('Question')
ax.set_ylabel('Improvement in Key Term Overlap (%)')
ax.set_title('Key Term Coverage Improvement per Question')
ax.grid(axis='y', alpha=0.3)

# Add average line; `:+` prints the sign explicitly so a negative average is not shown as "+-x%"
avg_improvement = np.mean(improvements)
ax.axhline(y=avg_improvement, color='blue', linestyle='--', linewidth=2,
           label=f'Average: {avg_improvement:+.1f}%')
ax.legend()

plt.tight_layout()
plt.show()

## 5. Detailed Response Comparison

In [None]:
def display_comparison(result_index):
    """
    Print a side-by-side comparison for one evaluated question.

    Shows the ground truth, the baseline and KRT-enhanced responses with their scores
    and cited sources, and the per-question improvement (KRT minus baseline).

    Args:
        result_index: Position in ``evaluation_results``. Normal list indexing applies,
            so an out-of-range index raises IndexError.
    """
    result = evaluation_results[result_index]

    print("=" * 80)
    print(f"QUESTION {result['question_id']}: {result['question']}")
    print(f"KRT Focus: {result['krt_focus']}")
    print("=" * 80)

    print("\n📊 GROUND TRUTH (Expected Answer):")
    print("-" * 40)
    print(result['ground_truth'])

    print("\n\n🔴 BASELINE RESPONSE (Without KRT):")
    print("-" * 40)
    print(result['baseline_response'])
    print(f"\n  Key Term Overlap: {result['baseline_key_term_overlap']:.1f}%")
    print(f"  Factual Score: {result['baseline_fact_score']:.1f}%")
    print(f"  Sources Cited: {result['baseline_sources'] if result['baseline_sources'] else 'None'}")

    print("\n\n🟢 KRT-ENHANCED RESPONSE (With KRT):")
    print("-" * 40)
    print(result['krt_response'])
    print(f"\n  Key Term Overlap: {result['krt_key_term_overlap']:.1f}%")
    print(f"  Factual Score: {result['krt_fact_score']:.1f}%")
    print(f"  Sources Cited: {result['krt_sources'] if result['krt_sources'] else 'None'}")

    improvement_terms = result['krt_key_term_overlap'] - result['baseline_key_term_overlap']
    improvement_facts = result['krt_fact_score'] - result['baseline_fact_score']

    print("\n\n📈 IMPROVEMENT:")
    print("-" * 40)
    # `:+` prints "+" for gains and "-" for regressions
    print(f"  Key Term Overlap: {improvement_terms:+.1f}%")
    print(f"  Factual Accuracy: {improvement_facts:+.1f}%")
    print("=" * 80)


# Display first comparison as example
print("Example comparison for Question 1:\n")
display_comparison(0)

In [None]:
# Interactive: View any question comparison
# Valid indices run from 0 to len(evaluation_results) - 1 (one per curated Q&A pair)
question_to_view = 3  # Change this to view different questions
if 0 <= question_to_view < len(evaluation_results):
    display_comparison(question_to_view)
else:
    print(f"question_to_view must be between 0 and {len(evaluation_results) - 1}")

## 6. Final Metrics Report

In [None]:
# Generate final report.
# "Improvement" and "Δ" values are differences between two percentage scores. They are
# printed with an explicit sign, so a regression shows as "-x%" rather than "+-x%".


def _truncate(text, width=40):
    """Return text cut to `width` characters, appending '...' only if something was cut."""
    return text if len(text) <= width else text[:width] + '...'


n_questions = len(evaluation_results)
term_gain = avg_krt_terms - avg_baseline_terms
fact_gain = avg_krt_facts - avg_baseline_facts

print("\n" + "=" * 70)
print("        KRT DIABETES KNOWLEDGE INTEGRATION - FINAL REPORT")
print("=" * 70)

print("\n📊 OVERALL PERFORMANCE METRICS")
print("-" * 70)

# Key metrics table
metrics_table = pd.DataFrame({
    'Metric': ['Key Term Overlap', 'Factual Accuracy', 'Source Citations'],
    'Baseline': [
        f"{avg_baseline_terms:.1f}%",
        f"{avg_baseline_facts:.1f}%",
        f"{baseline_with_sources}/{n_questions}",
    ],
    'KRT-Enhanced': [
        f"{avg_krt_terms:.1f}%",
        f"{avg_krt_facts:.1f}%",
        f"{krt_with_sources}/{n_questions}",
    ],
    'Improvement': [
        f"{term_gain:+.1f}%",
        f"{fact_gain:+.1f}%",
        f"{krt_with_sources - baseline_with_sources:+d}",
    ],
})

print(metrics_table.to_string(index=False))

print("\n\n📋 PER-QUESTION BREAKDOWN")
print("-" * 70)

question_summary = pd.DataFrame({
    'Q#': [f"Q{r['question_id']}" for r in evaluation_results],
    'KRT Focus': [_truncate(r['krt_focus']) for r in evaluation_results],
    'Baseline': [f"{r['baseline_key_term_overlap']:.0f}%" for r in evaluation_results],
    'KRT': [f"{r['krt_key_term_overlap']:.0f}%" for r in evaluation_results],
    'Δ': [f"{r['krt_key_term_overlap'] - r['baseline_key_term_overlap']:+.0f}%" for r in evaluation_results],
    'Pass': ['✅' if r['krt_key_term_overlap'] > r['baseline_key_term_overlap'] else '❌' for r in evaluation_results],
})

print(question_summary.to_string(index=False))

# Pass rate: share of questions where KRT strictly improved key-term coverage.
# Guarded so an empty evaluation run does not raise ZeroDivisionError.
pass_count = sum(1 for r in evaluation_results if r['krt_key_term_overlap'] > r['baseline_key_term_overlap'])
pass_rate = pass_count / n_questions * 100 if n_questions else 0.0

print(f"\n\n✅ PASS RATE: {pass_count}/{n_questions} ({pass_rate:.0f}%)")
print("   Questions where KRT injection improved key term coverage")

print("\n\n🎯 CONCLUSION")
print("-" * 70)

# Claim an improvement (and recommend KRT) only when key-term coverage rose and factual
# accuracy did not drop. NaN averages (no results) compare False and fall through to
# "inconclusive".
if term_gain > 0 and fact_gain >= 0:
    print("✅ KRT diabetes knowledge injection IMPROVES answer quality")
    print(f"   - Average key term overlap increased by {term_gain:.1f}%")
    print(f"   - Average factual accuracy changed by {fact_gain:+.1f}%")

    print("\n   Recommendation: The KRT context injection provides the model with")
    print("   authoritative, source-cited information that results in more accurate")
    print("   and comprehensive responses to diabetes-related questions.")
else:
    print("⚠️ Results inconclusive - further tuning may be needed")

print("\n" + "=" * 70)

In [None]:
# Persist the per-question results (questions, ground truth, both raw responses and
# all computed scores) as CSV for further analysis; the DataFrame index is omitted.
results_df.to_csv(path_or_buf='krt_evaluation_results.csv', index=False)
print("Results saved to 'krt_evaluation_results.csv'")