# 07 - Inference Demo

## Goal

Try the model on real abstracts: paste text or a PubMed URL; show predicted labels with probabilities. We'll also add a tiny Gradio block.


In [37]:
# === TODO (you code this) ===
# Goal: Import libraries for inference demo.
# Hints:
# 1) transformers, torch, gradio
# 2) src.utils (eutils_get)
# 3) lxml.etree for PubMed parsing
# Acceptance:
# - All imports successful

# TODO: import libraries
import logging
import os
import sys
from pathlib import Path

import lxml.etree
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Make the project root importable so that `src.utils` resolves from this notebook
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.utils import eutils_get


In [38]:
# === TODO (you code this) ===
# Goal: Define LABELS (same order as training!).
# Acceptance:
# - LABELS list with all 10 categories

# TODO: define LABELS
LABELS = [
    'SystematicReview',  # 1. Systematic reviews
    'MetaAnalysis',      # 2. Meta-analyses (quantitative synthesis)
    'RCT',               # 3. Randomized Controlled Trials
    'ClinicalTrial',     # 4. Non-randomized clinical trials
    'Cohort',            # 5. Cohort studies (prospective/retrospective)
    'CaseControl',       # 6. Case-control studies
    'CaseReport',        # 7. Case reports / case series
    'InVitro',           # 8. In vitro or ex vivo laboratory studies
    'Animal',            # 9. Animal studies
    'Human'              # 10. Human subjects (not mutually exclusive)
]



## Load from Hub

Load the model directly from Hugging Face (or local path if not yet pushed).


In [39]:
# === TODO (you code this) ===
# Goal: Load model and tokenizer from Hugging Face Hub (or local).
# Hints:
# 1) Use "Tuminha/dental-evidence-triage" or "../artifacts/model/best"
# 2) Set model to eval mode
# Acceptance:
# - tokenizer and model loaded
# - Ready for inference

# TODO: load model
# Load from HF Hub (or use "../artifacts/model/best" for local)
hf_path = "Tuminha/dental-evidence-triage"
# Alternative: hf_path = "../artifacts/model/best"  # Use this if model not yet on HF

tokenizer = AutoTokenizer.from_pretrained(hf_path)
model = AutoModelForSequenceClassification.from_pretrained(hf_path)

# Use CPU for inference (more reliable than MPS for Hugging Face models)
# MPS can have issues with certain operations, CPU is more stable
device = torch.device("cpu")
# Alternative: device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Note: MPS (Apple Silicon) can cause "Placeholder storage" errors, so we use CPU

model = model.to(device)

# Set to evaluation mode
model.eval()

print(f"✅ Model loaded from: {hf_path}")
print(f"✅ Device: {device} (using CPU for stability)")
print(f"✅ Model ready for inference")



✅ Model loaded from: Tuminha/dental-evidence-triage
✅ Device: cpu (using CPU for stability)
✅ Model ready for inference
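
The inference code maps logit index `i` to `LABELS[i]`, so a mismatch between that list and the order used at training time would silently mislabel every prediction. A minimal sketch of a sanity check follows. It assumes the checkpoint's `config.id2label` was populated with real label names when the model was saved; if it only holds generic `LABEL_0`-style names, only the label count can be compared. The helper name `check_label_order` is hypothetical.

```python
def check_label_order(labels, id2label):
    """Compare a local label list with a checkpoint's id2label mapping.

    Returns a list of human-readable problems; an empty list means no
    mismatch was detected.
    """
    problems = []
    if len(labels) != len(id2label):
        problems.append(
            f"expected {len(labels)} labels, checkpoint has {len(id2label)}"
        )
        return problems

    # Keys may be ints or strings depending on how the config was loaded
    names = {int(k): v for k, v in id2label.items()}

    # Generic names (LABEL_0, LABEL_1, ...) carry no order information
    if all(names.get(i) == f"LABEL_{i}" for i in range(len(labels))):
        return problems

    for i, label in enumerate(labels):
        if names.get(i) != label:
            problems.append(
                f"index {i}: local '{label}' vs checkpoint '{names.get(i)}'"
            )
    return problems


# Toy example with a swapped pair (not the real model config)
print(check_label_order(["RCT", "Human"], {0: "Human", 1: "RCT"}))
```

In the notebook this would be called as `check_label_order(LABELS, model.config.id2label)`; an empty result does not prove the order is right when the names are generic, it only rules out an obvious mismatch.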


## Utility: Fetch PubMed Abstract by PMID


In [40]:
# === TODO (you code this) ===
# Goal: Utility to fetch abstract from PubMed by PMID.
# Hints:
# 1) Use eutils_get with efetch.fcgi
# 2) Parse XML with lxml.etree
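# 3) Extract title, abstract, publication types, and MeSH terms; return as tuple
# Acceptance:
# - Function fetch_abstract_by_pmid(pmid) -> (title, abstract, pub_types, mesh_terms)
# - title and abstract are strings; pub_types and mesh_terms are lists of strings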

def fetch_abstract_by_pmid(pmid):
    """Fetch abstract text and metadata from PubMed by PMID.
    
    Returns:
        tuple: (title, abstract, pub_types, mesh_terms)
        - title: Article title
        - abstract: Abstract text
        - pub_types: List of Publication Types from PubMed
        - mesh_terms: List of MeSH terms from PubMed
    """
    # Prepare parameters for efetch
    params = {
        'db': 'pubmed',
        'id': str(pmid),
        'retmode': 'xml',
        'rettype': 'medline'
    }
    
    # Call eutils_get - returns requests.Response object
    response = eutils_get('efetch.fcgi', params)
    
    # Parse XML from response text
    root = lxml.etree.fromstring(response.text.encode('utf-8'))
    
    # Extract title; itertext() keeps text inside inline markup such as <i> or <sup>
    title_elem = root.find('.//ArticleTitle')
    title = ''.join(title_elem.itertext()).strip() if title_elem is not None else ""
    
    # Extract abstract. Structured abstracts contain several AbstractText
    # elements (one per labelled section), so join all of them instead of
    # keeping only the first one.
    abstract_parts = [
        ''.join(elem.itertext()).strip()
        for elem in root.findall('.//AbstractText')
    ]
    abstract = ' '.join(part for part in abstract_parts if part)
    
    # Extract Publication Types (what PubMed provides - NOT our labels!)
    pub_types = root.xpath('.//PublicationType/text()')
    
    # Extract MeSH terms (what PubMed provides - NOT our labels!)
    mesh_terms = root.xpath('.//MeshHeading/DescriptorName/text()')
    
    return title, abstract, pub_types, mesh_terms

# Test with a real PMID (a dental research paper)
title, abstract, pub_types, mesh_terms = fetch_abstract_by_pmid("24660200")
print(f"Title: {title}")
print(f"Abstract: {abstract[:100]}...")
print(f"Publication Types (from PubMed): {pub_types}")
print(f"MeSH Terms (from PubMed): {mesh_terms[:3]}...")
print(f"\n⚠️  Note: These are raw metadata from PubMed.")
print(f"   True labels come from our dataset files (created in Notebook 02).")

Title: Loading protocols for single-implant crowns: a systematic review and meta-analysis.
Abstract: To test whether or not immediate loading of single-implant crowns renders different results from ear...
Publication Types (from PubMed): ['Journal Article', 'Meta-Analysis', "Research Support, Non-U.S. Gov't", 'Systematic Review']
MeSH Terms (from PubMed): ['Female', 'Humans', 'Middle Aged']...

⚠️  Note: These are raw metadata from PubMed.
   True labels come from our dataset files (created in Notebook 02).
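
The goal mentions pasting a PubMed URL, while `fetch_abstract_by_pmid` expects a bare PMID. A small sketch of bridging the two follows. The helper name `extract_pmid` is hypothetical, and the two URL shapes it recognises (`pubmed.ncbi.nlm.nih.gov/<id>` and the older `ncbi.nlm.nih.gov/pubmed/<id>`) are assumptions about what users will paste.

```python
import re

# Matches the numeric ID in the two assumed PubMed URL shapes
_PMID_IN_URL = re.compile(
    r"(?:pubmed\.ncbi\.nlm\.nih\.gov|ncbi\.nlm\.nih\.gov/pubmed)/(\d+)"
)


def extract_pmid(text):
    """Return the PMID as a string, or None if none can be recognised."""
    text = str(text).strip()
    # A bare number is treated as a PMID already
    if text.isascii() and text.isdigit():
        return text
    match = _PMID_IN_URL.search(text)
    return match.group(1) if match else None


print(extract_pmid("https://pubmed.ncbi.nlm.nih.gov/24660200/"))
print(extract_pmid("24660200"))
print(extract_pmid("not a pubmed link"))
```

The result can be passed straight to `fetch_abstract_by_pmid`; a `None` return is the cue to treat the input as pasted abstract text instead.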


## Predict Function

Returns top-k labels with probabilities, sorted descending.


In [41]:
# === TODO (you code this) ===
# Goal: Predict top-k labels for input text.
# Hints:
# 1) Tokenize, run through model, apply sigmoid
# 2) Sort probabilities, return top-k (label, score) pairs
# Acceptance:
# - Function predict(text, top_k) -> list[(label, score)]
# - Returns top predictions sorted by score

def predict(text, top_k=5):
    """Predict study design labels for text."""
    # Truncate text to 2000 chars (same as training)
    text = str(text)[:2000]
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    
    # Get device from model (works even if device variable not in scope)
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Run through model
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    
    # Apply sigmoid to get probabilities
    probs = torch.sigmoid(logits)[0]  # Get first (and only) batch item
    
    # Create list of (label, probability) pairs
    label_probs = [(LABELS[i], float(probs[i])) for i in range(len(LABELS))]
    
    # Sort by probability (descending) and return top_k
    label_probs.sort(key=lambda x: x[1], reverse=True)
    
    return label_probs[:top_k]

# Test
print(predict("To test whether or not immediate loading of single-implant crowns renders different results from early and conventional loading with respect to implant survival, marginal bone loss, stability of peri-implant soft tissue, esthetics, and patient satisfaction."))    



[('Human', 0.9927734732627869), ('CaseControl', 0.009760159067809582), ('CaseReport', 0.0034319646656513214), ('Cohort', 0.002296593738719821), ('Animal', 0.0019579280633479357)]
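
The scores above do not sum to 1 because sigmoid squashes each logit independently, which is what a multi-label task needs: a human RCT should be able to score high on both `RCT` and `Human`. The pure-Python sketch below contrasts this with softmax, which forces labels to compete. The logits are made-up illustrative values, not model output.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    # Subtract the max before exponentiating for numerical stability
    shifted = [x - max(xs) for x in xs]
    exps = [math.exp(x) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three labels; illustrative values only
logits = [4.0, 3.5, -5.0]

independent = [sigmoid(x) for x in logits]  # one probability per label
competing = softmax(logits)                 # one distribution over labels

print("sigmoid:", [round(p, 3) for p in independent])
print("softmax:", [round(p, 3) for p in competing])
```

With sigmoid the first two labels both stay above 0.9; with softmax they have to share a total of 1, so the second label would be pushed below a 0.5 threshold.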


## Test on Known Abstracts

Try 3-5 diverse examples:
1. **Systematic Review**: should predict `[SystematicReview, MetaAnalysis?]`
2. **RCT**: should predict `[RCT, Human]`
3. **Case Report**: should predict `[CaseReport, Human]`
4. **In Vitro**: should predict `[InVitro]`
5. **Animal Study**: should predict `[Animal]`
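
One way to turn the list above into a repeatable check is to pair each example with the labels it is expected to trigger and report which ones fall below the threshold. This is only a sketch: `check_examples` and `stub_predict` are hypothetical names, the texts are placeholders rather than real abstracts, and the stub stands in for the notebook's `predict` so the block runs on its own.

```python
def check_examples(predict_fn, examples, threshold=0.5):
    """Return {name: missing_labels} for expected labels scoring below threshold.

    predict_fn(text, top_k) must return a list of (label, score) pairs,
    like the notebook's predict().
    """
    report = {}
    for name, (text, expected) in examples.items():
        scores = dict(predict_fn(text, top_k=10))
        report[name] = [lab for lab in expected if scores.get(lab, 0.0) < threshold]
    return report


# Placeholder texts; replace with real abstracts before drawing conclusions
examples = {
    "rct": ("<paste an RCT abstract here>", ["RCT", "Human"]),
    "in_vitro": ("<paste an in vitro abstract here>", ["InVitro"]),
}


def stub_predict(text, top_k=10):
    # Stand-in for predict(); returns fixed scores for illustration only
    return [("Human", 0.9), ("RCT", 0.8), ("InVitro", 0.1)][:top_k]


report = check_examples(stub_predict, examples)
print(report)
```

An empty list for an example means every expected label cleared the threshold; swapping `stub_predict` for `predict` runs the same check against the real model.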


In [43]:
# === TODO (you code this) ===
# Goal: Test predict function on known examples.
# Hints:
# 1) Create dict with 5 diverse examples (SR, RCT, CaseReport, InVitro, Animal)
# 2) For each, call predict() and print top-3 results
# 3) Verify predictions match expected study designs
# Acceptance:
# - Tests 5 examples covering different label types
# - Shows label : probability for each

def get_true_labels_from_dataset(pmid):
    """Get true labels for a PMID from our processed dataset.
    
    Returns:
        tuple: (labels, source), where source is "original", "train", "val",
        or "test"; (None, None) if the PMID is not found.
    """
    # Compare as strings because PMIDs may be stored as int or str
    pmid_str = str(pmid)
    processed_data_path = Path("../data/processed")
    
    # Check the original dataset file first (before splits), then each split
    candidates = [("original", "dental_abstracts.parquet")]
    candidates += [(name, f"{name}.parquet") for name in ('train', 'val', 'test')]
    
    for source, filename in candidates:
        path = processed_data_path / filename
        if not path.exists():
            continue
        try:
            df = pd.read_parquet(path)
            matching_rows = df[df['pmid'].astype(str) == pmid_str]
            if len(matching_rows) > 0:
                # List columns may load from Parquet as NumPy arrays; convert to
                # a plain list so truthiness checks like `if true_labels:` work
                labels = list(matching_rows.iloc[0]['labels'])
                return labels, source
        except Exception as e:
            print(f"   (Note: Could not check {filename}: {e})")
    
    return None, None

def test_with_pmid(pmid, threshold=0.5):
    """Test model predictions on a PMID and compare with true labels.
    
    Important: True labels come from our dataset files (created in Notebook 02),
    NOT directly from PubMed. PubMed only provides Publication Types and MeSH terms,
    which were converted to canonical labels in Notebook 02.
    """
    print(f"\n{'='*80}")
    print(f"Testing PMID: {pmid}")
    print(f"{'='*80}")
    
    # Fetch abstract and metadata from PubMed
    try:
        title, abstract, pub_types, mesh_terms = fetch_abstract_by_pmid(pmid)
        print(f"\n📄 Title: {title}")
        print(f"📝 Abstract: {abstract[:200]}..." if len(abstract) > 200 else f"📝 Abstract: {abstract}")
        
        # Show what PubMed provides (raw metadata)
        print(f"\n📋 PubMed Metadata (raw - NOT our labels):")
        print(f"   Publication Types: {pub_types if pub_types else 'None'}")
        print(f"   MeSH Terms: {mesh_terms[:5] if mesh_terms else 'None'}..." if len(mesh_terms) > 5 else f"   MeSH Terms: {mesh_terms if mesh_terms else 'None'}")
        print(f"   ⚠️  Note: These are converted to canonical labels in Notebook 02")
    except Exception as e:
        print(f"❌ Error fetching abstract: {e}")
        return
    
    # Get true labels from our dataset (these were created in Notebook 02)
    # The labels come from mapping Publication Types + MeSH terms → canonical labels
    true_labels, split_name = get_true_labels_from_dataset(pmid)
    if true_labels:
        print(f"\n✅ Found in our dataset ({split_name} split)")
        print(f"🏷️  True Labels (derived in Notebook 02): {true_labels}")
        print(f"   📌 These labels were created by mapping:")
        print(f"      - Publication Types → labels")
        print(f"      - MeSH terms → labels")
        print(f"      - Keywords → labels (fallback)")
    else:
        print(f"\n⚠️  Not found in our dataset")
        print(f"   (Searched in: dental_abstracts.parquet, train.parquet, val.parquet, test.parquet)")
        print(f"   Possible reasons:")
        print(f"   - PMID from different time period (our dataset: 2018-2025)")
        print(f"   - Article was filtered out during labeling (no matching labels)")
        print(f"   - Article not in our dental query scope")
        # Optional: Show a sample of PMIDs from the dataset for debugging
        try:
            sample_path = Path("../data/processed/train.parquet")
            if sample_path.exists():
                sample_df = pd.read_parquet(sample_path)
                print(f"\n   Sample PMIDs from train split: {sample_df['pmid'].head(3).tolist()}")
        except Exception:
            pass
        true_labels = []
    
    # Get predictions
    text = f"{title} {abstract}"
    predictions = predict(text, top_k=10)
    
    # Apply threshold to get binary predictions
    predicted_labels = [label for label, prob in predictions if prob >= threshold]
    
    print(f"\n🤖 Model Predictions (threshold={threshold}):")
    print(f"   Predicted Labels: {predicted_labels}")
    print(f"\n📊 Top Predictions with Probabilities:")
    for label, prob in predictions[:5]:
        marker = "✅" if label in true_labels else "❌" if true_labels else "  "
        print(f"   {marker} {label:20s}: {prob:.4f}")
    
    # Compare if we have true labels
    if true_labels:
        print(f"\n📈 Comparison:")
        correct = set(predicted_labels) & set(true_labels)
        false_positives = set(predicted_labels) - set(true_labels)
        false_negatives = set(true_labels) - set(predicted_labels)
        
        if correct:
            print(f"   ✅ Correct: {sorted(correct)}")
        if false_positives:
            print(f"   ❌ False Positives: {sorted(false_positives)}")
        if false_negatives:
            print(f"   ⚠️  False Negatives (missed): {sorted(false_negatives)}")
        
        # Calculate simple accuracy metrics
        if true_labels:
            precision = len(correct) / len(predicted_labels) if predicted_labels else 0
            recall = len(correct) / len(true_labels) if true_labels else 0
            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            
            print(f"\n   Precision: {precision:.2%} ({len(correct)}/{len(predicted_labels)})")
            print(f"   Recall: {recall:.2%} ({len(correct)}/{len(true_labels)})")
            print(f"   F1 Score: {f1_score:.2%}")
            
            # Overall accuracy assessment
            print(f"\n{'='*80}")
            print(f"📊 ACCURACY ASSESSMENT")
            print(f"{'='*80}")
            
            # Check if predictions match well
            all_correct = (set(predicted_labels) == set(true_labels))
            mostly_correct = len(correct) >= len(true_labels) * 0.8 and len(false_positives) <= 1
            partially_correct = len(correct) > 0 and f1_score >= 0.5
            
            if all_correct:
                print(f"✅ EXCELLENT! Model predicted perfectly!")
                print(f"   All {len(true_labels)} true label(s) were correctly identified.")
                print(f"   No false positives or false negatives.")
            elif mostly_correct:
                print(f"✅ VERY GOOD! Model predicted accurately!")
                print(f"   Correctly identified {len(correct)}/{len(true_labels)} true label(s).")
                if false_positives:
                    print(f"   Minor issue: {len(false_positives)} false positive(s): {sorted(false_positives)}")
                if false_negatives:
                    print(f"   Minor issue: {len(false_negatives)} missed label(s): {sorted(false_negatives)}")
            elif partially_correct:
                print(f"⚠️  MODERATE: Model predicted partially correctly.")
                print(f"   Correctly identified {len(correct)}/{len(true_labels)} true label(s).")
                if false_positives:
                    print(f"   ⚠️  {len(false_positives)} false positive(s): {sorted(false_positives)}")
                if false_negatives:
                    print(f"   ⚠️  {len(false_negatives)} missed label(s): {sorted(false_negatives)}")
            else:
                print(f"❌ POOR: Model predictions don't match well with true labels.")
                print(f"   Only {len(correct)}/{len(true_labels)} true label(s) correctly identified.")
                if false_positives:
                    print(f"   ❌ {len(false_positives)} false positive(s): {sorted(false_positives)}")
                if false_negatives:
                    print(f"   ❌ {len(false_negatives)} missed label(s): {sorted(false_negatives)}")
            
            # Additional feedback based on F1 score
            print(f"\n📈 Performance Summary:")
            if f1_score >= 0.9:
                print(f"   🏆 Outstanding performance (F1 ≥ 0.90)")
            elif f1_score >= 0.7:
                print(f"   👍 Good performance (F1 ≥ 0.70)")
            elif f1_score >= 0.5:
                print(f"   ⚠️  Moderate performance (F1 ≥ 0.50)")
            else:
                print(f"   ❌ Needs improvement (F1 < 0.50)")
            
            print(f"{'='*80}")

# Helper function to get a PMID from the dataset for testing
def get_sample_pmid_from_dataset(split='train', index=0):
    """Get a sample PMID from the dataset for testing."""
    processed_data_path = Path("../data/processed")
    split_path = processed_data_path / f"{split}.parquet"
    
    if split_path.exists():
        df = pd.read_parquet(split_path)
        if len(df) > index:
            return df.iloc[index]['pmid']
    return None

# Test with the PMID from Cell 8 (may not be in dataset)
print("="*80)
print("TEST 1: PMID 24660200 (may not be in dataset)")
print("="*80)
test_with_pmid("24660200")

# Test with a PMID that IS in the dataset
print("\n\n" + "="*80)
print("TEST 2: Using a PMID from your dataset")
print("="*80)

# Get a sample PMID from the train split
sample_pmid = get_sample_pmid_from_dataset('train', 0)
if sample_pmid:
    print(f"Using sample PMID from train split: {sample_pmid}")
    test_with_pmid(sample_pmid)
else:
    print("⚠️  Could not find a sample PMID from the dataset")

# You can also test with specific PMIDs from your dataset
print("\n\n" + "="*80)
print("TEST 3: Additional Examples")
print("="*80)
print("To test with more PMIDs from your dataset, uncomment and modify:")
print("# df = pd.read_parquet('../data/processed/train.parquet')")
print("# test_with_pmid(df.iloc[5]['pmid'])  # Test with 6th article")
print("# test_with_pmid(df.iloc[10]['pmid'])  # Test with 11th article")


TEST 1: PMID 24660200 (may not be in dataset)

Testing PMID: 24660200

📄 Title: Loading protocols for single-implant crowns: a systematic review and meta-analysis.
📝 Abstract: To test whether or not immediate loading of single-implant crowns renders different results from early and conventional loading with respect to implant survival, marginal bone loss, stability of peri-...

📋 PubMed Metadata (raw - NOT our labels):
   Publication Types: ['Journal Article', 'Meta-Analysis', "Research Support, Non-U.S. Gov't", 'Systematic Review']
   MeSH Terms: ['Female', 'Humans', 'Middle Aged', 'Alveolar Bone Loss', 'Bone Density']...
   ⚠️  Note: These are converted to canonical labels in Notebook 02

⚠️  Not found in our dataset
   (Searched in: dental_abstracts.parquet, train.parquet, val.parquet, test.parquet)
   Possible reasons:
   - PMID from different time period (our dataset: 2018-2025)
   - Article was filtered out during labeling (no matching labels)
   - Article not in
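
The comparison inside `test_with_pmid` boils down to set arithmetic on the predicted and true labels. Pulled out into a standalone helper (the name `label_set_metrics` is hypothetical), it can be checked by hand on a small example:

```python
def label_set_metrics(predicted, true):
    """Precision, recall, and F1 for one article's predicted vs. true label sets."""
    predicted, true = set(predicted), set(true)
    correct = predicted & true
    precision = len(correct) / len(predicted) if predicted else 0.0
    recall = len(correct) / len(true) if true else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1


# Two predictions, one of them right; two true labels, one of them found
p, r, f1 = label_set_metrics(["Human", "Cohort"], ["Human", "RCT"])
print(p, r, f1)  # 0.5 0.5 0.5
```

With only a handful of labels per article these figures move in large steps, so they are probably better read as a spot check than as an evaluation metric.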

## Gradio Demo (Optional)

Build a simple interface for interactive testing.


In [44]:
!pip install gradio

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
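
The warning above names its own remedy: set `TOKENIZERS_PARALLELISM` explicitly. A minimal sketch, assuming it runs in a cell before the tokenizer is first used (ideally at the top of the notebook):

```python
import os

# Disable tokenizer thread parallelism so that forking (for example the
# subprocess spawned by `!pip install`) does not trigger the deadlock warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```

For single-abstract inference the lost parallelism should not matter; the setting mainly affects batch tokenization speed.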



In [46]:
# === TODO (you code this) ===
# Goal: (Optional) Create interactive Gradio demo.
import gradio as gr

def gradio_predict(text):
    """Wrapper function for Gradio interface."""
    if not text.strip():
        return "Please enter some text."
    
    preds = predict(text, top_k=10)
    output = "\n".join([f"{label}: {prob:.3f}" for label, prob in preds])
    return output

# Create interface
demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Textbox(lines=10, placeholder="Paste title + abstract here..."),
    outputs=gr.Textbox(label="Predicted Labels", lines=15),
    title="🦷 Dental Evidence Triage",
    description="Classify dental research abstracts by study design."
)

# Launch
demo.launch(share=True)


* Running on local URL:  http://127.0.0.1:7861




* Running on public URL: https://5afec64398548b08b1.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




## Recommendations

- **Add 3-5 known abstracts** (one SR, one RCT, one CaseReport) as quick sanity checks
- **Remind users:** assistive triage, not ground truth
- **Test edge cases:** very short abstracts, non-English (should fail gracefully), missing abstract
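
For the edge cases listed above, a thin guard in front of `predict` can reject inputs the model should not score. The sketch below rests on several assumptions: `validate_abstract`, the 50-character minimum, and the ASCII-letter ratio used as a crude stand-in for language detection are all hypothetical choices, not part of the notebook.

```python
def validate_abstract(text, min_chars=50, min_ascii_letter_ratio=0.6):
    """Return (ok, message); ok is False when the text should not be scored."""
    if text is None or not str(text).strip():
        return False, "Missing abstract: nothing to classify."
    text = str(text).strip()
    if len(text) < min_chars:
        return False, f"Abstract too short ({len(text)} chars); results would be unreliable."
    # Crude heuristic: English text consists mostly of ASCII letters
    letters = [c for c in text if c.isalpha()]
    ascii_letters = [c for c in letters if c.isascii()]
    if not letters or len(ascii_letters) / len(letters) < min_ascii_letter_ratio:
        return False, "Text does not look like English."
    return True, ""


print(validate_abstract(""))
print(validate_abstract("Too short."))
print(validate_abstract(
    "A randomized controlled trial comparing two implant loading protocols in adults."
))
```

In the Gradio wrapper, the message could be returned to the user in place of predictions whenever `ok` is `False`.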

## 🧘 Reflection Log

**What did you learn in this session?**
- 

**What challenges did you encounter?**
- 

**How will this improve Periospot AI?**
- 
