In [None]:
import json
from pathlib import Path

INPUT_DIR = Path("/kaggle/input/fullfull/annotations")
OUTPUT_FILE = Path("merged_spans_with_entities.jsonl")

# Gather every record that carries at least one span, keeping its text,
# tokens and spans. Blank and malformed lines are reported and skipped.
merged = []

for span_path in sorted(INPUT_DIR.glob("*_spans.jsonl")):
    filename = span_path.name
    with span_path.open("r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            payload = raw_line.strip()
            if payload == "":
                print(f"Skipping empty line at {filename}:{lineno}")
                continue

            try:
                rec = json.loads(payload)
            except json.JSONDecodeError as e:
                print(f"JSON decode error at {filename}:{lineno} — {e}")
                continue

            spans = rec.get("spans", [])
            if spans:
                # Records without any span are dropped on purpose.
                merged.append({
                    "text": rec.get("text", ""),
                    "tokens": rec.get("tokens", []),
                    "spans": spans,
                })

# One JSON object per line; non-ASCII characters are written verbatim.
with OUTPUT_FILE.open("w", encoding="utf-8") as fw:
    fw.writelines(json.dumps(entry, ensure_ascii=False) + "\n" for entry in merged)

print(f"Merged and saved {len(merged)} entity-containing records to: {OUTPUT_FILE.resolve()}")

In [None]:
import json
from pathlib import Path

INPUT_DIR = Path("/kaggle/input/fullfull/annotations")      # change to your input directory
OUTPUT_FILE = Path("merged_spans.jsonl")                    # output file name, adjust as needed

# Same merge as the previous cell, but each kept record holds only text and spans.
merged = []

for span_path in sorted(INPUT_DIR.glob("*_spans.jsonl")):
    filename = span_path.name
    with span_path.open("r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            payload = raw_line.strip()
            if payload == "":
                print(f"Skipping empty line at {filename}:{lineno}")
                continue

            try:
                rec = json.loads(payload)
            except json.JSONDecodeError as e:
                print(f"JSON decode error at {filename}:{lineno} — {e}")
                continue

            spans = rec.get("spans", [])
            if spans:
                # Keep only text and spans; records without spans are skipped.
                merged.append({"text": rec.get("text", ""), "spans": spans})

# One JSON object per line; non-ASCII characters are written verbatim.
with OUTPUT_FILE.open("w", encoding="utf-8") as fw:
    fw.writelines(json.dumps(entry, ensure_ascii=False) + "\n" for entry in merged)

print(f"Merged and saved {len(merged)} records (text + spans) to: {OUTPUT_FILE.resolve()}")

In [None]:
import json
from pathlib import Path
import pandas as pd

# --- CONFIGURATION ---
INPUT_XLSX = Path("/kaggle/input/polgtable/subset_POLG.xlsx")
OUTPUT_FILE = Path("hpo_terms.jsonl")

# --- STEP 1: read the spreadsheet ---
df = pd.read_excel(INPUT_XLSX, engine="openpyxl")

# --- STEP 2: collect every HPO term (first-seen order, de-duplicated) ---
# A dict keeps insertion order, so it doubles as an ordered set here.
ordered_terms = {}
for cell in df["HPO_Term"]:
    if pd.isna(cell):
        continue
    # A cell may hold several terms separated by ';'.
    for piece in str(cell).split(";"):
        cleaned = piece.strip()
        if cleaned:
            ordered_terms.setdefault(cleaned, None)

all_hpo_terms = list(ordered_terms)
seen = set(all_hpo_terms)

# --- STEP 3: write the JSONL file ---
with OUTPUT_FILE.open("w", encoding="utf-8") as fout:
    fout.writelines(
        json.dumps({"HPO_TERM": hpo}, ensure_ascii=False) + "\n"
        for hpo in all_hpo_terms
    )

print(f"✅ 提取完成，共写入 {len(all_hpo_terms)} 个 HPO_TERM 到 {OUTPUT_FILE.resolve()}")



In [None]:
import json
from pathlib import Path
from collections import Counter, defaultdict
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# -------------------
# Constants & Paths
# -------------------
# Gold records to convert; presumably the file written by the first merge cell
# (assumes the notebook's working directory is /kaggle/working — confirm).
FILE_MERGED = Path("/kaggle/working/merged_spans_with_entities.jsonl")
OUT_DIR     = Path("/kaggle/working/bio_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_BIO = OUT_DIR / "train.jsonl"
DEV_BIO   = OUT_DIR / "dev.jsonl"
TEST_BIO  = OUT_DIR / "test.jsonl"
TEST_TEXT_ONLY = OUT_DIR / "test_text_only.jsonl"   # ✅ newly added path: raw test-set texts only

# Span labels that are converted to BIO tags; spans with any other label are ignored.
ENTITY_TYPES = {
    "AGE_ONSET", "AGE_FOLLOWUP", "AGE_DEATH",
    "PATIENT", "HPO_TERM", "GENE", "GENE_VARIANT"
}

# Fast tokenizer is needed for offset mappings and word ids used during BIO alignment.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True
)

# -------------------
# Utility Functions
# -------------------
def iter_jsonl(path: Path):
    """Yield one parsed JSON object per non-blank line of a JSONL file.

    Blank lines are skipped silently. Lines that are not valid JSON are
    skipped as well, but each one is reported with its file name and line
    number (same message format as the merge cells), so lost records are
    visible instead of disappearing without a trace.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Yields:
        The decoded object of every valid line, in file order.
    """
    with path.open("r", encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"JSON decode error at {path.name}:{lineno} — {e}")
                continue
            # Yield outside the try so that only parsing errors are caught here.
            yield record

def make_bio_labels(spans, enc):
    """Turn character-level spans into one BIO tag per token of ``enc``.

    Args:
        spans: Dicts with ``start``, ``end`` (character offsets) and ``label``.
        enc: Encoding object exposing ``tokens()``, ``word_ids()`` and
            ``enc["offset_mapping"]``.

    Returns:
        A list of tag strings ("O", "B-<label>", "I-<label>"), one per token.
    """
    pieces = enc.tokens()
    char_ranges = enc["offset_mapping"]
    word_index = enc.word_ids()
    bio = ["O"] * len(pieces)
    covered_per_span = []

    for span in spans:
        lo, hi, kind = span["start"], span["end"], span["label"]
        # A token belongs to the span when the two character ranges overlap.
        covered = [
            pos for pos, (tok_lo, tok_hi) in enumerate(char_ranges)
            if tok_hi > lo and tok_lo < hi
        ]
        covered_per_span.append(covered)
        if covered:
            bio[covered[0]] = f"B-{kind}"
            for pos in covered[1:]:
                bio[pos] = f"I-{kind}"

    # A span that covers a single token must start with B-, even if a later
    # overlapping span re-tagged that token as I-.
    for covered in covered_per_span:
        if len(covered) == 1:
            only = covered[0]
            bio[only] = bio[only].replace("I-", "B-")

    # Extend an entity over the remaining sub-tokens of a word it only partly covers.
    last_word = None
    for pos, word in enumerate(word_index):
        same_word = word is not None and word == last_word
        if same_word and bio[pos] == "O" and bio[pos - 1].startswith(("B-", "I-")):
            bio[pos] = "I-" + bio[pos - 1][2:]
        last_word = word

    return bio

def record_to_bio(rec):
    """Convert one annotated record into a BIO example.

    Args:
        rec: Dict with a ``text`` string and a ``spans`` list.

    Returns:
        ``None`` when the record has no span whose label is in ENTITY_TYPES;
        otherwise a dict with the original ``text``, the tokenizer's
        ``tokens`` and the matching BIO ``labels``.
    """
    text = rec.get("text", "")
    wanted = list(filter(lambda s: s.get("label") in ENTITY_TYPES, rec.get("spans", [])))
    if len(wanted) == 0:
        return None

    # No special tokens are added, and the text is cut at 512 tokens.
    enc = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        truncation=True,
        max_length=512
    )

    example = {"text": text}  # ✅ keep the raw sentence so it can be saved later
    example["tokens"] = enc.tokens()
    example["labels"] = make_bio_labels(wanted, enc)
    return example

def dump_jsonl(path: Path, data):
    """Write every object in ``data`` to ``path`` as one JSON line each.

    The file is overwritten and encoded as UTF-8; non-ASCII characters are
    written verbatim rather than escaped.
    """
    with path.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(obj, ensure_ascii=False) + "\n" for obj in data)

# -------------------
# Load & Convert
# -------------------
# Records for which record_to_bio returns None (no span of a known entity type) are dropped.
print(">> Loading and converting gold-standard data …")
merged_bio = [
    bio for rec in iter_jsonl(FILE_MERGED)
    if (bio := record_to_bio(rec)) is not None
]
print(f"Total valid records: {len(merged_bio)}")

# -------------------
# Simple Random Split (70/20/10)
# -------------------
# First hold out 30 % of the records, then give one third of that hold-out
# (about 10 % overall) to TEST and the remaining two thirds (about 20 %) to DEV.
# random_state is fixed so the split is reproducible.
train_set, temp_set = train_test_split(
    merged_bio, test_size=0.3, random_state=42
)
dev_set, test_set = train_test_split(
    temp_set, test_size=1/3, random_state=42
)

print(f"Split sizes – TRAIN: {len(train_set)}, DEV: {len(dev_set)}, TEST: {len(test_set)}")

# -------------------
# Save BIO Format Files
# -------------------
dump_jsonl(TRAIN_BIO, train_set)
dump_jsonl(DEV_BIO, dev_set)
dump_jsonl(TEST_BIO, test_set)

# -------------------
# ✅ Save only raw test set texts
# -------------------
# Each line holds just {"text": ...} for one test example.
with TEST_TEXT_ONLY.open("w", encoding="utf-8") as fw:
    for ex in test_set:
        if "text" in ex:
            fw.write(json.dumps({"text": ex["text"]}, ensure_ascii=False) + "\n")

print(f"Saved to {TRAIN_BIO}, {DEV_BIO}, {TEST_BIO}")
print(f"✅ Raw test text saved to: {TEST_TEXT_ONLY}")


In [None]:
!pip install seqeval evaluate torchcrf

In [None]:
import json
from pathlib import Path

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# 1. Load BIO datasets
BIO_DIR = Path("/kaggle/working/bio_outputs")

def load_jsonl(path: Path):
    """Read a UTF-8 JSONL file into a list, ignoring whitespace-only lines."""
    records = []
    with path.open(encoding="utf-8") as f:
        for raw in f:
            if raw.strip():
                records.append(json.loads(raw))
    return records

# 1.1 Map every tag that is not an HPO_TERM tag to 'O'
def keep_only_hpo_labels(example):
    """Collapse all non-HPO_TERM tags of ``example["labels"]`` to "O".

    The example dict is modified in place and also returned.
    """
    masked = []
    for tag in example["labels"]:
        masked.append(tag if tag.endswith("HPO_TERM") else "O")
    example["labels"] = masked
    return example

# Load each split and collapse every non-HPO_TERM tag to 'O' before building the datasets.
train_examples = [keep_only_hpo_labels(ex) for ex in load_jsonl(BIO_DIR / "train.jsonl")]
dev_examples   = [keep_only_hpo_labels(ex) for ex in load_jsonl(BIO_DIR / "dev.jsonl")]
test_examples  = [keep_only_hpo_labels(ex) for ex in load_jsonl(BIO_DIR / "test.jsonl")]

ds_splits = DatasetDict({
    "train":      Dataset.from_list(train_examples),
    "validation": Dataset.from_list(dev_examples),
    "test":       Dataset.from_list(test_examples),
})
print("Loaded dataset sizes:", {k: len(v) for k, v in ds_splits.items()})

# 2. Tokenizer & label mapping
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True,
)

# 2.1 The label vocabulary is whatever tags occur in the three splits; after the
# filtering above this is expected to be ['B-HPO_TERM', 'I-HPO_TERM', 'O'].
# Ids follow the sorted order of the tag strings.
unique_labels = sorted({lab for ex in train_examples + dev_examples + test_examples
                        for lab in ex["labels"]})
label2id = {lab: i for i, lab in enumerate(unique_labels)}
id2label = {i: lab for lab, i in label2id.items()}

def tokenize_and_align_labels(ex):
    """Encode one BIO example so that every label sits on its own token.

    ``ex["tokens"]`` are already WordPiece strings (the BIO conversion cell
    stores ``enc.tokens()``), with exactly one tag per piece. Running them
    through the tokenizer again with ``is_split_into_words=True`` would
    prepend [CLS], append [SEP] and may re-split pieces such as "##axia",
    while the tag list stays unchanged, so tags would be shifted against the
    inputs. The pieces are therefore converted to ids directly and the two
    special positions get the ignore index -100.

    Args:
        ex: Dict with parallel lists ``tokens`` (WordPiece strings) and
            ``labels`` (BIO tag strings).

    Returns:
        Dict with ``input_ids``, ``token_type_ids``, ``attention_mask`` and
        ``labels``, all of the same length (at most 512).
    """
    max_length = 512
    budget = max_length - 2  # leave room for [CLS] and [SEP]
    pieces = ex["tokens"][:budget]
    tags = ex["labels"][:budget]

    input_ids = (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(pieces)
        + [tokenizer.sep_token_id]
    )
    return {
        "input_ids": input_ids,
        "token_type_ids": [0] * len(input_ids),
        "attention_mask": [1] * len(input_ids),
        # -100 is ignored by the loss and filtered out by the metric code.
        "labels": [-100] + [label2id[tag] for tag in tags] + [-100],
    }

# Encode every split one example at a time; the raw "tokens"/"labels" columns
# are removed and replaced by what tokenize_and_align_labels returns.
ds_splits = ds_splits.map(
    tokenize_and_align_labels,
    batched=False,
    remove_columns=["tokens", "labels"],
)

# 3. Model
# Token-classification model whose output size and label names come from the
# label maps built above.
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=len(unique_labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# 4. Metrics
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    """Compute overall seqeval scores for one evaluation pass.

    Args:
        p: Object with ``predictions`` (per-token scores) and ``label_ids``.

    Returns:
        Dict with overall precision, recall, F1 and accuracy. Positions whose
        gold id is -100 are excluded from both sequences.
    """
    predicted_ids = p.predictions.argmax(-1)
    gold_ids = p.label_ids

    true_labels = []
    for gold_row in gold_ids:
        true_labels.append([id2label[lid] for lid in gold_row if lid != -100])

    pred_labels = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        kept = [pid for pid, lid in zip(pred_row, gold_row) if lid != -100]
        pred_labels.append([id2label[pid] for pid in kept])

    result = seqeval.compute(predictions=pred_labels, references=true_labels)
    wanted = ("overall_precision", "overall_recall", "overall_f1", "overall_accuracy")
    return {key: result[key] for key in wanted}

data_collator = DataCollatorForTokenClassification(tokenizer)

# 5. Training arguments and Trainer
# Evaluation and logging run every 50 steps and checkpoints are saved every
# 500 steps. With load_best_model_at_end, the checkpoint with the highest
# validation overall_f1 is restored after training.
training_args = TrainingArguments(
    output_dir="ner_pubmedbert",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="overall_f1",
    greater_is_better=True,
    report_to=["none"],
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_splits["train"],
    eval_dataset=ds_splits["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# 6. Train and evaluate
trainer.train()
# Final evaluation on the validation split; the returned metrics are not stored.
trainer.evaluate()

# 7. Predict on test set
# Inference runs once; metrics, scores and gold ids all come from the same
# output (previously the test split was predicted twice in a row).
test_output = trainer.predict(ds_splits["test"])
test_metrics = test_output.metrics
print("Test set metrics:", test_metrics)
predictions, labels = test_output.predictions, test_output.label_ids
preds = predictions.argmax(-1)

# Map ids back to tag names, skipping positions whose gold id is -100.
true_labels = [
    [id2label[label_id] for label_id in seq if label_id != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pred_id] for pred_id, label_id in zip(pred_seq, label_seq) if label_id != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

# 8. Only show HPO_TERM in per-label report
# Nothing is printed below the header when seqeval reports no HPO_TERM entry.
print("\n HPO_TERM classification report:")
for label, metrics in detailed_result.items():
    if label != "HPO_TERM":
        continue
    print(f" {label:20} | Precision: {metrics['precision']:.3f} | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}")


In [None]:
from evaluate import load

# Predict on the test split (uses `trainer`, `ds_splits` and `id2label` from the training cell)
predictions, labels, _ = trainer.predict(ds_splits["test"])
preds = predictions.argmax(-1)

# Map label ids back to label names (skipping -100, the ignore index)
true_labels = [
    [id2label[label_id] for label_id in seq if label_id != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pred_id] for pred_id, label_id in zip(pred_seq, label_seq) if label_id != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

# Compute the evaluation results for all labels with seqeval
seqeval = load("seqeval")
detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

# Print the overall F1 (optional)
print(f"\nOverall F1 score: {detailed_result.get('overall_f1', 0):.3f}")

# Output only the HPO_TERM results
print("\n HPO_TERM classification report:")
hpo_metrics = detailed_result.get("HPO_TERM")
if hpo_metrics:
    print(f" {'HPO_TERM':20} | Precision: {hpo_metrics['precision']:.3f} | Recall: {hpo_metrics['recall']:.3f} | F1: {hpo_metrics['f1']:.3f}")
else:
    print("No HPO_TERM entities found in predictions.")


In [None]:
from collections import defaultdict

# ----------- Step 1: extract HPO_TERM entity spans -----------
def extract_entities(labels):
    """Collect HPO_TERM entities from a sequence of label ids.

    Args:
        labels: Label ids; ids missing from ``id2label`` count as "O".

    Returns:
        List of ``(start, end, "HPO_TERM")`` tuples with inclusive token
        indices. An I- tag without an open entity does not start one.
    """
    found = []
    open_at = None  # index where the currently open entity began
    for pos, lab_id in enumerate(labels):
        tag = id2label.get(lab_id, "O")
        begins = tag.startswith("B-HPO_TERM")
        if not begins and tag.startswith("I-HPO_TERM") and open_at is not None:
            continue  # still inside the open entity
        if open_at is not None:
            found.append((open_at, pos - 1, "HPO_TERM"))
            open_at = None
        if begins:
            open_at = pos
    if open_at is not None:
        found.append((open_at, len(labels) - 1, "HPO_TERM"))
    return found

# ----------- Step 2: IoU computation & relaxed matching -----------
def iou(a, b):
    """Return the overlap ratio of two inclusive index ranges ``(start, end)``.

    The denominator is the length of the range from the smaller start to the
    larger end, so disjoint ranges give 0.
    """
    overlap = min(a[1], b[1]) - max(a[0], b[0]) + 1
    if overlap < 0:
        overlap = 0
    hull = max(a[1], b[1]) - min(a[0], b[0]) + 1
    return overlap / hull

def relaxed_match(pred_span, true_span):
    """Decide whether a predicted span loosely matches a gold span.

    Both spans are ``(start, end, label)`` tuples. Labels must be equal; the
    spans then match when both boundaries are at most 4 tokens apart or when
    their IoU is at least 0.4.
    """
    p_start, p_end, p_kind = pred_span
    t_start, t_end, t_kind = true_span
    if p_kind != t_kind:
        return False
    boundaries_close = abs(p_start - t_start) <= 4 and abs(p_end - t_end) <= 4
    return bool(boundaries_close or iou((p_start, p_end), (t_start, t_end)) >= 0.4)

# ----------- Step 3: Relaxed Evaluation Metric ----------
def relaxed_compute_metrics(preds, refs):
    """Score predictions against references with relaxed span matching.

    Each predicted entity is paired with the first still-unmatched gold
    entity it loosely matches. Prints a per-label report and returns the
    overall scores.

    Args:
        preds: Predicted label-id sequences.
        refs: Gold label-id sequences, parallel to ``preds``.

    Returns:
        Dict with ``precision``, ``recall`` and ``f1``.
    """
    tp = fp = fn = 0
    label_metrics = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})

    for pred_seq, ref_seq in zip(preds, refs):
        pred_ents = extract_entities(pred_seq)
        true_ents = extract_entities(ref_seq)
        used = set()

        for pred_ent in pred_ents:
            hit = next(
                (j for j, gold in enumerate(true_ents)
                 if j not in used and relaxed_match(pred_ent, gold)),
                None,
            )
            if hit is None:
                fp += 1
                label_metrics["HPO_TERM"]["fp"] += 1
            else:
                used.add(hit)
                tp += 1
                label_metrics["HPO_TERM"]["tp"] += 1

        # Gold entities nobody matched are false negatives; the per-label
        # entry is only touched when there is at least one.
        missed = len(true_ents) - len(used)
        for _ in range(missed):
            fn += 1
            label_metrics["HPO_TERM"]["fn"] += 1

    # The small epsilon avoids division by zero when a count is empty.
    precision = tp / (tp + fp + 1e-10)
    recall    = tp / (tp + fn + 1e-10)
    f1        = 2 * precision * recall / (precision + recall + 1e-10)

    print("\n Relaxed Per-label HPO_TERM classification report:")
    for label, m in label_metrics.items():
        lp = m["tp"] / (m["tp"] + m["fp"] + 1e-10)
        lr = m["tp"] / (m["tp"] + m["fn"] + 1e-10)
        lf1 = 2 * lp * lr / (lp + lr + 1e-10)
        print(f" {label:20} | Precision: {lp:.3f} | Recall: {lr:.3f} | F1: {lf1:.3f}")

    return dict(precision=precision, recall=recall, f1=f1)

# ----------- Step 4: strip -100 padding ----------
# Uses `preds` and `labels` left in the kernel by the prediction cell above.
# Positions whose gold id is -100 are dropped from both sequences so that
# predictions and references stay aligned.
filtered_preds = []
filtered_labels = []

for pred_seq, label_seq in zip(preds, labels):
    filtered_pred = [p for p, l in zip(pred_seq, label_seq) if l != -100]
    filtered_label = [l for l in label_seq if l != -100]
    filtered_preds.append(filtered_pred)
    filtered_labels.append(filtered_label)

# ----------- Step 5: repair the BIO structure -----------
def clean_prediction_structure(labels):
    """Repair orphan I- tags and single-token gaps between entity fragments.

    Two rules are applied left to right:

    * An I- tag that follows "O" becomes the matching B- tag.
    * An "O" that sits between an entity and a following "B-, I-" pair of
      the same type becomes an I- tag, bridging the one-token gap.

    The gap rule needs a preceding entity. Without that guard every "O" in
    front of a multi-token entity would be turned into an I- tag, which a
    later I- repair step would promote to a spurious one-token entity.

    Args:
        labels: Tag strings ("O", "B-<type>", "I-<type>").

    Returns:
        A new list with the repaired tags; the input is not modified. The
        B- tag after a bridged gap is left unchanged.
    """
    cleaned = []
    prev = "O"
    for i, label in enumerate(labels):
        if label.startswith("I-") and prev == "O":
            label = "B-" + label[2:]
        elif (label == "O"
                and prev.startswith(("B-", "I-"))
                and i + 2 < len(labels)
                and labels[i + 1].startswith("B-")
                and labels[i + 2].startswith("I-")
                and labels[i + 1][2:] == prev[2:]):
            label = "I-" + labels[i + 1][2:]
        cleaned.append(label)
        prev = label
    return cleaned

def fix_illegal_I(labels):
    """Promote every I- tag that does not continue an entity of its type to B-.

    Args:
        labels: Tag strings ("O", "B-<type>", "I-<type>").

    Returns:
        A new list in which each I- tag follows a B-/I- tag of the same type.
    """
    repaired = []
    open_type = "O"  # type of the entity currently open, "O" when none
    for tag in labels:
        prefix, kind = tag[:2], tag[2:]
        if prefix == "I-" and kind != open_type:
            tag = "B-" + kind
            prefix = "B-"
        repaired.append(tag)
        if prefix == "B-":
            open_type = kind
        elif prefix != "I-":
            open_type = "O"
    return repaired

def clean_and_fix_prediction_sequence(label_ids):
    """Apply both BIO repairs to a sequence of label ids.

    Ids are mapped to tag strings, passed through
    ``clean_prediction_structure`` and ``fix_illegal_I``, and mapped back.

    Args:
        label_ids: Predicted label ids; unknown ids are treated as "O".

    Returns:
        List of repaired label ids. A repaired tag that has no id falls back
        to the id of "O" rather than id 0, which is the first label in sorted
        order and not necessarily "O".
    """
    outside_id = label2id.get("O", 0)
    labels = [id2label.get(lid, "O") for lid in label_ids]
    labels = clean_prediction_structure(labels)
    labels = fix_illegal_I(labels)
    return [label2id.get(l, outside_id) for l in labels]

# ----------- Step 6: apply the repairs and evaluate ----------
# Only the predictions are repaired; the gold sequences are used as they are.
filtered_preds_cleaned = [clean_and_fix_prediction_sequence(seq) for seq in filtered_preds]

print("\n Running relaxed evaluation on test set (HPO_TERM only, with structure repair)...")
relaxed_metrics = relaxed_compute_metrics(filtered_preds_cleaned, filtered_labels)
print("\n Relaxed HPO_TERM test set metrics:", relaxed_metrics)


In [None]:
from collections import defaultdict

# ----------- Step 1: extract HPO_TERM entity spans -----------
def extract_entities(labels):
    """Collect HPO_TERM entities from a sequence of label ids.

    Args:
        labels: Label ids; ids missing from ``id2label`` count as "O".

    Returns:
        List of ``(start, end, "HPO_TERM")`` tuples with inclusive token
        indices. An I- tag without an open entity does not start one.
    """
    found = []
    open_at = None  # index where the currently open entity began
    for pos, lab_id in enumerate(labels):
        tag = id2label.get(lab_id, "O")
        begins = tag.startswith("B-HPO_TERM")
        if not begins and tag.startswith("I-HPO_TERM") and open_at is not None:
            continue  # still inside the open entity
        if open_at is not None:
            found.append((open_at, pos - 1, "HPO_TERM"))
            open_at = None
        if begins:
            open_at = pos
    if open_at is not None:
        found.append((open_at, len(labels) - 1, "HPO_TERM"))
    return found

# ----------- Step 2: IoU computation & relaxed matching -----------
def iou(a, b):
    """Return the overlap ratio of two inclusive index ranges ``(start, end)``.

    The denominator is the length of the range from the smaller start to the
    larger end, so disjoint ranges give 0.
    """
    overlap = min(a[1], b[1]) - max(a[0], b[0]) + 1
    if overlap < 0:
        overlap = 0
    hull = max(a[1], b[1]) - min(a[0], b[0]) + 1
    return overlap / hull

def relaxed_match(pred_span, true_span):
    """Decide whether a predicted span loosely matches a gold span.

    Both spans are ``(start, end, label)`` tuples. Labels must be equal; the
    spans then match when both boundaries are at most 4 tokens apart or when
    their IoU is at least 0.4.
    """
    p_start, p_end, p_kind = pred_span
    t_start, t_end, t_kind = true_span
    if p_kind != t_kind:
        return False
    boundaries_close = abs(p_start - t_start) <= 4 and abs(p_end - t_end) <= 4
    return bool(boundaries_close or iou((p_start, p_end), (t_start, t_end)) >= 0.4)

# ----------- Step 3: Relaxed Evaluation Metric ----------
def relaxed_compute_metrics(preds, refs):
    """Score predictions against references with relaxed span matching.

    Each predicted entity is paired with the first still-unmatched gold
    entity it loosely matches. Prints a per-label report and returns the
    overall scores.

    Args:
        preds: Predicted label-id sequences.
        refs: Gold label-id sequences, parallel to ``preds``.

    Returns:
        Dict with ``precision``, ``recall`` and ``f1``.
    """
    tp = fp = fn = 0
    label_metrics = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})

    for pred_seq, ref_seq in zip(preds, refs):
        pred_ents = extract_entities(pred_seq)
        true_ents = extract_entities(ref_seq)
        used = set()

        for pred_ent in pred_ents:
            hit = next(
                (j for j, gold in enumerate(true_ents)
                 if j not in used and relaxed_match(pred_ent, gold)),
                None,
            )
            if hit is None:
                fp += 1
                label_metrics["HPO_TERM"]["fp"] += 1
            else:
                used.add(hit)
                tp += 1
                label_metrics["HPO_TERM"]["tp"] += 1

        # Gold entities nobody matched are false negatives; the per-label
        # entry is only touched when there is at least one.
        missed = len(true_ents) - len(used)
        for _ in range(missed):
            fn += 1
            label_metrics["HPO_TERM"]["fn"] += 1

    # The small epsilon avoids division by zero when a count is empty.
    precision = tp / (tp + fp + 1e-10)
    recall    = tp / (tp + fn + 1e-10)
    f1        = 2 * precision * recall / (precision + recall + 1e-10)

    print("\n Relaxed Per-label HPO_TERM classification report:")
    for label, m in label_metrics.items():
        lp = m["tp"] / (m["tp"] + m["fp"] + 1e-10)
        lr = m["tp"] / (m["tp"] + m["fn"] + 1e-10)
        lf1 = 2 * lp * lr / (lp + lr + 1e-10)
        print(f"{label:20} | Precision: {lp:.3f} | Recall: {lr:.3f} | F1: {lf1:.3f}")

    return dict(precision=precision, recall=recall, f1=f1)

# ----------- Step 4: strip -100 padding ----------
# Uses `preds` and `labels` left in the kernel by an earlier prediction cell.
# Positions whose gold id is -100 are dropped from both sequences so that
# predictions and references stay aligned.
filtered_preds = []
filtered_labels = []

for pred_seq, label_seq in zip(preds, labels):
    filtered_pred = [p for p, l in zip(pred_seq, label_seq) if l != -100]
    filtered_label = [l for l in label_seq if l != -100]
    filtered_preds.append(filtered_pred)
    filtered_labels.append(filtered_label)

# ----------- Step 5: evaluate directly, without structure repair ----------
print("\n Running relaxed evaluation on test set (HPO_TERM only, no structure repair)...")
relaxed_metrics = relaxed_compute_metrics(filtered_preds, filtered_labels)
print("\n Relaxed HPO_TERM test set metrics:", relaxed_metrics)


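## 加银数据 (Adding silver data: the cells below repeat the pipeline with silver-standard records appended to the training split)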

In [None]:
import json
from pathlib import Path
from collections import Counter, defaultdict
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

# -------------------
# Constants & Paths
# -------------------
# Gold records to convert; presumably the file written by the first merge cell
# (assumes the notebook's working directory is /kaggle/working — confirm).
FILE_MERGED = Path("/kaggle/working/merged_spans_with_entities.jsonl")
# Directory whose *.jsonl files are added to the training split as silver data.
DIR_SILVER  = Path("/kaggle/input/hpo-only")
OUT_DIR     = Path("/kaggle/working/bio_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Same output paths as the gold-only conversion cell, so those files get overwritten.
TRAIN_BIO = OUT_DIR / "train.jsonl"
DEV_BIO   = OUT_DIR / "dev.jsonl"
TEST_BIO  = OUT_DIR / "test.jsonl"

# Span labels that are converted to BIO tags; spans with any other label are ignored.
ENTITY_TYPES = {
    "AGE_ONSET", "AGE_FOLLOWUP", "AGE_DEATH",
    "PATIENT", "HPO_TERM", "GENE", "GENE_VARIANT"
}

# Fast tokenizer is needed for offset mappings and word ids used during BIO alignment.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True
)

# -------------------
# Utility Functions
# -------------------
def iter_jsonl(path: Path):
    """Yield one parsed JSON object per non-blank line of a JSONL file.

    Blank lines are skipped silently. Lines that are not valid JSON are
    skipped as well, but each one is reported with its file name and line
    number (same message format as the merge cells), so lost records are
    visible instead of disappearing without a trace.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Yields:
        The decoded object of every valid line, in file order.
    """
    with path.open("r", encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"JSON decode error at {path.name}:{lineno} — {e}")
                continue
            # Yield outside the try so that only parsing errors are caught here.
            yield record

def make_bio_labels(spans, enc):
    """Turn character-level spans into one BIO tag per token of ``enc``.

    Args:
        spans: Dicts with ``start``, ``end`` (character offsets) and ``label``.
        enc: Encoding object exposing ``tokens()``, ``word_ids()`` and
            ``enc["offset_mapping"]``.

    Returns:
        A list of tag strings ("O", "B-<label>", "I-<label>"), one per token.
    """
    pieces = enc.tokens()
    char_ranges = enc["offset_mapping"]
    word_index = enc.word_ids()
    bio = ["O"] * len(pieces)
    covered_per_span = []

    for span in spans:
        lo, hi, kind = span["start"], span["end"], span["label"]
        # A token belongs to the span when the two character ranges overlap.
        covered = [
            pos for pos, (tok_lo, tok_hi) in enumerate(char_ranges)
            if tok_hi > lo and tok_lo < hi
        ]
        covered_per_span.append(covered)
        if covered:
            bio[covered[0]] = f"B-{kind}"
            for pos in covered[1:]:
                bio[pos] = f"I-{kind}"

    # A span that covers a single token must start with B-, even if a later
    # overlapping span re-tagged that token as I-.
    for covered in covered_per_span:
        if len(covered) == 1:
            only = covered[0]
            bio[only] = bio[only].replace("I-", "B-")

    # Extend an entity over the remaining sub-tokens of a word it only partly covers.
    last_word = None
    for pos, word in enumerate(word_index):
        same_word = word is not None and word == last_word
        if same_word and bio[pos] == "O" and bio[pos - 1].startswith(("B-", "I-")):
            bio[pos] = "I-" + bio[pos - 1][2:]
        last_word = word

    return bio

def record_to_bio(rec):
    """Convert one annotated record into a BIO example.

    Args:
        rec: Dict with a ``text`` string and a ``spans`` list.

    Returns:
        ``None`` when the record has no span whose label is in ENTITY_TYPES;
        otherwise a dict with the tokenizer's ``tokens`` and the matching BIO
        ``labels``.
    """
    text = rec.get("text", "")
    wanted = list(filter(lambda s: s.get("label") in ENTITY_TYPES, rec.get("spans", [])))
    if len(wanted) == 0:
        return None

    # No special tokens are added, and the text is cut at 512 tokens.
    enc = tokenizer(
        text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        truncation=True,
        max_length=512
    )

    example = {"tokens": enc.tokens()}
    example["labels"] = make_bio_labels(wanted, enc)
    return example

def dump_jsonl(path: Path, data):
    """Write every object in ``data`` to ``path`` as one JSON line each.

    The file is overwritten and encoded as UTF-8; non-ASCII characters are
    written verbatim rather than escaped.
    """
    with path.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(obj, ensure_ascii=False) + "\n" for obj in data)

def load_extra_bio(path: Path):
    """Convert every usable record of a JSONL file into a BIO example.

    Records for which ``record_to_bio`` returns nothing are left out.
    """
    converted = (record_to_bio(rec) for rec in iter_jsonl(path))
    return [bio for bio in converted if bio]

# -------------------
# Step 1: Load and convert gold data
# -------------------
# Records for which record_to_bio returns None (no span of a known entity type) are dropped.
print(">> Loading gold data …")
merged_bio = [
    bio for rec in iter_jsonl(FILE_MERGED)
    if (bio := record_to_bio(rec)) is not None
]
print(f"Total valid records in gold: {len(merged_bio)}")

# -------------------
# Step 2: Random split gold data
# -------------------
# First, set aside 20% for test
train_dev, test_set = train_test_split(
    merged_bio,
    test_size=0.20,
    random_state=42
)

# Second, give 25% of the remaining 80% to dev ⇒ 0.25×80% = 20%
train_set, dev_set = train_test_split(
    train_dev,
    test_size=0.25,
    random_state=42
)

print(f"Split sizes – TRAIN: {len(train_set)}, DEV: {len(dev_set)}, TEST: {len(test_set)}")

# -------------------
# Step 3: Add silver data to train set
# -------------------
# Silver records go into the training split only; dev and test stay gold.
extra_train = []
if DIR_SILVER.exists():
    print(">> Loading silver data from hpo-only/")
    for jf in sorted(DIR_SILVER.glob("*.jsonl")):
        print(f"  - {jf.name}")
        extra_train.extend(load_extra_bio(jf))
else:
    print(">> Silver data directory not found.")

train_final = train_set + extra_train
print(f"Final train size: {len(train_final)} (including {len(extra_train)} silver records)")

# -------------------
# Step 4: Save to disk
# -------------------
dump_jsonl(TRAIN_BIO, train_final)
dump_jsonl(DEV_BIO, dev_set)
dump_jsonl(TEST_BIO, test_set)

print(f"\nSaved to:")
print(f"  ➜ {TRAIN_BIO}")
print(f"  ➜ {DEV_BIO}")
print(f"  ➜ {TEST_BIO}")


In [None]:
import json
from pathlib import Path

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# 1. Load BIO datasets
BIO_DIR = Path("/kaggle/working/bio_outputs")

def load_jsonl(path: Path):
    """Read a UTF-8 JSONL file into a list, ignoring whitespace-only lines."""
    records = []
    with path.open(encoding="utf-8") as f:
        for raw in f:
            if raw.strip():
                records.append(json.loads(raw))
    return records

# 1.1 Map every tag that is not an HPO_TERM tag to 'O'
def keep_only_hpo_labels(example):
    """Collapse all non-HPO_TERM tags of ``example["labels"]`` to "O".

    The example dict is modified in place and also returned.
    """
    masked = []
    for tag in example["labels"]:
        masked.append(tag if tag.endswith("HPO_TERM") else "O")
    example["labels"] = masked
    return example

# Load each split and collapse every non-HPO_TERM tag to 'O' before building the datasets.
train_examples = [keep_only_hpo_labels(ex) for ex in load_jsonl(BIO_DIR / "train.jsonl")]
dev_examples   = [keep_only_hpo_labels(ex) for ex in load_jsonl(BIO_DIR / "dev.jsonl")]
test_examples  = [keep_only_hpo_labels(ex) for ex in load_jsonl(BIO_DIR / "test.jsonl")]

ds_splits = DatasetDict({
    "train":      Dataset.from_list(train_examples),
    "validation": Dataset.from_list(dev_examples),
    "test":       Dataset.from_list(test_examples),
})
print("Loaded dataset sizes:", {k: len(v) for k, v in ds_splits.items()})

# 2. Tokenizer & label mapping
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True,
)

# 2.1 The label vocabulary is whatever tags occur in the three splits; after the
# filtering above this is expected to be ['B-HPO_TERM', 'I-HPO_TERM', 'O'].
# Ids follow the sorted order of the tag strings.
unique_labels = sorted({lab for ex in train_examples + dev_examples + test_examples
                        for lab in ex["labels"]})
label2id = {lab: i for i, lab in enumerate(unique_labels)}
id2label = {i: lab for lab, i in label2id.items()}

def tokenize_and_align_labels(ex):
    """Encode one BIO example so that every label sits on its own token.

    ``ex["tokens"]`` are already WordPiece strings (the BIO conversion cell
    stores ``enc.tokens()``), with exactly one tag per piece. Running them
    through the tokenizer again with ``is_split_into_words=True`` would
    prepend [CLS], append [SEP] and may re-split pieces such as "##axia",
    while the tag list stays unchanged, so tags would be shifted against the
    inputs. The pieces are therefore converted to ids directly and the two
    special positions get the ignore index -100.

    Args:
        ex: Dict with parallel lists ``tokens`` (WordPiece strings) and
            ``labels`` (BIO tag strings).

    Returns:
        Dict with ``input_ids``, ``token_type_ids``, ``attention_mask`` and
        ``labels``, all of the same length (at most 512).
    """
    max_length = 512
    budget = max_length - 2  # leave room for [CLS] and [SEP]
    pieces = ex["tokens"][:budget]
    tags = ex["labels"][:budget]

    input_ids = (
        [tokenizer.cls_token_id]
        + tokenizer.convert_tokens_to_ids(pieces)
        + [tokenizer.sep_token_id]
    )
    return {
        "input_ids": input_ids,
        "token_type_ids": [0] * len(input_ids),
        "attention_mask": [1] * len(input_ids),
        # -100 is ignored by the loss and filtered out by the metric code.
        "labels": [-100] + [label2id[tag] for tag in tags] + [-100],
    }

# Encode every split one example at a time; the raw "tokens"/"labels" columns
# are removed and replaced by what tokenize_and_align_labels returns.
ds_splits = ds_splits.map(
    tokenize_and_align_labels,
    batched=False,
    remove_columns=["tokens", "labels"],
)

# 3. Model
# Token-classification model whose output size and label names come from the
# label maps built above.
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=len(unique_labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# 4. Metrics
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    """Compute overall seqeval scores for one evaluation pass.

    Args:
        p: Object with ``predictions`` (per-token scores) and ``label_ids``.

    Returns:
        Dict with overall precision, recall, F1 and accuracy. Positions whose
        gold id is -100 are excluded from both sequences.
    """
    predicted_ids = p.predictions.argmax(-1)
    gold_ids = p.label_ids

    true_labels = []
    for gold_row in gold_ids:
        true_labels.append([id2label[lid] for lid in gold_row if lid != -100])

    pred_labels = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        kept = [pid for pid, lid in zip(pred_row, gold_row) if lid != -100]
        pred_labels.append([id2label[pid] for pid in kept])

    result = seqeval.compute(predictions=pred_labels, references=true_labels)
    wanted = ("overall_precision", "overall_recall", "overall_f1", "overall_accuracy")
    return {key: result[key] for key in wanted}

data_collator = DataCollatorForTokenClassification(tokenizer)

# 5. Training arguments and Trainer
# Evaluation and logging run every 50 steps and checkpoints are saved every
# 500 steps. With load_best_model_at_end, the checkpoint with the highest
# validation overall_f1 is restored after training.
training_args = TrainingArguments(
    output_dir="ner_pubmedbert",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="overall_f1",
    greater_is_better=True,
    report_to=["none"],
    save_total_limit=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_splits["train"],
    eval_dataset=ds_splits["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# 6. Train and evaluate
trainer.train()
# Final evaluation on the validation split; the returned metrics are not stored.
trainer.evaluate()

# 7. Predict on test set
# Inference runs once; metrics, scores and gold ids all come from the same
# output (previously the test split was predicted twice in a row).
test_output = trainer.predict(ds_splits["test"])
test_metrics = test_output.metrics
print("Test set metrics:", test_metrics)
predictions, labels = test_output.predictions, test_output.label_ids
preds = predictions.argmax(-1)

# Map ids back to tag names, skipping positions whose gold id is -100.
true_labels = [
    [id2label[label_id] for label_id in seq if label_id != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pred_id] for pred_id, label_id in zip(pred_seq, label_seq) if label_id != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

# 8. Only show HPO_TERM in per-label report
# Nothing is printed below the header when seqeval reports no HPO_TERM entry.
print("\n HPO_TERM classification report:")
for label, metrics in detailed_result.items():
    if label != "HPO_TERM":
        continue
    print(f" {label:20} | Precision: {metrics['precision']:.3f} | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}")

In [None]:
from evaluate import load

# Predict on the test split (uses `trainer`, `ds_splits` and `id2label` from the training cell)
predictions, labels, _ = trainer.predict(ds_splits["test"])
preds = predictions.argmax(-1)

# Map label ids back to label names (skipping -100, the ignore index)
true_labels = [
    [id2label[label_id] for label_id in seq if label_id != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pred_id] for pred_id, label_id in zip(pred_seq, label_seq) if label_id != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

# Compute the evaluation results for all labels with seqeval
seqeval = load("seqeval")
detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

# Print the overall F1 (optional)
print(f"\nOverall F1 score: {detailed_result.get('overall_f1', 0):.3f}")

# Output only the HPO_TERM results
print("\n HPO_TERM classification report:")
hpo_metrics = detailed_result.get("HPO_TERM")
if hpo_metrics:
    print(f" {'HPO_TERM':20} | Precision: {hpo_metrics['precision']:.3f} | Recall: {hpo_metrics['recall']:.3f} | F1: {hpo_metrics['f1']:.3f}")
else:
    print("No HPO_TERM entities found in predictions.")

In [None]:
from collections import defaultdict

# ----------- Step 1: extract HPO_TERM entity spans -----------
def extract_entities(labels):
    """Collect HPO_TERM entities from a sequence of label ids.

    Args:
        labels: Label ids; ids missing from ``id2label`` count as "O".

    Returns:
        List of ``(start, end, "HPO_TERM")`` tuples with inclusive token
        indices. An I- tag without an open entity does not start one.
    """
    found = []
    open_at = None  # index where the currently open entity began
    for pos, lab_id in enumerate(labels):
        tag = id2label.get(lab_id, "O")
        begins = tag.startswith("B-HPO_TERM")
        if not begins and tag.startswith("I-HPO_TERM") and open_at is not None:
            continue  # still inside the open entity
        if open_at is not None:
            found.append((open_at, pos - 1, "HPO_TERM"))
            open_at = None
        if begins:
            open_at = pos
    if open_at is not None:
        found.append((open_at, len(labels) - 1, "HPO_TERM"))
    return found

# ----------- Step 2: IoU computation & relaxed matching -----------
def iou(a, b):
    """Return the overlap ratio of two inclusive index ranges.

    ``a`` and ``b`` are indexable as ``(start, end)``. The numerator is the
    number of shared positions (0 when the ranges are disjoint); the
    denominator is the length of the range spanning both inputs.
    """
    a_lo, a_hi = a[0], a[1]
    b_lo, b_hi = b[0], b[1]
    shared = min(a_hi, b_hi) - max(a_lo, b_lo) + 1
    if shared < 0:
        shared = 0
    covered = max(a_hi, b_hi) - min(a_lo, b_lo) + 1
    return shared / covered

def relaxed_match(pred_span, true_span):
    """Decide whether a predicted span loosely matches a gold span.

    Both arguments are ``(start, end, label)`` tuples. Spans with different
    labels never match. Otherwise they match when both boundaries are within
    4 positions of each other, or when ``iou`` of the two ranges is >= 0.4.
    """
    p_start, p_end, p_type = pred_span
    t_start, t_end, t_type = true_span
    if p_type != t_type:
        return False
    boundaries_close = abs(p_start - t_start) <= 4 and abs(p_end - t_end) <= 4
    return boundaries_close or iou((p_start, p_end), (t_start, t_end)) >= 0.4

# ----------- Step 3: relaxed evaluation metric ----------
def relaxed_compute_metrics(preds, refs):
    """Score predicted label-id sequences against references with relaxed matching.

    For every (prediction, reference) pair the HPO_TERM spans are extracted
    with ``extract_entities``. Each predicted span is greedily paired with the
    first not-yet-used gold span accepted by ``relaxed_match``; unpaired
    predictions count as false positives and unpaired gold spans as false
    negatives.

    Prints a per-label report (HPO_TERM is the only label counted) and returns
    a dict with "precision", "recall" and "f1".
    """
    counts = {"tp": 0, "fp": 0, "fn": 0}

    for pred_seq, ref_seq in zip(preds, refs):
        predicted = extract_entities(pred_seq)
        gold = extract_entities(ref_seq)
        used = set()  # indices of gold spans already paired

        for candidate in predicted:
            hit = next(
                (j for j, target in enumerate(gold)
                 if j not in used and relaxed_match(candidate, target)),
                None,
            )
            if hit is None:
                counts["fp"] += 1
            else:
                counts["tp"] += 1
                used.add(hit)

        counts["fn"] += len(gold) - len(used)

    tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
    # The 1e-10 terms keep the divisions defined when a count is zero.
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)

    print("\n Relaxed Per-label HPO_TERM classification report:")
    if tp or fp or fn:
        # Only HPO_TERM is ever counted, so its per-label numbers equal the totals;
        # the line is omitted when no span was seen at all.
        print(f" {'HPO_TERM':20} | Precision: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f}")

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# ----------- Step 4: strip -100 padding positions ----------
# Keep only the positions whose gold label is not -100, for both sequences.
filtered_preds = [
    [p for p, l in zip(pred_seq, label_seq) if l != -100]
    for pred_seq, label_seq in zip(preds, labels)
]
filtered_labels = [
    [l for l in label_seq if l != -100]
    for _, label_seq in zip(preds, labels)
]

# ----------- Step 5: repair the BIO structure -----------
def clean_prediction_structure(labels):
    """Repair orphan I- tags and one-token gaps inside an entity (B-O-B).

    Args:
        labels: list of BIO label strings ("O", "B-<type>", "I-<type>").

    Returns:
        A new list of the same length in which
        * an "I-<type>" that directly follows "O" (or starts the sequence) is
          promoted to "B-<type>";
        * an "O" that sits between an entity of <type> and a following
          "B-<type>", "I-<type>" pair is relabelled "I-<type>".

    The gap rule only fires when the previous (already cleaned) label belongs
    to an entity of the same type. Without that guard every "O" preceding a
    multi-token entity would be turned into an entity token, fabricating
    predictions that were never made. The "B-" tag after a bridged gap is left
    unchanged.
    """
    cleaned = []
    prev = "O"
    for i, label in enumerate(labels):
        if label.startswith("I-") and prev == "O":
            label = "B-" + label[2:]
        elif (
            label == "O"
            and prev != "O"
            and i + 2 < len(labels)
            and labels[i + 1].startswith("B-")
            and labels[i + 2].startswith("I-")
            and labels[i + 1][2:] == prev[2:]
        ):
            label = "I-" + labels[i + 1][2:]
        cleaned.append(label)
        prev = label
    return cleaned

def fix_illegal_I(labels):
    """Turn every I- tag that does not continue an entity of its own type into B-.

    Walks the label strings left to right while tracking the type of the
    entity currently open ("O" when none). Returns a new list.
    """
    repaired = []
    open_type = "O"
    for tag in labels:
        prefix, entity = tag[:2], tag[2:]
        if prefix == "I-" and open_type != entity:
            tag = "B-" + entity
            prefix = "B-"
        repaired.append(tag)
        if prefix == "B-":
            open_type = entity
        elif prefix != "I-":
            # anything that is neither B- nor I- closes the entity
            open_type = "O"
    return repaired

def clean_and_fix_prediction_sequence(label_ids):
    """Apply both BIO repairs to a sequence of label ids and return ids again.

    Ids are decoded with ``id2label`` (unknown ids become "O"), passed through
    ``clean_prediction_structure`` and ``fix_illegal_I``, then re-encoded with
    ``label2id`` (unknown labels become id 0).
    """
    tags = [id2label.get(label_id, "O") for label_id in label_ids]
    repaired = fix_illegal_I(clean_prediction_structure(tags))
    return [label2id.get(tag, 0) for tag in repaired]

# ----------- Step 6: apply the repairs and evaluate ----------
# Repair every filtered prediction sequence before scoring it.
filtered_preds_cleaned = list(map(clean_and_fix_prediction_sequence, filtered_preds))

print("\n Running relaxed evaluation on test set (HPO_TERM only)...")
relaxed_metrics = relaxed_compute_metrics(filtered_preds_cleaned, filtered_labels)
print("\n Relaxed HPO_TERM test set metrics:", relaxed_metrics)

In [None]:
from collections import defaultdict

# ----------- Step 1: extract HPO_TERM entity spans -----------
def extract_entities(labels):
    """Return the HPO_TERM spans encoded in a sequence of label ids.

    Ids are resolved through the module-level ``id2label`` mapping, with
    unknown ids treated as "O". Labels starting with "B-HPO_TERM" open a span,
    labels starting with "I-HPO_TERM" extend an open span, and every other
    label closes it. An "I-" label without an open span is skipped.

    Returns ``(start, end, "HPO_TERM")`` tuples; both positions are inclusive.
    """
    spans = []
    span_begin = None  # start index of the span being built, or None
    last_index = len(labels) - 1
    for index, label_id in enumerate(labels):
        name = id2label.get(label_id, "O")
        if name.startswith("B-HPO_TERM"):
            if span_begin is not None:
                spans.append((span_begin, index - 1, "HPO_TERM"))
            span_begin = index
        elif span_begin is not None and not name.startswith("I-HPO_TERM"):
            spans.append((span_begin, index - 1, "HPO_TERM"))
            span_begin = None
    if span_begin is not None:
        spans.append((span_begin, last_index, "HPO_TERM"))
    return spans

# ----------- Step 2: IoU computation & relaxed matching -----------
def iou(a, b):
    """Overlap ratio of two inclusive ``(start, end)`` index ranges.

    Shared positions (never negative) divided by the length of the range that
    spans both inputs.
    """
    latest_start = max(a[0], b[0])
    earliest_end = min(a[1], b[1])
    inter = max(0, earliest_end - latest_start + 1)
    extent = max(a[1], b[1]) - min(a[0], b[0]) + 1
    return inter / extent

def relaxed_match(pred_span, true_span):
    """Return True when a predicted span loosely agrees with a gold span.

    Spans are ``(start, end, label)`` tuples. Labels must be equal; then the
    spans match if both boundaries differ by at most 4 positions, or if the
    ``iou`` of the two ranges is at least 0.4.
    """
    pred_start, pred_end, pred_label = pred_span
    gold_start, gold_end, gold_label = true_span
    if pred_label != gold_label:
        return False
    start_gap = abs(pred_start - gold_start)
    end_gap = abs(pred_end - gold_end)
    if start_gap <= 4 and end_gap <= 4:
        return True
    return iou((pred_start, pred_end), (gold_start, gold_end)) >= 0.4

# ----------- Step 3: relaxed evaluation metric ----------
def relaxed_compute_metrics(preds, refs):
    """Compute relaxed precision/recall/F1 for HPO_TERM spans.

    Spans are extracted from each prediction/reference pair with
    ``extract_entities``. A predicted span is paired with the first unused gold
    span that ``relaxed_match`` accepts. Leftover predictions are false
    positives and leftover gold spans are false negatives.

    Prints the per-label report (only HPO_TERM is counted) and returns a dict
    with "precision", "recall" and "f1".
    """
    true_pos = false_pos = false_neg = 0

    for pred_seq, ref_seq in zip(preds, refs):
        pred_ents = extract_entities(pred_seq)
        true_ents = extract_entities(ref_seq)
        taken = set()  # gold indices that already have a partner

        for pred_ent in pred_ents:
            partner = None
            for gold_idx, true_ent in enumerate(true_ents):
                if gold_idx not in taken and relaxed_match(pred_ent, true_ent):
                    partner = gold_idx
                    break
            if partner is None:
                false_pos += 1
            else:
                taken.add(partner)
                true_pos += 1

        false_neg += sum(1 for gold_idx in range(len(true_ents)) if gold_idx not in taken)

    # The 1e-10 terms avoid division by zero when a count is empty.
    precision = true_pos / (true_pos + false_pos + 1e-10)
    recall = true_pos / (true_pos + false_neg + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)

    print("\nRelaxed Per-label HPO_TERM classification report:")
    if true_pos + false_pos + false_neg > 0:
        # HPO_TERM is the only label tallied, so its numbers are the totals;
        # nothing is printed for it when no span occurred.
        print(f"{'HPO_TERM':20} | Precision: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f}")

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# ----------- Step 4: strip -100 padding positions ----------
# Keep only the positions whose gold label is not -100, for both sequences.
filtered_preds = [
    [p for p, l in zip(pred_seq, label_seq) if l != -100]
    for pred_seq, label_seq in zip(preds, labels)
]
filtered_labels = [
    [l for l in label_seq if l != -100]
    for _, label_seq in zip(preds, labels)
]

# ----------- Step 5: evaluate directly, without structure repair ----------
print("\n Running relaxed evaluation on test set...")
relaxed_metrics = relaxed_compute_metrics(filtered_preds, filtered_labels)
print("\n Relaxed HPO_TERM test set metrics:", relaxed_metrics)

In [None]:
# Write the fine-tuned model and its tokenizer into the same directory.
SAVE_DIR = "ner_pubmedbert_saved_HPO"
trainer.save_model(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

In [None]:
pip install transformers obonet rapidfuzz

In [None]:
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
MODEL_DIR = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE  = Path("/kaggle/working/test_normalized_mentions.jsonl")
MAX_LENGTH = 512  # maximum number of tokens fed to the model
# Prefer the first GPU, but fall back to CPU so the cell also runs on
# kernels without an accelerator.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Step 1: load the test data ===
# A context manager closes the file handle; blank lines are skipped because
# json.loads would raise on them.
with TEST_FILE.open(encoding="utf-8") as fin:
    test_data = [json.loads(line) for line in fin if line.strip()]
sentences = [" ".join(ex["tokens"]) for ex in test_data]

# === Step 2: load tokenizer & model ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)
model.to(DEVICE)
model.eval()

# label map: id -> label string, e.g. "B-HPO_TERM", "I-HPO_TERM", "O"
id2label = model.config.id2label

# === Step 3: build the HPO lookup (lower-cased name/synonym -> term ids) ===
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph = obonet.read_obo(obo_url)
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text is taken from after the first double quote; entries
        # without any quote are skipped instead of raising IndexError.
        pieces = syn.split('"')
        if len(pieces) > 1:
            hpo_map.setdefault(pieces[1].lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention string to an HPO term id, or None when nothing matches.

    The lower-cased text is first looked up verbatim in ``hpo_map``; failing
    that, the single best rapidfuzz candidate with a score of at least 80 is
    used. When several term ids share a name, the first one stored is returned.
    """
    key = text.lower()
    exact_ids = hpo_map.get(key)
    if exact_ids is not None:
        return exact_ids[0]
    best = process.extract(key, list(hpo_map.keys()), limit=1, score_cutoff=80)
    return hpo_map[best[0][0]][0] if best else None

# === Step 4: run NER on every sentence ===
normalized_mentions = []


def _emit_mention(sentence_index, sentence, start, end):
    """Normalize sentence[start:end] and append its record to normalized_mentions."""
    mention_text = sentence[start:end]
    normalized_mentions.append({
        "sentence_index": sentence_index,
        "sentence":       sentence,
        "mention":        mention_text,
        "span":           (start, end),
        "hpo_id":         normalize_mention(mention_text)
    })


for idx, sentence in enumerate(sentences):
    # Tokenize + truncate
    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model(**{k: encoding[k] for k in ["input_ids", "attention_mask"]})
    logits = outputs.logits  # indexed below as [batch, token, label]
    preds = logits.argmax(dim=-1)[0].cpu().tolist()
    offsets = encoding["offset_mapping"][0].cpu().tolist()  # [(start, end), ...]

    # Extract contiguous HPO_TERM spans as character offsets into `sentence`.
    span_start, span_end = None, None
    for (tok_start, tok_end), label_id in zip(offsets, preds):
        if tok_start == tok_end:
            # Zero-width offsets (special tokens) carry no text; skipping them
            # keeps a stray prediction on them from corrupting span_end.
            continue
        label = id2label[label_id]
        if label == "I-HPO_TERM" and span_start is not None:
            # continue the open span
            span_end = tok_end
            continue
        # Any other label ends the open span. This includes a new B- tag:
        # previously a B- directly after an entity overwrote the open span and
        # the earlier mention was silently lost.
        if span_start is not None:
            _emit_mention(idx, sentence, span_start, span_end)
            span_start, span_end = None, None
        if label == "B-HPO_TERM":
            span_start, span_end = tok_start, tok_end

    # the sentence ended while a span was still open
    if span_start is not None:
        _emit_mention(idx, sentence, span_start, span_end)

# === Step 5: print coverage statistics ===
total = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"Total mentions: {total}")
if total:
    print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
    print(f"Failed to map:   {total-mapped} ({(total-mapped)/total:.1%})")
else:
    # Avoid ZeroDivisionError when the model predicted no mention at all.
    print("No mentions extracted; coverage is undefined.")

# === Step 6: save as JSONL ===
with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Normalized results saved to: {OUT_FILE.resolve()}")



In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
MODEL_DIR  = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE  = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE   = Path("/kaggle/working/test_normalized_mentions.jsonl")
MAX_LENGTH = 512         # maximum number of tokens fed to the model
# Prefer the first GPU, but fall back to CPU so the cell also runs on
# kernels without an accelerator.
DEVICE     = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Utility: text cleaning (citations, stand-alone numbers, ## sub-word markers) ===
def clean_text(text: str) -> str:
    """Strip citation markers, stand-alone numbers and "##" markers from text.

    Steps, in order:
    1. remove bracketed numeric citations such as "[24, 25]";
    2. remove stand-alone integers, optionally followed by "%";
    3. remove the "##" sub-word marker;
    4. collapse runs of whitespace to one space and trim the ends.

    Only isolated numbers are removed: digits attached to letters or other
    word characters (e.g. "CD4") are kept, because stripping them would alter
    the term itself.
    """
    text = re.sub(r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", "", text)
    # Lookarounds restrict the match to digit runs not glued to a word character.
    text = re.sub(r"(?<!\w)\d+%?(?!\w)", "", text)
    text = text.replace("##", "")
    return re.sub(r"\s{2,}", " ", text).strip()

# === Utility: noise filter ===
def is_noise(mention: str) -> bool:
    """Return True when the mention is only a number/percentage or only punctuation."""
    noise_patterns = (
        r"\d+%?",     # pure digits, optionally with a percent sign
        r"[^\w\s]+",  # pure punctuation
    )
    return any(re.fullmatch(pattern, mention) for pattern in noise_patterns)

# === Step 1: load the test data ===
# A context manager closes the file handle; blank lines are skipped because
# json.loads would raise on them.
with TEST_FILE.open(encoding="utf-8") as fin:
    test_data = [json.loads(line) for line in fin if line.strip()]
orig_sentences = [" ".join(ex["tokens"]) for ex in test_data]

# === Step 2: load tokenizer & model ===
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR, use_fast=True, local_files_only=True
)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_DIR, local_files_only=True
).to(DEVICE)
model.eval()

# id -> label mapping, e.g. "B-HPO_TERM", "I-HPO_TERM", "O"
id2label = model.config.id2label

# === Step 3: build the HPO lookup (lower-cased name/synonym -> term ids) ===
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph   = obonet.read_obo(obo_url)
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text is taken from after the first double quote; entries
        # without any quote are skipped instead of raising IndexError.
        pieces = syn.split('"')
        if len(pieces) > 1:
            hpo_map.setdefault(pieces[1].lower(), []).append(node_id)

def normalize_mention(text: str):
    """Resolve a mention to an HPO term id via exact, then fuzzy, lookup.

    Returns the first id stored in ``hpo_map`` for the lower-cased text, or for
    the best rapidfuzz candidate scoring at least 80; None if neither exists.
    """
    key = text.lower()
    exact_ids = hpo_map.get(key)
    if exact_ids is not None:
        return exact_ids[0]
    candidates = process.extract(key, list(hpo_map.keys()), limit=1, score_cutoff=80)
    if not candidates:
        return None
    best_name = candidates[0][0]
    return hpo_map[best_name][0]

# === Step 4: clean each sentence, run NER, normalize the mentions ===
normalized_mentions = []


def _emit_mention(sentence_index, sentence, start, end):
    """Append the record for sentence[start:end] unless is_noise rejects it."""
    mention = sentence[start:end]
    # 4.5 post-processing: drop noise
    if is_noise(mention):
        return
    normalized_mentions.append({
        "sentence_index": sentence_index,
        "sentence":       sentence,
        "mention":        mention,
        "span":           (start, end),
        "hpo_id":         normalize_mention(mention)
    })


for idx, orig in enumerate(orig_sentences):
    # 4.1 clean
    sentence = clean_text(orig)

    # 4.2 tokenize, truncate, encode
    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    ).to(DEVICE)

    # 4.3 model inference
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"]
        )
    logits  = outputs.logits[0]
    preds   = logits.argmax(dim=-1).cpu().tolist()
    offsets = encoding["offset_mapping"][0].cpu().tolist()

    # 4.4 extract contiguous HPO_TERM spans as character offsets into `sentence`
    span_start = span_end = None
    for (tok_start, tok_end), label_id in zip(offsets, preds):
        if tok_start == tok_end:
            # Zero-width offsets (special tokens) carry no text; skipping them
            # keeps a stray prediction on them from corrupting span_end.
            continue
        label = id2label[label_id]
        if label == "I-HPO_TERM" and span_start is not None:
            span_end = tok_end
            continue
        # Any other label ends the open span. This includes a new B- tag:
        # previously a B- directly after an entity overwrote the open span and
        # the earlier mention was silently lost.
        if span_start is not None:
            _emit_mention(idx, sentence, span_start, span_end)
            span_start = span_end = None
        if label == "B-HPO_TERM":
            span_start, span_end = tok_start, tok_end

    # a span was still open at the end of the sentence
    if span_start is not None:
        _emit_mention(idx, sentence, span_start, span_end)

# === Step 5: coverage statistics & saving ===
total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"Total mentions: {total}")
if total:
    print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
    print(f"Failed to map:   {total-mapped} ({(total-mapped)/total:.1%})")
else:
    # Avoid ZeroDivisionError when no mention survived extraction/filtering.
    print("No mentions extracted; coverage is undefined.")

with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Results saved to {OUT_FILE.resolve()}")


In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
MODEL_DIR  = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE  = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE   = Path("/kaggle/working/test_normalized_mentions.jsonl")
MAX_LENGTH = 512         # maximum number of tokens fed to the model
# Prefer the first GPU, but fall back to CPU so the cell also runs on
# kernels without an accelerator.
DEVICE     = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Utility: text cleaning ([UNK], citations, stand-alone numbers, ## markers) ===
def clean_text(text: str) -> str:
    """Strip [UNK], citation markers, stand-alone numbers and "##" markers.

    Steps, in order:
    1. replace the literal "[UNK]" with a space;
    2. remove bracketed numeric citations such as "[24, 25]";
    3. remove stand-alone integers, optionally followed by "%";
    4. remove the "##" sub-word marker;
    5. collapse runs of whitespace to one space and trim the ends.

    Only isolated numbers are removed: digits attached to letters or other
    word characters (e.g. "CD4") are kept, because stripping them would alter
    the term itself.
    """
    text = text.replace("[UNK]", " ")
    text = re.sub(r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", "", text)
    # Lookarounds restrict the match to digit runs not glued to a word character.
    text = re.sub(r"(?<!\w)\d+%?(?!\w)", "", text)
    text = text.replace("##", "")
    return re.sub(r"\s{2,}", " ", text).strip()

# === Utility: noise filter ===
def is_noise(mention: str) -> bool:
    """Return True for mentions that are numeric, pure punctuation, or blank."""
    only_number = re.fullmatch(r"\d+%?", mention) is not None       # digits / percentage
    only_punct = re.fullmatch(r"[^\w\s]+", mention) is not None     # punctuation only
    too_short = len(mention.strip()) < 1                            # nothing left after trimming
    return only_number or only_punct or too_short

# === Step 1: load the test data ===
# A context manager closes the file handle; blank lines are skipped because
# json.loads would raise on them.
with TEST_FILE.open(encoding="utf-8") as fin:
    test_data = [json.loads(line) for line in fin if line.strip()]
orig_sentences = [" ".join(ex["tokens"]) for ex in test_data]

# === Step 2: load tokenizer & model ===
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR, use_fast=True, local_files_only=True
)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_DIR, local_files_only=True
).to(DEVICE)
model.eval()

# id -> label mapping, e.g. "B-HPO_TERM", "I-HPO_TERM", "O"
id2label = model.config.id2label

# === Step 3: build the HPO lookup (lower-cased name/synonym -> term ids) ===
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph   = obonet.read_obo(obo_url)
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text is taken from after the first double quote; entries
        # without any quote are skipped instead of raising IndexError.
        pieces = syn.split('"')
        if len(pieces) > 1:
            hpo_map.setdefault(pieces[1].lower(), []).append(node_id)

def normalize_mention(text: str):
    """Return an HPO term id for the mention, or None if no match is found.

    Tries an exact lookup of the lower-cased text in ``hpo_map`` first, then
    the best rapidfuzz candidate with a score of at least 80. The first id
    stored under the matched name is returned.
    """
    key = text.lower()
    try:
        return hpo_map[key][0]
    except KeyError:
        pass  # no exact hit; fall through to fuzzy matching
    candidates = process.extract(key, list(hpo_map.keys()), limit=1, score_cutoff=80)
    if candidates:
        return hpo_map[candidates[0][0]][0]
    return None

# === Step 4: clean each sentence, run NER, normalize the mentions ===
normalized_mentions = []


def _emit_mention(sentence_index, sentence, start, end):
    """Append the record for sentence[start:end] unless is_noise rejects it."""
    mention = sentence[start:end]
    # 4.5 post-processing: drop noise
    if is_noise(mention):
        return
    normalized_mentions.append({
        "sentence_index": sentence_index,
        "sentence":       sentence,
        "mention":        mention,
        "span":           (start, end),
        "hpo_id":         normalize_mention(mention)
    })


for idx, orig in enumerate(orig_sentences):
    # 4.1 clean
    sentence = clean_text(orig)

    # 4.2 tokenize, truncate, encode
    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    ).to(DEVICE)

    # 4.3 model inference
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"]
        )
    logits  = outputs.logits[0]
    preds   = logits.argmax(dim=-1).cpu().tolist()
    offsets = encoding["offset_mapping"][0].cpu().tolist()

    # 4.4 extract contiguous HPO_TERM spans as character offsets into `sentence`
    span_start = span_end = None
    for (tok_start, tok_end), label_id in zip(offsets, preds):
        if tok_start == tok_end:
            # Zero-width offsets (special tokens) carry no text; skipping them
            # keeps a stray prediction on them from corrupting span_end.
            continue
        label = id2label[label_id]
        if label == "I-HPO_TERM" and span_start is not None:
            span_end = tok_end
            continue
        # Any other label ends the open span. This includes a new B- tag:
        # previously a B- directly after an entity overwrote the open span and
        # the earlier mention was silently lost.
        if span_start is not None:
            _emit_mention(idx, sentence, span_start, span_end)
            span_start = span_end = None
        if label == "B-HPO_TERM":
            span_start, span_end = tok_start, tok_end

    # a span was still open at the end of the sentence
    if span_start is not None:
        _emit_mention(idx, sentence, span_start, span_end)

# === Step 5: coverage statistics & saving ===
total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"Total mentions: {total}")
if total:
    print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
    print(f"Failed to map:   {total-mapped} ({(total-mapped)/total:.1%})")
else:
    # Avoid ZeroDivisionError when no mention survived extraction/filtering.
    print("No mentions extracted; coverage is undefined.")

with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Results saved to {OUT_FILE.resolve()}")


In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
MODEL_DIR  = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE  = Path("/kaggle/working/bio_outputs/test.jsonl")
OUT_FILE   = Path("/kaggle/working/test_normalized_mentions.jsonl")
MAX_LENGTH = 512
# Prefer the first GPU, but fall back to CPU so the cell also runs on
# kernels without an accelerator.
DEVICE     = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Text cleaning ===
def clean_text(text: str) -> str:
    """Strip [UNK], citation markers, stand-alone numbers and "##" markers.

    Steps, in order:
    1. replace the literal "[UNK]" with a space;
    2. remove bracketed numeric citations such as "[24, 25]";
    3. remove stand-alone integers, optionally followed by "%";
    4. remove the "##" sub-word marker;
    5. collapse runs of whitespace to one space and trim the ends.

    Only isolated numbers are removed: digits attached to letters or other
    word characters (e.g. "CD4") are kept, because stripping them would alter
    the term itself.
    """
    text = text.replace("[UNK]", " ")
    text = re.sub(r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", "", text)
    # Lookarounds restrict the match to digit runs not glued to a word character.
    text = re.sub(r"(?<!\w)\d+%?(?!\w)", "", text)
    text = text.replace("##", "")
    return re.sub(r"\s{2,}", " ", text).strip()

# === Decide whether a mention is noise ===
def is_noise(mention: str) -> bool:
    """Return True for numeric, punctuation-only, or shorter-than-3-character mentions."""
    if re.fullmatch(r"\d+%?", mention) is not None:
        # digits, optionally followed by a percent sign
        return True
    if re.fullmatch(r"[^\w\s]+", mention) is not None:
        # punctuation only
        return True
    trimmed = mention.strip()
    return len(trimmed) < 3

# === Step 1: load the test data ===
# A context manager closes the file handle; blank lines are skipped because
# json.loads would raise on them.
with TEST_FILE.open(encoding="utf-8") as fin:
    test_data = [json.loads(line) for line in fin if line.strip()]
orig_sentences = [" ".join(ex["tokens"]) for ex in test_data]

# === Step 2: load model and tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True).to(DEVICE)
model.eval()
id2label = model.config.id2label

# === Step 3: load the HPO dictionary (lower-cased name/synonym -> term ids) ===
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph = obonet.read_obo(obo_url)
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text is taken from after the first double quote; entries
        # without any quote are skipped instead of raising IndexError.
        pieces = syn.split('"')
        if len(pieces) > 1:
            hpo_map.setdefault(pieces[1].lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention to an HPO term id (exact lookup, then fuzzy with cutoff 85).

    Returns the first id stored in ``hpo_map`` under the matched name, or None
    when neither the exact nor the fuzzy lookup yields a candidate.
    """
    key = text.lower()
    exact_ids = hpo_map.get(key)
    if exact_ids is not None:
        return exact_ids[0]
    candidates = process.extract(key, list(hpo_map.keys()), limit=1, score_cutoff=85)
    return hpo_map[candidates[0][0]][0] if candidates else None

# === Step 4: NER with a confidence filter ===
# Minimum mean per-token probability a mention needs in order to be kept.
MIN_AVG_CONFIDENCE = 0.6

normalized_mentions = []


def _flush_mention(sentence_index, sentence, start, end, token_scores):
    """Record sentence[start:end] if it is clean, confident and maps to an HPO id.

    The mention text is sliced from the sentence, so spaces between words are
    preserved and the stored span matches the stored text exactly.
    """
    mention = sentence[start:end].strip()
    avg_score = sum(token_scores) / len(token_scores) if token_scores else 0
    # Cheap checks first: the fuzzy lookup is skipped for mentions that would
    # be discarded anyway.
    if is_noise(mention) or not avg_score > MIN_AVG_CONFIDENCE:
        return
    hpo_id = normalize_mention(mention)
    if hpo_id:
        normalized_mentions.append({
            "sentence_index": sentence_index,
            "sentence":       sentence,
            "mention":        mention,
            "span":           (start, end),
            "hpo_id":         hpo_id
        })


for idx, orig in enumerate(orig_sentences):
    sentence = clean_text(orig)
    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        is_split_into_words=False
    )

    # offset_mapping is removed so that it is not passed to the model
    offset_mapping = encoding.pop("offset_mapping")
    encoding = {k: v.to(DEVICE) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)
    logits = outputs.logits[0]
    probs  = torch.softmax(logits, dim=-1)
    preds  = probs.argmax(dim=-1).cpu().tolist()
    scores = probs.max(dim=-1).values.cpu().tolist()
    offsets = offset_mapping[0].tolist()

    current_start = current_end = None
    current_score = []

    for label_id, (tok_start, tok_end), score in zip(preds, offsets, scores):
        if tok_start == tok_end:  # special tokens like [CLS], [SEP]
            continue
        label = id2label[label_id]
        if label == "I-HPO_TERM" and current_start is not None:
            current_end = tok_end
            current_score.append(score)
            continue
        # Any other label (including a new B-) closes the open mention.
        if current_start is not None:
            _flush_mention(idx, sentence, current_start, current_end, current_score)
            current_start = current_end = None
            current_score = []
        if label == "B-HPO_TERM":
            current_start, current_end = tok_start, tok_end
            current_score = [score]

    # flush a mention that is still open at the end of the sentence
    if current_start is not None:
        _flush_mention(idx, sentence, current_start, current_end, current_score)

# === Step 5: statistics and saving ===
# Only mentions with an HPO id are stored above, so "mapped" equals "total".
total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"Total mentions: {total}")
if total:
    print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
    print(f"Failed to map:   {total-mapped} ({(total-mapped)/total:.1%})")
else:
    # Avoid ZeroDivisionError when no mention passed the filters.
    print("No mentions extracted; coverage is undefined.")

with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Results saved to {OUT_FILE.resolve()}")



In [None]:
pip install nltk obonet rapidfuzz transformers torch

In [None]:
import nltk
# Fetch the "punkt" resource; the call was missing its closing parenthesis,
# which made the whole cell a SyntaxError.
nltk.download('punkt')

In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
MODEL_DIR = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test_text_only.jsonl")
MAX_LENGTH = 512
# Prefer the first GPU, but fall back to CPU so the cell also runs on
# kernels without an accelerator.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Utility: text cleaning ([UNK], sub-word markers, citations, stand-alone numbers) ===
def clean_text(text: str) -> str:
    """Strip [UNK], citation markers, stand-alone numbers and "##" markers.

    Steps, in order:
    1. replace the literal "[UNK]" with a space;
    2. remove bracketed numeric citations such as "[24, 25]";
    3. remove stand-alone integers, optionally followed by "%";
    4. remove the "##" sub-word marker;
    5. collapse runs of whitespace to one space and trim the ends.

    Only isolated numbers are removed: digits attached to letters or other
    word characters (e.g. "CD4") are kept, because stripping them would alter
    the term itself.
    """
    text = text.replace("[UNK]", " ")
    text = re.sub(r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", "", text)
    # Lookarounds restrict the match to digit runs not glued to a word character.
    text = re.sub(r"(?<!\w)\d+%?(?!\w)", "", text)
    text = text.replace("##", "")
    return re.sub(r"\s{2,}", " ", text).strip()

# === Utility: stricter noise filter (glued tokens, meaningless words) ===
def is_noise(mention: str) -> bool:
    """Return True when the trimmed mention should be discarded.

    A mention is noise if it is empty, a number/percentage, punctuation only,
    shorter than 3 characters, has no a/e/i/o/u vowel, is a single unbroken
    string longer than 25 characters, or is one of the blacklisted words.
    """
    text = mention.strip()
    lowered = text.lower()
    blacklist = {"showed", "found", "revealed", "including", "video", "fig", "fig.", "information"}
    return (
        not text
        or re.fullmatch(r"\d+%?", text) is not None
        or re.fullmatch(r"[^\w\s]+", text) is not None
        or len(text) < 3
        or re.search(r"[aeiou]", lowered) is None
        or (len(text) > 25 and " " not in text)
        or lowered in blacklist
    )

# === Step 1: Load raw sentence text ===
# A context manager closes the file handle; blank lines are skipped because
# json.loads would raise on them.
with TEST_FILE.open(encoding="utf-8") as fin:
    test_data = [json.loads(line) for line in fin if line.strip()]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True).to(DEVICE)
model.eval()
id2label = model.config.id2label

# === Step 3: Build HPO dictionary from hp.obo ===
# Maps each lower-cased term name or quoted synonym text to a list of term ids.
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph = obonet.read_obo(obo_url)
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # Only the first double-quoted part of the synonym entry is used.
        match = re.search(r'"(.+?)"', syn)
        if match:
            hpo_map.setdefault(match.group(1).lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention to an HPO term id, or None when nothing matches.

    An exact lookup of the lower-cased text in ``hpo_map`` is tried first, then
    rapidfuzz ``extractOne`` with a score cutoff of 85. The first id stored
    under the matched name is returned.
    """
    key = text.lower()
    exact_ids = hpo_map.get(key)
    if exact_ids is not None:
        return exact_ids[0]
    best = process.extractOne(key, hpo_map.keys(), score_cutoff=85)
    return hpo_map[best[0]][0] if best else None

# === Step 4: Run NER + Normalize ===
mapped_mentions = []
unmapped_mentions = []


def _collect_mention(sentence, token_offsets):
    """File the mention covered by token_offsets under mapped or unmapped.

    The text is sliced from the first token's start to the last token's end.
    Joining the individual token pieces with spaces, as was done before, split
    words that the tokenizer had broken into several sub-word pieces.
    """
    first_start = token_offsets[0][0]
    last_end = token_offsets[-1][1]
    mention = sentence[first_start:last_end].strip()
    if is_noise(mention):
        return
    hpo_id = normalize_mention(mention)
    (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))


for idx, orig in enumerate(orig_sentences):
    sentence = clean_text(orig)

    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        is_split_into_words=False
    )
    # offset_mapping is taken out so that it is not passed to the model
    offset_mapping = encoding.pop("offset_mapping")[0].tolist()
    encoding = {k: v.to(DEVICE) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)
    preds = outputs.logits.argmax(dim=-1)[0].cpu().tolist()

    current_offsets = []
    for label_id, (start, end) in zip(preds, offset_mapping):
        if start == end:
            # zero-width offsets (special tokens) are ignored
            continue
        label = id2label[label_id]
        if label == "I-HPO_TERM" and current_offsets:
            current_offsets.append((start, end))
            continue
        # Any other label (including a new B-) closes the open mention.
        if current_offsets:
            _collect_mention(sentence, current_offsets)
        current_offsets = [(start, end)] if label == "B-HPO_TERM" else []

    # Last one
    if current_offsets:
        _collect_mention(sentence, current_offsets)

# === Step 5: Output
print(f"\n✅ Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n❌ Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)



In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
# Directory holding the fine-tuned token-classification checkpoint.
MODEL_DIR = "/kaggle/working/ner_pubmedbert_saved_HPO"
# JSONL input; the "text" field of each record is what gets tagged.
TEST_FILE = Path("/kaggle/working/bio_outputs/test_text_only.jsonl")
# Token budget passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512
# Use the first GPU when one is visible, otherwise fall back to CPU so the
# cell still runs on a CPU-only kernel instead of failing in `.to(DEVICE)`.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Utility: text cleaning (drops [UNK], sub-word markers, citations, numbers) ===
def clean_text(text: str) -> str:
    """Return ``text`` with tokenizer artefacts, citations and numbers removed.

    Steps, in order: "[UNK]" becomes a space; bracketed numeric citations such
    as "[24, 25]" are deleted; every run of digits (with an optional trailing
    "%") is deleted; "##" markers are deleted; runs of two or more whitespace
    characters collapse to one space; the result is stripped.
    """
    without_unk = text.replace("[UNK]", " ")
    without_citations = re.sub(r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", "", without_unk)
    without_numbers = re.sub(r"\d+%?", "", without_citations)
    without_markers = without_numbers.replace("##", "")
    collapsed = re.sub(r"\s{2,}", " ", without_markers)
    return collapsed.strip()

# === Utility: stricter noise filter ===
def is_noise(mention: str) -> bool:
    """Return True when ``mention`` should be discarded as extraction noise.

    A stripped mention is noise when it is empty, purely numeric (optionally
    with "%"), purely punctuation, shorter than 3 characters, contains no
    ASCII vowel, is a single unbroken token longer than 25 characters, is a
    single word of at most 5 characters, or is a blacklisted filler word
    (compared lower-cased with surrounding periods removed).
    """
    candidate = mention.strip()
    if not candidate:
        return True

    lowered = candidate.lower()
    single_word = len(candidate.split()) < 2
    rejected = (
        re.fullmatch(r"\d+%?", candidate) is not None
        or re.fullmatch(r"[^\w\s]+", candidate) is not None
        or len(candidate) < 3
        or re.search(r"[aeiou]", lowered) is None
        or (len(candidate) > 25 and " " not in candidate)
        or (single_word and len(candidate) <= 5)
    )
    if rejected:
        return True

    filler_words = {
        "showed", "found", "revealed", "including", "video", "fig", "fig.",
        "information", "inserted", "chinese", "data", "further", "proband", "thereafter"
    }
    return lowered.strip(".") in filler_words

# === Step 1: Load test data ===
# Read inside a `with` block so the handle is closed deterministically, and
# skip blank lines (e.g. a trailing newline) that json.loads would reject.
with TEST_FILE.open(encoding="utf-8") as fh:
    test_data = [json.loads(line) for line in fh if line.strip()]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True).to(DEVICE)
model.eval()  # inference only
id2label = model.config.id2label

# === Step 3: Load HPO terms from hp.obo ===
# NOTE: this fetches the ontology over the network every time the cell runs.
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph = obonet.read_obo(obo_url)
# hpo_map: lower-cased term name or synonym text -> list of node ids using it.
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text is the first double-quoted part of the entry.
        match = re.search(r'"(.+?)"', syn)
        if match:
            hpo_map.setdefault(match.group(1).lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention to an HPO node id, or return None when nothing matches.

    The lower-cased mention is first looked up verbatim in ``hpo_map``. If it
    is absent, the best fuzzy match among the map's keys is used, provided
    its score reaches the cutoff of 85. When several node ids share a key,
    the first one recorded is returned.
    """
    lookup = text.lower()
    exact_ids = hpo_map.get(lookup)
    if exact_ids is not None:
        return exact_ids[0]
    best = process.extractOne(lookup, hpo_map.keys(), score_cutoff=85)
    return hpo_map[best[0]][0] if best else None

# === Step 4: Run NER + Normalize ===
mapped_mentions = []
unmapped_mentions = []


def _record_mention(sentence: str, token_offsets: list) -> None:
    """File the mention covered by one finished B/I token run.

    ``token_offsets`` holds the (start, end) character offsets of the tokens
    in the run; an empty list is a no-op. The mention is trimmed of leading
    and trailing non-word characters, whitespace-normalised, dropped if
    ``is_noise`` rejects it, and otherwise appended as ``(mention, hpo_id)``
    to ``mapped_mentions`` or ``unmapped_mentions``.
    """
    if not token_offsets:
        return
    # Slice from the first token's start to the last token's end. Joining the
    # individual token slices with spaces would put a space between the word
    # pieces of a single word ("hypo tonia") and break the dictionary lookup.
    mention = sentence[token_offsets[0][0]:token_offsets[-1][1]]
    mention = re.sub(r"^[^\w]+", "", mention)
    mention = re.sub(r"[^\w]+$", "", mention)
    mention = re.sub(r"\s+", " ", mention).strip()
    if is_noise(mention):
        return
    hpo_id = normalize_mention(mention)
    (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))


for idx, orig in enumerate(orig_sentences):
    sentence = clean_text(orig)

    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        is_split_into_words=False
    )
    # Offsets are only needed on the host; the model does not accept them.
    offset_mapping = encoding.pop("offset_mapping")[0].tolist()
    encoding = {k: v.to(DEVICE) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)
    preds = outputs.logits.argmax(dim=-1)[0].cpu().tolist()

    current_offsets = []
    for i, label_id in enumerate(preds):
        label = id2label[label_id]
        start, end = offset_mapping[i]
        if start == end:
            # Zero-width offsets carry no text; ignore them.
            continue
        if label == "B-HPO_TERM":
            # A new entity starts: close the open one first.
            _record_mention(sentence, current_offsets)
            current_offsets = [(start, end)]
        elif label == "I-HPO_TERM" and current_offsets:
            current_offsets.append((start, end))
        else:
            _record_mention(sentence, current_offsets)
            current_offsets = []

    # Close a mention that runs to the end of the sentence.
    _record_mention(sentence, current_offsets)

# === Step 5: Output
print(f"\n✅ Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n❌ Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)


In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
# Directory holding the fine-tuned token-classification checkpoint.
MODEL_DIR  = "/kaggle/working/ner_pubmedbert_saved_HPO"
# JSONL input with one record per line.
TEST_FILE  = Path("/kaggle/working/bio_outputs/test.jsonl")
# Extracted mentions are written here, one JSON object per line.
OUT_FILE   = Path("/kaggle/working/test_normalized_mentions.jsonl")
# Token budget passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512
# Use the first GPU when one is visible, otherwise fall back to CPU so the
# cell still runs on a CPU-only kernel instead of failing in `.to(DEVICE)`.
DEVICE     = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Utility: text cleaning ===
def clean_text(text: str) -> str:
    """Strip tokenizer artefacts, numeric citations and numbers from ``text``.

    "[UNK]" is replaced by a space, bracketed numeric citations and all digit
    runs (with optional "%") are removed, "##" markers are removed, repeated
    whitespace is collapsed to one space, and the result is stripped.
    """
    cleaned = text.replace("[UNK]", " ")
    # Citations must go before bare digits, or "[24, 25]" would leave "[, ]".
    for pattern in (r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", r"\d+%?"):
        cleaned = re.sub(pattern, "", cleaned)
    cleaned = cleaned.replace("##", "")
    return re.sub(r"\s{2,}", " ", cleaned).strip()

# === Utility: decide whether a mention is noise ===
def is_noise(mention: str) -> bool:
    """Return True when ``mention`` looks like an extraction artefact.

    The mention is stripped and lower-cased first. It is noise when it is
    shorter than 3 characters, is a common function word, starts with "," or
    ".", contains 4 or more spaces (an over-long fragment), consists only of
    non-word characters, or reads like a pronoun-plus-verb clause.
    """
    candidate = mention.strip().lower()
    if len(candidate) < 3:  # also covers the empty string
        return True

    function_words = {"showed", "had", "was", "were", "is", "are", "and", "or", "she", "he"}
    clause_pattern = r"\b(?:he|she|it|they|this|that)\b.*\b(?:is|was|were|had|has|showed)\b"
    return (
        candidate in function_words
        or candidate.startswith((",", ".", " "))
        or candidate.count(" ") >= 4
        or re.fullmatch(r"[^\w]+", candidate) is not None
        or re.search(clause_pattern, candidate) is not None
    )

# === Step 1: Load data ===
# Read inside a `with` block so the handle is closed deterministically, and
# skip blank lines (e.g. a trailing newline) that json.loads would reject.
with TEST_FILE.open(encoding="utf-8") as fh:
    test_data = [json.loads(line) for line in fh if line.strip()]
# Prefer the pre-tokenised form when present; the split cells in this
# notebook write records with only "text" and "spans", so fall back to "text".
orig_sentences = [
    " ".join(ex["tokens"]) if "tokens" in ex else ex["text"]
    for ex in test_data
]

# === Step 2: Load model ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True).to(DEVICE)
model.eval()  # inference only
id2label = model.config.id2label

# === Step 3: Load the HPO dictionary ===
# NOTE: this fetches the ontology over the network every time the cell runs.
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph   = obonet.read_obo(obo_url)
# hpo_map: lower-cased term name or synonym text -> list of node ids using it.
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text sits between the first pair of double quotes; skip
        # entries without a quoted, non-empty text instead of raising
        # IndexError or registering an empty key.
        parts = syn.split('"')
        if len(parts) < 2 or not parts[1]:
            continue
        syn_text = parts[1]
        hpo_map.setdefault(syn_text.lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention to an HPO node id, or return None when nothing matches.

    The mention is lower-cased and stripped, then looked up verbatim in
    ``hpo_map``. Otherwise the single best fuzzy match among the map's keys is
    used if its score reaches the cutoff of 80. The first node id recorded
    for the chosen key is returned.
    """
    query = text.lower().strip()
    if query not in hpo_map:
        candidates = list(hpo_map.keys())
        top_hits = process.extract(query, candidates, limit=1, score_cutoff=80)
        if not top_hits:
            return None
        # Each hit is a tuple whose first element is the matched key.
        query = top_hits[0][0]
    return hpo_map[query][0]

# === Step 4: Run inference and extract mentions ===
normalized_mentions = []


def _emit_mention(idx: int, sentence: str, span_start, span_end) -> None:
    """Append the record for one finished span unless it is noise.

    Does nothing when no span is open (``span_start is None``). The mention
    is the stripped slice ``sentence[span_start:span_end]``; its ``hpo_id``
    is whatever ``normalize_mention`` returns (possibly None).
    """
    if span_start is None:
        return
    mention_cleaned = sentence[span_start:span_end].strip()
    if is_noise(mention_cleaned):
        return
    normalized_mentions.append({
        "sentence_index": idx,
        "sentence":       sentence,
        "mention":        mention_cleaned,
        "span":           (span_start, span_end),
        "hpo_id":         normalize_mention(mention_cleaned)
    })


for idx, raw in enumerate(orig_sentences):
    sentence = clean_text(raw)
    encoding = tokenizer(sentence, return_offsets_mapping=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        # Only these two tensors are model inputs; offsets stay out of the call.
        outputs = model(**{k: encoding[k] for k in ["input_ids", "attention_mask"]})
    logits  = outputs.logits[0]
    preds   = logits.argmax(dim=-1).cpu().tolist()
    offsets = encoding["offset_mapping"][0].cpu().tolist()

    span_start = span_end = None
    for i, label_id in enumerate(preds):
        tok_start, tok_end = offsets[i]
        if tok_start == tok_end:
            # Zero-width offsets carry no text. Letting them through could
            # open an empty span or reset an open span's end to 0.
            continue
        label = id2label[label_id]
        if label == "B-HPO_TERM":
            # Close the open span first; overwriting it would silently drop
            # the first of two adjacent mentions.
            _emit_mention(idx, sentence, span_start, span_end)
            span_start, span_end = tok_start, tok_end
        elif label == "I-HPO_TERM" and span_start is not None:
            span_end = tok_end
        else:
            _emit_mention(idx, sentence, span_start, span_end)
            span_start = span_end = None

    # Close a span that runs to the end of the sentence.
    _emit_mention(idx, sentence, span_start, span_end)

# === Step 5: Report and save ===
total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"Total mentions: {total}")
if total:
    print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
    print(f"Failed to map:   {total - mapped} ({(total - mapped)/total:.1%})")
else:
    # Avoid dividing by zero when nothing was extracted.
    print("No mentions extracted; mapping rates are undefined.")

with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Results saved to {OUT_FILE.resolve()}")


In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification
import obonet
from rapidfuzz import process
import torch

# === Config ===
# Directory holding the fine-tuned token-classification checkpoint.
MODEL_DIR  = "/kaggle/working/ner_pubmedbert_saved_HPO"
# JSONL input with one record per line.
TEST_FILE  = Path("/kaggle/working/bio_outputs/test.jsonl")
# Extracted mentions are written here, one JSON object per line.
OUT_FILE   = Path("/kaggle/working/test_normalized_mentions.jsonl")
MAX_LENGTH = 512         # maximum number of tokens fed to the model
# Use the first GPU when one is visible, otherwise fall back to CPU so the
# cell still runs on a CPU-only kernel instead of failing in `.to(DEVICE)`.
DEVICE     = "cuda:0" if torch.cuda.is_available() else "cpu"

# === Utility: text cleaning (drops numbers, citations, "##" markers, [UNK]) ===
def clean_text(text: str) -> str:
    """Return ``text`` without tokenizer artefacts, citations or numbers.

    Applied in order: "[UNK]" -> space; bracketed numeric citations such as
    "[24, 25]" removed; digit runs with optional "%" removed; "##" sub-word
    markers removed; repeated whitespace collapsed; result stripped.
    """
    # Inner-most call runs first: [UNK], then citations, then bare numbers.
    without_numbers = re.sub(
        r"\d+%?",
        "",
        re.sub(r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]", "", text.replace("[UNK]", " ")),
    )
    # Drop sub-word markers, then merge the gaps the removals left behind.
    return re.sub(r"\s{2,}", " ", without_numbers.replace("##", "")).strip()

# === Utility: noise filter ===
def is_noise(mention: str) -> bool:
    """Return True for mentions that carry no usable text.

    A mention is noise when the whole string is a number (optionally followed
    by "%"), when it consists only of punctuation, or when fewer than 3
    characters remain after stripping whitespace.
    """
    only_number = re.fullmatch(r"\d+%?", mention) is not None
    only_punctuation = re.fullmatch(r"[^\w\s]+", mention) is not None
    too_short = len(mention.strip()) < 3
    return only_number or only_punctuation or too_short

# === Step 1: Load test data ===
# Read inside a `with` block so the handle is closed deterministically, and
# skip blank lines (e.g. a trailing newline) that json.loads would reject.
with TEST_FILE.open(encoding="utf-8") as fh:
    test_data = [json.loads(line) for line in fh if line.strip()]
# Prefer the pre-tokenised form when present; the split cells in this
# notebook write records with only "text" and "spans", so fall back to "text".
orig_sentences = [
    " ".join(ex["tokens"]) if "tokens" in ex else ex["text"]
    for ex in test_data
]

# === Step 2: Load tokenizer & model ===
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR, use_fast=True, local_files_only=True
)
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_DIR, local_files_only=True
).to(DEVICE)
model.eval()  # inference only

# id -> label mapping, e.g. "B-HPO_TERM", "I-HPO_TERM", "O"
id2label = model.config.id2label

# === Step 3: Build the HPO ontology lookup ===
# NOTE: this fetches the ontology over the network every time the cell runs.
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph   = obonet.read_obo(obo_url)
# hpo_map: lower-cased term name or synonym text -> list of node ids using it.
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text sits between the first pair of double quotes; skip
        # entries without a quoted, non-empty text instead of raising
        # IndexError or registering an empty key.
        parts = syn.split('"')
        if len(parts) < 2 or not parts[1]:
            continue
        text = parts[1]
        hpo_map.setdefault(text.lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention to an HPO node id, or return None when nothing matches.

    An exact hit on the lower-cased mention wins. Otherwise the single best
    fuzzy match among the keys of ``hpo_map`` is taken when its score reaches
    the cutoff of 85. The first node id recorded for the key is returned.
    """
    lookup = text.lower()
    exact_ids = hpo_map.get(lookup)
    if exact_ids is not None:
        return exact_ids[0]
    # Fall back to fuzzy matching over every known name and synonym.
    top_hits = process.extract(lookup, list(hpo_map.keys()), limit=1, score_cutoff=85)
    if not top_hits:
        return None
    best_key = top_hits[0][0]
    return hpo_map[best_key][0]

# === Step 4: Clean each sentence and extract mentions ===
normalized_mentions = []


def _emit_mapped_mention(idx: int, sentence: str, span_start, span_end) -> None:
    """Append the record for one finished span if it is clean and mapped.

    Does nothing when no span is open (``span_start is None``). The mention
    ``sentence[span_start:span_end]`` is kept only when ``is_noise`` accepts
    it and ``normalize_mention`` returns an id.
    """
    if span_start is None:
        return
    mention = sentence[span_start:span_end]
    # Cheap filter first: no point fuzzy-matching a mention that is dropped.
    if is_noise(mention):
        return
    hpo_id = normalize_mention(mention)
    if not hpo_id:
        return
    normalized_mentions.append({
        "sentence_index": idx,
        "sentence":       sentence,
        "mention":        mention,
        "span":           (span_start, span_end),
        "hpo_id":         hpo_id
    })


for idx, orig in enumerate(orig_sentences):
    # 4.1 Clean
    sentence = clean_text(orig)

    # 4.2 Tokenize, truncate and encode
    encoding = tokenizer(
        sentence,
        return_offsets_mapping=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    ).to(DEVICE)

    # 4.3 Model inference
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"]
        )
    logits  = outputs.logits[0]  # [seq_len, num_labels]
    preds   = logits.argmax(dim=-1).cpu().tolist()
    offsets = encoding["offset_mapping"][0].cpu().tolist()

    # 4.4 Collect contiguous HPO_TERM spans
    span_start = span_end = None
    for i, label_id in enumerate(preds):
        tok_start, tok_end = offsets[i]
        if tok_start == tok_end:
            # Zero-width offsets carry no text. Letting them through could
            # open an empty span or reset an open span's end to 0.
            continue
        label = id2label[label_id]
        if label == "B-HPO_TERM":
            # Close the open span first; overwriting it would silently drop
            # the first of two adjacent mentions.
            _emit_mapped_mention(idx, sentence, span_start, span_end)
            span_start, span_end = tok_start, tok_end
        elif label == "I-HPO_TERM" and span_start is not None:
            span_end = tok_end
        else:
            # 4.5 Post-processing: keep only mapped, non-noise mentions
            _emit_mapped_mention(idx, sentence, span_start, span_end)
            span_start = span_end = None

    # Close a span still open at the end of the sentence.
    _emit_mapped_mention(idx, sentence, span_start, span_end)

# === Step 5: Print statistics & save ===
total  = len(normalized_mentions)
mapped = sum(1 for r in normalized_mentions if r["hpo_id"] is not None)
print(f"Total mentions: {total}")
if total:
    print(f"Mapped to HP ID: {mapped} ({mapped/total:.1%})")
    print(f"Failed to map:   {total-mapped} ({(total-mapped)/total:.1%})")
else:
    # Avoid dividing by zero when nothing was extracted.
    print("No mentions extracted; mapping rates are undefined.")

with OUT_FILE.open("w", encoding="utf-8") as fout:
    for rec in normalized_mentions:
        fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Results saved to {OUT_FILE.resolve()}")


8888888888888888888

In [None]:
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# -------------------
# Constants & Paths
# -------------------
# Input records that are split into train/dev/test below (the "gold" data).
FILE_MERGED = Path("/kaggle/working/merged_spans_with_entities.jsonl")
# Directory scanned for extra *.jsonl files that are added to train only.
DIR_SILVER  = Path("/kaggle/input/hpo-only")
# Destination directory for the splits; created when this cell runs.
OUT_DIR     = Path("/kaggle/working/bio_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_FILE = OUT_DIR / "train.jsonl"
DEV_FILE   = OUT_DIR / "dev.jsonl"
TEST_FILE  = OUT_DIR / "test.jsonl"

# Span labels kept by filter_valid_entities; spans with any other label are
# dropped.
ENTITY_TYPES = {
    "AGE_ONSET", "AGE_FOLLOWUP", "AGE_DEATH",
    "PATIENT", "HPO_TERM", "GENE", "GENE_VARIANT"
}

# -------------------
# Utility Functions
# -------------------
def iter_jsonl(path: Path):
    """Yield one parsed JSON value per non-blank line of ``path``.

    Blank lines are skipped silently. Lines that are not valid JSON are also
    skipped, but each one is reported with file name and line number so that
    corrupt input cannot shrink the dataset unnoticed.
    """
    with path.open("r", encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"JSON decode error at {path.name}:{lineno} — {exc}")
                continue
            yield record

def filter_valid_entities(rec):
    """Keep only the spans of ``rec`` whose label is a known entity type.

    Returns a new dict holding the record's "text" and the surviving spans,
    or None when no span has a label in ENTITY_TYPES. Any other key of the
    record is dropped.
    """
    kept = []
    for span in rec.get("spans", []):
        if span.get("label") in ENTITY_TYPES:
            kept.append(span)
    if not kept:
        return None
    return {"text": rec["text"], "spans": kept}

def dump_jsonl(path: Path, data):
    """Write every object in ``data`` to ``path`` as one UTF-8 JSON line.

    The file is overwritten; non-ASCII characters are written unescaped.
    """
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in data)

def load_filtered_silver(path: Path):
    """Read a JSONL file and return its records after entity filtering.

    Every record from ``iter_jsonl(path)`` goes through
    ``filter_valid_entities``; records for which that returns a falsy value
    are left out.
    """
    filtered_records = (filter_valid_entities(raw) for raw in iter_jsonl(path))
    return [kept for kept in filtered_records if kept]

# -------------------
# Step 1: Load and convert gold data
# -------------------
print(">> Loading gold data …")
# Keep only records that still have at least one span after label filtering.
merged_filtered = []
for rec in iter_jsonl(FILE_MERGED):
    filtered = filter_valid_entities(rec)
    if filtered:
        merged_filtered.append(filtered)
print(f"Total valid records in gold: {len(merged_filtered)}")

# -------------------
# Step 2: Split gold into train/dev/test
# -------------------
# First hold out 20% as the test set ...
train_dev, test_set = train_test_split(
    merged_filtered,
    test_size=0.20,
    random_state=42
)
# ... then take 25% of the remaining 80% as dev, i.e. 60/20/20 overall.
train_set, dev_set = train_test_split(
    train_dev,
    test_size=0.25,
    random_state=42
)
print(f"Split sizes – TRAIN: {len(train_set)}, DEV: {len(dev_set)}, TEST: {len(test_set)}")

# -------------------
# Step 3: Add silver data to train set
# -------------------
# Silver records are appended to the training split only; dev and test stay
# gold-only.
extra_train = []
if DIR_SILVER.exists():
    print(">> Loading silver data from hpo-only/")
    # Sorted so the files are read in a stable order.
    for jf in sorted(DIR_SILVER.glob("*.jsonl")):
        print(f"  - {jf.name}")
        extra_train.extend(load_filtered_silver(jf))
else:
    print(">> Silver data directory not found.")

train_final = train_set + extra_train
print(f"Final train size: {len(train_final)} (including {len(extra_train)} silver records)")

# -------------------
# Step 4: Save to disk
# -------------------
# Each file is overwritten on every run of this cell.
dump_jsonl(TRAIN_FILE, train_final)
dump_jsonl(DEV_FILE, dev_set)
dump_jsonl(TEST_FILE, test_set)

print(f"\nSaved to:")
print(f"  ➜ {TRAIN_FILE}")
print(f"  ➜ {DEV_FILE}")
print(f"  ➜ {TEST_FILE}")


In [None]:
import json
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# === 1. Load pre-split data with silver already included ===
BIO_DIR = Path("/kaggle/working/bio_outputs")

def load_jsonl(path: Path):
    """Return the parsed JSON values of ``path``, one per non-blank line."""
    records = []
    with path.open(encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue  # ignore blank / whitespace-only lines
            records.append(json.loads(raw_line))
    return records

# Read the three splits written to BIO_DIR.
train_data = load_jsonl(BIO_DIR / "train.jsonl")
dev_data   = load_jsonl(BIO_DIR / "dev.jsonl")
test_data  = load_jsonl(BIO_DIR / "test.jsonl")

# The dev split is registered under the "validation" key.
ds_raw = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(dev_data),
    "test": Dataset.from_list(test_data),
})
print("✅ Loaded dataset sizes:", {k: len(v) for k, v in ds_raw.items()})

# === 2. Tokenizer and label mappings ===
# A fast tokenizer is requested; encode_and_align_labels asks it for
# character offsets.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True
)

# Single-entity BIO scheme: the list index is the class id.
label_list = ["O", "B-HPO_TERM", "I-HPO_TERM"]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

# === 3. Span-to-token label encoder ===
def encode_and_align_labels(example):
    """Tokenize one example and attach a BIO label id to every token.

    ``example`` must provide "text" and "spans"; each span needs "start" and
    "end" character offsets and a "label". Only spans labelled "HPO_TERM"
    produce entity tags. A token lying fully inside such a span is tagged
    "B-HPO_TERM" when it starts exactly at the span start and "I-HPO_TERM"
    otherwise; every other token, including zero-width special tokens, is
    tagged "O".

    Returns the tokenizer encoding with an added "labels" list of label ids.
    """
    text = example["text"]
    # The label set here only knows HPO_TERM, but the split step keeps spans
    # of other entity types (genes, ages, ...). Tagging those as HPO_TERM
    # would train the model on wrong positives, so they are ignored.
    entities = [
        (s["start"], s["end"])
        for s in example["spans"]
        if s.get("label") == "HPO_TERM"
    ]

    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    )

    labels = []
    for offset in encoding["offset_mapping"]:
        tok_start, tok_end = offset
        if tok_start == 0 and tok_end == 0:
            # (0, 0) marks a token with no source text.
            labels.append("O")
            continue
        tag = "O"
        for start, end in entities:
            if tok_start >= start and tok_end <= end:
                tag = "B-HPO_TERM" if tok_start == start else "I-HPO_TERM"
                break
        labels.append(tag)

    encoding["labels"] = [label2id[l] for l in labels]
    return encoding

# === 4. Encode all splits ===
# One example at a time; the raw columns are dropped from the encoded dataset.
ds_encoded = ds_raw.map(
    encode_and_align_labels,
    batched=False,
    remove_columns=["text", "spans"]
)
print("✅ Encoding complete.")

# === 5. Load model ===
# Same checkpoint as the tokenizer, with a token-classification head sized to
# label_list.
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# === 6. Evaluation metrics ===
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    """Compute overall seqeval scores for one evaluation pass.

    ``p.predictions`` holds per-token class scores and ``p.label_ids`` the
    gold label ids. Positions whose gold id is -100 are excluded from both
    the reference and the predicted tag sequences. Returns the overall
    precision, recall, F1 and accuracy reported by seqeval.
    """
    predicted_ids = p.predictions.argmax(-1)
    gold_ids = p.label_ids

    true_labels, pred_labels = [], []
    for pred_seq, gold_seq in zip(predicted_ids, gold_ids):
        gold_tags, pred_tags = [], []
        for pid, lid in zip(pred_seq, gold_seq):
            if lid == -100:
                continue
            gold_tags.append(id2label[lid])
            pred_tags.append(id2label[pid])
        true_labels.append(gold_tags)
        pred_labels.append(pred_tags)

    scores = seqeval.compute(predictions=pred_labels, references=true_labels)
    wanted = ("overall_precision", "overall_recall", "overall_f1", "overall_accuracy")
    return {name: scores[name] for name in wanted}

# === 7. Training configuration ===
# Evaluation and logging run every 50 steps, checkpoints every 500 steps.
# The best checkpoint is chosen by the "overall_f1" value returned from
# compute_metrics (higher is better).
training_args = TrainingArguments(
    output_dir="ner_pubmedbert",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="overall_f1",
    greater_is_better=True,
    report_to=["none"],
    save_total_limit=1,
)

# Train on the "train" split and evaluate on "validation"; the "test" split
# is only used in step 9.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_encoded["train"],
    eval_dataset=ds_encoded["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
)

# === 8. Train and evaluate ===
trainer.train()
# NOTE(review): the returned metrics dict is discarded here.
trainer.evaluate()

# === 9. Predict on test set ===
print("\n--- Predicting on test set ---")
pred_output = trainer.predict(ds_encoded["test"])
preds = pred_output.predictions.argmax(-1)
labels = pred_output.label_ids

# Rebuild tag sequences, skipping positions whose gold id is -100, the same
# way compute_metrics does.
true_labels = [
    [id2label[lid] for lid in seq if lid != -100]
    for seq in labels
]
pred_labels = [
    [id2label[pid] for pid, lid in zip(pred_seq, label_seq) if lid != -100]
    for pred_seq, label_seq in zip(preds, labels)
]

detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

# Print only the entry keyed "HPO_TERM"; every other entry of the result is
# skipped.
print("\n📊 HPO_TERM classification report:")
for label, metrics in detailed_result.items():
    if label == "HPO_TERM":
        print(f"{label:20} | Precision: {metrics['precision']:.3f} | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}")



In [None]:
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# === Config ===
MODEL_DIR = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")  # test set; the raw "text" field is used
MAX_LENGTH = 512
DEVICE = 0  # use -1 for CPU (if no GPU)

# === Step 1: Load test data ===
print(">> Loading test data")
# Read inside a `with` block so the handle is closed deterministically, and
# skip blank lines (e.g. a trailing newline) that json.loads would reject.
with TEST_FILE.open(encoding="utf-8") as fh:
    test_data = [json.loads(line) for line in fh if line.strip()]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
# Apply the configured token budget, as the other inference cells do;
# MAX_LENGTH was otherwise defined but never used here.
# NOTE(review): presumably the pipeline truncates to this length — confirm.
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",  # merge B-/I- tagged tokens into one entity
    device=DEVICE
)

# === Step 3: Run NER without post-processing
print(">> Running NER without post-processing")
all_results = []

for idx, sentence in enumerate(orig_sentences):
    results = ner_pipeline(sentence)
    for ent in results:
        # One flat record per predicted entity, keeping the source sentence.
        all_results.append({
            "sentence_idx": idx,
            "text": sentence,
            "mention": ent["word"],
            "start": ent["start"],
            "end": ent["end"],
            "label": ent["entity_group"],
            "score": round(ent["score"], 4)
        })

# === Step 4: Print Results
print(f"\nTotal mentions extracted: {len(all_results)}")
for r in all_results:
    print(f"[{r['label']}] {r['mention']} (score={r['score']}, span={r['start']}-{r['end']})")


In [None]:
import re
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import obonet
from rapidfuzz import process

# === Config ===
MODEL_DIR = "/kaggle/working/ner_pubmedbert_saved_HPO"
TEST_FILE = Path("/kaggle/working/bio_outputs/test.jsonl")
MAX_LENGTH = 512
DEVICE = 0  # use -1 for CPU

# === Step 1: Load test data ===
print(">> Loading test data")
# Read inside a `with` block so the handle is closed deterministically, and
# skip blank lines (e.g. a trailing newline) that json.loads would reject.
with TEST_FILE.open(encoding="utf-8") as fh:
    test_data = [json.loads(line) for line in fh if line.strip()]
orig_sentences = [ex["text"] for ex in test_data]

# === Step 2: Load model and tokenizer with pipeline ===
print(">> Loading model and tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, local_files_only=True)
# Apply the configured token budget, as the other inference cells do;
# MAX_LENGTH was otherwise defined but never used here.
# NOTE(review): presumably the pipeline truncates to this length — confirm.
tokenizer.model_max_length = MAX_LENGTH
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=DEVICE
)

# === Step 3: Load HPO terms from hp.obo
print(">> Loading HPO terms from obo")
# NOTE: this fetches the ontology over the network every time the cell runs.
obo_url = "http://purl.obolibrary.org/obo/hp.obo"
graph = obonet.read_obo(obo_url)
# hpo_map: lower-cased term name or synonym text -> list of node ids using it.
hpo_map = {}
for node_id, data in graph.nodes(data=True):
    name = data.get("name")
    if name:
        hpo_map.setdefault(name.lower(), []).append(node_id)
    for syn in data.get("synonym", []):
        # The synonym text is the first double-quoted part of the entry.
        match = re.search(r'"(.+?)"', syn)
        if match:
            hpo_map.setdefault(match.group(1).lower(), []).append(node_id)

def normalize_mention(text: str):
    """Map a mention to an HPO node id, or return None when nothing matches.

    Tries an exact lookup of the lower-cased mention in ``hpo_map`` and then
    the best fuzzy match among the map's keys with a score cutoff of 85. The
    first node id recorded for the chosen key is returned.
    """
    lookup = text.lower()
    if lookup not in hpo_map:
        best = process.extractOne(lookup, hpo_map.keys(), score_cutoff=85)
        if not best:
            return None
        # The first element of the result is the matched key.
        lookup = best[0]
    return hpo_map[lookup][0]

# === Step 4: Run NER + Normalize (No Noise Filtering)
print(">> Running NER and normalization (no filtering)")
# Both lists hold (mention, hpo_id) tuples; hpo_id is falsy in the second.
mapped_mentions = []
unmapped_mentions = []

for idx, sentence in enumerate(orig_sentences):
    results = ner_pipeline(sentence)
    for ent in results:
        # The mention is the pipeline's "word" field, not a slice of the
        # sentence.
        mention = ent["word"].strip()
        hpo_id = normalize_mention(mention)
        # File the pair by whether a node id was found.
        (mapped_mentions if hpo_id else unmapped_mentions).append((mention, hpo_id))

# === Step 5: Output
print(f"\n✅ Mapped Mentions ({len(mapped_mentions)}):")
for mention, hpo_id in mapped_mentions:
    print(f"{mention} --> {hpo_id}")

print(f"\n❌ Unmapped Mentions ({len(unmapped_mentions)}):")
for mention, _ in unmapped_mentions:
    print(mention)


In [None]:
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# -------------------
# Constants & Paths
# -------------------
# Input records that are split into train/dev/test below (the "gold" data).
FILE_MERGED = Path("/kaggle/working/merged_spans_with_entities.jsonl")
# Single extra JSONL file whose records are added to train only.
SILVER_FILE = Path("/kaggle/input/hpo-only/HPO_only.jsonl")
# Destination directory for the splits; created when this cell runs.
OUT_DIR     = Path("/kaggle/working/bio_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_FILE = OUT_DIR / "train.jsonl"
DEV_FILE   = OUT_DIR / "dev.jsonl"
TEST_FILE  = OUT_DIR / "test.jsonl"

# Span labels kept by filter_valid_entities; spans with any other label are
# dropped.
ENTITY_TYPES = {
    "AGE_ONSET", "AGE_FOLLOWUP", "AGE_DEATH",
    "PATIENT", "HPO_TERM", "GENE", "GENE_VARIANT"
}

# -------------------
# Utility Functions
# -------------------
def iter_jsonl(path: Path):
    """Yield one parsed JSON value per non-blank line of ``path``.

    Blank lines are skipped silently. Lines that fail to parse are skipped as
    well, but each is reported with file name and line number so that corrupt
    input cannot shrink the dataset unnoticed.
    """
    with path.open("r", encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"JSON decode error at {path.name}:{lineno} — {exc}")
                continue
            yield record

def filter_valid_entities(rec):
    """Keep only the spans of ``rec`` whose label is a known entity type.

    Returns ``{"text": ..., "spans": [...]}`` with the surviving spans, or
    None when no span carries a label from ENTITY_TYPES. Other keys of the
    record are not copied.
    """
    kept = list(filter(lambda span: span.get("label") in ENTITY_TYPES, rec.get("spans", [])))
    if not kept:
        return None
    return {"text": rec["text"], "spans": kept}

def dump_jsonl(path: Path, data):
    """Overwrite ``path`` with one UTF-8 JSON line per object in ``data``.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with path.open("w", encoding="utf-8") as sink:
        for item in data:
            serialized = json.dumps(item, ensure_ascii=False)
            sink.write(serialized + "\n")

# -------------------
# Step 1: Load and convert gold data
# -------------------
print(">> Loading gold data …")
# Keep only records that still have at least one span after label filtering.
merged_filtered = []
for rec in iter_jsonl(FILE_MERGED):
    filtered = filter_valid_entities(rec)
    if filtered:
        merged_filtered.append(filtered)
print(f"Total valid records in gold: {len(merged_filtered)}")

# -------------------
# Step 2: Split gold into train/dev/test
# -------------------
# First hold out 20% as the test set ...
train_dev, test_set = train_test_split(
    merged_filtered,
    test_size=0.20,
    random_state=42
)
# ... then take 25% of the remaining 80% as dev, i.e. 60/20/20 overall.
train_set, dev_set = train_test_split(
    train_dev,
    test_size=0.25,
    random_state=42
)
print(f"Split sizes – TRAIN: {len(train_set)}, DEV: {len(dev_set)}, TEST: {len(test_set)}")

# -------------------
# Step 3: Add silver data to train set
# -------------------
# Silver records are appended to the training split only; dev and test stay
# gold-only.
extra_train = []
if SILVER_FILE.exists():
    print(">> Loading silver data from HPO_only.jsonl")
    for rec in iter_jsonl(SILVER_FILE):
        filtered = filter_valid_entities(rec)
        if filtered:
            extra_train.append(filtered)
    print(f"  ➜ Loaded {len(extra_train)} silver records.")
else:
    print(f">> Silver file not found: {SILVER_FILE}")

train_final = train_set + extra_train
print(f"Final train size: {len(train_final)} (including {len(extra_train)} silver records)")

# -------------------
# Step 4: Save to disk
# -------------------
# Each file is overwritten on every run of this cell.
dump_jsonl(TRAIN_FILE, train_final)
dump_jsonl(DEV_FILE, dev_set)
dump_jsonl(TEST_FILE, test_set)

print(f"\nSaved to:")
print(f"  ➜ {TRAIN_FILE}")
print(f"  ➜ {DEV_FILE}")
print(f"  ➜ {TEST_FILE}")


In [None]:
import json
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# === 1. Load pre-split data with silver already included ===
BIO_DIR = Path("/kaggle/working/bio_outputs")

def load_jsonl(path: Path):
    """Parse ``path`` as JSON Lines and return the values as a list.

    Lines that are empty or whitespace-only are ignored.
    """
    with path.open(encoding="utf-8") as handle:
        meaningful_lines = filter(lambda raw: raw.strip(), handle)
        return list(map(json.loads, meaningful_lines))

# Read the three splits written to BIO_DIR.
train_data = load_jsonl(BIO_DIR / "train.jsonl")
dev_data   = load_jsonl(BIO_DIR / "dev.jsonl")
test_data  = load_jsonl(BIO_DIR / "test.jsonl")

# The dev split is registered under the "validation" key.
ds_raw = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(dev_data),
    "test": Dataset.from_list(test_data),
})
print("✅ Loaded dataset sizes:", {k: len(v) for k, v in ds_raw.items()})

# === 2. Tokenizer and label mappings ===
# A fast tokenizer is requested; encode_and_align_labels asks it for
# character offsets.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    use_fast=True
)

# Single-entity BIO scheme: the list index is the class id.
label_list = ["O", "B-HPO_TERM", "I-HPO_TERM"]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

# === 3. Span-to-token label encoder ===
def encode_and_align_labels(example):
    """Tokenize one record and turn its character spans into per-token BIO label ids.

    A token belongs to an entity only when its character offsets lie entirely
    within that span's ``start``/``end`` offsets. The first such token of each
    span is tagged ``B-HPO_TERM`` and the following ones ``I-HPO_TERM``; every
    other ordinary token is ``O``. When spans overlap, the first matching span
    in list order decides the tag.

    Args:
        example: Mapping with a ``"text"`` string and a ``"spans"`` list whose
            items provide character offsets under ``"start"`` and ``"end"``.

    Returns:
        The tokenizer encoding (still containing ``"offset_mapping"``) with an
        added ``"labels"`` list holding one id per token: a ``label2id`` value
        for ordinary tokens and -100 for special tokens.

    Raises:
        KeyError: If ``example`` or one of its spans lacks a required key.
    """
    text = example["text"]
    entities = [(span["start"], span["end"]) for span in example["spans"]]

    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=512,
    )

    # -100 is the value compute_metrics in this notebook filters on; giving it
    # to special tokens keeps them from being scored as genuine "O" targets.
    ignore_id = -100
    outside_id = label2id["O"]
    begin_id = label2id["B-HPO_TERM"]
    inside_id = label2id["I-HPO_TERM"]

    labels = []
    started = set()  # indices of spans that have already received their B- tag
    for tok_start, tok_end in encoding["offset_mapping"]:
        if (tok_start, tok_end) == (0, 0):
            # Special tokens are reported with an empty (0, 0) offset.
            labels.append(ignore_id)
            continue

        label_id = outside_id
        for ent_idx, (start, end) in enumerate(entities):
            if tok_start >= start and tok_end <= end:
                # Tag the first contained token as B- even when the span start
                # does not coincide with a token boundary (e.g. the annotation
                # begins on whitespace); otherwise the entity would consist of
                # I- tags only.
                if ent_idx in started:
                    label_id = inside_id
                else:
                    label_id = begin_id
                    started.add(ent_idx)
                break
        labels.append(label_id)

    encoding["labels"] = labels
    return encoding

# === 4. Encode all splits ===
# Run encode_and_align_labels over every split one record at a time
# (batched=False) and drop the raw "text" and "spans" columns from the result.
# Only those two columns are named for removal here.
ds_encoded = ds_raw.map(
    encode_and_align_labels,
    batched=False,
    remove_columns=["text", "spans"]
)
print("✅ Encoding complete.")

# === 5. Load model ===
# Token-classification model from the same checkpoint name as the tokenizer.
# The number of labels and both label maps come from label_list, so the model
# is configured with the same ids used when encoding the labels.
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# === 6. Evaluation metrics ===
# Metric object used by compute_metrics and by the test-set report in step 9;
# both call seqeval.compute with lists of tag-string sequences.
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    """Score predictions with seqeval and return its four overall metrics.

    Args:
        p: Object exposing ``predictions`` (scores whose argmax over the last
            axis gives the predicted label id) and ``label_ids`` (gold ids).

    Returns:
        Dict with the keys ``overall_precision``, ``overall_recall``,
        ``overall_f1`` and ``overall_accuracy`` taken from the seqeval result.
    """
    predicted_ids = p.predictions.argmax(-1)
    gold_ids = p.label_ids

    references = []
    predictions = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        ref_tags = []
        pred_tags = []
        for pred_id, gold_id in zip(pred_row, gold_row):
            # Positions whose gold id is -100 are left out of both sequences.
            if gold_id == -100:
                continue
            ref_tags.append(id2label[gold_id])
            pred_tags.append(id2label[pred_id])
        references.append(ref_tags)
        predictions.append(pred_tags)

    scores = seqeval.compute(predictions=predictions, references=references)
    reported_keys = (
        "overall_precision",
        "overall_recall",
        "overall_f1",
        "overall_accuracy",
    )
    return {key: scores[key] for key in reported_keys}

# === 7. Training configuration ===
# Evaluation and logging both run every 50 steps, while checkpoints are written
# on a separate 500-step cadence (a multiple of eval_steps). The best model is
# selected by the highest "overall_f1" returned by compute_metrics, and
# report_to=["none"] turns off external experiment trackers.
# NOTE(review): with save_steps=500 and eval_steps=50, presumably only
# evaluations that coincide with a saved checkpoint can be restored by
# load_best_model_at_end — confirm against the installed transformers version.
training_args = TrainingArguments(
    output_dir="ner_pubmedbert",
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="overall_f1",
    greater_is_better=True,
    report_to=["none"],
    save_total_limit=1,
)

# Wire the model, the encoded train/validation splits, the token-classification
# collator and compute_metrics into a single Trainer.
# NOTE(review): newer transformers releases deprecate the `tokenizer=` keyword
# in favour of `processing_class=` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_encoded["train"],
    eval_dataset=ds_encoded["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
)

# === 8. Train and evaluate ===
# Train, then run one more evaluation on the validation split given to the
# Trainer. The value returned by evaluate() is not stored in a variable here.
trainer.train()
trainer.evaluate()

# === 9. Predict on test set ===
print("\n--- Predicting on test set ---")
pred_output = trainer.predict(ds_encoded["test"])
labels = pred_output.label_ids
preds = pred_output.predictions.argmax(-1)

# Map ids back to tag strings, keeping only positions whose gold id is not -100.
true_labels = [
    [id2label[gold_id] for gold_id in gold_row if gold_id != -100]
    for gold_row in labels
]
pred_labels = [
    [
        id2label[pred_id]
        for pred_id, gold_id in zip(pred_row, gold_row)
        if gold_id != -100
    ]
    for pred_row, gold_row in zip(preds, labels)
]

detailed_result = seqeval.compute(predictions=pred_labels, references=true_labels)

print("\n📊 HPO_TERM classification report:")
for label, metrics in detailed_result.items():
    if label != "HPO_TERM":
        # Only the HPO_TERM entry is reported; every other key is skipped.
        continue
    print(
        f"{label:20} | Precision: {metrics['precision']:.3f}"
        f" | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}"
    )