# 🚀 RunPod GPU Setup

**This notebook is optimized for RunPod GPU pods with NVIDIA GPUs**

## Quick Start on RunPod:

1. **Launch a GPU Pod** (RTX 3090, 4090, or A5000 recommended)
2. **Upload this notebook** to the pod
3. **Upload the test split** (`test.jsonl`) to `/workspace/data/`
4. **Run cells in order** - evaluation should complete in ~5-10 minutes

## Expected Performance:
- **GPU**: RTX 3090/4090 → ~0.5-1 sec/sample (~5 min total)
- **GPU**: RTX A5000 → ~1-2 sec/sample (~10 min total)
- **Full evaluation**: 300 samples

---

# Medical NER Model Evaluation

This notebook evaluates the fine-tuned Llama 3.2 3B medical NER model.

## ✅ DATASET VERIFIED & READY FOR EVALUATION

**Current Dataset Distribution** (from `both_rel_instruct_all.jsonl`):
- **1,000 Chemical extraction** examples (33.3%)
- **1,000 Disease extraction** examples (33.3%)
- **1,000 Relationship extraction** examples (33.3%)

**Data Splits Status**: ✅ Properly stratified using `stratify=` parameter
- Training (2,400): 33.3% chemical, 33.3% disease, 33.3% relationship
- Validation (300): 33.3% chemical, 33.3% disease, 33.3% relationship
- Test (300): 33.3% chemical, 33.3% disease, 33.3% relationship

**Balanced Distribution**:
- All three tasks equally represented
- Stratified splitting keeps the per-task proportions identical in every split (800/100/100 examples per task for train/validation/test)
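
The training notebook performs the split with a `stratify=` parameter (likely scikit-learn's `train_test_split`). The stdlib sketch below only illustrates the idea and is not the code that produced `splits_20251111`: splitting each task's rows separately with the same 80/10/10 fractions is what yields 800/100/100 examples per task. The function name `stratified_split` and the toy rows are hypothetical.

```python
import random
from collections import Counter, defaultdict

def stratified_split(rows, key, fractions=(0.8, 0.1, 0.1), seed=42):
    """Split rows into (train, val, test) so every task keeps the same share in each split."""
    by_task = defaultdict(list)
    for row in rows:
        by_task[row[key]].append(row)

    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, val, test = [], [], []
    for task_rows in by_task.values():
        rng.shuffle(task_rows)
        n_train = round(len(task_rows) * fractions[0])
        n_val = round(len(task_rows) * fractions[1])
        train.extend(task_rows[:n_train])
        val.extend(task_rows[n_train:n_train + n_val])
        test.extend(task_rows[n_train + n_val:])  # remainder goes to test
    return train, val, test

# Toy stand-in for both_rel_instruct_all.jsonl: 1,000 rows per task
rows = [{"task": t, "id": i} for t in ("chemicals", "diseases", "influences") for i in range(1000)]
train, val, test = stratified_split(rows, "task")
print(len(train), len(val), len(test))
print(Counter(r["task"] for r in test))
```

Because each task is sliced independently, no row can land in two splits, which is also why the split introduces no train/test leakage by construction.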

**Next Steps**:
1. ✅ Training data is properly split with stratification
2. ✅ No data leakage between train/val/test
3. Update `HF_MODEL_ID` below with your trained model ID
4. Run this evaluation notebook on the balanced test set

---

## Prerequisites:
1. Complete training in `Medical_NER_Fine_Tuning.ipynb` (uses stratified splits!)
2. Model saved to `./final_model` or uploaded to HuggingFace Hub
3. Test data available in `../data/splits_20251111/test.jsonl`

## Evaluation Tasks:
1. Load the fine-tuned model
2. Evaluate on test set (33.3% each: chemicals, diseases, relationships)
3. Calculate precision, recall, F1 scores per task type
4. Test on custom medical texts
5. Analyze errors and false positives

## 📊 Expected Dataset Characteristics (from Deep Data Exploration)

**This evaluation is calibrated against known dataset statistics**

### Test Set Expected Properties:
- **Total samples**: 300 (10% of 3,000)
- **Task distribution**: Perfectly balanced (stratified split)
  - Chemical extraction: 100 samples (33.3%)
  - Disease extraction: 100 samples (33.3%)
  - Relationship extraction: 100 samples (33.3%)

### Entity Universe (Full Dataset):
- **Unique chemicals**: ~1,578 total entities
- **Unique diseases**: ~2,199 total entities  
- **Vocabulary**: ~13,710 unique words

### Entity Complexity:
- **Chemical names**: Avg 11.1 chars, 1.2 words
- **Disease names**: Avg 14.9 chars, 1.7 words
- **Hyphenated entities**: ~459 (e.g., "type-2 diabetes")
  - Model should preserve hyphens during extraction
- **Special characters**: 13 types found in entities

### Expected Model Behavior:
1. **Format Handling**: Model trained with simple system prompt
   - Training format: Llama 3 chat with basic NER instructions
   - Evaluation uses matching prompt format for consistency
2. **Entity Recognition**: Should extract only entities appearing verbatim in text
   - Post-filters verify all predictions exist in source text
3. **Precision Focus**: Conservative training to minimize false positives

### Quality Benchmarks:
- **Training data quality**: Zero duplicates, zero empty completions
- **Vocabulary coverage**: Complete (13,710 words trained)
- **Task distribution**: Perfect 33.3% per task maintained in all splits

**Use these statistics to validate evaluation results and detect anomalies.**

## 0. Environment Variables Setup

⚠️ **IMPORTANT**: Set your credentials before running the notebook!

**Note**: `hf_transfer` is enabled for faster downloads from HuggingFace Hub.

In [None]:
import os
from getpass import getpass

# Enable hf_transfer for faster downloads from HuggingFace Hub
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# HuggingFace Token (required to download your model from Hub)
# Get your token from: https://huggingface.co/settings/tokens
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("HF_TOKEN not found in environment variables")
    hf_token = getpass("Enter your HuggingFace token: ")
    os.environ["HF_TOKEN"] = hf_token
else:
    print("✓ HF_TOKEN loaded from environment")

# Weights & Biases API Key (optional - only if tracking evaluation metrics)
# Get your key from: https://wandb.ai/authorize
wandb_key = os.getenv("WANDB_API_KEY")
if wandb_key:
    print("✓ WANDB_API_KEY loaded from environment")
else:
    print("ℹ WANDB_API_KEY not set (optional)")

print("\n✓ Environment variables configured")
print(f"  HF_HUB_ENABLE_HF_TRANSFER: {os.getenv('HF_HUB_ENABLE_HF_TRANSFER')}")

## 1. Setup and Installation


In [None]:
# Install the required packages (PyTorch itself is assumed to be preinstalled, e.g. by the pod image)
!pip install -q transformers datasets peft accelerate bitsandbytes
!pip install -q huggingface-hub tokenizers hf-transfer

print("✓ All packages installed successfully!")
print("  - transformers (HuggingFace models)")
print("  - peft (LoRA adapters)")
print("  - accelerate (device management)")
print("  - bitsandbytes (quantization)")
print("  - hf-transfer (fast downloads)")

## 2. Import Libraries


In [None]:

import json
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Reusable Utilities

⚠️ **IMPORTANT**: Run this cell BEFORE running evaluation cells below!

These utility functions provide text normalization, hashing, parsing, and validation for the evaluation pipeline.

In [None]:
# ===== Utilities: normalization, hashing, parsing =====
import re, json, hashlib
from collections import Counter

def dehyphenate(s: str) -> str:
    # Join words broken across lines with hyphens + whitespace
    return re.sub(r"(\w+)-\s+(\w+)", r"\1\2", s)

def normalize_text(s: str) -> str:
    s = dehyphenate(s or "")
    s = s.lower()
    s = re.sub(r"[\u00A0\t\r\n]+", " ", s)     # spaces/newlines
    s = re.sub(r"\s+", " ", s).strip()
    return s

def prompt_hash(prompt: str) -> str:
    return hashlib.md5(normalize_text(prompt).encode("utf-8")).hexdigest()

def parse_bullets(text: str):
    items = []
    for line in (text or "").splitlines():
        m = re.match(r"^\s*[-*]\s*(.+?)\s*$", line)
        if m:
            items.append(m.group(1))
    return items

def normalize_item(s: str) -> str:
    s = (s or "").lower()
    # Keep hyphens intact (e.g., "type-2 diabetes" stays "type-2 diabetes")
    s = re.sub(r"\s+", " ", s)  # Only normalize whitespace
    s = re.sub(r"[\.,;:]+$", "", s).strip()
    return s

def in_text(item: str, text: str) -> bool:
    """Check if item appears in text as a whole word/phrase, avoiding partial matches."""
    item_norm = normalize_item(item)
    if not item_norm:
        return False
    text_norm = normalize_text(text)
    # Lookarounds instead of \b: still rejects "aspirin" inside "aspirinate", but also
    # works for entities that start or end with a non-word character (e.g. "ca2+")
    pattern = r'(?<!\w)' + re.escape(item_norm) + r'(?!\w)'
    return bool(re.search(pattern, text_norm))

def unique_preserve_order(seq):
    seen = set()
    out = []
    for x in seq:
        if x not in seen:
            seen.add(x); out.append(x)
    return out

print("✓ Utility functions loaded")
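
A regex `\b` only matches at a transition between a word character and a non-word character, so an entity that begins or ends with punctuation (for example `ca2+` or a parenthesized abbreviation) can fail a `\b`-anchored check even when it appears verbatim. The standalone comparison below contrasts `\b` with the lookaround pair `(?<!\w)` / `(?!\w)`; the helper names and the sample sentence are made up for illustration.

```python
import re

def match_with_word_boundary(item: str, text: str) -> bool:
    # \b needs a word/non-word transition, so it fails next to a leading "(" or trailing "+"
    return re.search(r"\b" + re.escape(item) + r"\b", text) is not None

def match_with_lookarounds(item: str, text: str) -> bool:
    # Only require that no word character touches the item on either side
    return re.search(r"(?<!\w)" + re.escape(item) + r"(?!\w)", text) is not None

text = "serum ca2+ levels fell after aspirinate and aspirin (asa) use"
for item in ["aspirin", "ca2+", "(asa)"]:
    print(f"{item!r}: \\b={match_with_word_boundary(item, text)}, "
          f"lookarounds={match_with_lookarounds(item, text)}")
```

Both forms still refuse to match `aspirin` inside `aspirinate`; only the lookaround form accepts the punctuation-edged entities.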

## 3. Configuration

⚠️ **Update these paths** to match your model location!


In [None]:
# Model configuration
BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"

# ⚠️ IMPORTANT: Update with YOUR HuggingFace model ID
# Find it at: https://huggingface.co/your-username
# Format: "your-username/llama3-medical-ner-lora-YYYYMMDD_HHMMSS"
HF_MODEL_ID = "albyos/llama3-medical-ner-checkpoint-450-20251108_114135"  # ← UPDATE THIS!

# Alternative: Use local model if you prefer
USE_HF_HUB = True  # Set to False to use local ../final_model
PROJECT_ROOT = Path.cwd().parent
LOCAL_MODEL_PATH = PROJECT_ROOT / "final_model"

ADAPTER_PATH = HF_MODEL_ID if USE_HF_HUB else str(LOCAL_MODEL_PATH)

# Data configuration
# Candidate locations, checked in order: RunPod upload dir, repo split dir
# (two levels up), repo split dir (one level up), current directory
CANDIDATE_TEST_PATHS = [
    Path("/workspace/data/test.jsonl"),
    Path.cwd().parent.parent / "data" / "splits_20251111" / "test.jsonl",
    Path.cwd().parent / "data" / "splits_20251111" / "test.jsonl",
    Path("test.jsonl"),
]
TEST_DATA_PATH = next((p for p in CANDIDATE_TEST_PATHS if p.exists()), CANDIDATE_TEST_PATHS[-1])

# Verify test data exists
if not TEST_DATA_PATH.exists():
    print("❌ Test data not found. Searched:")
    for candidate in CANDIDATE_TEST_PATHS:
        print(f"   - {candidate}")
    print("💡 RunPod: Upload to /workspace/data/test.jsonl")
    print("💡 Local: Place in ../data/splits_20251111/test.jsonl or next to this notebook")
    raise FileNotFoundError(f"Test data file not found: {TEST_DATA_PATH}")

print("✓ Configuration loaded")
print(f"  Base model: {BASE_MODEL}")
print(f"  Adapter source: {'HuggingFace Hub' if USE_HF_HUB else 'Local filesystem'}")
print(f"  Adapter path: {ADAPTER_PATH}")
print(f"  Test data: {TEST_DATA_PATH}")
print(f"  Test data exists: {TEST_DATA_PATH.exists()}")

## 4. Authenticate with Hugging Face

Log into Hugging Face to download the LoRA adapter when `USE_HF_HUB` is enabled.

In [None]:
# Login to HuggingFace Hub to access your model
import os
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")

if not hf_token:
    print("❌ HF_TOKEN not found in environment")
    print("   Please run the Environment Variables Setup cell (section 0) first to set your HF token")
    raise ValueError("HF_TOKEN is required to download model from HuggingFace Hub")

# Login to HuggingFace
login(token=hf_token, add_to_git_credential=True)

print("✓ Logged into Hugging Face Hub")
print(f"  Will load model from: {HF_MODEL_ID}")

## 5. Load the Fine-Tuned Model

Load the base model and attach the LoRA adapter from either Hugging Face Hub or your local filesystem.

**Note**: Using `hf_transfer` for faster downloads from HuggingFace Hub.
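
A rough estimate of weight memory helps interpret the "GPU Memory Used" figure printed below. The sketch assumes about 3.2 billion parameters for Llama 3.2 3B and counts weights only; the KV cache, activations, and the CUDA context add more, so treat the result as a lower bound. The LoRA adapter is typically small by comparison. The helper name is hypothetical.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory needed for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# Assumption: about 3.2 billion parameters for Llama 3.2 3B
NUM_PARAMS = 3.2e9
for dtype_name, nbytes in [("float32", 4), ("float16", 2)]:
    print(f"{dtype_name}: ~{weight_memory_gb(NUM_PARAMS, nbytes):.1f} GB")
```

At float16 the weights need roughly 6.4 GB, which leaves ample headroom on the 24 GB RTX 3090, 4090, and A5000 cards recommended above.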

In [None]:
# hf_transfer must be enabled before huggingface_hub is first imported (done in section 0).
# The variable is read at import time, so this line is only a reminder once the library is loaded.
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Load the fine-tuned model for inference
print("="*80)
print("LOADING FINE-TUNED MODEL")
print("="*80)

print(f"\nLoading base model: {BASE_MODEL}...")

# Load tokenizer (EOS doubles as the padding token)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

print("✓ Tokenizer loaded")

# Check for GPU support (optimized for RunPod/CUDA)
if torch.cuda.is_available():
    device = "cuda"
    print(f"🚀 NVIDIA GPU detected: {torch.cuda.get_device_name(0)}")
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif torch.backends.mps.is_available():
    device = "mps"
    print("🚀 Apple Silicon GPU (MPS) detected")
else:
    device = "cpu"
    print("⚠️  No GPU detected, using CPU (very slow)")

# float16 on GPUs; float16 matmuls are slow or unsupported on many CPU builds, so use float32 there
dtype = torch.float32 if device == "cpu" else torch.float16

# Load base model with GPU acceleration
# On RunPod: uses CUDA with float16 for optimal performance
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    device_map="auto",  # Automatically place the model on the available device(s)
    low_cpu_mem_usage=True,
)

print(f"\n✓ Base model loaded: {BASE_MODEL}")
print(f"  Device: {device.upper()}")
print(f"  Precision: {base_model.dtype}")
if device == "cuda":
    print(f"  GPU Memory Used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# Load LoRA adapter from HuggingFace Hub or local path
print(f"\nLoading LoRA adapter from: {ADAPTER_PATH}...")
if USE_HF_HUB:
    print("  Using hf_transfer for faster downloads...")

model = PeftModel.from_pretrained(
    base_model,
    ADAPTER_PATH,
)
model.eval()  # Disable dropout for inference

print("\n✓ Fine-tuned model loaded successfully!")
print(f"  Base: {BASE_MODEL}")
print(f"  LoRA adapter: {ADAPTER_PATH}")
print(f"  Source: {'HuggingFace Hub' if USE_HF_HUB else 'Local filesystem'}")

In [None]:
# ===== Deterministic generation for evaluation =====
def generate_response(prompt_text, max_new_tokens=128):
    """
    Generate a response for a given prompt with greedy (deterministic) decoding.
    
    CRITICAL: Uses the SIMPLE system prompt that matches the training format exactly.
    This keeps training and inference consistent.
    
    Generation parameters:
    - do_sample=False: greedy decoding, so the same input always gives the same output
    - num_beams=1: no beam search
    - Sampling parameters (temperature, top_k, top_p) are omitted because greedy
      decoding ignores them
    """
    formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical NER expert. Extract the requested entities from medical texts accurately.<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    
    # The template already starts with <|begin_of_text|>; don't let the tokenizer add a second BOS
    inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding (deterministic)
            num_beams=1,  # No beam search (faster)
            # Note: the penalty applies to every token already in the sequence, prompt included,
            # so values > 1.0 can discourage copying entity names verbatim from the text
            repetition_penalty=1.15,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Enable KV cache for faster generation
        )
    
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    
    return response.strip()

print("✓ Deterministic inference function ready")
print("  System prompt: MATCHES training format exactly (simple NER instructions)")
print("\n  Generation parameters:")
print("    - do_sample: False (greedy decoding)")
print("    - max_new_tokens: 128 by default (relationship tests below pass larger values)")
print("    - use_cache: True (KV cache for speed)")
print("\n  Benefits:")
print("    - Training-inference consistency (same prompt structure)")
print("    - Reproducible results (same input → same output)")
print("    - No sampling randomness, so no spurious entities from low-probability tokens")
print("    - Faster inference (no sampling overhead)")

## 🔍 System Prompt - Training-Inference Consistency

**CRITICAL**: The evaluation prompt must match the training prompt for optimal performance.

### Training Prompt (from Medical_NER_Fine_Tuning.ipynb):
```
You are a medical NER expert. Extract the requested entities from medical texts accurately.
```

### Why Simple Prompt Works:
- **Training used this exact format** → Model learned these specific instructions
- **No complex rules** → Model relies on training data patterns, not verbose instructions
- **Clean output** → Trained to produce bullet lists without explanations

### Training-Inference Mismatch Problems:
- ❌ Enhanced prompt during eval ≠ simple prompt during training = confusion
- ❌ Model hasn't seen complex instructions → may ignore or misinterpret them
- ❌ Different prompt structure → activates different model behaviors

### Our Solution:
- ✅ Use identical system prompt in training AND evaluation
- ✅ Model sees familiar instruction pattern it was trained on
- ✅ Activates learned behaviors consistently

**Result**: More accurate evaluation that reflects true model performance.
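
One way to guarantee identical prompts is to build the Llama 3 chat string in a single helper that both the training and evaluation notebooks import. The sketch below is hypothetical (`SYSTEM_PROMPT` and `format_llama3_prompt` are not defined in this project); it reproduces the string that `generate_response` builds, assuming the training notebook used the same layout.

```python
SYSTEM_PROMPT = (
    "You are a medical NER expert. "
    "Extract the requested entities from medical texts accurately."
)

def format_llama3_prompt(user_text: str, system_prompt: str = SYSTEM_PROMPT) -> str:
    """Build the Llama 3 chat string, ending at the assistant header so generation can start."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

prompt = format_llama3_prompt(
    "Create a list only of the chemicals mentioned.\n\nList of extracted chemicals:\n"
)
print(prompt)
```

`tokenizer.apply_chat_template(..., add_generation_prompt=True)` can build a similar string, but the Llama 3.1+ chat template may insert extra system text (such as knowledge-cutoff and current dates), so it is not guaranteed to be byte-identical to this simple format.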

## 6. Task Classification and Post-Filters

These functions classify tasks from prompts and filter predictions to ensure they appear in the source text, reducing false positives.

In [None]:
# ===== Task classification and post-filters =====

# Task classifier
def task_from_prompt(prompt: str) -> str:
    """Classify task type from prompt text."""
    p = normalize_text(prompt)
    # Check for influences FIRST because those prompts also contain "chemicals" and "diseases"
    if "influences between" in p or "list of extracted influences" in p:
        return "influences"
    if "list of extracted chemicals" in p or "chemicals mentioned" in p:
        return "chemicals"
    if "list of extracted diseases" in p or "diseases mentioned" in p:
        return "diseases"
    return "other"

# Entity extraction and filtering
def extract_list_from_generation(gen_text):
    """Parse bullets from the model output."""
    return parse_bullets(gen_text)

def filter_items_against_text(pred_items, prompt_text):
    """
    Keep only items that appear in the source text (after normalization). Deduplicate.
    
    Enhanced with data exploration insights:
    - Strict word boundary matching (~459 hyphenated entities)
    - Preserves multi-word entities (avg 1.7 words for diseases)
    - Handles special characters (13 types found)
    """
    keep = []
    for it in pred_items:
        if in_text(it, prompt_text):
            keep.append(normalize_item(it))
    return unique_preserve_order(keep)

# ENHANCED: Strict filtering with confidence scoring
def strict_filter_items_against_text(pred_items, prompt_text, min_length=2):
    """
    Stricter filtering to reduce false positives.
    
    Based on data exploration insights:
    - Filters entities shorter than `min_length` characters (likely fragments);
      real entities are much longer on average (11.1 chars for chemicals, 14.9 for diseases)
    - Requires strict word boundaries
    
    Args:
        pred_items: Predicted entity list
        prompt_text: Source text to verify against
        min_length: Minimum entity length (default 2 to avoid single chars)
    """
    keep = []
    for it in pred_items:
        normalized = normalize_item(it)
        # Skip very short entities (likely noise or fragments)
        if len(normalized) < min_length:
            continue
        # Strict word boundary check
        if in_text(it, prompt_text):
            keep.append(normalized)
    return unique_preserve_order(keep)

# Influences/Relationships - parse as pairs or sentences
def parse_pairs(gen_text):
    """Parse 'chemical | disease' pairs from generation output."""
    pairs = []
    for line in parse_bullets(gen_text):
        parts = [p.strip() for p in line.split("|")]
        if len(parts)==2:
            pairs.append(tuple(parts))
    return unique_preserve_order(pairs)

def parse_pairs_from_sentence(gen_text):
    """Parse sentence format: 'Chemical X influences disease Y' from generation."""
    pairs = []
    for line in parse_bullets(gen_text):
        # Match pattern: "Chemical NAME influences disease NAME"
        m = re.match(r'^\s*chemical\s+(.+?)\s+influences\s+disease\s+(.+?)\s*$', line, re.I)
        if m:
            pairs.append((m.group(1).strip(), m.group(2).strip()))
    return unique_preserve_order(pairs)

def filter_pairs_against_text(pairs, prompt_text):
    """Keep the pair only if BOTH sides appear in the prompt."""
    kept = []
    for chem, dis in pairs:
        if in_text(chem, prompt_text) and in_text(dis, prompt_text):
            kept.append((normalize_item(chem), normalize_item(dis)))
    # Deduplicate normalized pairs
    seen=set(); out=[]
    for p in kept:
        if p not in seen:
            seen.add(p); out.append(p)
    return out

# ENHANCED: Fuzzy matching for minor variations
from difflib import SequenceMatcher

def fuzzy_match(pred, gold, threshold=0.9):
    """
    Allow minor typos or formatting differences.
    
    Based on data exploration:
    - 13 types of special characters may cause exact match failures
    - Hyphen variations (~459 entities)
    - Capitalization differences
    
    Args:
        pred: Predicted entity
        gold: Gold standard entity
        threshold: Similarity threshold (0.9 = 90% match)
    """
    return SequenceMatcher(None, pred.lower(), gold.lower()).ratio() > threshold

def enhanced_match_with_fuzzy(pred_set, gold_set, threshold=0.9):
    """
    Match predictions with gold using both exact and fuzzy matching.
    Returns true positives using flexible matching.
    """
    tp_exact = pred_set & gold_set
    
    # For remaining predictions, try fuzzy matching
    remaining_pred = pred_set - gold_set
    remaining_gold = gold_set - tp_exact
    
    tp_fuzzy = set()
    for pred in remaining_pred:
        for gold in remaining_gold:
            if fuzzy_match(pred, gold, threshold):
                tp_fuzzy.add(pred)
                remaining_gold.discard(gold)
                break
    
    return tp_exact | tp_fuzzy

print("✓ Task classification and filter functions loaded")
print("\n  Active Functions (for 3-task dataset):")
print("    - task_from_prompt(): Classify chemicals, diseases, or influences")
print("    - filter_items_against_text(): Keep only entities in source text")
print("    - strict_filter_items_against_text(): Stricter filtering for FP reduction")
print("    - parse_pairs() / parse_pairs_from_sentence(): Parse relationship pairs")
print("    - filter_pairs_against_text(): Keep pairs where both sides exist")
print("\n  ENHANCED Functions (based on data exploration):")
print("    - fuzzy_match(): Handle minor variations (hyphens, special chars)")
print("    - enhanced_match_with_fuzzy(): Flexible matching for evaluation")
print("\n  Addresses data exploration findings:")
print("    - ~459 hyphenated entities (strict preservation)")
print("    - 13 special character types (flexible matching)")
print("    - Avg 1.7 words for diseases (multi-word validation)")
print("    - Relationship parsing for 'Chemical X influences disease Y' format")
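
`fuzzy_match` compares characters, not meaning, so a 0.9 threshold tolerates a hyphen variant but can also accept a clinically different entity that differs by a single character. The standalone check below re-declares the same rule; the three pairs are made-up examples.

```python
from difflib import SequenceMatcher

def fuzzy_match(pred: str, gold: str, threshold: float = 0.9) -> bool:
    # Same rule as fuzzy_match above: character-level similarity ratio above threshold
    return SequenceMatcher(None, pred.lower(), gold.lower()).ratio() > threshold

pairs = [
    ("type-2 diabetes", "type 2 diabetes"),  # hyphen variant: intended match
    ("type 1 diabetes", "type 2 diabetes"),  # different disease, yet still above 0.9
    ("aspirin", "heparin"),                  # unrelated drugs: rejected
]
for pred, gold in pairs:
    ratio = SequenceMatcher(None, pred, gold).ratio()
    print(f"{pred!r} vs {gold!r}: ratio={ratio:.3f} match={fuzzy_match(pred, gold)}")
```

The evaluation loop in section 7 scores exact set matches only; if `enhanced_match_with_fuzzy` is used instead, reporting exact and fuzzy scores side by side keeps cases like the second pair visible.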

## 7. Evaluate on the Held-Out Test Set

Run inference on the test set with deterministic generation and post-filters.

**Key Features**:
- **Deterministic generation**: No sampling (do_sample=False)
- **Post-filters**: Keep only entities that appear in source text
- **Per-task metrics**: Separate P/R/F1 for chemicals, diseases, influences
- **Sanity checks**: Show examples of false positives and false negatives

## 🔧 Critical Fixes Applied

**Format Mismatch Issue Resolved:**

The test data uses OLD format for influences:
```
"- chemical cyclophosphamide influences disease urinary bladder cancer"
```

But the model may output NEW format:
```
"- cyclophosphamide | urinary bladder cancer"
```

**Solution:** The evaluation now handles BOTH formats automatically by:
1. Parsing gold data from the OLD sentence format
2. Parsing model output in the OLD sentence format first, falling back to the NEW pipe format when no sentence-format lines are found
3. Normalizing both to `"chemical | disease"` format for comparison

This ensures accurate metrics regardless of which format the model learned!
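
The normalization step can be seen in isolation in the sketch below: each bullet is tried against the sentence pattern first and then against the pipe format, and both end up as `chemical | disease`. Unlike the notebook code, which picks one format for the whole generation, this hypothetical `to_pair` helper decides per line; the example lines are the ones quoted above.

```python
import re

SENTENCE_RE = re.compile(r"^chemical\s+(.+?)\s+influences\s+disease\s+(.+?)$", re.I)

def to_pair(bullet: str):
    """Return 'chemical | disease' for either format, or None if the line matches neither."""
    line = re.sub(r"^\s*[-*]\s*", "", bullet).strip()  # drop the bullet marker
    m = SENTENCE_RE.match(line)
    if m:  # OLD sentence format
        chem, dis = m.group(1), m.group(2)
    else:  # NEW pipe format
        parts = [p.strip() for p in line.split("|")]
        if len(parts) != 2 or not all(parts):
            return None
        chem, dis = parts
    return f"{chem.strip().lower()} | {dis.strip().lower()}"

old = "- chemical cyclophosphamide influences disease urinary bladder cancer"
new = "- cyclophosphamide | urinary bladder cancer"
print(to_pair(old))
print(to_pair(new))
```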

In [None]:
# ===== Evaluation with per-task metrics and filters =====
from statistics import mean

def f1(p, r): 
    return 0.0 if (p+r)==0 else 2*p*r/(p+r)

# Load test data
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:
    test_data = [json.loads(line) for line in f]

print(f"✓ Loaded test set: {len(test_data)} samples")
print("\n⚠️  IMPORTANT:")
print("  - Training set (80%): Used for fine-tuning")
print("  - Validation set (10%): Monitored during training (W&B)")
print("  - Test set (10%): Used ONLY NOW for final evaluation")
print("\nRunning evaluation with deterministic generation + post-filters...")

# Initialize per-task counters
gold_total = {"chemicals":0, "diseases":0, "influences":0}
pred_total = {"chemicals":0, "diseases":0, "influences":0}
tp_total   = {"chemicals":0, "diseases":0, "influences":0}

examples_fp = []  # False positives
examples_fn = []  # False negatives

# Process each test sample
for idx, row in enumerate(test_data):
    if (idx + 1) % 50 == 0:
        print(f"  Progress: {idx + 1}/{len(test_data)} samples...")
    
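    prompt = row["prompt"]
    task = task_from_prompt(prompt)
    if task not in gold_total:
        # Unrecognized prompt type: skip it instead of raising KeyError on the per-task counters
        print(f"  ⚠️ Skipping sample {idx}: unrecognized task type")
        continue
    gold_items = [normalize_item(x) for x in parse_bullets(row.get("completion",""))]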
    
    # Generate prediction
    gen = generate_response(prompt, max_new_tokens=128)
    pred_raw = extract_list_from_generation(gen)
    
    # Apply filters based on task type
    if task in {"chemicals", "diseases"}:
        pred = filter_items_against_text(pred_raw, prompt)
    elif task == "influences":
        # Parse gold data (format: "Chemical X influences disease Y")
        gold_pairs = []
        for item in parse_bullets(row.get("completion","")):
            # Parse sentence format
            m = re.match(r'^\s*chemical\s+(.+?)\s+influences\s+disease\s+(.+?)\s*$', item, re.I)
            if m:
                chem = normalize_item(m.group(1))
                dis = normalize_item(m.group(2))
                gold_pairs.append(f"{chem} | {dis}")
        gold_items = gold_pairs
        
        # Parse model output (could be sentence format OR pipe-separated)
        pairs_sentence = parse_pairs_from_sentence(gen)  # Try sentence format first
        pairs_pipe = parse_pairs(gen)  # Try pipe format as fallback
        all_pairs = pairs_sentence if pairs_sentence else pairs_pipe
        
        # Normalize both sides of the pair for consistent comparison
        pred = [f"{normalize_item(c)} | {normalize_item(d)}" 
                for (c,d) in filter_pairs_against_text(all_pairs, prompt)]
    else:
        pred = []
    
    # Convert to sets for metrics
    gs = set(gold_items)
    ps = set(pred)
    
    tp = len(gs & ps)
    fp = len(ps - gs)
    fn = len(gs - ps)
    
    gold_total[task] += len(gs)
    pred_total[task] += len(ps)
    tp_total[task]   += tp
    
    # Collect examples for analysis
    if fp and len(examples_fp) < 8:
        examples_fp.append({
            "task": task,
            "prompt_preview": prompt[:120]+"...",
            "pred_extras": list(ps-gs)[:5]
        })
    if fn and len(examples_fn) < 8:
        examples_fn.append({
            "task": task,
            "prompt_preview": prompt[:120]+"...",
            "missed": list(gs-ps)[:5]
        })

print("\n✓ Evaluation complete!")
print(f"\n{'='*80}")
print("PER-TASK METRICS (with post-filters)")
print(f"{'='*80}\n")

# Calculate and display metrics for each task
for t in ["chemicals", "diseases", "influences"]:
    P = 0.0 if pred_total[t]==0 else tp_total[t]/pred_total[t]
    R = 0.0 if gold_total[t]==0 else tp_total[t]/gold_total[t]
    F = f1(P,R)
    print(f"{t.upper()}")
    print(f"  Precision: {P*100:5.1f}%  (TP={tp_total[t]}, Pred={pred_total[t]})")
    print(f"  Recall:    {R*100:5.1f}%  (TP={tp_total[t]}, Gold={gold_total[t]})")
    print(f"  F1 Score:  {F*100:5.1f}%")
    print()

# Overall metrics
total_tp = sum(tp_total.values())
total_pred = sum(pred_total.values())
total_gold = sum(gold_total.values())
overall_P = 0.0 if total_pred==0 else total_tp/total_pred
overall_R = 0.0 if total_gold==0 else total_tp/total_gold
overall_F = f1(overall_P, overall_R)

print(f"{'='*80}")
print("OVERALL METRICS")
print(f"{'='*80}")
print(f"  Precision: {overall_P*100:5.1f}%")
print(f"  Recall:    {overall_R*100:5.1f}%")
print(f"  F1 Score:  {overall_F*100:5.1f}%")
print(f"\n  Total TP: {total_tp}, Total Pred: {total_pred}, Total Gold: {total_gold}")

# Show example errors
if examples_fp:
    print(f"\n{'='*80}")
    print("EXAMPLE FALSE POSITIVES (model predicted, but not in gold)")
    print(f"{'='*80}")
    for e in examples_fp[:5]:
        print(f"\nTask: {e['task']}")
        print(f"Prompt: {e['prompt_preview']}")
        print(f"Extra predictions: {e['pred_extras']}")

if examples_fn:
    print(f"\n{'='*80}")
    print("EXAMPLE FALSE NEGATIVES (in gold, but model missed)")
    print(f"{'='*80}")
    for e in examples_fn[:5]:
        print(f"\nTask: {e['task']}")
        print(f"Prompt: {e['prompt_preview']}")
        print(f"Missed items: {e['missed']}")
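
The OVERALL block pools TP, predictions, and gold items across the three tasks (micro-averaging), so tasks that contribute more items carry more weight. A macro average of the per-task F1 scores weights each task equally. The sketch below computes both from per-task counts; the counts are invented for illustration and are not results from this model.

```python
def prf(tp: int, pred: int, gold: int):
    """Precision, recall, and F1 from raw counts (0.0 when a denominator is zero)."""
    p = tp / pred if pred else 0.0
    r = tp / gold if gold else 0.0
    f = 0.0 if p + r == 0 else 2 * p * r / (p + r)
    return p, r, f

# Hypothetical per-task counts, for illustration only
counts = {
    "chemicals":  {"tp": 90, "pred": 100, "gold": 120},
    "diseases":   {"tp": 60, "pred": 80,  "gold": 100},
    "influences": {"tp": 10, "pred": 40,  "gold": 50},
}

# Micro: pool the counts first, then compute P/R/F1 once
micro_p, micro_r, micro_f1 = prf(
    sum(c["tp"] for c in counts.values()),
    sum(c["pred"] for c in counts.values()),
    sum(c["gold"] for c in counts.values()),
)
# Macro: compute F1 per task, then take the unweighted mean
macro_f1 = sum(prf(**c)[2] for c in counts.values()) / len(counts)

print(f"micro F1: {micro_f1:.3f}")
print(f"macro F1: {macro_f1:.3f}")
```

In this made-up case the weak relationship task drags the macro score down more than the micro score, so reporting both gives a fuller picture of per-task performance.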

## 📊 Test Set Distribution Validation

Verify test set matches expected characteristics from data exploration.

In [None]:
# ===== Validate test set distribution against exploration findings =====
from collections import Counter

print("="*80)
print("TEST SET VALIDATION (vs. Exploration Expectations)")
print("="*80)

# Get task distribution
test_tasks = [task_from_prompt(row["prompt"]) for row in test_data]
task_dist = Counter(test_tasks)

print(f"\nüìä Test Set Size: {len(test_data)} samples")
print(f"   Expected: ~300 samples (10% of 3,000) ‚úì" if 250 <= len(test_data) <= 350 else f"   ‚ö†Ô∏è Size anomaly detected!")

print(f"\nüìä Task Distribution:")
for task in ["chemicals", "diseases", "influences"]:
    count = task_dist.get(task, 0)
    pct = count / len(test_data) * 100 if len(test_data) > 0 else 0
    expected_pct = 33.3
    status = "‚úì" if abs(pct - expected_pct) < 5 else "‚ö†Ô∏è"
    print(f"   {task.capitalize()}: {count} samples ({pct:.1f}%) {status}")
    print(f"      Expected: ~33.3% (stratified split)")

# Count unique entities in test completions
print(f"\nüìä Entity Statistics (Test Set):")
all_chemicals = set()
all_diseases = set()
all_relationships = 0

for row in test_data:
    task = task_from_prompt(row["prompt"])
    items = [normalize_item(x) for x in parse_bullets(row.get("completion",""))]
    
    if task == "chemicals":
        all_chemicals.update(items)
    elif task == "diseases":
        all_diseases.update(items)
    elif task == "influences":
        all_relationships += len(items)

print(f"   Unique chemicals in test: {len(all_chemicals)}")
print(f"      (Full dataset has ~1,578 unique chemicals)")
print(f"   Unique diseases in test: {len(all_diseases)}")
print(f"      (Full dataset has ~2,199 unique diseases)")
print(f"   Total relationships in test: {all_relationships}")

print("\n✓ Test set validation complete!")
print("  Use these figures to sanity-check the metrics reported above.")
print("="*80)

## 8. Custom Test Cases: Comprehensive NER Evaluation

Test the model's ability to:
1. **Extract Chemicals** - Identify drug names and chemical compounds
2. **Extract Diseases** - Identify medical conditions and diseases
3. **Extract Relationships** - Identify which chemicals are related to which diseases

In [None]:
# Test 1: Chemical Extraction
print("="*80)
print("TEST 1: CHEMICAL EXTRACTION")
print("="*80)

chemical_test = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the chemicals mentioned.

A patient was treated with aspirin and ibuprofen for pain relief. The combination of these NSAIDs proved effective in reducing inflammation. Additionally, metformin was prescribed for glucose control.

List of extracted chemicals:
"""

print(f"\nüìù Prompt:\n{chemical_test}")
print("\nü§ñ Model Output:")
print(generate_response(chemical_test))

In [None]:
# Test 2: Disease Extraction
print("\n" + "="*80)
print("TEST 2: DISEASE EXTRACTION")
print("="*80)

disease_test = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the diseases mentioned.

The patient presented with hypertension, diabetes mellitus, and chronic kidney disease. Laboratory findings revealed proteinuria and elevated creatinine levels, suggesting diabetic nephropathy.

List of extracted diseases:
"""

print(f"\nüìù Prompt:\n{disease_test}")
print("\nü§ñ Model Output:")
print(generate_response(disease_test))

# Test 3: Chemical-Disease Relationship Extraction
print("\n" + "="*80)
print("TEST 3: RELATIONSHIP EXTRACTION - BASIC")
print("="*80)

relationship_test_1 = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the influences between the chemicals and diseases mentioned.

Metformin is commonly prescribed for type 2 diabetes by improving insulin sensitivity and reducing hepatic glucose production. Aspirin is used in cardiovascular disease management in high-risk patients.

List of extracted influences:
"""

print(f"\nüìù Prompt:\n{relationship_test_1}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_1, max_new_tokens=600))

# Test 4: Multiple Relationship Extraction
print("\n" + "="*80)
print("TEST 4: RELATIONSHIP EXTRACTION - MULTIPLE PAIRS")
print("="*80)

relationship_test_2 = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the influences between the chemicals and diseases mentioned.

Long-term use of corticosteroids is associated with osteoporosis and increases the risk of bone fractures. NSAIDs are linked to chronic kidney disease and gastrointestinal bleeding in susceptible patients.

List of extracted influences:
"""

print(f"\nüìù Prompt:\n{relationship_test_2}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_2, max_new_tokens=600))

# Test 5: Complex Multi-Entity Relationship Extraction
print("\n" + "="*80)
print("TEST 5: RELATIONSHIP EXTRACTION - COMPLEX MULTI-ENTITY")
print("="*80)

relationship_test_3 = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the influences between the chemicals and diseases mentioned.

The patient with rheumatoid arthritis was started on methotrexate for inflammatory joint disease. However, methotrexate is associated with hepatotoxicity and requires monitoring. The patient also has hypertension managed with lisinopril. Statins were prescribed for cardiovascular disease prevention given elevated cholesterol levels.

List of extracted influences:
"""

print(f"\nüìù Prompt:\n{relationship_test_3}")
print("\nü§ñ Model Output:")
print(generate_response(relationship_test_3, max_new_tokens=800))
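Tests 3 to 5 print the raw relationship output. To score it, the output can be split into (chemical, disease) pairs. A minimal sketch, assuming the NEW pipe-separated format (`chemical | disease`, one pair per line) that the format-verification section checks for; `parse_relationships` is a hypothetical helper:

```python
def parse_relationships(raw_output):
    """Parse 'chemical | disease' lines into a set of lowercase (chemical, disease) pairs.

    Assumes the NEW pipe-separated format with one relationship per line.
    Lines that do not split into exactly two non-empty fields (for example
    OLD sentence-format output) are returned separately so format drift is
    visible rather than silently dropped.
    """
    pairs, unparsed = set(), []
    for line in raw_output.splitlines():
        # Remove surrounding whitespace and a leading "-" or "*" bullet
        line = line.strip().lstrip('-*').strip()
        if not line:
            continue
        fields = [field.strip() for field in line.split('|')]
        if len(fields) == 2 and all(fields):
            pairs.add((fields[0].lower(), fields[1].lower()))
        else:
            unparsed.append(line)
    return pairs, unparsed

# Usage with a made-up model output (not a real result)
sample_output = """metformin | type 2 diabetes
- aspirin | cardiovascular disease
chemical lisinopril influences disease hypertension"""
pairs, unparsed = parse_relationships(sample_output)
print(sorted(pairs))
print(unparsed)
```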

## 🔍 Enhanced False Positive Analysis

Based on data exploration insights, this section provides diagnostic tools to categorize and analyze false positives:

### Known Risk Factors from Data Exploration:
1. **Hyphen Variations** (~459 hyphenated entities)
   - Examples: "type-2" vs "type 2", "5-fluorouracil" vs "5 fluorouracil"
   
2. **Multi-word Partial Extraction** (avg 1.7 words for diseases)
   - Examples: Predicting "disease" instead of "chronic kidney disease"
   
3. **Special Character Mismatches** (13 types found)
   - Examples: "COVID-19" vs "COVID19", parentheses, slashes
   
4. **Format Confusion** (2,050 relationships in OLD format)
   - The model emits the OLD sentence format where the pipe-separated format is expected
   
5. **Synonym Generation** (13,710 vocabulary words)
   - Examples: "myocardial infarction" vs "heart attack"

### Diagnostic Approach:
- Categorize each false positive by likely root cause
- Compute statistics per category
- Show representative examples for targeted improvements

In [None]:
# ===== False Positive Categorization Functions =====

def has_hyphen_variation(pred, gold_set):
    """
    Check if FP is due to hyphen/space differences (e.g. "type 2" vs "type-2").
    Based on data exploration: ~459 hyphenated entities at risk.
    """
    # Normalize both sides identically: lowercase and treat hyphens as spaces.
    # (Replacing hyphens only on the prediction misses "type 2" vs "type-2".)
    pred_norm = pred.lower().replace('-', ' ')
    
    for gold in gold_set:
        # Only attribute the mismatch to hyphens if a hyphen is actually involved
        if '-' not in pred and '-' not in gold:
            continue
        if pred_norm == gold.lower().replace('-', ' '):
            return True, gold
    return False, None

def is_partial_multiword(pred, gold_set):
    """
    Check if FP is a partial extraction of a multi-word entity.
    Based on data exploration: avg 1.7 words for diseases, 1.2 for chemicals.
    """
    pred_lower = pred.lower()
    for gold in gold_set:
        gold_lower = gold.lower()
        # Check if prediction is a substring of gold (or vice versa)
        if pred_lower in gold_lower and pred_lower != gold_lower:
            return True, gold
        if gold_lower in pred_lower and pred_lower != gold_lower:
            return True, gold
    return False, None

def has_special_char_mismatch(pred, gold_set):
    """
    Check if FP is due to special character differences.
    Based on data exploration: 13 types of special characters found.
    """
    import re
    # Remove all non-alphanumeric except spaces
    pred_clean = re.sub(r'[^a-zA-Z0-9\s]', '', pred)
    
    for gold in gold_set:
        gold_clean = re.sub(r'[^a-zA-Z0-9\s]', '', gold)
        if pred_clean.lower() == gold_clean.lower():
            return True, gold
    return False, None

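def is_likely_synonym(pred, gold_set, threshold=0.7):
    """
    Check if FP is a close lexical variant of a gold entity.
    Uses character-level similarity (difflib.SequenceMatcher), so it catches
    near-miss spellings such as "nephropathies" vs "nephropathy", but NOT
    true synonyms with different wording ("heart attack" vs "myocardial infarction").
    Based on data exploration: 13,710 vocabulary words allow many alternatives.
    """
    from difflib import SequenceMatcher
    
    pred_lower = pred.lower()
    best_gold, best_sim = None, 0.0
    for gold in gold_set:
        # ratio() is in [0, 1]; 1.0 means identical after lowercasing
        similarity = SequenceMatcher(None, pred_lower, gold.lower()).ratio()
        # Keep the most similar gold entity above the threshold (exact matches excluded)
        if threshold < similarity < 1.0 and similarity > best_sim:
            best_gold, best_sim = gold, similarity
    if best_gold is not None:
        return True, best_gold, best_sim
    return False, None, 0.0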

def categorize_false_positive(fp, gold_set):
    """
    Categorize a false positive by likely root cause.
    
    Returns:
        category: str - Type of FP
        matched_gold: str or None - Gold entity it's similar to
        details: dict - Additional diagnostic information
    """
    # Check each category in order
    is_hyphen, gold1 = has_hyphen_variation(fp, gold_set)
    if is_hyphen:
        return "hyphen_variation", gold1, {"original": fp, "gold": gold1}
    
    is_partial, gold2 = is_partial_multiword(fp, gold_set)
    if is_partial:
        return "partial_multiword", gold2, {"original": fp, "gold": gold2}
    
    is_special, gold3 = has_special_char_mismatch(fp, gold_set)
    if is_special:
        return "special_char", gold3, {"original": fp, "gold": gold3}
    
    is_synonym, gold4, sim = is_likely_synonym(fp, gold_set)
    if is_synonym:
        return "synonym", gold4, {"original": fp, "gold": gold4, "similarity": sim}
    
    # If none of the above, it's a true false positive (hallucination)
    return "hallucination", None, {"original": fp}

def analyze_false_positives(fps, gold_set):
    """
    Analyze all false positives and categorize them.
    
    Returns:
        dict: Statistics and examples for each FP category
    """
    categories = {
        "hyphen_variation": [],
        "partial_multiword": [],
        "special_char": [],
        "synonym": [],
        "hallucination": []
    }
    
    for fp in fps:
        category, matched_gold, details = categorize_false_positive(fp, gold_set)
        categories[category].append(details)
    
    # Compute statistics
    stats = {
        cat: {
            "count": len(items),
            "percentage": len(items) / len(fps) * 100 if fps else 0,
            "examples": items[:5]  # First 5 examples
        }
        for cat, items in categories.items()
    }
    
    return stats

print("✓ False positive categorization functions loaded")
print("\n  Available Functions:")
print("    - has_hyphen_variation(): Check for hyphen differences")
print("    - is_partial_multiword(): Check for incomplete multi-word extraction")
print("    - has_special_char_mismatch(): Check for special character issues")
print("    - is_likely_synonym(): Check for synonym/alternative terms")
print("    - categorize_false_positive(): Classify single FP")
print("    - analyze_false_positives(): Batch analysis with statistics")
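`is_partial_multiword` uses a character-level substring test, so it also fires on fragments such as "insulin" inside "insulinoma" or a single letter inside any entity. A word-level variant is sketched below as a hypothetical alternative (`is_partial_multiword_tokens` is not part of the notebook). It counts a prediction as partial only when its whole words form a contiguous run inside a gold entity, or vice versa. Words are split on whitespace, so "type-2" stays a single token.

```python
def is_partial_multiword_tokens(pred, gold_set):
    """Word-level partial-match check (hypothetical alternative to is_partial_multiword).

    Returns (True, gold) if pred's words form a contiguous run inside a gold
    entity's words (or the reverse), otherwise (False, None).
    """
    def tokens(text):
        return text.lower().split()

    def contains_run(longer, shorter):
        # True if `shorter` appears as a contiguous word sequence in `longer`
        n = len(shorter)
        return any(longer[i:i + n] == shorter for i in range(len(longer) - n + 1))

    pred_toks = tokens(pred)
    for gold in gold_set:
        gold_toks = tokens(gold)
        # Skip empty strings and exact (case-insensitive) matches
        if not pred_toks or not gold_toks or pred_toks == gold_toks:
            continue
        if contains_run(gold_toks, pred_toks) or contains_run(pred_toks, gold_toks):
            return True, gold
    return False, None

# Usage on made-up examples
print(is_partial_multiword_tokens("disease", {"chronic kidney disease"}))
print(is_partial_multiword_tokens("insulin", {"insulinoma"}))
```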

In [None]:
# ===== Example: Analyze False Positives from Test Results =====

# This cell demonstrates FP analysis on mock data
# Replace with actual predictions from your evaluation results

# Mock example data (replace with actual data from evaluation cells above)
example_fps = [
    "type 2 diabetes",      # Hyphen variation of "type-2 diabetes"
    "disease",              # Partial of "chronic kidney disease"
    "COVID19",              # Special char of "COVID-19"
    "heart attack",         # True synonym of "myocardial infarction"; string similarity misses it, so it lands in "hallucination"
    "unicorn syndrome"      # Hallucination (not in gold)
]

example_gold = {
    "type-2 diabetes",
    "chronic kidney disease",
    "COVID-19",
    "myocardial infarction",
    "hypertension"
}

# Run analysis
fp_stats = analyze_false_positives(example_fps, example_gold)

# Display results
print("=" * 80)
print("FALSE POSITIVE ANALYSIS REPORT")
print("=" * 80)
print(f"\nTotal False Positives: {len(example_fps)}")
print("\n" + "-" * 80)

for category, data in fp_stats.items():
    if data["count"] > 0:
        print(f"\nüìä {category.upper().replace('_', ' ')}")
        print(f"   Count: {data['count']} ({data['percentage']:.1f}%)")
        print(f"   Examples:")
        for ex in data["examples"]:
            if "gold" in ex:
                print(f"      • Predicted: '{ex['original']}' → Gold: '{ex['gold']}'")
                if "similarity" in ex:
                    print(f"        (Similarity: {ex['similarity']:.2f})")
            else:
                print(f"      • Predicted: '{ex['original']}' (No match in gold)")

print("\n" + "=" * 80)
print("\nüí° Usage Instructions:")
print("   1. After running evaluation cells above, collect actual false positives")
print("   2. Replace 'example_fps' with your actual FP list")
print("   3. Replace 'example_gold' with your gold standard entity set")
print("   4. Re-run this cell to see real FP breakdown")
print("\n   Example:")
print("      # After evaluation")
print("      actual_fps = pred_set - gold_set  # Your false positives")
print("      fp_stats = analyze_false_positives(list(actual_fps), gold_set)")
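The hint above subtracts a single gold set from a single prediction set. Over a full test set, each false positive should be judged against the gold entities of its own example; a chemical that is correct for one abstract is still a false positive for another. A minimal sketch, assuming you have collected one `(predicted, gold)` pair per example; `aggregate_fp_categories` is hypothetical and takes the categorizer as an argument, so `categorize_false_positive` can be passed in.

```python
from collections import Counter

def aggregate_fp_categories(samples, categorize):
    """Count FP categories across many test examples (hypothetical helper).

    samples:    iterable of (predicted_entities, gold_entities) pairs, one per example
    categorize: callable(fp, gold_set) -> (category, matched_gold, details),
                e.g. categorize_false_positive from the cell above
    Each FP is compared only against the gold set of its own example.
    """
    counts = Counter()
    examples = {}
    for preds, golds in samples:
        gold_set = set(golds)
        # sorted() makes the "first example per category" deterministic
        for fp in sorted(set(preds) - gold_set):
            category, matched_gold, _ = categorize(fp, gold_set)
            counts[category] += 1
            examples.setdefault(category, (fp, matched_gold))
    return counts, examples

# Usage (per_example_results is a hypothetical list you build during evaluation):
# counts, examples = aggregate_fp_categories(per_example_results, categorize_false_positive)
```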

## 🔧 Training Data Format Verification

Based on data exploration: **2,050 relationships (68%)** were in OLD sentence format.

This cell verifies that the format conversion worked correctly during training data preparation. If the model learned inconsistent relationship representations, format confusion can become a significant source of false positives. The check inspects only the first `num_samples` lines, so treat it as a spot check.
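For reference, the conversion being verified can be sketched as a line-by-line rewrite. The OLD pattern here is an assumption taken from the detection regex in the cell below (`chemical <X> influences disease <Y>`); `to_pipe_format` is a hypothetical helper, not the preparation script that was actually used.

```python
import re

# Assumed OLD sentence format, mirroring the detection regex used below:
# "chemical <X> influences disease <Y>" (an optional trailing period is ignored)
OLD_RELATION = re.compile(
    r'^\s*chemical\s+(?P<chem>.+?)\s+influences\s+disease\s+(?P<dis>.+?)\s*\.?\s*$',
    re.IGNORECASE,
)

def to_pipe_format(output):
    """Rewrite OLD-format relationship lines as NEW 'chemical | disease' lines.

    Lines that do not match the OLD pattern are kept (stripped); blank lines are dropped.
    """
    converted = []
    for line in output.splitlines():
        match = OLD_RELATION.match(line)
        if match:
            converted.append(f"{match.group('chem')} | {match.group('dis')}")
        elif line.strip():
            converted.append(line.strip())
    return "\n".join(converted)

# Usage with made-up lines
print(to_pipe_format("chemical aspirin influences disease asthma.\nmetformin | type 2 diabetes"))
```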

In [None]:
# ===== Verify Training Data Format Conversion =====

import json
import re

def check_format_conversion(train_file_path, num_samples=10):
    """
    Verify that relationship format conversion was successful.
    
    Based on data exploration:
    - 2,050 relationships (68%) were in OLD sentence format
    - Should have been converted to NEW pipe-separated format
    """
    print("=" * 80)
    print("TRAINING DATA FORMAT VERIFICATION")
    print("=" * 80)
    
    # Counters
    old_format_count = 0
    new_format_count = 0
    samples_checked = 0
    examples = []
    
    try:
        with open(train_file_path, 'r') as f:
            for idx, line in enumerate(f):
                if idx >= num_samples:
                    break
                    
                data = json.loads(line.strip())
                output = data.get('output', '')
                samples_checked += 1
                
                # Check for OLD format pattern: "chemical X influences disease Y"
                old_pattern = re.findall(r'chemical\s+.+?\s+influences\s+disease\s+.+', output, re.I)
                # Check for NEW format pattern: "chemical | disease"
                new_pattern = re.findall(r'.+?\s*\|\s*.+', output)
                
                if old_pattern:
                    old_format_count += 1
                    examples.append({
                        'sample': idx + 1,
                        'format': 'OLD',
                        'snippet': old_pattern[0][:100] + '...' if len(old_pattern[0]) > 100 else old_pattern[0]
                    })
                elif new_pattern and '|' in output:
                    new_format_count += 1
                    examples.append({
                        'sample': idx + 1,
                        'format': 'NEW',
                        'snippet': new_pattern[0][:100] + '...' if len(new_pattern[0]) > 100 else new_pattern[0]
                    })
        
        # Display results
        print(f"\nüìä Format Statistics (from {samples_checked} samples):")
        print(f"   OLD format (sentence): {old_format_count}")
        print(f"   NEW format (pipe):     {new_format_count}")
        print(f"   Other/None:            {samples_checked - old_format_count - new_format_count}")
        
        if old_format_count > 0:
            print("\n⚠️  WARNING: Found OLD format relationships in training data!")
            print("   This may cause format confusion during inference.")
        elif new_format_count > 0:
            print("\n✓ All relationship samples use consistent NEW (pipe) format")
        else:
            print("\n⚠️  No relationship samples found; increase num_samples to cover them.")
        
        print("\nüìù Sample Outputs:")
        for ex in examples[:5]:
            print(f"   Sample {ex['sample']} ({ex['format']}): {ex['snippet']}")
        
    except FileNotFoundError:
        print(f"\n❌ Training file not found: {train_file_path}")
        print("   Please update the path to your training data file.")
    except Exception as e:
        print(f"\n❌ Error reading training data: {e}")
    
    print("\n" + "=" * 80)
    return {
        'old_format': old_format_count,
        'new_format': new_format_count,
        'total_checked': samples_checked
    }

# Run verification
# Update path to your actual training data file
train_data_path = "../data/train.jsonl"
format_stats = check_format_conversion(train_data_path, num_samples=50)

print("\nüí° Interpretation:")
print("   ‚Ä¢ If OLD format found: Re-prepare training data with consistent NEW format")
print("   ‚Ä¢ If all NEW format: Format conversion was successful")
print("   ‚Ä¢ Mixed formats may cause model to output inconsistent relationship formats")

## 📈 Comparative Evaluation: Standard vs Enhanced Filtering

This section compares evaluation metrics using:
1. **Standard Filtering**: Basic text verification (current approach)
2. **Strict Filtering**: Enhanced word boundaries + length requirements
3. **Fuzzy Matching**: Allows minor variations (hyphens, special chars)

Goal: Quantify the impact of enhanced filtering on false positive reduction.
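All three strategies are scored with the same set arithmetic. The sketch below restates it as a standalone function and applies it to made-up sets. `prf1` is hypothetical; it mirrors the precision, recall and F1 expressions inside `compare_filtering_strategies`.

```python
def prf1(predicted, gold):
    """Exact-match precision, recall and F1 for two sets of entity strings."""
    tp = len(predicted & gold)   # predicted and in gold
    fp = len(predicted - gold)   # predicted but not in gold
    fn = len(gold - predicted)   # in gold but missed
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    # Same as the harmonic mean 2PR / (P + R), but safe when P + R == 0
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
    return precision, recall, f1

# Toy sets (made up): one spurious prediction, two missed gold entities
predicted = {"aspirin", "ibuprofen", "nsaid"}
gold = {"aspirin", "ibuprofen", "metformin", "heparin"}

for label, preds in [("before filtering", predicted),
                     ("after dropping 'nsaid'", predicted - {"nsaid"})]:
    p, r, f = prf1(preds, gold)
    print(f"{label}: P={p:.3f} R={r:.3f} F1={f:.3f}")
```

In this toy case, dropping the single false positive raises precision from 0.667 to 1.000 while recall stays at 0.500, so F1 rises from 0.571 to 0.667.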

In [None]:
# ===== Comparative Evaluation Function =====

def compare_filtering_strategies(predictions, gold_standard, prompt_text):
    """
    Compare different filtering strategies on the same predictions.
    
    Args:
        predictions: List of predicted entities (before filtering)
        gold_standard: Set of gold standard entities
        prompt_text: Source text for verification
    
    Returns:
        dict: Metrics for each strategy
    """
    
    results = {}
    
    # Strategy 1: Standard filtering
    standard_filtered = filter_items_against_text(predictions, prompt_text)
    standard_set = set(standard_filtered)
    
    tp_standard = len(standard_set & gold_standard)
    fp_standard = len(standard_set - gold_standard)
    fn_standard = len(gold_standard - standard_set)
    
    results['standard'] = {
        'filtered_count': len(standard_set),
        'tp': tp_standard,
        'fp': fp_standard,
        'fn': fn_standard,
        'precision': tp_standard / len(standard_set) if standard_set else 0,
        'recall': tp_standard / len(gold_standard) if gold_standard else 0,
        'f1': 2 * tp_standard / (2 * tp_standard + fp_standard + fn_standard) if (2 * tp_standard + fp_standard + fn_standard) > 0 else 0,
        'fp_examples': list(standard_set - gold_standard)[:5]
    }
    
    # Strategy 2: Strict filtering
    strict_filtered = strict_filter_items_against_text(predictions, prompt_text, min_length=2)
    strict_set = set(strict_filtered)
    
    tp_strict = len(strict_set & gold_standard)
    fp_strict = len(strict_set - gold_standard)
    fn_strict = len(gold_standard - strict_set)
    
    results['strict'] = {
        'filtered_count': len(strict_set),
        'tp': tp_strict,
        'fp': fp_strict,
        'fn': fn_strict,
        'precision': tp_strict / len(strict_set) if strict_set else 0,
        'recall': tp_strict / len(gold_standard) if gold_standard else 0,
        'f1': 2 * tp_strict / (2 * tp_strict + fp_strict + fn_strict) if (2 * tp_strict + fp_strict + fn_strict) > 0 else 0,
        'fp_examples': list(strict_set - gold_standard)[:5],
        'fp_reduction': fp_standard - fp_strict
    }
    
    # Strategy 3: Fuzzy matching (on standard filtered)
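    tp_fuzzy_set = enhanced_match_with_fuzzy(standard_set, gold_standard, threshold=0.9)
    fp_fuzzy = len(standard_set - tp_fuzzy_set)
    # Assuming tp_fuzzy_set holds matched *predicted* strings (as the FP line above implies),
    # a fuzzy hit like "type 2 diabetes" never removes "type-2 diabetes" from gold_standard
    # by set difference, so derive FN from set sizes to stay consistent with recall below
    fn_fuzzy = max(len(gold_standard) - len(tp_fuzzy_set), 0)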
    
    results['fuzzy'] = {
        'filtered_count': len(standard_set),
        'tp': len(tp_fuzzy_set),
        'fp': fp_fuzzy,
        'fn': fn_fuzzy,
        'precision': len(tp_fuzzy_set) / len(standard_set) if standard_set else 0,
        'recall': len(tp_fuzzy_set) / len(gold_standard) if gold_standard else 0,
        'f1': 2 * len(tp_fuzzy_set) / (2 * len(tp_fuzzy_set) + fp_fuzzy + fn_fuzzy) if (2 * len(tp_fuzzy_set) + fp_fuzzy + fn_fuzzy) > 0 else 0,
        'fp_examples': list(standard_set - tp_fuzzy_set)[:5],
        'fp_reduction': fp_standard - fp_fuzzy
    }
    
    return results

def display_comparison_results(results):
    """Display comparative results in a formatted table."""
    print("=" * 100)
    print("FILTERING STRATEGY COMPARISON")
    print("=" * 100)
    
    print(f"\n{'Strategy':<15} {'Filtered':<10} {'TP':<6} {'FP':<6} {'FN':<6} {'Precision':<12} {'Recall':<12} {'F1':<12} {'FP Reduction':<12}")
    print("-" * 100)
    
    for strategy, metrics in results.items():
        fp_reduction = metrics.get('fp_reduction', 0)
        print(f"{strategy.capitalize():<15} "
              f"{metrics['filtered_count']:<10} "
              f"{metrics['tp']:<6} "
              f"{metrics['fp']:<6} "
              f"{metrics['fn']:<6} "
              f"{metrics['precision']:<12.3f} "
              f"{metrics['recall']:<12.3f} "
              f"{metrics['f1']:<12.3f} "
              f"{fp_reduction:<12}")
    
    print("\n" + "=" * 100)
    
    # Show FP examples for each strategy
    print("\nüìã False Positive Examples by Strategy:")
    for strategy, metrics in results.items():
        if metrics['fp_examples']:
            print(f"\n{strategy.capitalize()} ({metrics['fp']} total FPs):")
            for ex in metrics['fp_examples']:
                print(f"   • {ex}")

print("✓ Comparative evaluation functions loaded")
print("\n  Usage:")
print("    results = compare_filtering_strategies(predictions, gold_set, prompt_text)")
print("    display_comparison_results(results)")
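The strict strategy calls `strict_filter_items_against_text`, defined elsewhere in this notebook. To illustrate the whole-word idea in isolation, here is a minimal sketch under the same `min_length` convention. `word_boundary_filter` is hypothetical and may not match the notebook's implementation.

```python
import re

def word_boundary_filter(items, source_text, min_length=2):
    """Keep items that occur in source_text as whole words (case-insensitive).

    Hypothetical stand-in for strict filtering: an item must be at least
    `min_length` characters long and must not be embedded inside a longer
    word, so "insulin" is rejected when the text only says "insulinoma".
    """
    kept = []
    for item in items:
        item = item.strip()
        if len(item) < min_length:
            continue
        # Lookarounds instead of \b so items starting/ending in "+" or ")" still match
        pattern = r'(?<!\w)' + re.escape(item) + r'(?!\w)'
        if re.search(pattern, source_text, flags=re.IGNORECASE):
            kept.append(item)
    return kept

# Usage on a made-up sentence
text = "Insulinoma was treated; 5-fluorouracil and aspirin were given."
print(word_boundary_filter(["insulin", "5-fluorouracil", "Aspirin", "a"], text))
```

The lookarounds `(?<!\w)` and `(?!\w)` replace `\b` because `\b` fails when an entity starts or ends with a non-word character, as in "Na+".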

## 📚 Summary of False Positive Reduction Improvements

This notebook includes the following false positive (FP) reduction tools, each motivated by a data exploration finding:

### 🎯 Enhancements Added:

1. **Enhanced Post-Filtering** (Cell with `strict_filter_items_against_text`)
   - Stricter word boundary matching
   - Minimum entity length requirements
   - Addresses ~459 hyphenated entities and multi-word complexity
   
2. **Fuzzy Matching Support** (Cell with `fuzzy_match`)
   - Handles minor typos and formatting differences
   - 90% similarity threshold by default
   - Addresses 13 types of special characters
   
3. **False Positive Categorization** (Cell with `categorize_false_positive`)
   - Classifies FPs by root cause:
     * Hyphen variations (~459 at risk)
     * Partial multi-word extraction (1.7 avg words)
     * Special character mismatches (13 types)
     * Synonym generation (13,710 vocab)
     * True hallucinations
   - Provides diagnostic statistics and examples
   
4. **Training Data Format Verification** (Cell with `check_format_conversion`)
   - Spot-checks a sample of training examples to confirm that the 2,050 OLD-format relationships were converted
   - Detects format inconsistencies that cause inference errors
   - Sample inspection for quality assurance
   
5. **Comparative Evaluation** (Cell with `compare_filtering_strategies`)
   - Quantifies impact of different filtering approaches
   - Side-by-side metrics: Standard vs Strict vs Fuzzy
   - FP reduction tracking

### 📊 Data Exploration Insights Applied:

| Finding | Impact | Solution |
|---------|--------|----------|
| ~459 hyphenated entities | "type-2" vs "type 2" mismatches | Fuzzy matching + strict boundaries |
| Avg 1.7 words (diseases) | Partial extractions | Multi-word validation |
| 13 special char types | Format mismatches | Special char normalization |
| 2,050 OLD format (68%) | Format confusion | Conversion verification |
| 13,710 vocab words | Synonym generation | Synonym detection |
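The hyphen and special-character rows share a cause: surface variation in otherwise identical names. As an alternative to fuzzy scoring (not a description of the notebook's `fuzzy_match`), a single canonical key that drops case, hyphens, whitespace and punctuation turns both into exact lookups. This is a sketch: `canonical_key` and `normalized_matches` are hypothetical, the key does nothing for synonyms, and it can occasionally merge names that really are different.

```python
import re
from collections import defaultdict

def canonical_key(entity):
    """Collapse case, hyphens, whitespace and punctuation into one matching key.

    "Type-2 Diabetes", "type 2 diabetes" and "type2 diabetes" share a key, as do
    "COVID-19" and "COVID19". \\W is Unicode-aware, so letters like "α" are kept.
    """
    return re.sub(r'[\W_]', '', entity.casefold())

def normalized_matches(predicted, gold):
    """Return the set of gold entities matched by at least one prediction's key."""
    gold_by_key = defaultdict(set)
    for g in gold:
        gold_by_key[canonical_key(g)].add(g)
    matched = set()
    for p in predicted:
        matched |= gold_by_key.get(canonical_key(p), set())
    return matched

# Usage on made-up sets
print(normalized_matches({"type 2 diabetes", "covid19", "heart attack"},
                         {"type-2 diabetes", "COVID-19", "myocardial infarction"}))
```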

### 🚀 Recommended Workflow:

1. **Run standard evaluation** (existing cells)
2. **Apply strict filtering** to reduce noise
3. **Run FP analysis** to identify root causes
4. **Verify format conversion** if relationship extraction has issues
5. **Compare strategies** to quantify improvements
6. **Apply fuzzy matching** if precision is stable but recall suffers

### 💡 Expected Outcomes:

- **Precision improvement**: Strict filtering removes fragment/noise FPs
- **Recall preservation**: Fuzzy matching recovers minor variation mismatches
- **Diagnostic clarity**: FP categorization guides targeted fixes
- **Format consistency**: Conversion verification prevents systematic errors

### ⚙️ Next Steps:

1. Run evaluation with your fine-tuned model
2. Collect actual false positives
3. Use `analyze_false_positives()` to categorize them
4. Apply appropriate filtering strategy based on FP breakdown
5. Iterate on training data if format/quality issues detected

## 10. Suggested Next Steps

- **Compare results against exploration baselines**:
  - Reference: [DATA_EXPLORATION_SUMMARY.md](../../docs/DATA_EXPLORATION_SUMMARY.md)
  - Expected vocabulary: 13,710 words (ensure model covers this)
  - Entity complexity: ~459 hyphenated entities (verify preservation)
  - Format handling: 2,050 relationships trained in NEW pipe format

- **Evaluate the full test set** (set `num_test_samples = len(test_data)`) to capture complete performance
  - Test set should have ~300 samples (validated above)
  - Distribution should be 33.3% / 33.3% / 33.3% across tasks (stratified)

- **Compare with the base model** to quantify the lift from fine-tuning
  - Baseline: the base `Llama-3.2-3B-Instruct` model, run on the same prompts without the LoRA adapter
  - Measure the change in precision, recall, and F1 on medical terms

- **Log metrics to Weights & Biases** or another tracker for experiment history
  - Compare across different checkpoints (every 50 steps)
  - Track how format conversion affects relationship extraction F1

- **Export predictions for manual spot checks** with subject-matter experts
  - Focus on false positives (predicted but not in gold)
  - Focus on false negatives (in gold but missed)
  - Verify hyphen preservation in entity names

- **Analyze vocabulary coverage**:
  - Expected: 13,710 unique words in training set
  - Check if model generalizes to unseen entity combinations
  - Validate entity complexity handling (multi-word, special chars)
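For the vocabulary-coverage and generalization bullets, one concrete measurement is recall reported separately for gold entities that do and do not occur in the training split. A minimal sketch: `recall_by_novelty` is hypothetical, and `train_entities` is a set you would build from the training outputs yourself. For a whole test set, sum the per-example hit and total counts.

```python
def recall_by_novelty(predicted, gold, train_entities):
    """Split recall into gold entities seen in training vs unseen ones.

    predicted, gold: sets of entity strings for one evaluated example
    train_entities:  set of entity strings that appear in the training split
    Returns {"seen": (hits, total), "unseen": (hits, total)}; case-insensitive.
    """
    pred = {p.lower() for p in predicted}
    train = {t.lower() for t in train_entities}
    result = {"seen": [0, 0], "unseen": [0, 0]}
    for g in {x.lower() for x in gold}:
        bucket = "seen" if g in train else "unseen"
        result[bucket][1] += 1      # total gold entities in this bucket
        if g in pred:
            result[bucket][0] += 1  # gold entities the model recovered
    return {k: tuple(v) for k, v in result.items()}

# Usage on made-up sets
print(recall_by_novelty({"aspirin", "lisinopril"},
                        {"Aspirin", "metformin", "lisinopril"},
                        {"aspirin", "metformin"}))
```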

## 11. Usage Example (Optional)

How to load the model in a production script or service.

In [None]:
# Example: How to load and use the model later
usage_code = '''
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto"
)

# Load LoRA adapter from Hub
model = PeftModel.from_pretrained(
    base_model,
    "your-username/llama3-medical-ner-lora"  # Your model ID
)
model.eval()

# Use the model
prompt = """The following article contains technical terms including diseases, drugs and chemicals. Create a list only of the chemicals mentioned.

Patient was treated with metformin and insulin for diabetes management.

List of extracted chemicals:
"""

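# Generate response (greedy decoding on the raw prompt; mirror whatever
# generate_response() above does, e.g. chat template or max_new_tokens)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
# Decode only the newly generated tokens, not the echoed prompt
response = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)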
'''

print("Usage Example:")
print("="*80)
print(usage_code)

---

## Summary

This notebook:
1. ✅ Configured environment variables and authentication for Hugging Face and W&B.
2. ✅ Installed required evaluation dependencies.
3. ✅ Loaded the fine-tuned medical NER model (base + LoRA adapter).
4. ✅ Evaluated performance on unseen test samples with detailed metrics.
5. ✅ Aggregated precision, recall, and F1 across all evaluated examples.
6. ✅ Validated behaviour on curated chemical, disease, and relationship prompts.
7. ✅ Outlined next steps and provided a ready-to-use inference snippet.

**Your medical NER evaluation workflow is ready! üöÄ**