## 🚀 RunPod Quick Start

### Prerequisites:
1. **GPU Pod**: RTX 3090/4090, A5000, or A6000 recommended
2. **VRAM**: Minimum 12GB (24GB recommended for optimal performance)
3. **Test Data**: Upload to `/workspace/data/splits_cleaned_20251113/test.jsonl`

### Execution Steps:
1. Upload this notebook to your RunPod pod
2. Upload the test data file to the path above
3. Set your HuggingFace token (first code cell, Environment Setup)
4. Run all cells in order
5. Review the results and error analysis at the end

### Expected Runtime:
- **GPU**: RTX 4090 → ~3-5 minutes (299 samples)
- **GPU**: RTX 3090 → ~5-8 minutes
- **GPU**: A5000 → ~8-12 minutes

---

## 📊 Dataset Information

**Using Cleaned Dataset**: `splits_cleaned_20251113/`

### Quality Improvements:
- ✅ **0 empty completions** (removed 6 invalid samples)
- ✅ **0 prompts >2048 chars** (307 over-length prompts truncated)
- ✅ **99.8% retention** (2,994 of 3,000 samples)
- ✅ **Perfect stratification** (33.3% per task)

### Test Set Details:
- **Total samples**: 299
- **Chemicals**: 100 samples (33.4%)
- **Diseases**: 99 samples (33.1%)
- **Relationships**: 100 samples (33.4%)
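To confirm these per-task counts on the pod before a long run, a quick check against the uploaded file helps. This is a sketch that assumes each JSONL row has a `"prompt"` field and that the task wording matches the keywords used by `task_from_prompt()` later in the notebook (where the Relationships task is called `influences`). `classify_prompt` and `task_distribution` are hypothetical helper names.

```python
from collections import Counter

# Keyword -> task mapping; mirrors task_from_prompt() in Section 4.
# "influences" is the code's name for the Relationships task.
TASK_KEYWORDS = {
    "chemicals": "list only of the chemicals",
    "diseases": "list only of the diseases",
    "influences": "list only of the influences",
}

def classify_prompt(prompt: str) -> str:
    """Return the task whose keyword appears in the prompt, else 'unknown'."""
    p = prompt.lower()
    for task, keyword in TASK_KEYWORDS.items():
        if keyword in p:
            return task
    return "unknown"

def task_distribution(rows):
    """Map each task to (count, percent of all rows), percent rounded to 0.1."""
    counts = Counter(classify_prompt(r["prompt"]) for r in rows)
    total = sum(counts.values())
    if total == 0:
        return {}
    return {task: (n, round(100 * n / total, 1)) for task, n in counts.items()}

# Usage (after the file is uploaded):
# import json
# with open("/workspace/data/splits_cleaned_20251113/test.jsonl", encoding="utf-8") as f:
#     print(task_distribution([json.loads(line) for line in f]))
```

Any `"unknown"` entry in the result means some prompts would not be recognized by the evaluation loop.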

### Expected Performance:
- **Target F1**: 70-80% (medical domain model on clean data)
- **Previous baseline** (Llama-3.2-3B): ~54% F1
- **Key improvement**: Relationship extraction (was 0%, now measurable)

---

## 0️⃣ Environment Setup

⚠️ **IMPORTANT**: Set your HuggingFace token before proceeding!

In [None]:
import os
from getpass import getpass

# HuggingFace Token (required for model downloads)
# Get your token from: https://huggingface.co/settings/tokens
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("HF_TOKEN not found in environment variables")
    hf_token = getpass("Enter your HuggingFace token: ")
    os.environ["HF_TOKEN"] = hf_token
else:
    print("✓ HF_TOKEN loaded from environment")

# Enable fast transfers
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

print("\n✓ Environment configured for RunPod")
print(f"  HF_HUB_ENABLE_HF_TRANSFER: {os.getenv('HF_HUB_ENABLE_HF_TRANSFER')}")
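The cell above sets `HF_HUB_ENABLE_HF_TRANSFER=1` before `hf-transfer` is installed in the next cell. As far as we know, recent `huggingface_hub` releases read this flag when the library is imported and raise an error at download time if the flag is on but the `hf_transfer` package cannot be imported. A guarded variant could look like the sketch below. Run it after the install cell and before importing `huggingface_hub`. `enable_hf_transfer_if_available` is a hypothetical helper.

```python
import importlib.util
import os

def enable_hf_transfer_if_available() -> bool:
    """Turn on HF_HUB_ENABLE_HF_TRANSFER only if hf_transfer is importable."""
    available = importlib.util.find_spec("hf_transfer") is not None
    # "0" keeps the standard downloader; "1" opts into hf_transfer
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" if available else "0"
    return available

print("hf_transfer enabled:", enable_hf_transfer_if_available())
```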

## 1️⃣ Install Dependencies

Install required packages for AWQ quantized model inference.

In [None]:
# Install AutoAWQ for quantized model support
!pip install -q autoawq transformers accelerate
!pip install -q huggingface-hub hf-transfer

print("✓ Package installation step finished (scroll up for any pip errors)")
print("  - autoawq (AWQ quantization support)")
print("  - transformers (HuggingFace models)")
print("  - accelerate (device management)")
print("  - hf-transfer (fast downloads)")
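Because pip runs with `-q`, and a failing `!pip` line does not stop the cell, the messages above print even if an install failed. Reading the installed versions back with the standard library is a quick sanity check. `installed_versions` is a hypothetical helper, and the names passed to it are pip distribution names. If a package was already imported in this kernel before installing, restarting the kernel is the usual way to pick up the new version.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found

report = installed_versions(
    ["autoawq", "transformers", "accelerate", "huggingface-hub", "hf-transfer"]
)
for name, ver in report.items():
    print(f"  {name}: {ver or 'NOT INSTALLED'}")
```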

## 2️⃣ Import Libraries & Check GPU

In [None]:
import json
import torch
import re
from pathlib import Path
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from huggingface_hub import login

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"\nüöÄ GPU Detected:")
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")
else:
    print("\n‚ö†Ô∏è  WARNING: No GPU detected! This will be very slow.")
    print("   Please ensure you're running on a RunPod GPU pod.")
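The Prerequisites list a 12 GB minimum and 24 GB recommended. The sketch below turns the reported total memory into a simple status. `vram_tier` is a hypothetical helper, and its default thresholds come from the Prerequisites section. It uses decimal GB (1e9 bytes), matching the VRAM line printed above, so a card marketed as 24 GB typically reports slightly more than 24.

```python
def vram_tier(total_bytes: float, min_gb: float = 12.0, recommended_gb: float = 24.0) -> str:
    """Classify total GPU memory against the Prerequisites thresholds (decimal GB)."""
    gb = total_bytes / 1e9
    if gb < min_gb:
        return "below minimum"
    if gb < recommended_gb:
        return "meets minimum"
    return "recommended"

# Usage in this notebook (requires a CUDA device):
# print(vram_tier(torch.cuda.get_device_properties(0).total_memory))
```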

## 3️⃣ Configuration

Set model and data paths.

In [None]:
# Model Configuration
MODEL_ID = "BioMistral/BioMistral-7B-SLERP-AWQ-QGS128-W4-GEMM"

# Data Configuration (RunPod paths)
TEST_DATA_PATH = "/workspace/data/splits_cleaned_20251113/test.jsonl"

# Alternative local path (if running locally)
# TEST_DATA_PATH = "../../data/splits_cleaned_20251113/test.jsonl"

print("="*80)
print("EVALUATION CONFIGURATION")
print("="*80)
print(f"\nModel: {MODEL_ID}")
print(f"  Type: AWQ Quantized (4-bit)")
print(f"  Base: BioMistral-7B-SLERP")
print(f"  Optimization: QGS128-W4-GEMM (NVIDIA GPU optimized)")
print(f"\nTest Data: {TEST_DATA_PATH}")
print(f"  Dataset: Cleaned (splits_cleaned_20251113)")
print(f"  Quality: 99.8% retention, 0 issues")
print(f"  Expected samples: 299 (100 chemicals, 99 diseases, 100 relationships)")

# Verify test data exists
if Path(TEST_DATA_PATH).exists():
    print("\n✓ Test data file found")
else:
    print(f"\n⚠️  WARNING: Test data file not found at {TEST_DATA_PATH}")
    print("   Please upload test.jsonl to the correct location.")
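The existence check above does not confirm that the file meets the quality guarantees listed under Dataset Information. The sketch below scans the file for violations. It assumes each row has `prompt` and `completion` string fields and that the 2,048-character limit applies to the stored `prompt` field. `check_jsonl` is a hypothetical helper.

```python
import json
from pathlib import Path

def check_jsonl(path, required=("prompt", "completion"), max_prompt_chars=2048):
    """Count rows that violate the dataset's stated guarantees."""
    stats = {"rows": 0, "missing_keys": 0, "empty_completion": 0, "long_prompt": 0}
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            rec = json.loads(line)
            stats["rows"] += 1
            if any(k not in rec for k in required):
                stats["missing_keys"] += 1
                continue
            if not str(rec["completion"]).strip():
                stats["empty_completion"] += 1
            if len(rec["prompt"]) > max_prompt_chars:
                stats["long_prompt"] += 1
    return stats

# Usage: print(check_jsonl(TEST_DATA_PATH))
# For the cleaned split, every counter except "rows" is expected to be 0.
```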

## 4️⃣ Utility Functions

Reusable functions for text processing, parsing, and filtering (from previous evaluation).

In [None]:
# ===== Text Normalization =====
def dehyphenate(s: str) -> str:
    """Join words broken across lines with hyphens."""
    return re.sub(r"(\w+)-\s+(\w+)", r"\1\2", s)

def normalize_text(s: str) -> str:
    """Normalize text for consistent comparison."""
    s = dehyphenate(s or "")
    s = s.lower()
    s = re.sub(r"[\u00A0\t\r\n]+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

def normalize_item(item: str) -> str:
    """Normalize entity: lowercase, strip quotes/whitespace."""
    item = item.strip().lower()
    item = re.sub(r'^["\']|["\']$', '', item)
    item = re.sub(r'\s+', ' ', item).strip()
    return item

# ===== Parsing Functions =====
def parse_bullets(text: str):
    """Extract items from bullet list (- or *)."""
    items = []
    for line in (text or "").splitlines():
        m = re.match(r"^\s*[-*]\s*(.+?)\s*$", line)
        if m:
            items.append(m.group(1))
    return items

def extract_list_from_generation(gen: str):
    """Extract list items from model generation."""
    items = []
    for line in gen.splitlines():
        line = line.strip()
        if not line:
            continue
        m = re.match(r"^[-*]\s*(.+)$", line)
        if m:
            items.append(m.group(1).strip())
    return items

def task_from_prompt(prompt: str) -> str:
    """Determine task type from prompt text."""
    prompt_lower = prompt.lower()
    if "list only of the chemicals" in prompt_lower:
        return "chemicals"
    elif "list only of the diseases" in prompt_lower:
        return "diseases"
    elif "list only of the influences" in prompt_lower:
        return "influences"
    return "unknown"

# ===== Relationship Parsing =====
def parse_pairs(text: str):
    """Parse pipe-separated relationships: chemical | disease."""
    pairs = []
    for line in text.splitlines():
        line = line.strip()
        if '|' in line:
            line = re.sub(r'^[-*]\s*', '', line)
            parts = [p.strip() for p in line.split('|')]
            if len(parts) == 2:
                pairs.append((parts[0], parts[1]))
    return pairs

def parse_pairs_from_sentence(text: str):
    """Parse sentence format: chemical X influences disease Y."""
    pairs = []
    for line in text.splitlines():
        m = re.search(r'chemical\s+(.+?)\s+influences\s+disease\s+(.+)', line, re.I)
        if m:
            pairs.append((m.group(1).strip(), m.group(2).strip()))
    return pairs

# ===== Enhanced Filtering (Reduces False Positives) =====
def filter_entities_enhanced(pred_items, prompt_text, task_type):
    """
    Enhanced filtering to reduce false positives.
    Filters out: generic terms, instruction words, entity confusion, short fragments.
    """
    GENERIC_BLACKLIST = {
        'pain', 'drug', 'drugs', 'chemical', 'chemicals',
        'disease', 'diseases', 'medication', 'medications',
        'treatment', 'treatments', 'therapy', 'article',
        'mentioned', 'list', 'extracted', 'following'
    }
    
    # Disease markers (shouldn't be in chemicals)
    DISEASE_MARKERS = {'syndrome', 'disease', 'disorder', 'infection', 'itis', 'osis'}
    
    prompt_lower = prompt_text.lower()
    filtered = []
    
    for item in pred_items:
        item_norm = normalize_item(item)
        
        # Skip empty or very short
        if len(item_norm) < 3:
            continue
        
        # Skip generic terms
        if item_norm in GENERIC_BLACKLIST:
            continue
        
        # Skip if not in prompt text
        if item_norm not in prompt_lower:
            continue
        
        # Entity type validation for chemicals
        if task_type == "chemicals":
            # Skip if it has disease markers
            if any(marker in item_norm for marker in DISEASE_MARKERS):
                continue
        
        filtered.append(item_norm)
    
    return filtered

def filter_pairs_against_text(pairs, prompt_text):
    """Filter relationship pairs to only those mentioned in text."""
    prompt_lower = prompt_text.lower()
    filtered = []
    for (c, d) in pairs:
        c_norm = normalize_item(c)
        d_norm = normalize_item(d)
        if c_norm in prompt_lower and d_norm in prompt_lower:
            filtered.append((c_norm, d_norm))
    return filtered

print("✓ Utility functions loaded")
print("  - Text normalization")
print("  - Bullet list parsing")
print("  - Task identification")
print("  - Relationship parsing (pipe + sentence formats)")
print("  - Enhanced filtering (reduces false positives)")
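`filter_entities_enhanced` and `filter_pairs_against_text` keep an item if it appears as a plain substring of the lowercased prompt. This has two side effects. A short item such as `ace` passes because it occurs inside `acetaminophen`. A multi-word item whose spacing differs from the prompt (for example, a line break in the source text) is rejected. The sketch below is a hypothetical word-boundary alternative (`mentioned_in_text` is not part of the notebook). Swapping it in would change scores relative to the Llama-3.2-3B baseline, so keep the original check when comparing the two runs directly.

```python
import re

def mentioned_in_text(item: str, text: str) -> bool:
    """True if `item` occurs in `text` as a whole-word phrase (case/whitespace-insensitive)."""
    item_n = re.sub(r"\s+", " ", item.strip().lower())
    text_n = re.sub(r"\s+", " ", text.lower())
    if not item_n:
        return False
    # (?<!\w) and (?!\w) behave like \b but also work when the item
    # starts or ends with a non-word character (e.g. "5-fluorouracil.")
    pattern = r"(?<!\w)" + re.escape(item_n) + r"(?!\w)"
    return re.search(pattern, text_n) is not None

print("ace" in "acetaminophen overdose")                              # plain substring check
print(mentioned_in_text("ace", "Acetaminophen overdose"))             # word-boundary check
print(mentioned_in_text("renal failure", "acute renal\nfailure was seen"))
```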

## 5️⃣ Load Model & Tokenizer

Load the AWQ quantized BioMistral model.

In [None]:
# Login to HuggingFace
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)
    print("✓ Logged into HuggingFace Hub")

print("="*80)
print("LOADING AWQ QUANTIZED MODEL")
print("="*80)

print(f"\nModel: {MODEL_ID}")
print("Loading tokenizer...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("✓ Tokenizer loaded")

# Load AWQ quantized model
print("\nLoading AWQ quantized model (4-bit)...")
print("  This may take 1-2 minutes on first run (caching model)")

model = AutoAWQForCausalLM.from_quantized(
    MODEL_ID,
    fuse_layers=True,  # Enable layer fusion for better performance
    device_map="auto",  # Automatically distribute across GPUs
)

print("\n✓ Model loaded successfully!")
print(f"  Model: {MODEL_ID}")
print("  Quantization: AWQ 4-bit")
print(f"  Device: {model.model.device}")  # AutoAWQ wraps the Hugging Face model in .model

if torch.cuda.is_available():
    print(f"  GPU Memory Used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"  GPU Memory Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

## 6️⃣ Generation Function

Deterministic generation optimized for BioMistral (Mistral chat format).

**Note**: Uses the exact prompt format from the training data without adding system instructions. This ensures consistency between baseline evaluation and future fine-tuning.

In [None]:
def generate_response(prompt_text, max_new_tokens=128):
    """
    Generate a response with BioMistral using the Mistral [INST] chat format.
    Uses greedy decoding (do_sample=False) so runs are reproducible.

    IMPORTANT: Uses the EXACT prompt format from the training data (no system
    instruction added) to keep baseline evaluation and fine-tuning consistent.
    """
    # Mistral chat format - raw prompt from the dataset, unmodified
    formatted_prompt = f"""<s>[INST] {prompt_text} [/INST]
"""

    # The string already contains <s>, so don't let the tokenizer prepend a second BOS token
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.model.device)  # AutoAWQ wraps the Hugging Face model in .model

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,          # Greedy decoding (deterministic)
            num_beams=1,              # No beam search
            repetition_penalty=1.15,  # Mild penalty against repeated tokens
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,           # Enable KV cache
        )

    # Decode only the newly generated tokens, not the echoed prompt
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

print("✓ Generation function ready")
print("  Format: Mistral chat format (<s>[INST] ... [/INST])")
print("  Prompt: Uses exact dataset format (no added system instruction)")
print("  Mode: Deterministic (greedy decoding)")
print("  Max tokens: 128 (sized for short entity lists)")
print("  KV cache: Enabled (faster inference)")
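Greedy decoding should make outputs repeatable, but it is worth confirming on the actual pod because GPU kernels (including fused ones) are not guaranteed to be bit-for-bit reproducible. The sketch below calls any generate function several times per prompt and reports the prompts whose outputs differ. `nondeterministic_prompts` is a hypothetical helper, and the usage line assumes `test_data` from Section 7 has already been loaded.

```python
from typing import Callable, Iterable, List

def nondeterministic_prompts(generate_fn: Callable[[str], str],
                             prompts: Iterable[str], runs: int = 2) -> List[str]:
    """Return the prompts whose outputs are not identical across `runs` calls."""
    unstable = []
    for p in prompts:
        outputs = {generate_fn(p) for _ in range(runs)}  # identical outputs collapse to one
        if len(outputs) > 1:
            unstable.append(p)
    return unstable

# Usage after loading test_data:
# print(nondeterministic_prompts(generate_response, [r["prompt"] for r in test_data[:5]]))
```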

## 7️⃣ Run Evaluation

Evaluate model on test set with per-task metrics.

In [None]:
from statistics import mean

def f1(p, r):
    """Calculate F1 score from precision and recall."""
    return 0.0 if (p + r) == 0 else 2 * p * r / (p + r)

# Load test data
print("Loading test data...")
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:
    test_data = [json.loads(line) for line in f]

print(f"✓ Loaded test set: {len(test_data)} samples\n")

print("="*80)
print("STARTING EVALUATION")
print("="*80)
print("\nConfiguration:")
print("  - Deterministic generation (greedy decoding)")
print("  - Enhanced filtering (reduces false positives)")
print("  - Task-specific processing")
print("\nProcessing samples...\n")

# Initialize per-task counters
gold_total = {"chemicals": 0, "diseases": 0, "influences": 0}
pred_total = {"chemicals": 0, "diseases": 0, "influences": 0}
tp_total = {"chemicals": 0, "diseases": 0, "influences": 0}

examples_fp = []  # False positives
examples_fn = []  # False negatives

# Process each test sample
import time
start_time = time.time()

for idx, row in enumerate(test_data):
    if (idx + 1) % 50 == 0:
        elapsed = time.time() - start_time
        rate = elapsed / (idx + 1)
        remaining = rate * (len(test_data) - idx - 1)
        print(f"  Progress: {idx + 1}/{len(test_data)} samples ({elapsed:.1f}s elapsed, ~{remaining:.1f}s remaining)")
    
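    prompt = row["prompt"]
    task = task_from_prompt(prompt)
    if task not in gold_total:
        # Unrecognized prompt wording: skip instead of raising KeyError on the counters below
        print(f"  Skipping sample {idx}: could not determine task from prompt")
        continue
    gold_items = [normalize_item(x) for x in parse_bullets(row.get("completion", ""))]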
    
    # Generate prediction
    gen = generate_response(prompt, max_new_tokens=128)
    pred_raw = extract_list_from_generation(gen)
    
    # Apply task-specific processing
    if task in {"chemicals", "diseases"}:
        # Use enhanced filtering
        pred = filter_entities_enhanced(pred_raw, prompt, task)
    elif task == "influences":
        # Parse gold data (pipe-separated format)
        gold_pairs = []
        for item in parse_bullets(row.get("completion", "")):
            parts = [p.strip() for p in item.split("|")]
            if len(parts) == 2:
                chem = normalize_item(parts[0])
                dis = normalize_item(parts[1])
                gold_pairs.append(f"{chem} | {dis}")
        gold_items = gold_pairs
        
        # Parse model output (try both formats)
        pairs_sentence = parse_pairs_from_sentence(gen)
        pairs_pipe = parse_pairs(gen)
        all_pairs = pairs_sentence if pairs_sentence else pairs_pipe
        
        # Normalize and filter
        pred = [f"{normalize_item(c)} | {normalize_item(d)}"
                for (c, d) in filter_pairs_against_text(all_pairs, prompt)]
    else:
        pred = []
    
    # Calculate metrics
    gs = set(gold_items)
    ps = set(pred)
    
    tp = len(gs & ps)
    fp = len(ps - gs)
    fn = len(gs - ps)
    
    gold_total[task] += len(gs)
    pred_total[task] += len(ps)
    tp_total[task] += tp
    
    # Collect examples
    if fp and len(examples_fp) < 8:
        examples_fp.append({
            "task": task,
            "prompt_preview": prompt[:120] + "...",
            "pred_extras": list(ps - gs)[:5]
        })
    if fn and len(examples_fn) < 8:
        examples_fn.append({
            "task": task,
            "prompt_preview": prompt[:120] + "...",
            "missed": list(gs - ps)[:5]
        })

total_time = time.time() - start_time

print("\n✓ Evaluation complete!")
print(f"  Total time: {total_time:.1f}s")
print(f"  Average: {total_time / len(test_data):.2f}s per sample")
print(f"  Throughput: {len(test_data) / total_time:.1f} samples/second")
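The loop scores each sample by comparing sets, so duplicate predictions collapse. Counts are then pooled within each task: `pred_total` is the sum of TP + FP and `gold_total` is the sum of TP + FN. The worked example below reproduces that arithmetic on toy data. `score_sets` and `pooled_prf` are hypothetical helpers, and the entity strings are made up for illustration.

```python
def score_sets(gold, pred):
    """Set-based (TP, FP, FN) for one sample, as in the evaluation loop."""
    gs, ps = set(gold), set(pred)
    return len(gs & ps), len(ps - gs), len(gs - ps)

def pooled_prf(samples):
    """Pool TP/FP/FN over (gold, pred) pairs, then compute P, R, F1 once."""
    tp = fp = fn = 0
    for gold, pred in samples:
        t, f_p, f_n = score_sets(gold, pred)
        tp, fp, fn = tp + t, fp + f_p, fn + f_n
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f

toy = [
    (["cisplatin", "doxorubicin"], ["cisplatin", "saline"]),  # 1 TP, 1 FP, 1 FN
    (["heparin"], ["heparin"]),                               # 1 TP
]
print(tuple(round(x, 3) for x in pooled_prf(toy)))  # (0.667, 0.667, 0.667)
```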

## 8️⃣ Results & Analysis

Display comprehensive per-task metrics and analysis.

In [None]:
print("\n" + "="*80)
print("EVALUATION RESULTS - BioMistral-7B-SLERP-AWQ")
print("="*80 + "\n")

# Calculate and display per-task metrics
all_metrics = []

for task in ["chemicals", "diseases", "influences"]:
    P = 0.0 if pred_total[task] == 0 else tp_total[task] / pred_total[task]
    R = 0.0 if gold_total[task] == 0 else tp_total[task] / gold_total[task]
    F1 = f1(P, R)
    
    all_metrics.append({"task": task, "P": P, "R": R, "F1": F1})
    
    print(f"üìä {task.upper()}")
    print(f"   Precision: {P*100:.1f}%")
    print(f"   Recall:    {R*100:.1f}%")
    print(f"   F1 Score:  {F1*100:.1f}%")
    print(f"   True Positives:  {tp_total[task]}")
    print(f"   Gold Standard:   {gold_total[task]}")
    print(f"   Predictions:     {pred_total[task]}")
    print()

# Overall metrics (macro-average)
overall_P = mean([m["P"] for m in all_metrics])
overall_R = mean([m["R"] for m in all_metrics])
overall_F1 = mean([m["F1"] for m in all_metrics])

print("="*80)
print("üìà OVERALL PERFORMANCE (Macro-Average)")
print("="*80)
print(f"Precision: {overall_P*100:.1f}%")
print(f"Recall:    {overall_R*100:.1f}%")
print(f"F1 Score:  {overall_F1*100:.1f}%")
print()

# Comparison with baseline
print("="*80)
print("üìä COMPARISON WITH BASELINE")
print("="*80)
baseline_f1 = 53.8  # From previous Llama-3.2-3B evaluation
improvement = overall_F1 * 100 - baseline_f1

print(f"Llama-3.2-3B Baseline: {baseline_f1:.1f}% F1")
print(f"BioMistral-7B-AWQ:     {overall_F1*100:.1f}% F1")
print(f"Improvement:           {improvement:+.1f} points")
print()

if improvement > 0:
    print(f"✅ BioMistral shows a {improvement:.1f}-point improvement!")
elif improvement > -5:
    print(f"⚠️  Performance similar to baseline ({improvement:+.1f} points)")
else:
    print(f"❌ Performance below baseline ({improvement:+.1f} points)")
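The overall score above is a macro-average: each task's P, R, and F1 counts equally, regardless of how many entities it contributes. A micro-average pools TP, predictions, and gold counts across tasks first, so tasks with more entities weigh more. The two can differ noticeably when one task (here, often relationships) behaves very differently. The notebook does not say which averaging produced the 53.8% Llama-3.2-3B baseline, so both are worth reporting. `macro_and_micro` is a hypothetical helper that takes the notebook's `tp_total`, `pred_total`, and `gold_total` dicts. The toy numbers in the example are illustrative only.

```python
def _prf(tp, pred, gold):
    """Precision, recall, F1 from raw counts (0.0 when undefined)."""
    p = tp / pred if pred else 0.0
    r = tp / gold if gold else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f

def macro_and_micro(tp_by_task, pred_by_task, gold_by_task):
    """Return ((macro P, R, F1), (micro P, R, F1)) over the tasks in tp_by_task."""
    tasks = list(tp_by_task)
    if not tasks:
        raise ValueError("no tasks to average")
    per_task = [_prf(tp_by_task[t], pred_by_task[t], gold_by_task[t]) for t in tasks]
    macro = tuple(sum(m[i] for m in per_task) / len(per_task) for i in range(3))
    micro = _prf(sum(tp_by_task.values()), sum(pred_by_task.values()),
                 sum(gold_by_task.values()))
    return macro, micro

# Usage in this notebook: macro_and_micro(tp_total, pred_total, gold_total)
macro, micro = macro_and_micro({"a": 8, "b": 1}, {"a": 10, "b": 10}, {"a": 10, "b": 2})
print("macro:", [round(x, 4) for x in macro], "micro:", [round(x, 4) for x in micro])
```

The macro F1 returned here is the same quantity as `overall_F1` above (the mean of per-task F1 scores).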

## 9️⃣ Error Analysis

Examine false positives and false negatives.

In [None]:
print("\n" + "="*80)
print("üîç ERROR ANALYSIS")
print("="*80 + "\n")

# False Positives
if examples_fp:
    print("❌ FALSE POSITIVES (Predicted but not in gold standard)\n")
    for i, ex in enumerate(examples_fp[:5], 1):
        print(f"{i}. Task: {ex['task']}")
        print(f"   Text: {ex['prompt_preview']}")
        print(f"   Extra predictions: {ex['pred_extras']}")
        print()
else:
    print("✓ No false positives collected\n")

# False Negatives
if examples_fn:
    print("⚠️  FALSE NEGATIVES (In gold standard but missed by model)\n")
    for i, ex in enumerate(examples_fn[:5], 1):
        print(f"{i}. Task: {ex['task']}")
        print(f"   Text: {ex['prompt_preview']}")
        print(f"   Missed entities: {ex['missed']}")
        print()
else:
    print("✓ No false negatives collected\n")
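The loop keeps only the first 8 false-positive and false-negative samples, so recurring mistakes (for example, the same generic term predicted across many abstracts) are easy to miss. One option is to tally error strings across all samples. This sketch assumes a hypothetical `records` list that the evaluation loop does not currently build: you would add `records.append((task, gs, ps))` after the sets are computed. `top_errors` is also a hypothetical helper.

```python
from collections import Counter, defaultdict

def top_errors(records, k=10):
    """records: iterable of (task, gold_items, pred_items).
    Returns {task: {"fp": [(item, count), ...], "fn": [...]}} with the k most common."""
    fp_counts = defaultdict(Counter)
    fn_counts = defaultdict(Counter)
    for task, gold, pred in records:
        gs, ps = set(gold), set(pred)
        fp_counts[task].update(ps - gs)  # predicted but not in gold
        fn_counts[task].update(gs - ps)  # in gold but missed
    tasks = set(fp_counts) | set(fn_counts)
    return {t: {"fp": fp_counts[t].most_common(k), "fn": fn_counts[t].most_common(k)}
            for t in tasks}

# Usage (after collecting records in the loop):
# for task, errs in top_errors(records).items():
#     print(task, "FP:", errs["fp"][:5], "FN:", errs["fn"][:5])
```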

## 🔟 Test Sample Examples

Run model on a few test samples to see actual outputs.

In [None]:
print("\n" + "="*80)
print("üìù SAMPLE PREDICTIONS")
print("="*80 + "\n")

# Show 3 examples (one per task)
task_examples = {}
for row in test_data:
    task = task_from_prompt(row["prompt"])
    if task not in task_examples:
        task_examples[task] = row
    if len(task_examples) == 3:
        break

for task in ["chemicals", "diseases", "influences"]:
    if task in task_examples:
        row = task_examples[task]
        prompt = row["prompt"]
        gold = row["completion"]
        
        # Generate prediction
        pred = generate_response(prompt, max_new_tokens=128)
        
        print(f"\n{'='*80}")
        print(f"TASK: {task.upper()}")
        print(f"{'='*80}")
        print(f"\nPROMPT:\n{prompt[:300]}...")
        print(f"\nGOLD STANDARD:\n{gold}")
        print(f"\nMODEL PREDICTION:\n{pred}")
        print()

## 📋 Summary & Conclusions

**Model**: BioMistral-7B-SLERP-AWQ-QGS128-W4-GEMM  
**Dataset**: Cleaned splits_cleaned_20251113 (299 test samples)  
**Date**: November 15, 2025

### Key Findings:

1. **Medical Domain Advantage**: Evaluate whether BioMistral's medical pretraining improves performance over the Llama-3.2-3B baseline
2. **Quantization Impact**: Check whether AWQ 4-bit quantization preserves accuracy while reducing memory use
3. **Clean Data Effect**: Results use the cleaned dataset (99.8% retention, 0 known issues)
4. **Inference Speed**: Compare the throughput measured in Section 7 against expectations for AWQ on RunPod GPUs

### Next Steps:

- [ ] Compare with Llama-3.2-3B baseline
- [ ] Fine-tune BioMistral on medical NER data
- [ ] Test with different quantization levels
- [ ] Optimize prompt format for BioMistral

---

**Notebook**: `BioMistral_7B_AWQ_Evaluation_20251115.ipynb`  
**Related Issue**: #2 - Retrain with BioMistral-7B-SLERP