In [1]:
import json
from pathlib import Path

INPUT_DIR = Path("/kaggle/input/fullfull/annotations")
OUTPUT_FILE = Path("merged_spans_with_entities.jsonl")

merged = []

for span_path in sorted(INPUT_DIR.glob("*_spans.jsonl")):
    filename = span_path.name
    with span_path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                print(f"Skipping empty line at {filename}:{lineno}")
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"JSON decode error at {filename}:{lineno} — {e}")
                continue

            spans = rec.get("spans", [])
            if not spans:
                continue

            entry = {
                "text": rec.get("text", ""),
                "tokens": rec.get("tokens", []),
                "spans": spans,
            }
            merged.append(entry)

with OUTPUT_FILE.open("w", encoding="utf-8") as fw:
    for entry in merged:
        fw.write(json.dumps(entry, ensure_ascii=False) + "\n")

print(f"Merged and saved {len(merged)} entity-containing records to: {OUTPUT_FILE.resolve()}")

Skipping empty line at polg_16919951_spans.jsonl:1
JSON decode error at polg_16919951_spans.jsonl:2 — Expecting value: line 1 column 1 (char 0)
Skipping empty line at polg_16919951_spans.jsonl:3
Merged and saved 675 entity-containing records to: /kaggle/working/merged_spans_with_entities.jsonl


In [2]:
!pip install seqeval evaluate torchcrf

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting torchcrf
  Downloading TorchCRF-1.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->torchcrf)
  Downloading nvidia_cuda_cupti_cu1

No silver data

In [3]:
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# Constants & Paths
FILE_MERGED = Path("/kaggle/working/merged_spans_with_entities.jsonl")
OUT_DIR     = Path("/kaggle/working/bio_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_FILE = OUT_DIR / "train.jsonl"
DEV_FILE   = OUT_DIR / "dev.jsonl"
TEST_FILE  = OUT_DIR / "test.jsonl"

ENTITY_TYPES = {
    "AGE_ONSET", "AGE_FOLLOWUP", "AGE_DEATH",
    "PATIENT", "HPO_TERM", "GENE", "GENE_VARIANT"
}

# Utility Functions
def iter_jsonl(path: Path):
    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue

def filter_valid_entities(rec):
    spans = [s for s in rec.get("spans", []) if s.get("label") in ENTITY_TYPES]
    if spans:
        return {
            "text": rec["text"],
            "spans": spans
        }
    return None

def dump_jsonl(path: Path, data):
    with path.open("w", encoding="utf-8") as fh:
        for obj in data:
            fh.write(json.dumps(obj, ensure_ascii=False) + "\n")

# Load and filter gold data
print(">> Loading gold data …")
merged_filtered = []
for rec in iter_jsonl(FILE_MERGED):
    filtered = filter_valid_entities(rec)
    if filtered:
        merged_filtered.append(filtered)
print(f"Total valid records in gold: {len(merged_filtered)}")

# Split into train/dev/test (80/10/10)
train_dev, test_set = train_test_split(
    merged_filtered,
    test_size=0.20,
    random_state=42
)
train_set, dev_set = train_test_split(
    train_dev,
    test_size=0.25,
    random_state=42
)
print(f"Split sizes – TRAIN: {len(train_set)}, DEV: {len(dev_set)}, TEST: {len(test_set)}")


# Save to disk
dump_jsonl(TRAIN_FILE, train_set)
dump_jsonl(DEV_FILE, dev_set)
dump_jsonl(TEST_FILE, test_set)

print(f"\nSaved to:")
print(f"  ➜ {TRAIN_FILE}")
print(f"  ➜ {DEV_FILE}")
print(f"  ➜ {TEST_FILE}")


>> Loading gold data …
Total valid records in gold: 675
Split sizes – TRAIN: 405, DEV: 135, TEST: 135

Saved to:
  ➜ /kaggle/working/bio_outputs/train.jsonl
  ➜ /kaggle/working/bio_outputs/dev.jsonl
  ➜ /kaggle/working/bio_outputs/test.jsonl


In [4]:
import json
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# === 1. Load pre-split data with silver already included ===
BIO_DIR = Path("/kaggle/working/bio_outputs")

def load_jsonl(path: Path):
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

train_data = load_jsonl(BIO_DIR / "train.jsonl")
dev_data   = load_jsonl(BIO_DIR / "dev.jsonl")
test_data  = load_jsonl(BIO_DIR / "test.jsonl")

ds_raw = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(dev_data),
    "test": Dataset.from_list(test_data),
})
print("Loaded dataset sizes:", {k: len(v) for k, v in ds_raw.items()})

# === 2. Tokenizer and label mappings ===
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True
)

label_list = ["O", "B-HPO_TERM", "I-HPO_TERM"]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

# === 3. Span-to-token label encoder ===
def encode_and_align_labels(example):
    text = example["text"]
    spans = example["spans"]

    # 仅使用 HPO_TERM 的 span 构建标签
    hpo_spans = [(s["start"], s["end"]) for s in spans if s["label"] == "HPO_TERM"]

    # Tokenize 带 offset，用于将 span 映射到 token 层
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    )

    labels = []
    for offset in encoding["offset_mapping"]:
        if offset == (0, 0):
            labels.append("O")
            continue

        tag = "O"
        for start, end in hpo_spans:
            if offset[0] >= start and offset[1] <= end:
                tag = "B-HPO_TERM" if offset[0] == start else "I-HPO_TERM"
                break

        labels.append(tag)

    encoding["labels"] = [label2id[tag] for tag in labels]
    return encoding

# === 4. Encode all splits ===
ds_encoded = ds_raw.map(
    encode_and_align_labels,
    batched=False,
    remove_columns=["text", "spans"]
)
print("Encoding complete.")

# === 5. Load model ===
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# === 6. Evaluation metrics ===
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    true_labels = [
        [id2label[lid] for lid in seq if lid != -100]
        for seq in labels
    ]
    pred_labels = [
        [id2label[pid] for pid, lid in zip(pred_seq, label_seq) if lid != -100]
        for pred_seq, label_seq in zip(preds, labels)
    ]
    result = seqeval.compute(predictions=pred_labels, references=true_labels)
    return {
        "overall_precision": result["overall_precision"],
        "overall_recall":    result["overall_recall"],
        "overall_f1":        result["overall_f1"],
        "overall_accuracy":  result["overall_accuracy"],
    }

# === 7. Training configuration ===
training_args = TrainingArguments(
    output_dir="ner_pubmedbert",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="overall_f1",
    greater_is_better=True,
    report_to=["none"],
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_encoded["train"],
    eval_dataset=ds_encoded["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
)

# === 8. Train and evaluate ===
trainer.train()
trainer.evaluate()

# === 9. Predict on test set ===
print("\n--- Predicting on test set ---")
pred_output = trainer.predict(ds_encoded["test"])
preds = pred_output.predictions.argmax(-1)
labels = pred_output.label_ids

true_labels = [
    [id2label[lid] for lid in seq if lid != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pid] for pid, lid in zip(pred_seq, label_seq) if lid != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

print("\n HPO_TERM classification report:")
for label, metrics in detailed_result.items():
    if label == "HPO_TERM":
        print(f"{label:20} | Precision: {metrics['precision']:.3f} | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}")

2025-09-05 04:54:47.595486: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757048087.975777      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757048088.087212      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


✅ Loaded dataset sizes: {'train': 405, 'validation': 135, 'test': 135}


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/405 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

✅ Encoding complete.


pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading builder script: 0.00B [00:00, ?B/s]

  trainer = Trainer(


Step,Training Loss,Validation Loss,Overall Precision,Overall Recall,Overall F1,Overall Accuracy
50,0.2974,0.17867,0.368146,0.567404,0.446556,0.927922
100,0.1403,0.154322,0.449527,0.573441,0.503979,0.943191
150,0.0924,0.158061,0.489431,0.605634,0.541367,0.946278
200,0.063,0.171512,0.495868,0.603622,0.544465,0.948467
250,0.0453,0.181366,0.484663,0.635815,0.550044,0.948917





--- Predicting on test set ---

 HPO_TERM classification report:
HPO_TERM             | Precision: 0.476 | Recall: 0.608 | F1: 0.534


Add silver data

In [5]:
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# Constants & Paths
FILE_MERGED = Path("/kaggle/working/merged_spans_with_entities.jsonl")
DIR_SILVER  = Path("/kaggle/input/hpo-only")
OUT_DIR     = Path("/kaggle/working/bio_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_FILE = OUT_DIR / "train.jsonl"
DEV_FILE   = OUT_DIR / "dev.jsonl"
TEST_FILE  = OUT_DIR / "test.jsonl"

ENTITY_TYPES = {
    "AGE_ONSET", "AGE_FOLLOWUP", "AGE_DEATH",
    "PATIENT", "HPO_TERM", "GENE", "GENE_VARIANT"
}

# Utility Functions
def iter_jsonl(path: Path):
    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue

def filter_valid_entities(rec):
    spans = [s for s in rec.get("spans", []) if s.get("label") in ENTITY_TYPES]
    if spans:
        return {
            "text": rec["text"],
            "spans": spans
        }
    return None

def dump_jsonl(path: Path, data):
    with path.open("w", encoding="utf-8") as fh:
        for obj in data:
            fh.write(json.dumps(obj, ensure_ascii=False) + "\n")

def load_filtered_silver(path: Path):
    extra = []
    for rec in iter_jsonl(path):
        rec = filter_valid_entities(rec)
        if rec:
            extra.append(rec)
    return extra

# Step 1: Load and convert gold data
print(">> Loading gold data …")
merged_filtered = []
for rec in iter_jsonl(FILE_MERGED):
    filtered = filter_valid_entities(rec)
    if filtered:
        merged_filtered.append(filtered)
print(f"Total valid records in gold: {len(merged_filtered)}")

# Step 2: Split gold into train/dev/test
train_dev, test_set = train_test_split(
    merged_filtered,
    test_size=0.20,
    random_state=42
)
train_set, dev_set = train_test_split(
    train_dev,
    test_size=0.25,
    random_state=42
)
print(f"Split sizes – TRAIN: {len(train_set)}, DEV: {len(dev_set)}, TEST: {len(test_set)}")

# Step 3: Add silver data to train set
extra_train = []
if DIR_SILVER.exists():
    print(">> Loading silver data from hpo-only/")
    for jf in sorted(DIR_SILVER.glob("*.jsonl")):
        print(f"  - {jf.name}")
        extra_train.extend(load_filtered_silver(jf))
else:
    print(">> Silver data directory not found.")

train_final = train_set + extra_train
print(f"Final train size: {len(train_final)} (including {len(extra_train)} silver records)")

# Step 4: Save to disk
dump_jsonl(TRAIN_FILE, train_final)
dump_jsonl(DEV_FILE, dev_set)
dump_jsonl(TEST_FILE, test_set)

print(f"\nSaved to:")
print(f"  ➜ {TRAIN_FILE}")
print(f"  ➜ {DEV_FILE}")
print(f"  ➜ {TEST_FILE}")

>> Loading gold data …
Total valid records in gold: 675
Split sizes – TRAIN: 405, DEV: 135, TEST: 135
>> Loading silver data from hpo-only/
  - HPO.jsonl
  - HPO_only.jsonl
Final train size: 9331 (including 8926 silver records)

Saved to:
  ➜ /kaggle/working/bio_outputs/train.jsonl
  ➜ /kaggle/working/bio_outputs/dev.jsonl
  ➜ /kaggle/working/bio_outputs/test.jsonl


In [6]:
import json
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# === 1. Load pre-split data with silver already included ===
BIO_DIR = Path("/kaggle/working/bio_outputs")

def load_jsonl(path: Path):
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

train_data = load_jsonl(BIO_DIR / "train.jsonl")
dev_data   = load_jsonl(BIO_DIR / "dev.jsonl")
test_data  = load_jsonl(BIO_DIR / "test.jsonl")

ds_raw = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(dev_data),
    "test": Dataset.from_list(test_data),
})
print(" Loaded dataset sizes:", {k: len(v) for k, v in ds_raw.items()})

# === 2. Tokenizer and label mappings ===
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True
)

label_list = ["O", "B-HPO_TERM", "I-HPO_TERM"]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

# === 3. Span-to-token label encoder ===
def encode_and_align_labels(example):
    text = example["text"]
    spans = example["spans"]

    hpo_spans = [(s["start"], s["end"]) for s in spans if s["label"] == "HPO_TERM"]

    # Tokenize 
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    )

    labels = []
    for offset in encoding["offset_mapping"]:
        if offset == (0, 0):
            labels.append("O")
            continue

        tag = "O"
        for start, end in hpo_spans:
            if offset[0] >= start and offset[1] <= end:
                tag = "B-HPO_TERM" if offset[0] == start else "I-HPO_TERM"
                break

        labels.append(tag)

    encoding["labels"] = [label2id[tag] for tag in labels]
    return encoding

# === 4. Encode all splits ===
ds_encoded = ds_raw.map(
    encode_and_align_labels,
    batched=False,
    remove_columns=["text", "spans"]
)
print(" Encoding complete.")

# === 5. Load model ===
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# === 6. Evaluation metrics ===
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    true_labels = [
        [id2label[lid] for lid in seq if lid != -100]
        for seq in labels
    ]
    pred_labels = [
        [id2label[pid] for pid, lid in zip(pred_seq, label_seq) if lid != -100]
        for pred_seq, label_seq in zip(preds, labels)
    ]
    result = seqeval.compute(predictions=pred_labels, references=true_labels)
    return {
        "overall_precision": result["overall_precision"],
        "overall_recall":    result["overall_recall"],
        "overall_f1":        result["overall_f1"],
        "overall_accuracy":  result["overall_accuracy"],
    }

# === 7. Training configuration ===
training_args = TrainingArguments(
    output_dir="ner_pubmedbert",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="overall_f1",
    greater_is_better=True,
    report_to=["none"],
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_encoded["train"],
    eval_dataset=ds_encoded["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
)

# === 8. Train and evaluate ===
trainer.train()
trainer.evaluate()

# === 9. Predict on test set ===
print("\n--- Predicting on test set ---")
pred_output = trainer.predict(ds_encoded["test"])
preds = pred_output.predictions.argmax(-1)
labels = pred_output.label_ids

true_labels = [
    [id2label[lid] for lid in seq if lid != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pid] for pid, lid in zip(pred_seq, label_seq) if lid != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

print("\n HPO_TERM classification report:")
for label, metrics in detailed_result.items():
    if label == "HPO_TERM":
        print(f"{label:20} | Precision: {metrics['precision']:.3f} | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}")


 Loaded dataset sizes: {'train': 9331, 'validation': 135, 'test': 135}


Map:   0%|          | 0/9331 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

 Encoding complete.


Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss,Validation Loss,Overall Precision,Overall Recall,Overall F1,Overall Accuracy
50,0.4154,0.259372,0.036866,0.016097,0.022409,0.900696
100,0.3023,0.220026,0.15,0.078471,0.103038,0.911474
150,0.2722,0.201136,0.263158,0.28169,0.272109,0.918828
200,0.2608,0.188201,0.380531,0.432596,0.404896,0.930616
250,0.2414,0.191155,0.421384,0.404427,0.412731,0.9323
300,0.2417,0.186311,0.382766,0.384306,0.383534,0.933255
350,0.2333,0.182933,0.427105,0.418511,0.422764,0.937016
400,0.2311,0.167323,0.379245,0.404427,0.391431,0.938419
450,0.1955,0.153104,0.456113,0.585513,0.512775,0.944706
500,0.2168,0.143694,0.434716,0.569416,0.493031,0.948131





--- Predicting on test set ---

 HPO_TERM classification report:
HPO_TERM             | Precision: 0.793 | Recall: 0.823 | F1: 0.808


In [24]:
from collections import defaultdict

def extract_entities(labels):
    spans = []
    start = None
    current_label = None
    for i, lab_id in enumerate(labels):
        label = id2label.get(lab_id, "O")
        if label.startswith("B-HPO_TERM"):
            if current_label:
                spans.append((start, i - 1, current_label))
            start = i
            current_label = "HPO_TERM"
        elif label.startswith("I-HPO_TERM") and current_label:
            continue
        else:
            if current_label:
                spans.append((start, i - 1, current_label))
                current_label = None
                start = None
    if current_label:
        spans.append((start, len(labels) - 1, current_label))
    return spans

def iou(a, b):
    inter = max(0, min(a[1], b[1]) - max(a[0], b[0]) + 1)
    union = max(a[1], b[1]) - min(a[0], b[0]) + 1
    return inter / union

def relaxed_match(pred_span, true_span):
    ps, pe, plabel = pred_span
    ts, te, tlabel = true_span
    if plabel != tlabel:
        return False
    if abs(ps - ts) <= 1 and abs(pe - te) <= 1:
        return True
    if iou((ps, pe), (ts, te)) >= 0.5:
        return True
    return False

def relaxed_compute_metrics(preds, refs):
    tp, fp, fn = 0, 0, 0
    label_metrics = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})

    for pred_seq, ref_seq in zip(preds, refs):
        pred_ents = extract_entities(pred_seq)
        true_ents = extract_entities(ref_seq)
        matched = set()

        for pred_ent in pred_ents:
            match_found = False
            for i, true_ent in enumerate(true_ents):
                if i in matched:
                    continue
                if relaxed_match(pred_ent, true_ent):
                    tp += 1
                    label_metrics["HPO_TERM"]["tp"] += 1
                    matched.add(i)
                    match_found = True
                    break
            if not match_found:
                fp += 1
                label_metrics["HPO_TERM"]["fp"] += 1

        for i, true_ent in enumerate(true_ents):
            if i not in matched:
                fn += 1
                label_metrics["HPO_TERM"]["fn"] += 1

    precision = tp / (tp + fp + 1e-10)
    recall    = tp / (tp + fn + 1e-10)
    f1        = 2 * precision * recall / (precision + recall + 1e-10)

    print("\nRelaxed Per-label HPO_TERM classification report:")
    for label, m in label_metrics.items():
        lp = m["tp"] / (m["tp"] + m["fp"] + 1e-10)
        lr = m["tp"] / (m["tp"] + m["fn"] + 1e-10)
        lf1 = 2 * lp * lr / (lp + lr + 1e-10)
        print(f"{label:20} | Precision: {lp:.3f} | Recall: {lr:.3f} | F1: {lf1:.3f}")

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

filtered_preds = []
filtered_labels = []

for pred_seq, label_seq in zip(preds, labels):
    filtered_pred = [p for p, l in zip(pred_seq, label_seq) if l != -100]
    filtered_label = [l for l in label_seq if l != -100]
    filtered_preds.append(filtered_pred)
    filtered_labels.append(filtered_label)

print("\n Running relaxed evaluation on test set...")
relaxed_metrics = relaxed_compute_metrics(filtered_preds, filtered_labels)
print("\n Relaxed HPO_TERM test set metrics:", relaxed_metrics)


 Running relaxed evaluation on test set...

Relaxed Per-label HPO_TERM classification report:

 Relaxed HPO_TERM test set metrics: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


In [8]:
trainer.save_model("ner_pubmedbert_saved_HPO")
tokenizer.save_pretrained("ner_pubmedbert_saved_HPO")

('ner_pubmedbert_saved_HPO/tokenizer_config.json',
 'ner_pubmedbert_saved_HPO/special_tokens_map.json',
 'ner_pubmedbert_saved_HPO/vocab.txt',
 'ner_pubmedbert_saved_HPO/added_tokens.json',
 'ner_pubmedbert_saved_HPO/tokenizer.json')

In [9]:
pip install transformers obonet rapidfuzz

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting obonet
  Downloading obonet-1.1.1-py3-none-any.whl.metadata (6.7 kB)
Collecting rapidfuzz
  Downloading rapidfuzz-3.14.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading obonet-1.1.1-py3-none-any.whl (9.2 kB)
Downloading rapidfuzz-3.14.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: rapidfuzz, obonet
Successfully installed obonet-1.1.1 rapidfuzz-3.14.0
Note: you may need to restart the kernel to use updated packages.


As long as two strings are 85% similar, they will be considered as an acceptable match.

In [10]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process

# === Config ===
MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE  = Path("/kaggle/working/bio_outputs/mapped_mentions.jsonl")
MAX_LENGTH = 512
DEVICE = 0  # use -1 for CPU

# === Step 1: Load test data ===
print(">> Loading test data")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=DEVICE
)

# === Step 3: Load HPO terms from hp.obo
print(">> Loading HPO terms from obo")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph = obonet.read_obo(obo_url)
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        match = re.search(r'"(.+?)"', syn)
        if match:
            hpo_map.setdefault(match.group(1).lower(), []).append(node_id)

def normalize_mention(text: str):
    key = text.lower()
    if key in hpo_map:
        return hpo_map[key][0]
    match = process.extractOne(key, hpo_map.keys(), score_cutoff=85)
    if match:
        return hpo_map[match[0]][0]
    return None

# === Step 4: Run NER + Normalize (No Noise Filtering)
print(">> Running NER and normalization")
mapped_mentions = []
unmapped_mentions = []
normalized_mentions = []

for idx, sentence in enumerate(orig_sentences):
    results = ner_pipeline(sentence)
    for ent in results:
        if ent["entity_group"] != "HPO_TERM":
            continue
        mention = ent["word"].strip()
        hpo_id = normalize_mention(mention)
        (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))
        normalized_mentions.append({
            "sentence_id": idx,
            "mention": mention,
            "hpo_id": hpo_id
        })

total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"\nTotal mentions:  {total}")
print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
print(f"Failed to map:   {total - mapped} ({(total - mapped)/total:.1%})")

with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"\n Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)


>> Loading test data
>> Loading model and tokenizer


Device set to use cuda:0


>> Loading HPO terms from obo


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Running NER and normalization


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset



Total mentions:  644
Mapped to HP ID: 630 (97.8%)
Failed to map:   14 (2.2%)

 Mapped Mentions (630):
cytochrome c oxidase negative muscle fibers --> HP:0003688
reduction of activities of complex i and iv --> HP:0000002
lower extremities were diffusely thin --> HP:0020034
power to be 0 / 5 at wrist ﬂexors and extensors --> HP:0000152
0 / 5 at biceps --> HP:0000062
5 / 5 at deltoid muscles --> HP:0000062
/ 5 at --> HP:0000062
was 5 / 5 in --> HP:0000027
/ 5 at tibialis anterior --> HP:0000062
absent reﬂexes in upper extremities --> HP:0000027
polyneuropathy --> HP:0001271
neuropathic pain --> HP:6000040
ragged - red fibers --> HP:0003200
dyskinesia --> HP:0100660
severe stand - ing tremor --> HP:0001337
dysphagia --> HP:0002015
loss of weight --> HP:0001824
malnutrition --> HP:0004395
ventilatory failure --> HP:0000198
mechanical ventilation --> HP:0004887
serum ck level remains high --> HP:0000218
high serum pyruvate and lactate levels --> HP:0000093
low vitamin b12 and 25 - hydroxyvi

For the HPO _ TERM entity text identified by the model, firstly, the autocomplete interface of Monarch Initiative v3 API is called to find the most relevant HPO_TERM in the category of PhenotypicFeature and return their HP: ID; If Monarch doesn't have a matching result, call the ClinicalTables HPO API to search by keyword, and get the first ID that meets the HP: prefix; If the two steps don't hit, go back to the locally loaded hp.obo file, and do lowercase exact matching first. If it still doesn't hit, do fuzzy matching (similarity threshold is 85%) to find the closest term ID. Finally returns the first HP:ID found, otherwise returns None.

Without fine screening

In [11]:
import re
import json
from pathlib import Path
from typing import Optional
from functools import lru_cache

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process

# === Config ===
MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE  = Path("/kaggle/working/bio_outputs/mapped_mentions.jsonl")
MAX_LENGTH = 512  
DEVICE = 0  # use -1 for CPU

# Monarch v3 API & ClinicalTables
MONARCH_BASE = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"

FUZZY_CUTOFF = 85

# === HTTP session with retries ===
def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

# === Step 1: Load test data ===
print(">> Loading test data")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=DEVICE
)

# === Step 3: Load HPO terms from hp.obo
print(">> Loading HPO terms from obo")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo (will still try online normalization): {e}")
    graph = None

hpo_map = {}
if graph is not None:
    for node_id, data in graph.nodes(data=True):
        name = data.get("name")
        if name:
            hpo_map.setdefault(name.lower(), []).append(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                hpo_map.setdefault(m.group(1).lower(), []).append(node_id)
hpo_keys = list(hpo_map.keys())

# === Online normalizers (Monarch v3 -> ClinicalTables) ===
@lru_cache(maxsize=10000)
def normalize_via_monarch(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    params = {"q": q, "category": "biolink:PhenotypicFeature", "limit": 5}
    try:
        r = SESSION.get(f"{MONARCH_BASE}/autocomplete", params=params, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        items = data.get("items") or data.get("results") or data
        if not isinstance(items, list):
            return None
        for it in items:
            curie = it.get("id") or it.get("curie")
            cats  = it.get("category") or it.get("categories") or []
            if isinstance(cats, str):
                cats = [cats]
            if curie and str(curie).startswith("HP:"):
                if not cats or any("PhenotypicFeature" in c for c in cats):
                    return curie
    except Exception:
        return None
    return None

@lru_cache(maxsize=10000)
def normalize_via_ct(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 10}, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        # 格式通常为 [ids[], names[], ...]
        if isinstance(data, list) and len(data) >= 2:
            ids = data[0]
            for hp_id in ids:
                if isinstance(hp_id, str) and hp_id.startswith("HP:"):
                    return hp_id
    except Exception:
        return None
    return None

# === Your original normalize_mention, upgraded with online -> local fallback ===
def normalize_mention(text: str):
    # 1) Monarch v3
    curie = normalize_via_monarch(text)
    if curie:
        return curie
    # 2) ClinicalTables
    curie = normalize_via_ct(text)
    if curie:
        return curie
    # 3) Local exact + fuzzy
    key = text.lower()
    if key in hpo_map:
        return hpo_map[key][0]
    match = process.extractOne(key, hpo_keys, score_cutoff=FUZZY_CUTOFF) if hpo_keys else None
    if match:
        return hpo_map[match[0]][0]
    return None

# === Step 4: Run NER + Normalize (No Noise Filtering)
print(">> Running NER and normalization")
mapped_mentions = []
unmapped_mentions = []
normalized_mentions = []

for idx, sentence in enumerate(orig_sentences):
    #truncation/max_length
    results = ner_pipeline(sentence)
    for ent in results:
        if ent["entity_group"] != "HPO_TERM":
            continue
        mention = ent["word"].strip()
        hpo_id = normalize_mention(mention)
        (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))
        normalized_mentions.append({
            "sentence_id": idx,
            "mention": mention,
            "hpo_id": hpo_id
        })


total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"\nTotal mentions:  {total}")
print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})" if total else "Mapped to HP ID: 0 (n/a)")
print(f"Failed to map:   {total - mapped} ({(total - mapped)/total:.1%})" if total else "Failed to map: 0 (n/a)")

OUT_FILE.parent.mkdir(parents=True, exist_ok=True)
with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"\n Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)


Device set to use cuda:0


>> Loading test data
>> Loading model and tokenizer
>> Loading HPO terms from obo


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Running NER and normalization

Total mentions:  644
Mapped to HP ID: 632 (98.1%)
Failed to map:   12 (1.9%)

 Mapped Mentions (632):
cytochrome c oxidase negative muscle fibers --> HP:0003688
reduction of activities of complex i and iv --> HP:0000002
lower extremities were diffusely thin --> HP:0020034
power to be 0 / 5 at wrist ﬂexors and extensors --> HP:0000152
0 / 5 at biceps --> HP:0000062
5 / 5 at deltoid muscles --> HP:0000062
/ 5 at --> HP:0000062
was 5 / 5 in --> HP:0000027
/ 5 at tibialis anterior --> HP:0000062
absent reﬂexes in upper extremities --> HP:0000027
polyneuropathy --> HP:0001271
neuropathic pain --> HP:6000040
ragged - red fibers --> HP:0003200
dyskinesia --> HP:0100660
severe stand - ing tremor --> HP:0001337
dysphagia --> HP:0002015
loss of weight --> HP:0001824
malnutrition --> HP:0004395
ventilatory failure --> HP:0000198
mechanical ventilation --> HP:0004887
serum ck level remains high --> HP:0000218
high serum pyruvate and lactate levels --> HP:0002151
l

In [12]:
import re
import json
from pathlib import Path
from typing import Optional
from functools import lru_cache

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process
from rapidfuzz.fuzz import token_set_ratio  


MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE  = Path("/kaggle/working/bio_outputs/mapped_mentions.jsonl")
MAX_LENGTH = 512 
DEVICE = 0 

# Monarch v3 API & ClinicalTables
MONARCH_BASE = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"


FUZZY_CUTOFF = 85

SIM_THRESH = 65

# === HTTP session with retries ===
def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

# === Step 1: Load test data ===
print(">> Loading test data")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=DEVICE
)

# === Step 3: Load HPO terms from hp.obo
print(">> Loading HPO terms from obo")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo (will still try online normalization): {e}")
    graph = None

hpo_map = {}
if graph is not None:
    for node_id, data in graph.nodes(data=True):
        name = data.get("name")
        if name:
            hpo_map.setdefault(name.lower(), []).append(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                hpo_map.setdefault(m.group(1).lower(), []).append(node_id)
hpo_keys = list(hpo_map.keys())

# === Online normalizers (Monarch v3 -> ClinicalTables) ===
@lru_cache(maxsize=10000)
def normalize_via_monarch(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    params = {"q": q, "category": "biolink:PhenotypicFeature", "limit": 5}
    try:
        r = SESSION.get(f"{MONARCH_BASE}/autocomplete", params=params, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        items = data.get("items") or data.get("results") or data
        if not isinstance(items, list):
            return None

        best_curie, best_score = None, -1
        for it in items:
            curie = it.get("id") or it.get("curie")
            label = (it.get("label") or it.get("name") or "").strip()
            cats  = it.get("category") or it.get("categories") or []
            if isinstance(cats, str):
                cats = [cats]
            if curie and str(curie).startswith("HP:"):
                if not cats or any("PhenotypicFeature" in c for c in cats):
                    score = token_set_ratio(q.lower(), label.lower()) if label else 0
                    if score > best_score:
                        best_score, best_curie = score, curie
        return best_curie if best_score >= SIM_THRESH else None
    except Exception:
        return None

@lru_cache(maxsize=10000)
def normalize_via_ct(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 10}, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        if isinstance(data, list) and len(data) >= 2:
            ids, names = data[0], data[1]
            best_id, best_score = None, -1
            for hp_id, name in zip(ids, names):
                if isinstance(hp_id, str) and hp_id.startswith("HP:"):
                    score = token_set_ratio(q.lower(), (name or "").lower())
                    if score > best_score:
                        best_score, best_id = score, hp_id
            return best_id if best_score >= SIM_THRESH else None
    except Exception:
        return None
    return None

# === Your original normalize_mention, upgraded with online -> local fallback
def normalize_mention(text: str):
    # 1) Monarch v3
    curie = normalize_via_monarch(text)
    if curie:
        return curie
    # 2) ClinicalTables
    curie = normalize_via_ct(text)
    if curie:
        return curie
    key = text.lower().strip()

    if key in hpo_map:
        score = token_set_ratio(key, key)  
        return hpo_map[key][0] if score >= SIM_THRESH else None

    match = process.extractOne(key, hpo_keys, score_cutoff=FUZZY_CUTOFF) if hpo_keys else None
    if match:
        matched_term = match[0]  
        score = token_set_ratio(key, matched_term)
        if score >= SIM_THRESH:
            return hpo_map[matched_term][0]
    return None

# === Step 4: Run NER + Normalize (No Noise Filtering)
print(">> Running NER and normalization")
mapped_mentions = []
unmapped_mentions = []
normalized_mentions = []

for idx, sentence in enumerate(orig_sentences):
    # 按你的要求：不传 truncation/max_length
    results = ner_pipeline(sentence)
    for ent in results:
        if ent["entity_group"] != "HPO_TERM":
            continue
        mention = ent["word"].strip()
        hpo_id = normalize_mention(mention)
        (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))
        normalized_mentions.append({
            "sentence_id": idx,
            "mention": mention,
            "hpo_id": hpo_id
        })

total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"\nTotal mentions:  {total}")
print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})" if total else "Mapped to HP ID: 0 (n/a)")
print(f"Failed to map:   {total - mapped} ({(total - mapped)/total:.1%})" if total else "Failed to map: 0 (n/a)")

OUT_FILE.parent.mkdir(parents=True, exist_ok=True)
with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"\n Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)


Device set to use cuda:0


>> Loading test data
>> Loading model and tokenizer
>> Loading HPO terms from obo


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Running NER and normalization

Total mentions:  644
Mapped to HP ID: 447 (69.4%)
Failed to map:   197 (30.6%)

 Mapped Mentions (447):
cytochrome c oxidase negative muscle fibers --> HP:0003688
polyneuropathy --> HP:0001271
neuropathic pain --> HP:6000040
ragged - red fibers --> HP:0003200
dyskinesia --> HP:0100660
severe stand - ing tremor --> HP:0001337
dysphagia --> HP:0002015
loss of weight --> HP:0001824
malnutrition --> HP:0004395
mechanical ventilation --> HP:0004887
t2 hyperintensities in the --> HP:6000416
cerebellar white matter --> HP:0007033
neurological deterioration --> HP:0002344
tetraparesis --> HP:0002273
hypotonia --> HP:0001252
cognitive impairment --> HP:0100543
multifocal myoclonus --> HP:0040148
bilateral cerebellum --> HP:0012832
frontal cortex --> HP:0031421
bilateral parietal cortex --> HP:0012832
elevated lactate concentration --> HP:0002151
partial status epilepticus --> HP:0032662
progressive cognitive dysfunction --> HP:0003676
psychotic features --> HP:

In [13]:
import json
from pathlib import Path
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# ====== CONFIG ======
MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")      
OUTPUT_XLSX = Path("/kaggle/working/hpo_spans_dual.xlsx")     

LIGATURE_MAP = {
    "ﬁ": "fi",
    "ﬂ": "fl",
    "ﬀ": "ff",
    "ﬃ": "ffi",
    "ﬄ": "ffl",
}

def clean_with_map(orig_text: str):
    clean_chars = []
    c2o = []
    i = 0
    while i < len(orig_text):
        ch = orig_text[i]
        if ch in LIGATURE_MAP:
            repl = LIGATURE_MAP[ch]
            for rc in repl:
                clean_chars.append(rc)
                c2o.append(i) 
            i += 1
        else:
            clean_chars.append(ch)
            c2o.append(i)
            i += 1
    return "".join(clean_chars), c2o

device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple", 
    device=device
)

# ====== READ TEST DATA ======
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]

rows = []
for rec in test_data:
    orig_text = rec["text"]
    clean_text, c2o = clean_with_map(orig_text)

    ents = ner_pipeline(clean_text)

    for ent in ents:
        if ent.get("entity_group") != "HPO_TERM":
            continue

        clean_start = ent.get("start")
        clean_end = ent.get("end")
        if clean_start is None or clean_end is None:
            continue
        if clean_start < 0 or clean_end > len(clean_text) or clean_start >= clean_end:
            continue

        clean_span = clean_text[clean_start:clean_end]

        orig_start = c2o[clean_start]
        orig_end   = c2o[clean_end - 1] + 1 
        original_span = orig_text[orig_start:orig_end]

        rows.append({
            "original_text": orig_text,
            "original_span": original_span,
            "original_start": int(orig_start),
            "original_end": int(orig_end),
            "clean_text": clean_text,
            "clean_span": clean_span,
            "clean_start": int(clean_start),
            "clean_end": int(clean_end),
        })

if rows:
    df = pd.DataFrame(rows).sort_values(by=["original_start", "original_end"]).reset_index(drop=True)
else:
    df = pd.DataFrame(columns=[
        "original_text","original_span","original_start","original_end",
        "clean_text","clean_span","clean_start","clean_end"
    ])

df.to_excel(OUTPUT_XLSX, index=False, engine="openpyxl")
print(f"Saved {len(df)} rows to: {OUTPUT_XLSX.resolve()}")


Device set to use cuda:0
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Saved 658 rows to: /kaggle/working/hpo_spans_dual.xlsx


In [14]:
import re
import json
from pathlib import Path
from typing import Optional
from functools import lru_cache
import unicodedata

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process
from rapidfuzz.fuzz import token_set_ratio 

# ========== CONFIG ==========
MODEL_DIR   = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE   = Path("/kaggle/working/bio_outputs/test.jsonl")  
OUTPUT_XLSX = Path("/kaggle/working/hpo_for_thiloka.xlsx")

SIM_THRESH   = 50
FUZZY_CUTOFF = 85

# Monarch v3 API & ClinicalTables
MONARCH_BASE  = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"

LIGATURE_MAP = {
    "ﬁ": "fi",
    "ﬂ": "fl",
    "ﬀ": "ff",
    "ﬃ": "ffi",
    "ﬄ": "ffl",
}
def clean_with_map(orig_text: str):
    clean_chars, c2o = [], []
    i = 0
    while i < len(orig_text):
        ch = orig_text[i]
        if ch in LIGATURE_MAP:
            repl = LIGATURE_MAP[ch]
            for _ in repl:
                clean_chars.append(_)    
                c2o.append(i)         
            i += 1
        else:
            clean_chars.append(ch)
            c2o.append(i)
            i += 1
    return "".join(clean_chars), c2o

# ========== HTTP session with retries ==========
def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

print(">> Loading HPO terms from hp.obo …")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo (we'll still try online normalization): {e}")
    graph = None

hpo_map = {}
id_to_label = {}
if graph is not None:
    for node_id, data in graph.nodes(data=True):
        name = data.get("name")
        if node_id and name:
            id_to_label[node_id] = name
        if name:
            hpo_map.setdefault(name.lower(), []).append(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                hpo_map.setdefault(m.group(1).lower(), []).append(node_id)
hpo_keys = list(hpo_map.keys())

@lru_cache(maxsize=10000)
def normalize_via_monarch(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    params = {"q": q, "category": "biolink:PhenotypicFeature", "limit": 5}
    try:
        r = SESSION.get(f"{MONARCH_BASE}/autocomplete", params=params, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        items = data.get("items") or data.get("results") or data
        if not isinstance(items, list):
            return None
        best_curie, best_score = None, -1
        for it in items:
            curie  = it.get("id") or it.get("curie")
            label  = (it.get("label") or it.get("name") or "").strip()
            cats   = it.get("category") or it.get("categories") or []
            if isinstance(cats, str):
                cats = [cats]
            if curie and str(curie).startswith("HP:"):
                if not cats or any("PhenotypicFeature" in c for c in cats):
                    score = token_set_ratio(q.lower(), label.lower()) if label else 0
                    if score > best_score:
                        best_score, best_curie = score, curie
        return best_curie if best_score >= SIM_THRESH else None
    except Exception:
        return None

@lru_cache(maxsize=10000)
def normalize_via_ct(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 10}, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        if isinstance(data, list) and len(data) >= 2:
            ids, names = data[0], data[1]
            best_id, best_score = None, -1
            for hp_id, name in zip(ids, names):
                if isinstance(hp_id, str) and hp_id.startswith("HP:"):
                    score = token_set_ratio(q.lower(), (name or "").lower())
                    if score > best_score:
                        best_score, best_id = score, hp_id
            return best_id if best_score >= SIM_THRESH else None
    except Exception:
        return None
    return None

def normalize_mention(text: str) -> Optional[str]:
    curie = normalize_via_monarch(text)
    if curie:
        return curie
    curie = normalize_via_ct(text)
    if curie:
        return curie
    key = text.lower().strip()
    if key in hpo_map:
        return hpo_map[key][0]
    match = process.extractOne(key, hpo_keys, score_cutoff=FUZZY_CUTOFF) if hpo_keys else None
    if match:
        matched_term = match[0]
        score = token_set_ratio(key, matched_term)
        if score >= SIM_THRESH:
            return hpo_map[matched_term][0]
    return None

def label_for(hpo_id: Optional[str]) -> str:
    if not hpo_id:
        return "no_match"
    return id_to_label.get(hpo_id, "no_match")

print(">> Loading NER model …")
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model     = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple", 
    device=device
)

print(">> Loading test set …")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]

rows = []
print(">> Extracting spans and normalising …")
for rec in test_data:
    orig_text = rec["text"]
    clean_text, c2o = clean_with_map(orig_text)

    ents = ner_pipeline(clean_text)
    for ent in ents:
        if ent.get("entity_group") != "HPO_TERM":
            continue

        c_start = ent.get("start")
        c_end   = ent.get("end")
        if c_start is None or c_end is None:
            continue
        if c_start < 0 or c_end > len(clean_text) or c_start >= c_end:
            continue


        o_start = c2o[c_start]
        o_end   = c2o[c_end - 1] + 1
        span_text = orig_text[o_start:o_end]


        span_for_norm = clean_text[c_start:c_end]
        hp_id = normalize_mention(span_for_norm)
        hp_label = label_for(hp_id)


        rows.append({
            "Span": span_text,
            "Predicted standardised HPO term": hp_label if hp_id else "no_match",
            "Predicted HPO ID": hp_id if hp_id else "no_match",
            "Correct label? (Y/N/YN)": "",
            "Correct HPO term": "",
            "Correct HPO ID": ""
        })


df = pd.DataFrame(
    rows,
    columns=[
        "Span",
        "Predicted standardised HPO term",
        "Predicted HPO ID",
        "Correct label? (Y/N/YN)",
        "Correct HPO term",
        "Correct HPO ID",
    ]
)
df.to_excel(OUTPUT_XLSX, index=False, engine="openpyxl")
print(f">> Wrote {len(df)} rows to {OUTPUT_XLSX.resolve()}")


>> Loading HPO terms from hp.obo …


Device set to use cuda:0


>> Loading NER model …


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Loading test set …
>> Extracting spans and normalising …
>> Wrote 658 rows to /kaggle/working/hpo_for_thiloka.xlsx


In [15]:
import re
import json
from pathlib import Path
from typing import Optional
from functools import lru_cache
import unicodedata

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process
from rapidfuzz.fuzz import token_set_ratio  

# ========== CONFIG ==========
MODEL_DIR   = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE   = Path("/kaggle/working/bio_outputs/test.jsonl")  
OUTPUT_XLSX = Path("/kaggle/working/hpo_for_thiloka_process.xlsx")


SIM_THRESH   = 50
FUZZY_CUTOFF = 85

# Monarch v3 API & ClinicalTables
MONARCH_BASE  = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"


LIGATURE_MAP = {
    "ﬁ": "fi",
    "ﬂ": "fl",
    "ﬀ": "ff",
    "ﬃ": "ffi",
    "ﬄ": "ffl",
}
def clean_with_map(orig_text: str):

    clean_chars, c2o = [], []
    i = 0
    while i < len(orig_text):
        ch = orig_text[i]
        if ch in LIGATURE_MAP:
            repl = LIGATURE_MAP[ch]
            for rc in repl:
                clean_chars.append(rc)
                c2o.append(i)
            i += 1
        else:
            clean_chars.append(ch)
            c2o.append(i)
            i += 1
    return "".join(clean_chars), c2o


def clean_text(m: str) -> str:
    if not m:
        return m
    m = unicodedata.normalize("NFKC", m)
    m = m.replace("–", "-").replace("—", "-")
    m = m.translate(str.maketrans({
        "“": '"', "”": '"', "„": '"', "‟": '"',
        "’": "'", "‘": "'", "‚": "'", "‛": "'"
    }))
    m = re.sub(r"\s+", " ", m).strip()
    return m

_ABBR_PATTERNS = [
    (re.compile(r"\brrf\b", flags=re.I), "ragged red fibers"),
    (re.compile(r"\bragged[-\s]?red\b", flags=re.I), "ragged red"),
    (re.compile(r"\bcox\b", flags=re.I), "cytochrome c oxidase"),
    (re.compile(r"\bsdh\b", flags=re.I), "succinate dehydrogenase"),
    (re.compile(r"\bg[-\s]?tube\b", flags=re.I), "gastrostomy tube"),
]
def expand_abbrev(m: str) -> str:
    t = m
    for pat, rep in _ABBR_PATTERNS:
        t = pat.sub(rep, t)
    return t

_CANON_SUBS = [
    (re.compile(r"\bdysphasia\b", re.I), "aphasia"),
    (re.compile(r"\bwheel[-\s]?chair\s*bound\b", re.I), "wheelchair dependence"),
    (re.compile(r"\bfailing\s+to\s+thrive\b", re.I), "failure to thrive"),
    (re.compile(r"\bsyncopal\s+episode\b", re.I), "syncope"),
    (re.compile(r"\blumbosacral\s+radiculopathy\b", re.I), "radiculopathy"),
    (re.compile(r"\bragged\s+blue\b(?!\s*fib)", re.I), "ragged blue fibers"),
    (re.compile(r"cytochrome\s+c\s+oxidase\s*[-–—]?\s*negative", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"cytochrome\s+c\s+oxidase\s*[-–—]?\s*deficien\w*", re.I),
     "cytochrome c oxidase-deficient muscle fibers"),
    (re.compile(r"\boligoclonal\s+bands?\b.*\b(csfs?|cerebrospinal\s+fluid)\b", re.I),
     "oligoclonal bands in cerebrospinal fluid"),
    (re.compile(r"\bjerky\s+\w+(\s+and\s+\w+)?\s+movements\b", re.I), "myoclonus"),
    (re.compile(r"\bcomplete\s+absence\s+of\s+proprioceptive\s+sensation\b", re.I),
     "loss of proprioception"),
    (re.compile(r"\bpolymini\W*myoclonus\b", re.I), "polyminimyoclonus"),
    (re.compile(r"\bdystonic\s+toe\s+curling\b", re.I), "dystonia of toes"),
    (re.compile(r"\bstriatal\s+toes?\b", re.I), "striatal toe"),
    (re.compile(r"\bnear\s+falls?\b", re.I), "recurrent falls"),
]
def canonicalize_synonyms(m: str) -> str:
    s = m
    for pat, rep in _CANON_SUBS:
        s = pat.sub(rep, s)
    return s

_REGION_PATTERNS = [
    (re.compile(r"\bcortical\s+gr[ae]y(?:\s+matter)?\b", re.I), "cortical gray matter"),
    (re.compile(r"\bsubcortical\s+white\s+matter\b", re.I), "subcortical white matter"),
    (re.compile(r"\bbasal\s+ganglia\b", re.I), "basal ganglia"),
    (re.compile(r"\bthalam(?:us|i)\b", re.I), "thalamus"),
    (re.compile(r"\bparieto[-\s]?occipital\b", re.I), "parieto-occipital region"),
    (re.compile(r"\bfrontoparietal\s+subcortical\s+white\s+matter\b", re.I), "frontoparietal subcortical white matter"),
    (re.compile(r"\bpre[-\s]?rolandic\b", re.I), "pre-rolandic cortex"),
    (re.compile(r"\binferior\s+olivary\s+nucleus\b", re.I), "inferior olivary nucleus"),
]
_FINDING_PATTERNS = [
    (re.compile(r"\bt2\s*[- ]?\s*hyperintens\w*\b", re.I), "T2 hyperintensity"),
    (re.compile(r"\bhypersignal(s)?\b", re.I), "T2 hyperintensity"),
    (re.compile(r"\bhyperintens\w*\b", re.I), "T2 hyperintensity"),
    (re.compile(r"\bhemorrhag\w*\b", re.I), "hemorrhage"),
    (re.compile(r"\bswell\w*\b", re.I), "swelling"),
    (re.compile(r"\batroph\w*\b", re.I), "atrophy"),
    (re.compile(r"\bglios\w*\b", re.I), "gliosis"),
    (re.compile(r"\blesion\w*\b", re.I), "lesion"),
    (re.compile(r"\bprolongation\b", re.I), "T2 prolongation"),
    (re.compile(r"\bhypointens\w*\b", re.I), "T2 hypointensity"),
]
_PATHO_TEMPLATES = [
    (re.compile(r"\bcytochrome\s+c\s+oxidase\b.*\bnegative\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"\bcox[-\s]?negative\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"\bcox[-\s]?deficien\w*\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-deficient muscle fibers"),
    (re.compile(r"\bragged\s+blue\b.*\bfib(er|re)s\b", re.I),
     "ragged blue fibers"),
    (re.compile(r"\boligoclonal\s+bands?\b", re.I),
     "oligoclonal bands in cerebrospinal fluid"),
    (re.compile(r"\bvariation\s+of\s+fiber\s+calib(er|re)\b", re.I),
     "variation in skeletal muscle fiber size"),
    (re.compile(r"\bnuclear\s+centralization\b", re.I),
     "increased central nuclei in skeletal muscle fibers"),
    (re.compile(r"\bfatty\s+replacement\b.*\bendomysial\b", re.I),
     "fatty infiltration of skeletal muscle"),
    (re.compile(r"\bparacrystalline\s+inclusion\w*\b", re.I),
     "mitochondrial paracrystalline inclusions"),
    (re.compile(r"\bswollen\s+mitochondria\b|\babnormally\s+swollen\s+mitochondria\b", re.I),
     "swollen mitochondria"),
    (re.compile(r"\b(concentric|tubular|irregular)\s+cristae\b", re.I),
     "abnormal mitochondrial cristae morphology"),
    (re.compile(r"\b(poly)?spike\s*-\s*(and\s*-\s*)?slow\s*waves?\b", re.I),
     "EEG with epileptiform discharges"),
    (re.compile(r"\bsharp\s+and\s+slow\s+wave(s)?\b", re.I),
     "EEG with epileptiform discharges"),
]
def rewrite_imaging_pathology(m: str) -> Optional[str]:
    s = m
    for pat, rep in _PATHO_TEMPLATES:
        if pat.search(s):
            return rep
    found_region = None
    for pat, norm in _REGION_PATTERNS:
        if pat.search(s):
            found_region = norm
            break
    found_finding = None
    for pat, norm in _FINDING_PATTERNS:
        if pat.search(s):
            found_finding = norm
            break
    if found_finding and found_region:
        return f"{found_finding} of {found_region}"
    if found_finding:
        return found_finding
    return None

def rewrite_numeric_to_qualitative(m: str) -> Optional[str]:
    s = m.lower()
    if re.search(r"\b(vitamin\s*b12|b\s*12)\b", s) and re.search(r"\b(low|decreas\w*|deficien\w*)\b", s):
        return "decreased circulating vitamin B12"
    if re.search(r"\b(25[-\s]?hydroxyvitamin\s*d|25ohd|25\W*oh\W*d)\b", s) and re.search(r"\b(low|decreas\w*|deficien\w*)\b", s):
        return "decreased circulating 25-hydroxyvitamin D"
    if re.search(r"\b(csfs?|cerebrospinal\s+fluid)\b.*\blactate\b", s) and re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
        return "increased cerebrospinal fluid lactate"
    if re.search(r"\blactate\b.*\b(spectroscop\w*|mrs)\b", s):
        return "increased brain lactate on magnetic resonance spectroscopy"
    if re.search(r"\b(serum\s+)?ck\b", s) and (re.search(r"\blevel\b", s) or re.search(r"\d", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high|above\s+normal|x\s*normal)\b", s) or re.search(r"\b\d+(\.\d+)?\s*x\s*normal\b", s):
            return "elevated serum creatine kinase"
    if re.search(r"\bmyoglobin\b", s) and (re.search(r"\blevel\b", s) or re.search(r"\d", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
            return "elevated myoglobin"
    if re.search(r"\b(pyruvate)\b", s) and (re.search(r"\b(blood|serum)\b", s) or re.search(r"\b\du?mol\b", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high)\b", s) or re.search(r"\d", s):
            return "increased blood pyruvate"
    if re.search(r"\bvanillat\w*\b", s) and re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
        return "increased vanillate level"
    if re.search(r"\bdopamine\s+transporter\b", s) and re.search(r"\b(reduc\w*|decreas\w*|low)\b", s):
        return "decreased dopamine transporter level"
    if re.search(r"\blogmar\b", s):
        return "decreased visual acuity"
    if re.search(r"\bketonuria\b", s):
        return "ketonuria"
    if re.search(r"\bvisual\s+disturbance(s)?\b", s):
        return "visual impairment"
    return None

def apply_templates(m: str) -> str:
    img = rewrite_imaging_pathology(m)
    if img:
        return img
    numq = rewrite_numeric_to_qualitative(m)
    if numq:
        return numq
    return m

def dedup_adjacent_words(s: str) -> str:
    return re.sub(r'\b(\w+)(\s+\1\b)+', r'\1', s, flags=re.I)

_MRC = re.compile(r"\b\d+\s*/\s*\d+\b")  # 例如 5/5, 0/5
_STOPWORDS = {"in", "at", "of", "was", "to", "is", "are", "be", "the", "and"}
_BAD_START = re.compile(r"^##")
_BAD_FRAGMENTS = [
    re.compile(r"\bfed\s+through\s+a\b", re.I),
    re.compile(r"\bshowed\s+an\b", re.I),
    re.compile(r"^\s*(hepat|dysm|##ar|##et)\b", re.I),
]
def is_noise(m: str) -> bool:
    if not m:
        return True
    m_strip = m.strip()
    if _BAD_START.search(m_strip):
        return True
    for pat in _BAD_FRAGMENTS:
        if pat.search(m_strip):
            return True
    alpha = sum(c.isalpha() for c in m_strip)
    if alpha < 2 or len(m_strip) < 3:
        return True
    if _MRC.search(m_strip):
        return True
    toks = re.findall(r"[a-z]+", m_strip.lower())
    if toks and all(t in _STOPWORDS for t in toks):
        return True
    return False

# ========== HTTP session with retries ==========
def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

print(">> Loading HPO terms from hp.obo …")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo (we'll still try online normalization): {e}")
    graph = None

hpo_map = {}
id_to_label = {}
if graph is not None:
    for node_id, data in graph.nodes(data=True):
        name = data.get("name")
        if node_id and name:
            id_to_label[node_id] = name
        if name:
            hpo_map.setdefault(name.lower(), []).append(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                hpo_map.setdefault(m.group(1).lower(), []).append(node_id)
hpo_keys = list(hpo_map.keys())

@lru_cache(maxsize=10000)
def normalize_via_monarch(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    params = {"q": q, "category": "biolink:PhenotypicFeature", "limit": 5}
    try:
        r = SESSION.get(f"{MONARCH_BASE}/autocomplete", params=params, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        items = data.get("items") or data.get("results") or data
        if not isinstance(items, list):
            return None
        best_curie, best_score = None, -1
        for it in items:
            curie  = it.get("id") or it.get("curie")
            label  = (it.get("label") or it.get("name") or "").strip()
            cats   = it.get("category") or it.get("categories") or []
            if isinstance(cats, str):
                cats = [cats]
            if curie and str(curie).startswith("HP:"):
                if not cats or any("PhenotypicFeature" in c for c in cats):
                    score = token_set_ratio(q.lower(), label.lower()) if label else 0
                    if score > best_score:
                        best_score, best_curie = score, curie
        return best_curie if best_score >= SIM_THRESH else None
    except Exception:
        return None

@lru_cache(maxsize=10000)
def normalize_via_ct(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 10}, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        if isinstance(data, list) and len(data) >= 2:
            ids, names = data[0], data[1]
            best_id, best_score = None, -1
            for hp_id, name in zip(ids, names):
                if isinstance(hp_id, str) and hp_id.startswith("HP:"):
                    score = token_set_ratio(q.lower(), (name or "").lower())
                    if score > best_score:
                        best_score, best_id = score, hp_id
            return best_id if best_score >= SIM_THRESH else None
    except Exception:
        return None
    return None

def normalize_mention(text: str) -> Optional[str]:
    # 优先在线
    curie = normalize_via_monarch(text)
    if curie:
        return curie
    curie = normalize_via_ct(text)
    if curie:
        return curie
    # 本地兜底
    key = text.lower().strip()
    if key in hpo_map:
        return hpo_map[key][0]
    match = process.extractOne(key, hpo_keys, score_cutoff=FUZZY_CUTOFF) if hpo_keys else None
    if match:
        matched_term = match[0]
        score = token_set_ratio(key, matched_term)
        if score >= SIM_THRESH:
            return hpo_map[matched_term][0]
    return None

def label_for(hpo_id: Optional[str]) -> str:
    if not hpo_id:
        return "no_match"
    return id_to_label.get(hpo_id, "no_match")

print(">> Loading NER model …")
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model     = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=device
)

print(">> Loading test set …")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]

rows = []
print(">> Extracting spans and normalising …")
for rec in test_data:
    orig_text = rec["text"]
    clean_text_mapped, c2o = clean_with_map(orig_text)

    ents = ner_pipeline(clean_text_mapped)
    for ent in ents:
        if ent.get("entity_group") != "HPO_TERM":
            continue

        c_start = ent.get("start")
        c_end   = ent.get("end")
        if c_start is None or c_end is None:
            continue
        if c_start < 0 or c_end > len(clean_text_mapped) or c_start >= c_end:
            continue

        o_start = c2o[c_start]
        o_end   = c2o[c_end - 1] + 1
        span_text = orig_text[o_start:o_end]

        span_for_norm = clean_text(clean_text_mapped[c_start:c_end])
        span_for_norm = expand_abbrev(span_for_norm)
        span_for_norm = canonicalize_synonyms(span_for_norm)
        span_for_norm = apply_templates(span_for_norm)
        span_for_norm = dedup_adjacent_words(span_for_norm)
        if is_noise(span_for_norm):
            hp_id = None
        else:
            hp_id = normalize_mention(span_for_norm)

        hp_label = label_for(hp_id)

        rows.append({
            "Span": span_text,
            "Predicted standardised HPO term": hp_label if hp_id else "no_match",
            "Predicted HPO ID": hp_id if hp_id else "no_match",
            "Correct label? (Y/N/YN)": "",
            "Correct HPO term": "",
            "Correct HPO ID": ""
        })

df = pd.DataFrame(
    rows,
    columns=[
        "Span",
        "Predicted standardised HPO term",
        "Predicted HPO ID",
        "Correct label? (Y/N/YN)",
        "Correct HPO term",
        "Correct HPO ID",
    ]
)
df.to_excel(OUTPUT_XLSX, index=False, engine="openpyxl")
print(f">> Wrote {len(df)} rows to {OUTPUT_XLSX.resolve()}")


>> Loading HPO terms from hp.obo …


Device set to use cuda:0


>> Loading NER model …


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Loading test set …
>> Extracting spans and normalising …
>> Wrote 658 rows to /kaggle/working/hpo_for_thiloka_process.xlsx


In [16]:
import json
from pathlib import Path
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
import re
from rapidfuzz import process
from rapidfuzz.fuzz import token_set_ratio

import unicodedata 
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from typing import Optional
from functools import lru_cache

MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUTPUT_XLSX = Path("/kaggle/working/hpo_compare.xlsx")


FUZZY_CUTOFF = 85    
SIM_THRESH   = 70    

LIGATURE_MAP = {
    "ﬁ": "fi",
    "ﬂ": "fl",
    "ﬀ": "ff",
    "ﬃ": "ffi",
    "ﬄ": "ffl",
}

def clean_with_map(orig_text: str):
    clean_chars, c2o = [], []
    for i, ch in enumerate(orig_text):
        if ch in LIGATURE_MAP:
            repl = LIGATURE_MAP[ch]
            for _ in repl:
                clean_chars.append(_)
                c2o.append(i)
        else:
            clean_chars.append(ch)
            c2o.append(i)
    return "".join(clean_chars), c2o

def clean_text(m: str) -> str:
    if not m:
        return m
    m = unicodedata.normalize("NFKC", m)
    m = m.replace("–", "-").replace("—", "-")
    m = m.translate(str.maketrans({
        "“": '"', "”": '"', "„": '"', "‟": '"',
        "’": "'", "‘": "'", "‚": "'", "‛": "'"
    }))
    m = re.sub(r"\s+", " ", m).strip()
    return m


gold_spans_by_text = []
with TEST_FILE.open(encoding="utf-8") as f:
    for line in f:
        rec = json.loads(line)
        text = rec["text"]
        spans = rec.get("spans", [])
        gold_hpo = []
        for sp in spans:
            if sp.get("label") == "HPO_TERM":
                start, end = sp["start"], sp["end"]
                gold_hpo.append((start, end, text[start:end]))
        gold_spans_by_text.append({"text": text, "gold_spans": gold_hpo})


print(">> Loading hp.obo for local HPO normalization")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo: {e}")
    graph = None

hpo_keys = []
term2ids = {}
id2label = {}

if graph is not None:
    for node_id, data in graph.nodes(data=True):
        if not isinstance(node_id, str) or not node_id.startswith("HP:"):
            continue
        label = (data.get("name") or "").strip()
        if label:
            id2label[node_id] = label
            term2ids.setdefault(label.lower(), set()).add(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                syn_txt = m.group(1).strip()
                if syn_txt:
                    term2ids.setdefault(syn_txt.lower(), set()).add(node_id)
    hpo_keys = list(term2ids.keys())

# ====== Online standardization via Monarch and ClinicalTables ======
MONARCH_BASE = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"

# === Session creation for retries ===
def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

STOPWORDS = {"a","an","the","to","of","in","on","for","and","or","with","at","by","from","is","are","was","were"}
KEY_HEADWORDS = {"sensation","sense","vibration","pinprick","pain","touch","temperature",
                 "proprioception","hearing","vision","taste","smell"}

def _content_tokens(s: str):
    toks = re.findall(r"[a-z]+", (s or "").lower())
    return [t for t in toks if t not in STOPWORDS]

def _overlap_ratio(mention: str, label: str):
    m = set(_content_tokens(mention))
    l = set(_content_tokens(label))
    if not m or not l:
        return 0.0, 0.0, 0
    inter = m & l
    return (len(inter)/max(1,len(m))), (len(inter)/max(1,len(l))), len(inter)

def _headword_bonus(mention: str, label: str):
    m = set(_content_tokens(mention))
    l = set(_content_tokens(label))
    return 1.1 if (m & KEY_HEADWORDS) and (l & KEY_HEADWORDS) else 1.0

def _contextual_score(mention: str, sent_text: str, label: str):
    s1 = token_set_ratio((mention or "").lower(), (label or "").lower())
    s2 = token_set_ratio((sent_text or "").lower(), (label or "").lower())
    return 0.7 * s1 + 0.3 * s2

@lru_cache(maxsize=10000)
def normalize_via_monarch_full(q: str):
    q = (q or "").strip()
    if not q:
        return []
    try:
        r = SESSION.get(
            f"{MONARCH_BASE}/autocomplete",
            params={"q": q, "category": "biolink:PhenotypicFeature", "limit": 15},
            timeout=6
        )
        if r.status_code != 200:
            return []
        data = r.json()
        items = data.get("items") or data.get("results") or data
        out = []
        if isinstance(items, list):
            for it in items:
                hp = it.get("id") or it.get("curie")
                lab = (it.get("label") or it.get("name") or "").strip()
                if isinstance(hp, str) and hp.startswith("HP:") and lab:
                    out.append((hp, lab))
        return out
    except Exception:
        return []

@lru_cache(maxsize=10000)
def normalize_via_ct_full(q: str):
    q = (q or "").strip()
    if not q:
        return []
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 20}, timeout=6)
        if r.status_code != 200:
            return []
        data = r.json()
        out = []
        if isinstance(data, list) and len(data) >= 2:
            ids, names = data[0], data[1]
            for hp_id, name in zip(ids, names):
                if isinstance(hp_id, str) and hp_id.startswith("HP:") and name:
                    out.append((hp_id, name))
        return out
    except Exception:
        return []

def _local_candidates(q: str):
    key = (q or "").lower().strip()
    cands = []
    if key in term2ids: 
        for hp in term2ids[key]:
            cands.append((hp, id2label.get(hp, key)))
    if hpo_keys:

        match = process.extract(key, hpo_keys, limit=15, score_cutoff=max(70, FUZZY_CUTOFF-10))
        for cand, _score, _ in match:
            for hp in term2ids.get(cand, []):
                cands.append((hp, id2label.get(hp, cand)))
    seen, uniq = set(), []
    for hp, lab in cands:
        if (hp, lab) not in seen:
            uniq.append((hp, lab))
            seen.add((hp, lab))
    return uniq

def normalize_mention(mention: str, sent_text: str):

    cands = []
    cands += normalize_via_monarch_full(mention)
    cands += normalize_via_ct_full(mention)


    seen, uniq = set(), []
    for hp, lab in cands:
        if (hp, lab) not in seen:
            uniq.append((hp, lab))
            seen.add((hp, lab))

    if not uniq:
        uniq = _local_candidates(mention)

    best = (None, None, -1.0)  # (hp, lab, score)
    for hp, lab in uniq:
        if not lab:
            continue
        base = _contextual_score(mention, sent_text, lab)
        m_overlap, l_overlap, inter = _overlap_ratio(mention, lab)
        if m_overlap < 0.5 or l_overlap < 0.3 or inter == 0:
            continue
        bonus = _headword_bonus(mention, lab)
        score = base * bonus
        if score > best[2]:
            best = (hp, lab, score)

    hp_id, hp_lab, score = best
    if hp_id is None:
        return None, None
    if score < max(65, SIM_THRESH):
        return None, None
    return hp_id, hp_lab

device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)
ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer,
                        aggregation_strategy="simple", device=device)

rows = []
for rec in gold_spans_by_text:
    orig_text = rec["text"]
    gold_spans = rec["gold_spans"] 
    clean_txt, c2o = clean_with_map(orig_text)

    preds = ner_pipeline(clean_txt)

    for ent in preds:
        if ent.get("entity_group") != "HPO_TERM":
            continue

        clean_start, clean_end = ent["start"], ent["end"]
        orig_start = c2o[clean_start]
        orig_end = c2o[clean_end - 1] + 1
        pred_span_text = orig_text[orig_start:orig_end]

        clean_span_text = clean_txt[clean_start:clean_end]
        pred_id, pred_label = normalize_mention(clean_span_text, orig_text)

        predicted_hpo_term = pred_label if pred_label else "no_match"
        predicted_hpo_id   = pred_id if pred_id else "no_match"

        gold_term = ""

        for gs in gold_spans:
            if gs[0] == orig_start and gs[1] == orig_end:
                gold_term = gs[2]
                break
        if not gold_term:
            overlapped = None
            for gs in gold_spans:
                gs_start, gs_end = gs[0], gs[1]
                if gs_start <= orig_end and gs_end >= orig_start:
                    overlapped = gs
                    break
            if overlapped:
                gold_term = overlapped[2]
            else:
                gold_term = "; ".join([gs[2] for gs in gold_spans]) if gold_spans else ""

        rows.append({
            "Gold HPO term": gold_term,                           
            "Span": pred_span_text,                              
            "Predicted standardised HPO term": predicted_hpo_term,
            "Predicted HPO ID": predicted_hpo_id,                 
            "Correct label? (Y/N/YN)": "",                        
            "Correct HPO term": "",                               
            "Correct HPO ID": ""
        })

df = pd.DataFrame(rows)
df.to_excel(OUTPUT_XLSX, index=False, engine="openpyxl")
print(f"Saved comparison file to {OUTPUT_XLSX.resolve()}")




>> Loading hp.obo for local HPO normalization


Device set to use cuda:0
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Saved comparison file to /kaggle/working/hpo_compare.xlsx


In [17]:
import json
from pathlib import Path
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline


MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")   
OUTPUT_XLSX = Path("/kaggle/working/hpo_pred_vs_gold.xlsx")


LIGATURE_MAP = {
    "ﬁ": "fi",
    "ﬂ": "fl",
    "ﬀ": "ff",
    "ﬃ": "ffi",
    "ﬄ": "ffl",
}
def clean_with_map(orig_text: str):
    clean_chars, c2o = [], []
    i = 0
    while i < len(orig_text):
        ch = orig_text[i]
        if ch in LIGATURE_MAP:
            repl = LIGATURE_MAP[ch]
            for rc in repl:
                clean_chars.append(rc)
                c2o.append(i)  
            i += 1
        else:
            clean_chars.append(ch)
            c2o.append(i)
            i += 1
    return "".join(clean_chars), c2o


device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=device
)


data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]


def format_spans(spans):
    # spans: list of dict with text/start/end
    return "; ".join([f"{s['text']}[{s['start']}:{s['end']}]" for s in spans])

def relaxed_match(pred, gold, text_len, tol=1):

    ps, pe = max(0, pred['start'] - tol), min(text_len, pred['end'] + tol)
    gs, ge = max(0, gold['start'] - tol), min(text_len, gold['end'] + tol)
    inter = max(0, min(pe, ge) - max(ps, gs))
    denom = max(pred['end'] - pred['start'], gold['end'] - gold['start'])
    if denom <= 0:
        return False, 0.0
    frac = inter / denom
    return (frac >= 0.5), frac


per_sentence_rows = []
detail_rows = []

for sid, rec in enumerate(data):
    orig_text = rec["text"]
    gold_spans = []
    for s in rec.get("spans", []):
        if s.get("label") == "HPO_TERM":
            gold_spans.append({
                "text": s.get("text", orig_text[s["start"]:s["end"]]),
                "start": int(s["start"]),
                "end": int(s["end"]),
            })

   
    clean_text, c2o = clean_with_map(orig_text)
    preds = []
    for ent in ner_pipeline(clean_text):
        if ent.get("entity_group") != "HPO_TERM":
            continue
        cs, ce = int(ent["start"]), int(ent["end"])
        if cs < 0 or ce > len(clean_text) or cs >= ce:
            continue
        os, oe = c2o[cs], c2o[ce - 1] + 1
        preds.append({
            "text": orig_text[os:oe],
            "start": os,
            "end": oe,
        })

    gold_used = [False] * len(gold_spans)
    pred_used = [False] * len(preds)

    for pi, p in enumerate(preds):
        best_g, best_frac = -1, 0.0
        for gi, g in enumerate(gold_spans):
            if gold_used[gi]:
                continue
            ok, frac = relaxed_match(p, g, len(orig_text), tol=1)
            if ok and frac > best_frac:
                best_frac, best_g = frac, gi
        if best_g >= 0:
            pred_used[pi] = True
            gold_used[best_g] = True
            g = gold_spans[best_g]
            detail_rows.append({
                "sent_id": sid,
                "type": "TP",
                "text": orig_text,
                "pred_text": p["text"], "pred_start": p["start"], "pred_end": p["end"],
                "gold_text": g["text"], "gold_start": g["start"], "gold_end": g["end"],
                "overlap_frac": round(best_frac, 3),
            })

    for pi, p in enumerate(preds):
        if not pred_used[pi]:
            detail_rows.append({
                "sent_id": sid,
                "type": "FP",
                "text": orig_text,
                "pred_text": p["text"], "pred_start": p["start"], "pred_end": p["end"],
                "gold_text": "", "gold_start": -1, "gold_end": -1,
                "overlap_frac": 0.0,
            })

    for gi, g in enumerate(gold_spans):
        if not gold_used[gi]:
            detail_rows.append({
                "sent_id": sid,
                "type": "FN",
                "text": orig_text,
                "pred_text": "", "pred_start": -1, "pred_end": -1,
                "gold_text": g["text"], "gold_start": g["start"], "gold_end": g["end"],
                "overlap_frac": 0.0,
            })

    TP = sum(1 for r in detail_rows if r["sent_id"] == sid and r["type"] == "TP")
    FP = sum(1 for r in detail_rows if r["sent_id"] == sid and r["type"] == "FP")
    FN = sum(1 for r in detail_rows if r["sent_id"] == sid and r["type"] == "FN")

    per_sentence_rows.append({
        "sent_id": sid,
        "text": orig_text,
        "gold_spans": format_spans(gold_spans),
        "pred_spans": format_spans(preds),
        "TP": TP, "FP": FP, "FN": FN
    })
with pd.ExcelWriter(OUTPUT_XLSX, engine="openpyxl") as writer:
    pd.DataFrame(per_sentence_rows).to_excel(writer, index=False, sheet_name="per_sentence")
    pd.DataFrame(detail_rows).to_excel(writer, index=False, sheet_name="details")

print(f"Saved comparison to: {OUTPUT_XLSX.resolve()}")


Device set to use cuda:0
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Saved comparison to: /kaggle/working/hpo_pred_vs_gold.xlsx


In [18]:
import re
import json
from pathlib import Path
from typing import Optional
from functools import lru_cache
import unicodedata

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process
from rapidfuzz.fuzz import token_set_ratio 

# === Config ===
MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE  = Path("/kaggle/working/bio_outputs/mapped_mentions.jsonl")
MAX_LENGTH = 512 
DEVICE = 0  # use -1 for CPU

# Monarch v3 API & ClinicalTables
MONARCH_BASE = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"


FUZZY_CUTOFF = 85


SIM_THRESH = 65


def clean_text(m: str) -> str:
    if not m:
        return m
    m = unicodedata.normalize("NFKC", m)
    m = m.replace("–", "-").replace("—", "-")
    m = m.translate(str.maketrans({
        "“": '"', "”": '"', "„": '"', "‟": '"',
        "’": "'", "‘": "'", "‚": "'", "‛": "'"
    }))
    m = re.sub(r"\s+", " ", m).strip()
    return m

_ABBR_PATTERNS = [
    (re.compile(r"\brrf\b", flags=re.I), "ragged red fibers"),
    (re.compile(r"\bragged[-\s]?red\b", flags=re.I), "ragged red"),
    (re.compile(r"\bcox\b", flags=re.I), "cytochrome c oxidase"),
    (re.compile(r"\bsdh\b", flags=re.I), "succinate dehydrogenase"),
    (re.compile(r"\bg[-\s]?tube\b", flags=re.I), "gastrostomy tube"),
]
def expand_abbrev(m: str) -> str:
    t = m
    for pat, rep in _ABBR_PATTERNS:
        t = pat.sub(rep, t)
    return t


_CANON_SUBS = [
    (re.compile(r"\bdysphasia\b", re.I), "aphasia"),
    (re.compile(r"\bwheel[-\s]?chair\s*bound\b", re.I), "wheelchair dependence"),
    (re.compile(r"\bfailing\s+to\s+thrive\b", re.I), "failure to thrive"),
    (re.compile(r"\bsyncopal\s+episode\b", re.I), "syncope"),
    (re.compile(r"\blumbosacral\s+radiculopathy\b", re.I), "radiculopathy"),
    (re.compile(r"\bragged\s+blue\b(?!\s*fib)", re.I), "ragged blue fibers"),

    (re.compile(r"cytochrome\s+c\s+oxidase\s*[-–—]?\s*negative", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"cytochrome\s+c\s+oxidase\s*[-–—]?\s*deficien\w*", re.I),
     "cytochrome c oxidase-deficient muscle fibers"),
    (re.compile(r"\boligoclonal\s+bands?\b.*\b(csfs?|cerebrospinal\s+fluid)\b", re.I),
     "oligoclonal bands in cerebrospinal fluid"),
    (re.compile(r"\bjerky\s+\w+(\s+and\s+\w+)?\s+movements\b", re.I), "myoclonus"),
    (re.compile(r"\bcomplete\s+absence\s+of\s+proprioceptive\s+sensation\b", re.I),
     "loss of proprioception"),
]
def canonicalize_synonyms(m: str) -> str:
    s = m
    for pat, rep in _CANON_SUBS:
        s = pat.sub(rep, s)
    return s

_MRC = re.compile(r"\b\d+\s*/\s*\d+\b")  # e.g., 5/5, 0/5
_STOPWORDS = {"in", "at", "of", "was", "to", "is", "are", "be", "the", "and"}
def is_noise(m: str) -> bool:
    if not m:
        return True
    m_strip = m.strip()
    alpha = sum(c.isalpha() for c in m_strip)
    if alpha < 2 or len(m_strip) < 3:
        return True
    if _MRC.search(m_strip):
        return True
    toks = re.findall(r"[a-z]+", m_strip.lower())
    if toks and all(t in _STOPWORDS for t in toks):
        return True
    return False

_PHENOTYPE_HINTS = {
    "weakness","atrophy","pain","tremor","paralysis","dystonia","rigidity",
    "contracture","spasm","edema","hyperintensity","hypointensity","lesion",
    "defect","deficiency","deficient","absence","loss","dysphagia","apraxia",
    "seizure","myoclonus","dysarthria","ataxia","paresis","dysesthesia",
    "fibers","fibres","ragged","ragged-red","cox-negative","sdh-positive",
    "aphasia","syncope","radiculopathy","gliosis","proprioception","wheelchair"
}
def looks_like_phenotype(m: str) -> bool:
    toks = set(re.findall(r"[a-z]+", m.lower()))
    return bool(toks & _PHENOTYPE_HINTS)

_REGION_PATTERNS = [
    (re.compile(r"\bcortical\s+gr[ae]y(?:\s+matter)?\b", re.I), "cortical gray matter"),
    (re.compile(r"\bsubcortical\s+white\s+matter\b", re.I), "subcortical white matter"),
    (re.compile(r"\bbasal\s+ganglia\b", re.I), "basal ganglia"),
    (re.compile(r"\bthalam(?:us|i)\b", re.I), "thalamus"),
    (re.compile(r"\bparieto[-\s]?occipital\b", re.I), "parieto-occipital region"),
    (re.compile(r"\bfrontoparietal\s+subcortical\s+white\s+matter\b", re.I), "frontoparietal subcortical white matter"),
    (re.compile(r"\bpre[-\s]?rolandic\b", re.I), "pre-rolandic cortex"),
    (re.compile(r"\binferior\s+olivary\s+nucleus\b", re.I), "inferior olivary nucleus"),
]

_FINDING_PATTERNS = [
    (re.compile(r"\bt2\s*[- ]?\s*hyperintens\w*\b", re.I), "T2 hyperintensity"),
    (re.compile(r"\bhyperintens\w*\b", re.I), "T2 hyperintensity"),  # 默认归一到 T2 hyperintensity
    (re.compile(r"\bhemorrhag\w*\b", re.I), "hemorrhage"),
    (re.compile(r"\bswell\w*\b", re.I), "swelling"),
    (re.compile(r"\batroph\w*\b", re.I), "atrophy"),
    (re.compile(r"\bglios\w*\b", re.I), "gliosis"),
    (re.compile(r"\blesion\w*\b", re.I), "lesion"),
    (re.compile(r"\bprolongation\b", re.I), "T2 prolongation"),
    (re.compile(r"\bhypointens\w*\b", re.I), "T2 hypointensity"),
]

_PATHO_TEMPLATES = [
    (re.compile(r"\bcytochrome\s+c\s+oxidase\b.*\bnegative\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"\bcox[-\s]?negative\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"\bcox[-\s]?deficien\w*\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-deficient muscle fibers"),
    (re.compile(r"\bragged\s+blue\b.*\bfib(er|re)s\b", re.I),
     "ragged blue fibers"),
    (re.compile(r"\boligoclonal\s+bands?\b", re.I),
     "oligoclonal bands in cerebrospinal fluid"),
    (re.compile(r"\bvariation\s+of\s+fiber\s+calib(er|re)\b", re.I),
     "variation in skeletal muscle fiber size"),
    (re.compile(r"\bnuclear\s+centralization\b", re.I),
     "increased central nuclei in skeletal muscle fibers"),
    (re.compile(r"\bfatty\s+replacement\b.*\bendomysial\b", re.I),
     "fatty infiltration of skeletal muscle"),
]

def rewrite_imaging_pathology(m: str) -> Optional[str]:
    s = m
    for pat, rep in _PATHO_TEMPLATES:
        if pat.search(s):
            return rep

    found_region = None
    for pat, norm in _REGION_PATTERNS:
        if pat.search(s):
            found_region = norm
            break

    found_finding = None
    for pat, norm in _FINDING_PATTERNS:
        if pat.search(s):
            found_finding = norm
            break

    if found_finding and found_region:
        return f"{found_finding} of {found_region}"
    if found_finding:
        return found_finding
    return None

def rewrite_numeric_to_qualitative(m: str) -> Optional[str]:
    s = m.lower()


    if re.search(r"\b(vitamin\s*b12|b\s*12)\b", s) and re.search(r"\b(low|decreas\w*|deficien\w*)\b", s):
        return "decreased circulating vitamin B12"


    if re.search(r"\b(25[-\s]?hydroxyvitamin\s*d|25ohd|25\W*oh\W*d)\b", s) and re.search(r"\b(low|decreas\w*|deficien\w*)\b", s):
        return "decreased circulating 25-hydroxyvitamin D"


    if re.search(r"\b(csfs?|cerebrospinal\s+fluid)\b.*\blactate\b", s) and re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
        return "increased cerebrospinal fluid lactate"


    if re.search(r"\blactate\b.*\b(spectroscop\w*|mrs)\b", s):
        return "increased brain lactate on magnetic resonance spectroscopy"

    if re.search(r"\b(serum\s+)?ck\b", s) and (re.search(r"\blevel\b", s) or re.search(r"\d", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high|above\s+normal|x\s*normal)\b", s) or re.search(r"\b\d+(\.\d+)?\s*x\s*normal\b", s):
            return "elevated serum creatine kinase"


    if re.search(r"\bmyoglobin\b", s) and (re.search(r"\blevel\b", s) or re.search(r"\d", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
            return "elevated myoglobin"


    if re.search(r"\b(pyruvate)\b", s) and (re.search(r"\b(blood|serum)\b", s) or re.search(r"\b\du?mol\b", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high)\b", s) or re.search(r"\d", s):
            return "increased blood pyruvate"

    if re.search(r"\bvanillat\w*\b", s) and re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
        return "increased vanillate level"

    if re.search(r"\bdopamine\s+transporter\b", s) and re.search(r"\b(reduc\w*|decreas\w*|low)\b", s):
        return "decreased dopamine transporter level"

    if re.search(r"\blogmar\b", s):
        return "decreased visual acuity"

    if re.search(r"\bketonuria\b", s):
        return "ketonuria"

    if re.search(r"\bvisual\s+disturbance(s)?\b", s):
        return "visual impairment"

    return None

def apply_templates(m: str) -> str:
    img = rewrite_imaging_pathology(m)
    if img:
        return img
    numq = rewrite_numeric_to_qualitative(m)
    if numq:
        return numq
    return m

def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

# === Step 1: Load test data ===
print(">> Loading test data")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=DEVICE
)

# === Step 3: Load HPO terms from hp.obo
print(">> Loading HPO terms from obo")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo (will still try online normalization): {e}")
    graph = None

hpo_map = {}
if graph is not None:
    for node_id, data in graph.nodes(data=True):
        name = data.get("name")
        if name:
            hpo_map.setdefault(name.lower(), []).append(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                hpo_map.setdefault(m.group(1).lower(), []).append(node_id)
hpo_keys = list(hpo_map.keys())

# === Online normalizers (Monarch v3 -> ClinicalTables) ===
@lru_cache(maxsize=10000)
def normalize_via_monarch(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    params = {"q": q, "category": "biolink:PhenotypicFeature", "limit": 5}
    try:
        r = SESSION.get(f"{MONARCH_BASE}/autocomplete", params=params, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        items = data.get("items") or data.get("results") or data
        if not isinstance(items, list):
            return None
        best_curie, best_score = None, -1
        for it in items:
            curie = it.get("id") or it.get("curie")
            label = (it.get("label") or it.get("name") or "").strip()
            cats  = it.get("category") or it.get("categories") or []
            if isinstance(cats, str):
                cats = [cats]
            if curie and str(curie).startswith("HP:"):
                if not cats or any("PhenotypicFeature" in c for c in cats):
                    score = token_set_ratio(q.lower(), label.lower()) if label else 0
                    if score > best_score:
                        best_score, best_curie = score, curie
        return best_curie if best_score >= SIM_THRESH else None
    except Exception:
        return None

@lru_cache(maxsize=10000)
def normalize_via_ct(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 10}, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        if isinstance(data, list) and len(data) >= 2:
            ids, names = data[0], data[1]
            best_id, best_score = None, -1
            for hp_id, name in zip(ids, names):
                if isinstance(hp_id, str) and hp_id.startswith("HP:"):
                    score = token_set_ratio(q.lower(), (name or "").lower())
                    if score > best_score:
                        best_score, best_id = score, hp_id
            return best_id if best_score >= SIM_THRESH else None
    except Exception:
        return None
    return None

# === normalize_mention：
def normalize_mention(text: str):
    curie = normalize_via_monarch(text)
    if curie:
        return curie
    curie = normalize_via_ct(text)
    if curie:
        return curie
    key = text.lower().strip()
    if key in hpo_map:
        score = token_set_ratio(key, key)
        return hpo_map[key][0] if score >= SIM_THRESH else None
    match = process.extractOne(key, hpo_keys, score_cutoff=FUZZY_CUTOFF) if hpo_keys else None
    if match:
        matched_term = match[0]
        score = token_set_ratio(key, matched_term)
        if score >= SIM_THRESH:
            return hpo_map[matched_term][0]
    return None

# === Step 4: Run NER + Normalize ===
print(">> Running NER and normalization")
mapped_mentions = []
unmapped_mentions = []
normalized_mentions = []

for idx, sentence in enumerate(orig_sentences):
    results = ner_pipeline(sentence)
    for ent in results:
        if ent["entity_group"] != "HPO_TERM":
            continue
        mention = ent["word"].strip()

        mention = clean_text(mention)
        mention = expand_abbrev(mention)
        mention = canonicalize_synonyms(mention)
        mention = apply_templates(mention)
        if is_noise(mention):
            continue
        # if not looks_like_phenotype(mention):
        #     continue

        hpo_id = normalize_mention(mention)
        (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))
        normalized_mentions.append({
            "sentence_id": idx,
            "mention": mention,
            "hpo_id": hpo_id
        })

total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"\nTotal mentions:  {total}")
print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})" if total else "Mapped to HP ID: 0 (n/a)")
print(f"Failed to map:   {total - mapped} ({(total - mapped)/total:.1%})" if total else "Failed to map: 0 (n/a)")

OUT_FILE.parent.mkdir(parents=True, exist_ok=True)
with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"\n Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)



Device set to use cuda:0


>> Loading test data
>> Loading model and tokenizer
>> Loading HPO terms from obo


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Running NER and normalization

Total mentions:  638
Mapped to HP ID: 496 (77.7%)
Failed to map:   142 (22.3%)

 Mapped Mentions (496):
cytochrome c oxidase-negative muscle fibers --> HP:0003688
absent reflexes in upper extremities --> HP:0001284
polyneuropathy --> HP:0001271
neuropathic pain --> HP:6000040
ragged - red fibers --> HP:0003200
dyskinesia --> HP:0100660
severe stand - ing tremor --> HP:0001337
dysphagia --> HP:0002015
loss of weight --> HP:0001824
malnutrition --> HP:0004395
mechanical ventilation --> HP:0004887
elevated serum creatine kinase --> HP:0008180
increased blood pyruvate --> HP:0003542
decreased circulating vitamin B12 --> HP:0100502
T2 hyperintensity --> HP:0031206
cerebellar white matter --> HP:0007033
atrophy --> HP:0000029
neurological deterioration --> HP:0002344
tetraparesis --> HP:0002273
hypotonia --> HP:0001252
cognitive impairment --> HP:0100543
multifocal myoclonus --> HP:0040148
T2 hyperintensity --> HP:0031206
bilateral cerebellum --> HP:0012832


In [19]:
import re
import json
from pathlib import Path
from typing import Optional
from functools import lru_cache
import unicodedata  

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process
from rapidfuzz.fuzz import token_set_ratio 

# === Config ===
MODEL_DIR = "/kaggle/input/ner-pubmedbert-saved-hpo/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE  = Path("/kaggle/working/bio_outputs/mapped_mentions.jsonl")
MAX_LENGTH = 512  
DEVICE = 0  # use -1 for CPU


MONARCH_BASE = "https://api-v3.monarchinitiative.org/v3/api"
CT_HPO_SEARCH = "https://clinicaltables.nlm.nih.gov/api/hpo/v3/search"


FUZZY_CUTOFF = 85


SIM_THRESH = 50


def clean_text(m: str) -> str:
    if not m:
        return m
    m = unicodedata.normalize("NFKC", m)
    m = m.replace("–", "-").replace("—", "-")
    m = m.translate(str.maketrans({
        "“": '"', "”": '"', "„": '"', "‟": '"',
        "’": "'", "‘": "'", "‚": "'", "‛": "'"
    }))
    m = re.sub(r"\s+", " ", m).strip()
    return m

_ABBR_PATTERNS = [
    (re.compile(r"\brrf\b", flags=re.I), "ragged red fibers"),
    (re.compile(r"\bragged[-\s]?red\b", flags=re.I), "ragged red"),
    (re.compile(r"\bcox\b", flags=re.I), "cytochrome c oxidase"),
    (re.compile(r"\bsdh\b", flags=re.I), "succinate dehydrogenase"),
    (re.compile(r"\bg[-\s]?tube\b", flags=re.I), "gastrostomy tube"),
]
def expand_abbrev(m: str) -> str:
    t = m
    for pat, rep in _ABBR_PATTERNS:
        t = pat.sub(rep, t)
    return t


_CANON_SUBS = [
    (re.compile(r"\bdysphasia\b", re.I), "aphasia"),
    (re.compile(r"\bwheel[-\s]?chair\s*bound\b", re.I), "wheelchair dependence"),
    (re.compile(r"\bfailing\s+to\s+thrive\b", re.I), "failure to thrive"),
    (re.compile(r"\bsyncopal\s+episode\b", re.I), "syncope"),
    (re.compile(r"\blumbosacral\s+radiculopathy\b", re.I), "radiculopathy"),
    (re.compile(r"\bragged\s+blue\b(?!\s*fib)", re.I), "ragged blue fibers"),

    (re.compile(r"cytochrome\s+c\s+oxidase\s*[-–—]?\s*negative", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"cytochrome\s+c\s+oxidase\s*[-–—]?\s*deficien\w*", re.I),
     "cytochrome c oxidase-deficient muscle fibers"),
    (re.compile(r"\boligoclonal\s+bands?\b.*\b(csfs?|cerebrospinal\s+fluid)\b", re.I),
     "oligoclonal bands in cerebrospinal fluid"),
    (re.compile(r"\bjerky\s+\w+(\s+and\s+\w+)?\s+movements\b", re.I), "myoclonus"),
    (re.compile(r"\bcomplete\s+absence\s+of\s+proprioceptive\s+sensation\b", re.I),
     "loss of proprioception"),


    (re.compile(r"\bpolymini\W*myoclonus\b", re.I), "polyminimyoclonus"),
    (re.compile(r"\bdystonic\s+toe\s+curling\b", re.I), "dystonia of toes"),
    (re.compile(r"\bstriatal\s+toes?\b", re.I), "striatal toe"),
    (re.compile(r"\blaterocollis\b", re.I), "laterocollis"),
    (re.compile(r"\bnear\s+falls?\b", re.I), "recurrent falls"),
    (re.compile(r"\bdifficulties\s+in\s+handling\s+objects\b", re.I), "apraxia"),
]
def canonicalize_synonyms(m: str) -> str:
    s = m
    for pat, rep in _CANON_SUBS:
        s = pat.sub(rep, s)
    return s

_MRC = re.compile(r"\b\d+\s*/\s*\d+\b")  # e.g., 5/5, 0/5
_STOPWORDS = {"in", "at", "of", "was", "to", "is", "are", "be", "the", "and"}


_BAD_START = re.compile(r"^##")
_BAD_FRAGMENTS = [
    re.compile(r"\bfed\s+through\s+a\b", re.I),
    re.compile(r"\bshowed\s+an\b", re.I),
    re.compile(r"^\s*(hepat|dysm|##ar|##et)\b", re.I),
]
def is_noise(m: str) -> bool:
    if not m:
        return True
    m_strip = m.strip()
    if _BAD_START.search(m_strip):
        return True
    for pat in _BAD_FRAGMENTS:
        if pat.search(m_strip):
            return True
    alpha = sum(c.isalpha() for c in m_strip)
    if alpha < 2 or len(m_strip) < 3:
        return True
    if _MRC.search(m_strip):
        return True
    toks = re.findall(r"[a-z]+", m_strip.lower())
    if toks and all(t in _STOPWORDS for t in toks):
        return True
    return False


_PHENOTYPE_HINTS = {
    "weakness","atrophy","pain","tremor","paralysis","dystonia","rigidity",
    "contracture","spasm","edema","hyperintensity","hypointensity","lesion",
    "defect","deficiency","deficient","absence","loss","dysphagia","apraxia",
    "seizure","myoclonus","dysarthria","ataxia","paresis","dysesthesia",
    "fibers","fibres","ragged","ragged-red","cox-negative","sdh-positive",
    "aphasia","syncope","radiculopathy","gliosis","proprioception","wheelchair"
}
def looks_like_phenotype(m: str) -> bool:
    toks = set(re.findall(r"[a-z]+", m.lower()))
    return bool(toks & _PHENOTYPE_HINTS)


_REGION_PATTERNS = [
    (re.compile(r"\bcortical\s+gr[ae]y(?:\s+matter)?\b", re.I), "cortical gray matter"),
    (re.compile(r"\bsubcortical\s+white\s+matter\b", re.I), "subcortical white matter"),
    (re.compile(r"\bbasal\s+ganglia\b", re.I), "basal ganglia"),
    (re.compile(r"\bthalam(?:us|i)\b", re.I), "thalamus"),
    (re.compile(r"\bparieto[-\s]?occipital\b", re.I), "parieto-occipital region"),
    (re.compile(r"\bfrontoparietal\s+subcortical\s+white\s+matter\b", re.I), "frontoparietal subcortical white matter"),
    (re.compile(r"\bpre[-\s]?rolandic\b", re.I), "pre-rolandic cortex"),
    (re.compile(r"\binferior\s+olivary\s+nucleus\b", re.I), "inferior olivary nucleus"),
]

_FINDING_PATTERNS = [
    (re.compile(r"\bt2\s*[- ]?\s*hyperintens\w*\b", re.I), "T2 hyperintensity"),
    (re.compile(r"\bhypersignal(s)?\b", re.I), "T2 hyperintensity"),  # 新增：支持 hypersignal(s)
    (re.compile(r"\bhyperintens\w*\b", re.I), "T2 hyperintensity"),  # 默认归一到 T2 hyperintensity
    (re.compile(r"\bhemorrhag\w*\b", re.I), "hemorrhage"),
    (re.compile(r"\bswell\w*\b", re.I), "swelling"),
    (re.compile(r"\batroph\w*\b", re.I), "atrophy"),
    (re.compile(r"\bglios\w*\b", re.I), "gliosis"),
    (re.compile(r"\blesion\w*\b", re.I), "lesion"),
    (re.compile(r"\bprolongation\b", re.I), "T2 prolongation"),
    (re.compile(r"\bhypointens\w*\b", re.I), "T2 hypointensity"),
]


_PATHO_TEMPLATES = [
    (re.compile(r"\bcytochrome\s+c\s+oxidase\b.*\bnegative\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"\bcox[-\s]?negative\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-negative muscle fibers"),
    (re.compile(r"\bcox[-\s]?deficien\w*\b.*\bfib(er|re)s\b", re.I),
     "cytochrome c oxidase-deficient muscle fibers"),
    (re.compile(r"\bragged\s+blue\b.*\bfib(er|re)s\b", re.I),
     "ragged blue fibers"),
    (re.compile(r"\boligoclonal\s+bands?\b", re.I),
     "oligoclonal bands in cerebrospinal fluid"),
    (re.compile(r"\bvariation\s+of\s+fiber\s+calib(er|re)\b", re.I),
     "variation in skeletal muscle fiber size"),
    (re.compile(r"\bnuclear\s+centralization\b", re.I),
     "increased central nuclei in skeletal muscle fibers"),
    (re.compile(r"\bfatty\s+replacement\b.*\bendomysial\b", re.I),
     "fatty infiltration of skeletal muscle"),


    (re.compile(r"\bparacrystalline\s+inclusion\w*\b", re.I),
     "mitochondrial paracrystalline inclusions"),
    (re.compile(r"\bswollen\s+mitochondria\b|\babnormally\s+swollen\s+mitochondria\b", re.I),
     "swollen mitochondria"),
    (re.compile(r"\b(concentric|tubular|irregular)\s+cristae\b", re.I),
     "abnormal mitochondrial cristae morphology"),
    (re.compile(r"\b(poly)?spike\s*-\s*(and\s*-\s*)?slow\s*waves?\b", re.I),
     "EEG with epileptiform discharges"),
    (re.compile(r"\bsharp\s+and\s+slow\s+wave(s)?\b", re.I),
     "EEG with epileptiform discharges"),
]

def rewrite_imaging_pathology(m: str) -> Optional[str]:
    s = m

    for pat, rep in _PATHO_TEMPLATES:
        if pat.search(s):
            return rep

    found_region = None
    for pat, norm in _REGION_PATTERNS:
        if pat.search(s):
            found_region = norm
            break

    found_finding = None
    for pat, norm in _FINDING_PATTERNS:
        if pat.search(s):
            found_finding = norm
            break

    if found_finding and found_region:
        return f"{found_finding} of {found_region}"
    if found_finding:
        return found_finding
    return None


def rewrite_numeric_to_qualitative(m: str) -> Optional[str]:
    s = m.lower()


    if re.search(r"\b(vitamin\s*b12|b\s*12)\b", s) and re.search(r"\b(low|decreas\w*|deficien\w*)\b", s):
        return "decreased circulating vitamin B12"


    if re.search(r"\b(25[-\s]?hydroxyvitamin\s*d|25ohd|25\W*oh\W*d)\b", s) and re.search(r"\b(low|decreas\w*|deficien\w*)\b", s):
        return "decreased circulating 25-hydroxyvitamin D"


    if re.search(r"\b(csfs?|cerebrospinal\s+fluid)\b.*\blactate\b", s) and re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
        return "increased cerebrospinal fluid lactate"


    if re.search(r"\blactate\b.*\b(spectroscop\w*|mrs)\b", s):
        return "increased brain lactate on magnetic resonance spectroscopy"


    if re.search(r"\b(serum\s+)?ck\b", s) and (re.search(r"\blevel\b", s) or re.search(r"\d", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high|above\s+normal|x\s*normal)\b", s) or re.search(r"\b\d+(\.\d+)?\s*x\s*normal\b", s):
            return "elevated serum creatine kinase"


    if re.search(r"\bmyoglobin\b", s) and (re.search(r"\blevel\b", s) or re.search(r"\d", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
            return "elevated myoglobin"

    if re.search(r"\b(pyruvate)\b", s) and (re.search(r"\b(blood|serum)\b", s) or re.search(r"\b\du?mol\b", s)):
        if re.search(r"\b(increas\w*|elevat\w*|high)\b", s) or re.search(r"\d", s):
            return "increased blood pyruvate"

    if re.search(r"\bvanillat\w*\b", s) and re.search(r"\b(increas\w*|elevat\w*|high)\b", s):
        return "increased vanillate level"


    if re.search(r"\bdopamine\s+transporter\b", s) and re.search(r"\b(reduc\w*|decreas\w*|low)\b", s):
        return "decreased dopamine transporter level"


    if re.search(r"\blogmar\b", s):
        return "decreased visual acuity"


    if re.search(r"\bketonuria\b", s):
        return "ketonuria"


    if re.search(r"\bvisual\s+disturbance(s)?\b", s):
        return "visual impairment"

    return None

def apply_templates(m: str) -> str:
    img = rewrite_imaging_pathology(m)
    if img:
        return img
    numq = rewrite_numeric_to_qualitative(m)
    if numq:
        return numq
    return m

def dedup_adjacent_words(s: str) -> str:
    return re.sub(r'\b(\w+)(\s+\1\b)+', r'\1', s, flags=re.I)

# === HTTP session with retries ===
def make_session() -> requests.Session:
    s = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.3,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset(["GET"])
    )
    s.mount("https://", HTTPAdapter(max_retries=retries))
    s.mount("http://", HTTPAdapter(max_retries=retries))
    return s

SESSION = make_session()

# === Step 1: Load test data ===
print(">> Loading test data")
test_data = [json.loads(line) for line in TEST_FILE.open(encoding="utf-8")]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=DEVICE
)

# === Step 3: Load HPO terms from hp.obo
print(">> Loading HPO terms from obo")
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
try:
    graph = obonet.read_obo(obo_url)
except Exception as e:
    print(f">> Failed to fetch hp.obo (will still try online normalization): {e}")
    graph = None

hpo_map = {}
if graph is not None:
    for node_id, data in graph.nodes(data=True):
        name = data.get("name")
        if name:
            hpo_map.setdefault(name.lower(), []).append(node_id)
        for syn in data.get("synonym", []):
            m = re.search(r'"(.+?)"', syn)
            if m:
                hpo_map.setdefault(m.group(1).lower(), []).append(node_id)
hpo_keys = list(hpo_map.keys())

# === Online normalizers (Monarch v3 -> ClinicalTables) ===
@lru_cache(maxsize=10000)
def normalize_via_monarch(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    params = {"q": q, "category": "biolink:PhenotypicFeature", "limit": 5}
    try:
        r = SESSION.get(f"{MONARCH_BASE}/autocomplete", params=params, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        items = data.get("items") or data.get("results") or data
        if not isinstance(items, list):
            return None
        best_curie, best_score = None, -1
        for it in items:
            curie = it.get("id") or it.get("curie")
            label = (it.get("label") or it.get("name") or "").strip()
            cats  = it.get("category") or it.get("categories") or []
            if isinstance(cats, str):
                cats = [cats]
            if curie and str(curie).startswith("HP:"):
                if not cats or any("PhenotypicFeature" in c for c in cats):
                    score = token_set_ratio(q.lower(), label.lower()) if label else 0
                    if score > best_score:
                        best_score, best_curie = score, curie
        return best_curie if best_score >= SIM_THRESH else None
    except Exception:
        return None

@lru_cache(maxsize=10000)
def normalize_via_ct(text: str) -> Optional[str]:
    q = text.strip()
    if not q:
        return None
    try:
        r = SESSION.get(CT_HPO_SEARCH, params={"terms": q, "maxList": 10}, timeout=6)
        if r.status_code != 200:
            return None
        data = r.json()
        if isinstance(data, list) and len(data) >= 2:
            ids, names = data[0], data[1]
            best_id, best_score = None, -1
            for hp_id, name in zip(ids, names):
                if isinstance(hp_id, str) and hp_id.startswith("HP:"):
                    score = token_set_ratio(q.lower(), (name or "").lower())
                    if score > best_score:
                        best_score, best_id = score, hp_id
            return best_id if best_score >= SIM_THRESH else None
    except Exception:
        return None
    return None

def normalize_mention(text: str):
    curie = normalize_via_monarch(text)
    if curie:
        return curie
    curie = normalize_via_ct(text)
    if curie:
        return curie
    key = text.lower().strip()
    if key in hpo_map:
        score = token_set_ratio(key, key)
        return hpo_map[key][0] if score >= SIM_THRESH else None
    match = process.extractOne(key, hpo_keys, score_cutoff=FUZZY_CUTOFF) if hpo_keys else None
    if match:
        matched_term = match[0]
        score = token_set_ratio(key, matched_term)
        if score >= SIM_THRESH:
            return hpo_map[matched_term][0]
    return None

print(">> Running NER and normalization")
mapped_mentions = []
unmapped_mentions = []
normalized_mentions = []

for idx, sentence in enumerate(orig_sentences):
    results = ner_pipeline(sentence)
    for ent in results:
        if ent["entity_group"] != "HPO_TERM":
            continue
        mention = ent["word"].strip()

        mention = clean_text(mention)
        mention = expand_abbrev(mention)
        mention = canonicalize_synonyms(mention)
        mention = apply_templates(mention)
        mention = dedup_adjacent_words(mention)
        if is_noise(mention):
            continue
        # if not looks_like_phenotype(mention):
        #     continue

        hpo_id = normalize_mention(mention)
        (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))
        normalized_mentions.append({
            "sentence_id": idx,
            "mention": mention,
            "hpo_id": hpo_id
        })

total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"\nTotal mentions:  {total}")
print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})" if total else "Mapped to HP ID: 0 (n/a)")
print(f"Failed to map:   {total - mapped} ({(total - mapped)/total:.1%})" if total else "Failed to map: 0 (n/a)")

OUT_FILE.parent.mkdir(parents=True, exist_ok=True)
with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"\n Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)


Device set to use cuda:0


>> Loading test data
>> Loading model and tokenizer
>> Loading HPO terms from obo


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


>> Running NER and normalization

Total mentions:  631
Mapped to HP ID: 556 (88.1%)
Failed to map:   75 (11.9%)

 Mapped Mentions (556):
cytochrome c oxidase-negative muscle fibers --> HP:0003688
absent reflexes in upper extremities --> HP:0001284
polyneuropathy --> HP:0001271
neuropathic pain --> HP:6000040
ragged - red fibers --> HP:0003200
dyskinesia --> HP:0100660
severe stand - ing tremor --> HP:0001337
dysphagia --> HP:0002015
loss of weight --> HP:0001824
malnutrition --> HP:0004395
ventilatory failure --> HP:0000198
mechanical ventilation --> HP:0004887
elevated serum creatine kinase --> HP:0008180
increased blood pyruvate --> HP:0003542
decreased circulating vitamin B12 --> HP:0100502
T2 hyperintensity --> HP:0031206
cerebellar white matter --> HP:0007033
atrophy --> HP:0000029
neurological deterioration --> HP:0002344
tetraparesis --> HP:0002273
hypotonia --> HP:0001252
cognitive impairment --> HP:0100543
multifocal myoclonus --> HP:0040148
T2 hyperintensity --> HP:0031206
bi