GeneMiner: Transformer-based Gene & Protein Mining from DKD Literature
---------------------------------------------------------------------
This Jupyter notebook implements the GeneMiner pipeline as described in the manuscript:
- **Manuscript Title**: GeneMiner: A Transformer-Based Framework for Knowledge Extraction of Genes and Proteins from Diabetic Kidney Disease Literature
- **Authors**: Farnoush Kiyanpour, Mahdi Kalani, AmirSepehr Saffari, Yousof Gheisari*
- **Purpose**: Automated extraction, classification, and normalization of gene/protein mentions from DKD-related PubMed abstracts using transformer models (BioBERT, PubMedBERT) and hierarchical normalization (HGNC, MyGene, Wikipedia).

The pipeline includes:
1. PubMed abstract relevance classification (BioBERT)
2. Gene/Protein NER (PubMedBERT / BENT)
3. Entity normalization (HGNC + MyGene + Wikipedia fallback)
4. Reproducible output generation for downstream analysis

0. Global Configuration:
**Manuscript Reference**: Section "Method" – Implementation details, reproducibility, and environment setup.

This cell sets up the Python environment, imports necessary libraries, and configures global parameters such as random seeds and device settings for reproducible execution.

In [None]:
import os
import re
import time
import random
import json
import requests
import numpy as np
import pandas as pd
from typing import List, Dict

import torch
from torch.utils.data import Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    pipeline
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

1. Data Utilities:
**Manuscript Reference**: Section "Literature Retrieval and Annotation" – Data preprocessing and dataset preparation.

This section includes functions for text cleaning and dataset creation for transformer-based classification.

In [None]:
def clean_text(text: str) -> str:
    """Minimal biomedical-safe text normalization."""
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class AbstractDataset(Dataset):
    """Dataset for relevance classification."""

    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_len
        )
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

2. Relevance Classification (BioBERT):
**Manuscript Reference**: Section "Relevance Classification" – Fine-tuning BioBERT for binary classification of relevant vs. irrelevant DKD abstracts.

This module trains a BioBERT model to filter out non-relevant abstracts before entity extraction.

In [None]:
def train_relevance_classifier(
    train_texts, train_labels,
    val_texts, val_labels,
    output_dir="biobert_relevance"
):
    tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
    model = AutoModelForSequenceClassification.from_pretrained(
        "dmis-lab/biobert-v1.1",
        num_labels=2
    ).to(DEVICE)

    train_ds = AbstractDataset(train_texts, train_labels, tokenizer)
    val_ds = AbstractDataset(val_texts, val_labels, tokenizer)

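    args = TrainingArguments(
        output_dir=output_dir,
        # Recent transformers releases rename this argument to `eval_strategy`;
        # use whichever name the installed version accepts.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=4,
        weight_decay=0.01,
        load_best_model_at_end=True,
        fp16=(DEVICE == "cuda"),  # mixed precision needs a CUDA device
        seed=SEED,
        logging_steps=100
    )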

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds
    )

    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model, tokenizer
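
The `Trainer` above is built without a `compute_metrics` callback, so each epoch reports only the evaluation loss. A minimal sketch of a metrics function that could be passed as `compute_metrics=...` is shown below. The function name and the choice of metrics are assumptions, not taken from the manuscript. It is written in plain Python so it runs without extra dependencies, and it unpacks the `(logits, labels)` pair that `Trainer` supplies.

```python
from typing import Dict, Sequence, Tuple


def compute_binary_metrics(eval_pred: Tuple[Sequence, Sequence]) -> Dict[str, float]:
    """Accuracy, precision, recall and F1 for the positive (relevant = 1) class."""
    logits, labels = eval_pred
    # Predicted class = index of the largest logit in each row.
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    labels = [int(y) for y in labels]

    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    correct = sum(p == y for p, y in zip(preds, labels))

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    return {
        "accuracy": correct / len(labels) if labels else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Toy check: four abstracts, two of them classified correctly.
toy_logits = [[2.0, -1.0], [0.1, 0.9], [-0.5, 1.5], [1.2, 0.3]]
toy_labels = [0, 1, 0, 1]
metrics = compute_binary_metrics((toy_logits, toy_labels))
print(metrics)
```

With such a callback in place, `metric_for_best_model="f1"` in `TrainingArguments` would make `load_best_model_at_end` select the checkpoint by F1 rather than by loss; whether that matches the manuscript's selection criterion is not stated in this notebook.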


3. Named-Entity Recognition (NER):
**Manuscript Reference**: Section "Named-Entity Recognition (NER)" – Using BENT-PubMedBERT for gene/protein extraction.

This module loads a pre-trained NER pipeline and extracts gene/protein mentions from relevant abstracts.

In [None]:
def load_ner_pipeline():
    """
    BENT-PubMedBERT NER model for gene/protein extraction.
    """
    ner = pipeline(
        "token-classification",
        model="pruas/BENT-PubMedBERT-NER-Gene",
        tokenizer="pruas/BENT-PubMedBERT-NER-Gene",
        aggregation_strategy="simple",
        device=0 if DEVICE == "cuda" else -1
    )
    return ner


def extract_entities(texts: List[tuple], ner_pipeline) -> List[Dict]:
    """Extract gene/protein mentions from a list of (PMID, abstract text) pairs."""
    records = []
    for pmid, text in texts:
        entities = ner_pipeline(text)
        for ent in entities:
            records.append({
                "PMID": pmid,
                "mention": ent["word"],
                "start": ent["start"],
                "end": ent["end"]
            })
    return records
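
With `aggregation_strategy="simple"`, each entity returned by the pipeline is a dict with the keys `entity_group`, `score`, `word`, `start`, and `end`. The sketch below is a variant of `extract_entities` that keeps the score and drops low-confidence predictions. The `min_score` threshold, the function name, and the `fake_ner` stub (including its `"GENE"` label) are hypothetical and exist only so the example runs offline.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def extract_entities_scored(
    pairs: Iterable[Tuple[str, str]],
    ner_pipeline: Callable[[str], List[Dict]],
    min_score: float = 0.5,  # hypothetical threshold, not from the manuscript
) -> List[Dict]:
    """Like extract_entities, but keeps the score and drops weak predictions."""
    records = []
    for pmid, text in pairs:
        for ent in ner_pipeline(text):
            if ent["score"] < min_score:
                continue
            records.append({
                "PMID": pmid,
                # Slice the source text so the mention keeps its original
                # casing and spacing; the pipeline's "word" field is rebuilt
                # from tokens and can differ from the source string.
                "mention": text[ent["start"]:ent["end"]],
                "start": ent["start"],
                "end": ent["end"],
                "score": float(ent["score"]),
            })
    return records


def fake_ner(text: str) -> List[Dict]:
    """Stand-in for the real pipeline (hypothetical output values)."""
    return [
        {"entity_group": "GENE", "score": 0.98, "word": "tgf - β1",
         "start": 0, "end": 6},
        {"entity_group": "GENE", "score": 0.31, "word": "kidney",
         "start": 26, "end": 32},
    ]


records = extract_entities_scored(
    [("12345", "TGF-β1 drives fibrosis in kidney tissue.")], fake_ner
)
print(records)
```

Only the first mention survives the filter here, and it is stored as `TGF-β1` exactly as it appears in the abstract.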

4. Entity Normalization:
**Manuscript Reference**: Section "Entity Normalization" – Hierarchical mapping to HGNC-approved symbols using HGNC, MyGene, and Wikipedia.

This module normalizes extracted entities to standard gene symbols using a tiered approach.

4a. Character-Level Normalization (Manuscript-Compliant):
**Manuscript Reference**: Described in the normalization section – handling Greek letters, hyphens, and case standardization.

In [None]:
GREEK_MAP = {
    "α": "alpha", "β": "beta", "γ": "gamma",
    "δ": "delta", "κ": "kappa", "μ": "mu",
    "τ": "tau", "ω": "omega"
}

def normalize_characters(symbol: str) -> str:
    """
    Apply character-level normalization:
    - Greek letter conversion (lowercase letters listed in GREEK_MAP)
    - Hyphen / underscore / slash removal
    - Uppercasing
    """
    s = symbol
    for g, latin in GREEK_MAP.items():
        s = s.replace(g, latin)
    s = re.sub(r"[-_/]", "", s)
    return s.upper()
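
A few worked examples make the effect of the three steps concrete. The block below repeats the definitions from the cell above so that it runs on its own; the example mentions are illustrative and not drawn from the manuscript's data.

```python
import re

# Copied from the cell above so the example is self-contained.
GREEK_MAP = {
    "α": "alpha", "β": "beta", "γ": "gamma",
    "δ": "delta", "κ": "kappa", "μ": "mu",
    "τ": "tau", "ω": "omega"
}


def normalize_characters(symbol: str) -> str:
    s = symbol
    for g, latin in GREEK_MAP.items():
        s = s.replace(g, latin)
    s = re.sub(r"[-_/]", "", s)
    return s.upper()


examples = ["TGF-β1", "NF-κB", "il_6", "PPAR-γ"]
normalized = {m: normalize_characters(m) for m in examples}
for raw, norm in normalized.items():
    print(f"{raw!r} -> {norm!r}")
# 'TGF-β1' -> 'TGFBETA1'
# 'NF-κB' -> 'NFKAPPAB'
# 'il_6' -> 'IL6'
# 'PPAR-γ' -> 'PPARGAMMA'
```

The spelled-out forms are not approved symbols themselves (the approved symbol for TGF-β1 is TGFB1), so resolving them presumably relies on alias or previous-symbol matching in the lookup tiers that follow.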

4b. Wikipedia-Assisted Disambiguation (Auxiliary, Flagged):
**Manuscript Reference**: Auxiliary step used only when HGNC and MyGene fail; all matches flagged for manual review.

In [None]:
WIKI_API = "https://en.wikipedia.org/api/rest_v1/page/summary/{}"

def wikipedia_lookup(term: str) -> Dict:
    """
    Auxiliary Wikipedia title lookup for unresolved gene symbols.
    Used ONLY when HGNC and MyGene fail.
    All matches are flagged for manual curator verification.
    """
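    try:
        # Percent-encode the term so spaces or slashes cannot break the URL path
        url = WIKI_API.format(requests.utils.quote(term, safe=""))
        r = requests.get(url, timeout=5)
        if r.status_code != 200:
            return {}
        data = r.json()

        # "description" may be absent or null in the summary payload
        description = data.get("description") or ""
        if "gene" in description.lower():
            return {
                "wiki_title": data.get("title"),
                "wiki_description": description,
                "source": "Wikipedia"
            }
    except (requests.RequestException, ValueError):
        # Network failure or non-JSON body: treat as "no match"
        pass

    return {}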
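
The hierarchical workflow in 4c calls `normalize_hgnc` and `normalize_mygene`, which are not defined in the cells of this notebook. The sketch below shows one way they might look, assuming the public HGNC REST search endpoint and the MyGene.info v3 query endpoint. The helper names (`_get_json`, `first_hgnc_doc`, `first_mygene_hit`), the order of HGNC fields tried, and the returned keys other than `symbol` are assumptions, not the authors' implementation. It uses `urllib` from the standard library so that it has no third-party dependencies.

```python
import json
import urllib.parse
import urllib.request
from typing import Dict

HGNC_SEARCH = "https://rest.genenames.org/search/{field}/{term}"
MYGENE_QUERY = "https://mygene.info/v3/query"


def _get_json(url: str, timeout: float = 10.0) -> Dict:
    """GET a URL and decode its JSON body; return {} on network or parse errors."""
    req = urllib.request.Request(url, headers={"Accept": "application/json"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except (OSError, ValueError):
        return {}


def first_hgnc_doc(payload: Dict) -> Dict:
    """Pick the first document from an HGNC search response, if any."""
    docs = payload.get("response", {}).get("docs", [])
    if not docs:
        return {}
    return {"symbol": docs[0]["symbol"], "hgnc_id": docs[0].get("hgnc_id")}


def first_mygene_hit(payload: Dict) -> Dict:
    """Pick the first MyGene hit that carries a symbol, if any."""
    for hit in payload.get("hits", []):
        if "symbol" in hit:
            return {"symbol": hit["symbol"], "mygene_id": hit.get("_id")}
    return {}


def normalize_hgnc(mention: str) -> Dict:
    """Try the approved symbol first, then alias and previous symbols."""
    term = urllib.parse.quote(mention, safe="")
    for field in ("symbol", "alias_symbol", "prev_symbol"):
        hit = first_hgnc_doc(_get_json(HGNC_SEARCH.format(field=field, term=term)))
        if hit:
            return hit
    return {}


def normalize_mygene(mention: str) -> Dict:
    """Query MyGene.info for a human gene matching the mention."""
    query = urllib.parse.urlencode(
        {"q": mention, "species": "human", "fields": "symbol", "size": 1}
    )
    return first_mygene_hit(_get_json(f"{MYGENE_QUERY}?{query}"))
```

Both functions return an empty dict when nothing matches, which is the falsy value the `if hgnc:` and `if mg:` checks in 4c rely on. Keeping response parsing separate from the network call lets the parsing be tested on saved payloads without contacting either service.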

4c. Entity Normalization (Hierarchical Mapping):
**Manuscript Reference**: Primary normalization workflow – HGNC → MyGene → Wikipedia fallback.

In [None]:
def normalize_entities(entity_df: pd.DataFrame) -> pd.DataFrame:
    """
    Hierarchical entity normalization:
    1) Character-level normalization of the raw mention
    2) HGNC primary mapping
    3) MyGene fallback
    4) Wikipedia-assisted disambiguation (flagged for manual review)

    Mentions that no tier resolves are omitted from the output.

    NOTE: normalize_hgnc() and normalize_mygene() are not defined in the
    cells shown here; each is expected to take the normalized mention and
    return a dict with a "symbol" key, or an empty dict when nothing matches.
    """
    normalized = []

    for _, row in entity_df.iterrows():
        base = row.to_dict()  # plain dict so it can be merged with new keys
        mention = normalize_characters(base["mention"])

        # API rate limiting: pause before every mention's lookups, not only
        # on the Wikipedia path
        time.sleep(0.2)

        # --- Primary: HGNC
        hgnc = normalize_hgnc(mention)
        if hgnc:
            normalized.append({
                **base,
                "HGNC": hgnc.get("symbol"),
                "source": "HGNC"
            })
            continue

        # --- Secondary: MyGene
        mg = normalize_mygene(mention)
        if mg:
            normalized.append({
                **base,
                "HGNC": mg.get("symbol"),
                "source": "MyGene"
            })
            continue

        # --- Auxiliary: Wikipedia (flagged, non-quantitative)
        wiki = wikipedia_lookup(mention)
        if wiki:
            normalized.append({
                **base,
                "HGNC": None,
                "source": "Wikipedia",
                "wiki_title": wiki.get("wiki_title"),
                "flag_manual_review": True
            })

    return pd.DataFrame(normalized)
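
Because the same mention tends to recur across many abstracts, memoizing lookups keeps the number of API requests proportional to the number of distinct mentions rather than the number of rows. The sketch below illustrates the idea with `functools.lru_cache`; `slow_lookup` is a hypothetical stand-in for a remote resolver, and its small mapping is illustrative only.

```python
from functools import lru_cache
from typing import Optional

calls = {"n": 0}


def slow_lookup(mention: str) -> Optional[str]:
    """Stand-in for a remote resolver (hypothetical); counts how often it runs."""
    calls["n"] += 1
    return {"NPHS1": "NPHS1", "NEPHRIN": "NPHS1"}.get(mention)


@lru_cache(maxsize=None)
def cached_lookup(mention: str) -> Optional[str]:
    # lru_cache needs hashable arguments; returning an immutable value
    # (a string or None) avoids callers mutating a shared cached object.
    return slow_lookup(mention)


mentions = ["NPHS1", "NEPHRIN", "NPHS1", "UNKNOWN", "NPHS1", "UNKNOWN"]
resolved = [cached_lookup(m) for m in mentions]
print(resolved, "remote calls:", calls["n"])
```

Six mentions trigger only three underlying lookups here. One caveat of this design is that misses are cached too, so a lookup that failed because of a transient network error would stay unresolved until the cache is cleared with `cached_lookup.cache_clear()`.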


5. End-to-End Execution:
**Manuscript Reference**: Section "Method" – Integrated pipeline from classification to normalization.

This function runs the complete GeneMiner pipeline: relevance filtering, NER, and normalization.

In [None]:
def run_geneminer(
    abstracts_df: pd.DataFrame,
    relevance_model_dir: str
):
    """
    abstracts_df columns:
    - PMID
    - text
    """

    # ---- Load relevance classifier
    tokenizer = AutoTokenizer.from_pretrained(relevance_model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(
        relevance_model_dir
    ).to(DEVICE)

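    # Classify in batches so the whole corpus is never tokenized and moved
    # to the device at once (batch size is an assumed value; tune to memory).
    all_texts = abstracts_df["text"].tolist()
    batch_size = 32
    preds = []

    with torch.no_grad():
        for i in range(0, len(all_texts), batch_size):
            inputs = tokenizer(
                all_texts[i:i + batch_size],
                truncation=True,
                padding=True,
                max_length=512,
                return_tensors="pt"
            ).to(DEVICE)
            logits = model(**inputs).logits
            preds.extend(torch.argmax(logits, dim=1).cpu().tolist())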

    abstracts_df["relevant"] = preds
    relevant_df = abstracts_df[abstracts_df["relevant"] == 1]

    # ---- NER
    ner = load_ner_pipeline()
    texts = list(zip(relevant_df["PMID"], relevant_df["text"]))
    entities = extract_entities(texts, ner)

    entity_df = pd.DataFrame(entities)

    # ---- Normalization
    normalized_df = normalize_entities(entity_df)

    return normalized_df


6. Output Saving:
**Manuscript Reference**: Section "Reproducibility and FAIR Compliance" – Saving results for downstream analysis.

This module saves normalized gene lists and frequency tables in CSV format.

In [None]:
def save_outputs(df: pd.DataFrame, outdir="results"):
    os.makedirs(outdir, exist_ok=True)
    df.to_csv(f"{outdir}/GeneMiner_Normalized_Genes.csv", index=False)
    # Name the count column so the CSV header is self-describing
    df.groupby("HGNC").size().sort_values(ascending=False).rename("count").to_csv(
        f"{outdir}/Gene_Frequency.csv"
    )
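
`save_outputs` counts rows per symbol, so a gene mentioned several times in one abstract contributes several counts. Whether the manuscript's frequency table is meant at the mention level or the abstract level is not stated in this notebook. The sketch below contrasts the two using only the standard library; the function name and the toy rows are illustrative assumptions.

```python
from collections import Counter
from typing import Dict, Iterable, Tuple


def gene_frequencies(rows: Iterable[Dict]) -> Tuple[Counter, Counter]:
    """Return (mention-level counts, abstract-level counts) per HGNC symbol."""
    mention_counts: Counter = Counter()
    seen_pairs = set()
    for row in rows:
        symbol = row.get("HGNC")
        if not symbol:  # skip unresolved rows (e.g. Wikipedia-only matches)
            continue
        mention_counts[symbol] += 1
        seen_pairs.add((row["PMID"], symbol))
    # Each (PMID, symbol) pair counts once, however often it was mentioned.
    abstract_counts = Counter(symbol for _, symbol in seen_pairs)
    return mention_counts, abstract_counts


toy_rows = [
    {"PMID": "1", "HGNC": "TGFB1"},
    {"PMID": "1", "HGNC": "TGFB1"},
    {"PMID": "2", "HGNC": "TGFB1"},
    {"PMID": "2", "HGNC": "NPHS1"},
    {"PMID": "3", "HGNC": None},
]
by_mention, by_abstract = gene_frequencies(toy_rows)
print(dict(by_mention))   # {'TGFB1': 3, 'NPHS1': 1}
print(dict(by_abstract))
```

In pandas, the abstract-level count corresponds to `df.groupby("HGNC")["PMID"].nunique()`.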