# 07 - Inference Demo

## Goal

Try the model on real abstracts: paste text or a PubMed URL; show predicted labels with probabilities. We'll also add a tiny Gradio block.


In [3]:
# === TODO (you code this) ===
# Goal: Import libraries for inference demo.
# Hints:
# 1) transformers, torch, gradio
# 2) src.utils (eutils_get)
# 3) lxml.etree for PubMed parsing
# Acceptance:
# - All imports successful

# TODO: import libraries
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.utils import eutils_get
import lxml.etree
import os
import logging
import pandas as pd
import torch
from pathlib import Path


In [4]:
# === TODO (you code this) ===
# Goal: Define LABELS (same order as training!).
# Acceptance:
# - LABELS list with all 10 categories

# TODO: define LABELS
LABELS = [
    'SystematicReview',  # 1. Systematic reviews
    'MetaAnalysis',      # 2. Meta-analyses (quantitative synthesis)
    'RCT',               # 3. Randomized Controlled Trials
    'ClinicalTrial',     # 4. Non-randomized clinical trials
    'Cohort',            # 5. Cohort studies (prospective/retrospective)
    'CaseControl',       # 6. Case-control studies
    'CaseReport',        # 7. Case reports / case series
    'InVitro',           # 8. In vitro or ex vivo laboratory studies
    'Animal',            # 9. Animal studies
    'Human'              # 10. Human subjects (not mutually exclusive)
]



## Load from Hub

Load the model directly from Hugging Face (or local path if not yet pushed).


In [None]:
# === TODO (you code this) ===
# Goal: Load model and tokenizer from Hugging Face Hub (or local).
# Hints:
# 1) Use "Tuminha/dental-evidence-triage" or "../artifacts/model/best"
# 2) Set model to eval mode
# Acceptance:
# - tokenizer and model loaded
# - Ready for inference

# TODO: load model
# Load from HF Hub (or use "../artifacts/model/best" for local)
hf_path = "Tuminha/dental-evidence-triage"
# Alternative: hf_path = "../artifacts/model/best"  # Use this if model not yet on HF

tokenizer = AutoTokenizer.from_pretrained(hf_path)
model = AutoModelForSequenceClassification.from_pretrained(hf_path)

# Use CPU for inference (more reliable than MPS for Hugging Face models)
# MPS can have issues with certain operations, CPU is more stable
device = torch.device("cpu")
# Alternative: device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Note: MPS (Apple Silicon) can cause "Placeholder storage" errors, so we use CPU

model = model.to(device)

# Set to evaluation mode
model.eval()

print(f"âœ… Model loaded from: {hf_path}")
print(f"âœ… Device: {device} (using CPU for stability)")
print(f"âœ… Model ready for inference")



config.json:   0%|          | 0.00/982 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


## Utility: Fetch PubMed Abstract by PMID


In [9]:
# === TODO (you code this) ===
# Goal: Utility to fetch abstract from PubMed by PMID.
# Hints:
# 1) Use eutils_get with efetch.fcgi
# 2) Parse XML with lxml.etree
# 3) Extract title and abstract, return as tuple
# Acceptance:
# - Function fetch_abstract_by_pmid(pmid) -> (title, abstract)
# - Returns strings

def fetch_abstract_by_pmid(pmid):
    """Fetch abstract text from PubMed by PMID."""
    # Prepare parameters for efetch
    params = {
        'db': 'pubmed',
        'id': str(pmid),
        'retmode': 'xml',
        'rettype': 'medline'
    }
    
    # Call eutils_get - returns requests.Response object
    response = eutils_get('efetch.fcgi', params)
    
    # Parse XML from response text
    root = lxml.etree.fromstring(response.text.encode('utf-8'))
    
    # Extract title (handle None case)
    title_elem = root.find('.//ArticleTitle')
    title = title_elem.text if title_elem is not None and title_elem.text else ""
    
    # Extract abstract (handle None case and multiple AbstractText elements)
    abstract_elem = root.find('.//AbstractText')
    if abstract_elem is not None:
        # AbstractText can have text directly or nested elements
        if abstract_elem.text:
            abstract = abstract_elem.text
        else:
            # If no direct text, get all text content from nested elements
            abstract = ''.join(abstract_elem.itertext())
    else:
        abstract = ""
    
    return title, abstract

# Test with a real PMID (remove test with fake PMID)
# Example: A real dental research paper
print(fetch_abstract_by_pmid("24660200"))  

('Loading protocols for single-implant crowns: a systematic review and meta-analysis.', 'To test whether or not immediate loading of single-implant crowns renders different results from early and conventional loading with respect to implant survival, marginal bone loss, stability of peri-implant soft tissue, esthetics, and patient satisfaction.')


## Predict Function

Returns top-k labels with probabilities, sorted descending.


In [None]:
# === TODO (you code this) ===
# Goal: Predict top-k labels for input text.
# Hints:
# 1) Tokenize, run through model, apply sigmoid
# 2) Sort probabilities, return top-k (label, score) pairs
# Acceptance:
# - Function predict(text, top_k) -> list[(label, score)]
# - Returns top predictions sorted by score

def predict(text, top_k=5):
    """Predict study design labels for text."""
    # Truncate text to 2000 chars (same as training)
    text = str(text)[:2000]
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    
    # Get device from model (works even if device variable not in scope)
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Run through model
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    
    # Apply sigmoid to get probabilities
    probs = torch.sigmoid(logits)[0]  # Get first (and only) batch item
    
    # Create list of (label, probability) pairs
    label_probs = [(LABELS[i], float(probs[i])) for i in range(len(LABELS))]
    
    # Sort by probability (descending) and return top_k
    label_probs.sort(key=lambda x: x[1], reverse=True)
    
    return label_probs[:top_k]

# Test
print(predict("To test whether or not immediate loading of single-implant crowns renders different results from early and conventional loading with respect to implant survival, marginal bone loss, stability of peri-implant soft tissue, esthetics, and patient satisfaction."))    


NameError: name 'device' is not defined

## Test on Known Abstracts

Try 3-5 diverse examples:
1. **Systematic Review** â€” should predict `[SystematicReview, MetaAnalysis?]`
2. **RCT** â€” should predict `[RCT, Human]`
3. **Case Report** â€” should predict `[CaseReport, Human]`
4. **In Vitro** â€” should predict `[InVitro]`
5. **Animal Study** â€” should predict `[Animal]`


In [None]:
# === TODO (you code this) ===
# Goal: Test predict function on known examples.
# Hints:
# 1) Create dict with 5 diverse examples (SR, RCT, CaseReport, InVitro, Animal)
# 2) For each, call predict() and print top-3 results
# 3) Verify predictions match expected study designs
# Acceptance:
# - Tests 5 examples covering different label types
# - Shows label : probability for each

# TODO: test on examples


## Gradio Demo (Optional)

Build a simple interface for interactive testing.


In [None]:
# === TODO (you code this) ===
# Goal: (Optional) Create interactive Gradio demo.
# Hints:
#     if not text.strip():
#         return "Please enter some text."
#     
#     preds = predict(text, top_k=10)
#     output = "\\n".join([f"{label}: {prob:.3f}" for label, prob in preds])
#     return output
# 
# # Create interface
# demo = gr.Interface(
#     fn=gradio_predict,
#     inputs=gr.Textbox(lines=10, placeholder="Paste title + abstract here..."),
#     outputs=gr.Textbox(label="Predicted Labels"),
#     title="ðŸ¦· Dental Evidence Triage",
#     description="Classify dental research abstracts by study design."
# )
# 
# # Launch
# demo.launch(share=True)


## Recommendations

- **Add 3-5 known abstracts** (one SR, one RCT, one CaseReport) as quick sanity checks
- **Remind users:** assistive triage, not ground truth
- **Test edge cases:** very short abstracts, non-English (should fail gracefully), missing abstract

## ðŸ§˜ Reflection Log

**What did you learn in this session?**
- 

**What challenges did you encounter?**
- 

**How will this improve Periospot AI?**
- 
