In [None]:
# -*- coding: utf-8 -*-
# BioBERT + CRF for NER (BIO format)
# Usage:
#   python train_biobert_crf.py

import os
import random
from typing import Dict, Any, List, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF
from transformers import BertModel, AutoTokenizer, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report

# ========== Reproducibility ==========
# Fixed seed shared by every RNG seeded below and by the train/validation
# split in main().
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# ========== Device ==========
# Use CUDA when PyTorch reports it available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ========== Data ==========
def load_data(file_path: str) -> Tuple[List[List[Tuple[str, str]]], dict, dict, List[str]]:
    """Read a BIO-formatted file: one ``word tag`` per line, sentences separated by blank lines.

    The tag is the last whitespace-separated token on a line; every token before
    it is re-joined with single spaces to form the word.

    Args:
        file_path: path to a UTF-8 text file.

    Returns:
        sentences: list of sentences, each a list of ``(word, tag)`` pairs.
        tag2idx: tag -> index over every tag seen plus ``'O'``, indices assigned
            in sorted tag order.
        idx2tag: inverse mapping of ``tag2idx``.
        entity_types: sorted entity types, i.e. the part after the FIRST ``-``
            of every non-``O`` tag that contains a ``-`` (so ``B-GENE-VARIANT``
            yields ``GENE-VARIANT``).

    Raises:
        ValueError: if a non-blank line has fewer than two tokens (no word/tag pair).
    """
    sentences = []
    sentence = []
    entity_types = set()

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank line closes the current sentence (repeated blanks are harmless).
                if sentence:
                    sentences.append(sentence)
                    sentence = []
                continue

            parts = line.split()
            if len(parts) < 2:
                # A lone token cannot be split into word + tag; previously it was
                # stored as ('', token), polluting the tag vocabulary.
                raise ValueError(
                    f"{file_path}:{line_no}: expected 'word tag', got {line!r}"
                )
            word = ' '.join(parts[:-1])
            tag = parts[-1]
            sentence.append((word, tag))
            if tag != 'O' and '-' in tag:
                # maxsplit=1 keeps hyphenated entity types intact.
                entity_types.add(tag.split('-', 1)[1])

    # The file may not end with a blank line.
    if sentence:
        sentences.append(sentence)

    unique_tags = sorted({tag for sent in sentences for _, tag in sent} | {'O'})
    tag2idx = {tag: idx for idx, tag in enumerate(unique_tags)}
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    return sentences, tag2idx, idx2tag, sorted(entity_types)


class NERDataset(Dataset):
    """Dataset that tokenizes BIO-tagged sentences and aligns word tags to subword tokens.

    Each item is a dict of 1-D tensors of length ``max_len`` (the tokenizer is
    called with ``padding='max_length'`` and ``truncation=True``):
    ``input_ids``, ``attention_mask`` and ``labels``.

    Label alignment, driven by ``encoding.word_ids()``:
      * positions with no source word (``word_ids()`` is None) get ``'O'``;
      * the first subword of a word gets the word's own tag;
      * continuation subwords of a ``B-X`` word get ``I-X`` (falling back to
        ``B-X`` when ``I-X`` is not in ``tag2idx``); any other tag is copied.
    """

    def __init__(self, sentences, tag2idx, tokenizer, max_len=128):
        # sentences: list of [(word, tag), ...]; tag2idx must contain 'O'.
        # tokenizer must return an encoding exposing word_ids() (a "fast" tokenizer).
        self.sentences = sentences
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.o_id = self.tag2idx['O']

    def __len__(self):
        return len(self.sentences)

    def _continuation_id(self, tag):
        """Return the label id for a non-initial subword of a word tagged ``tag``."""
        if tag.startswith('B-'):
            # Keep the full entity type after the 'B-' prefix, so hyphenated types
            # such as 'GENE-VARIANT' map to 'I-GENE-VARIANT' rather than 'I-GENE'.
            return self.tag2idx.get(f'I-{tag[2:]}', self.tag2idx[tag])
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        words = [w for w, _ in sentence]
        tags = [t for _, t in sentence]

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = []
        prev_word_idx = None
        for word_idx in encoding.word_ids():
            if word_idx is None:
                # Special token or padding. NOTE: the model's CRF mask is the
                # attention mask, which excludes padding but (for BERT-style
                # tokenizers) NOT special tokens such as [CLS]/[SEP]; those
                # positions are therefore trained/evaluated with this 'O' label.
                labels.append(self.o_id)
            elif word_idx != prev_word_idx:
                # First subword of a word carries the word's tag.
                labels.append(self.tag2idx[tags[word_idx]])
            else:
                labels.append(self._continuation_id(tags[word_idx]))
            prev_word_idx = word_idx

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.long)
        }


# ========== Model ==========
class BERT_CRF_NER(nn.Module):
    """Pretrained BERT encoder -> dropout -> linear emission layer -> CRF.

    ``forward`` with ``labels`` returns the negated CRF log-likelihood
    (computed with ``reduction='mean'``) as the training loss; without
    ``labels`` it returns ``self.crf.decode(...)``, the best tag-index
    sequence for every row. In both cases the CRF mask is the boolean
    ``attention_mask``.
    """

    def __init__(self, bert_model: str, num_tags: int, dropout_prob: float = 0.1):
        super().__init__()
        # Creation order (encoder, dropout, classifier, CRF) is kept stable: it
        # fixes both state_dict key order and RNG consumption at init time.
        self.bert = BertModel.from_pretrained(bert_model)
        encoder_width = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(encoder_width, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        token_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        emissions = self.classifier(self.dropout(token_states))
        crf_mask = attention_mask.bool()

        if labels is None:
            # Inference: Viterbi decoding over the unmasked positions.
            return self.crf.decode(emissions, mask=crf_mask)

        # Training: CRF returns a log-likelihood, so negate it to get a loss.
        return -self.crf(emissions, labels, mask=crf_mask, reduction='mean')


# ========== Pretty Print ==========
def pretty_print_report(title: str, report: Dict[str, Any]) -> None:
    """Print a seqeval ``output_dict`` classification report as an aligned table.

    Layout:
      * a title line and a caption;
      * a column header;
      * one row per label, in the dict's own order. Every key other than
        'micro avg', 'macro avg', 'weighted avg' and 'accuracy' counts as a label;
      * whichever of the three average rows are present, each preceded by a
        blank line;
      * the 'weighted avg' precision/recall/F1 to four decimals as the
        "overall" figures.

    Metric fields missing from a row are shown as 0.

    Args:
        title: heading printed between ``---`` markers.
        report: dict as returned by ``classification_report(..., output_dict=True)``.
    """
    average_rows = ('micro avg', 'macro avg', 'weighted avg')
    not_a_label = frozenset(average_rows) | {'accuracy'}

    def format_row(name: str, stats: Dict[str, Any]) -> str:
        # Columns: name | precision | recall | f1-score | support
        precision = stats.get('precision', 0.0)
        recall = stats.get('recall', 0.0)
        f1 = stats.get('f1-score', 0.0)
        support = int(stats.get('support', 0))
        return f"{name:<14}{precision:>10.2f} {recall:>8.2f} {f1:>10.2f} {support:>9}"

    print(f"\n--- {title} ---")
    print("📊 分类报告：")
    print(f"{'':<14}{'precision':>10} {'recall':>8} {'f1-score':>10} {'support':>9}")

    for name, stats in report.items():
        if name in not_a_label:
            continue
        print(format_row(name, stats))

    for name in (key for key in average_rows if key in report):
        print("\n" + format_row(name, report[name]))

    overall = report.get('weighted avg', {})
    print("\n--- 总体性能指标 ---")
    print(f"Overall Precision: {overall.get('precision', 0.0):.4f}")
    print(f"Overall Recall:    {overall.get('recall', 0.0):.4f}")
    print(f"Overall F1-Score:  {overall.get('f1-score', 0.0):.4f}\n")


# ========== Train / Eval ==========
def train_one_epoch(model, dataloader, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch does the following:
      * move ``input_ids``, ``attention_mask`` and ``labels`` to ``device``;
      * compute the loss as ``model(input_ids, attention_mask, labels)`` and
        backpropagate it;
      * clip the global gradient norm to 1.0;
      * step the optimizer, then the scheduler if one is given (it is stepped
        once per batch, not per epoch).

    Returns 0.0 for an empty loader.
    """
    model.train()
    loss_sum = 0.0
    for batch in dataloader:
        ids, mask, gold = (batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels'))

        optimizer.zero_grad()
        batch_loss = model(ids, mask, gold)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        loss_sum += batch_loss.item()

    n_batches = len(dataloader)
    return loss_sum / (n_batches if n_batches > 0 else 1)


def evaluate(model, dataloader, device, idx2tag, verbose: bool = False):
    """Decode ``dataloader`` with ``model`` and score the predictions with seqeval.

    Gold and predicted tags are compared over the attended (attention_mask == 1)
    positions of each row. This is subword-level and includes any special tokens
    the attention mask covers; those carry the 'O' label from the dataset.

    Args:
        model: called as ``model(input_ids, attention_mask)``; must return one
            list of tag indices per row.
        dataloader: yields dicts with ``input_ids``, ``attention_mask``, ``labels``.
        device: device the input tensors are moved to.
        idx2tag: tag index -> tag string.
        verbose: print the full report via ``pretty_print_report``.

    Returns:
        ``(precision, recall, f1, acc, report)``:
          * precision / recall / f1 are seqeval's "weighted avg" scores;
          * acc is the share of positions whose GOLD tag is not 'O' that were
            predicted with exactly that tag. It is token-level, not entity-level,
            and false positives on 'O' positions do not affect it;
          * report is the full ``classification_report`` dict.
    """
    model.eval()
    true_labels, pred_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gold_ids = batch['labels'].cpu().numpy()
            # Attended positions per row, computed once for the whole batch.
            # Assumes right padding, so valid positions form a prefix of each
            # row (TODO confirm for other tokenizers).
            valid_lens = attention_mask.bool().sum(dim=1).tolist()

            preds = model(input_ids, attention_mask)  # list[list[int]]

            for pred_row, gold_row, valid_len in zip(preds, gold_ids, valid_lens):
                pred_tags = [idx2tag[p] for p in pred_row[:valid_len]]
                true_tags = [idx2tag[int(g)] for g in gold_row[:valid_len]]
                # Defensive: seqeval needs gold/pred sequences of equal length.
                m = min(len(pred_tags), len(true_tags))
                pred_labels.append(pred_tags[:m])
                true_labels.append(true_tags[:m])

    # zero_division=0 reports ill-defined scores (e.g. a type never predicted)
    # as 0.0, the same value the default 'warn' mode yields. It just skips
    # emitting UndefinedMetricWarning on every evaluation.
    report: Dict[str, Any] = classification_report(
        true_labels, pred_labels, output_dict=True, zero_division=0
    )
    if verbose:
        pretty_print_report("Validation Report", report)

    weighted = report['weighted avg']
    precision = weighted['precision']
    recall = weighted['recall']
    f1 = weighted['f1-score']

    # Token-level accuracy restricted to gold non-'O' positions (see docstring).
    correct = total = 0
    for t_seq, p_seq in zip(true_labels, pred_labels):
        for t, p in zip(t_seq, p_seq):
            if t != 'O':
                total += 1
                if t == p:
                    correct += 1
    acc = correct / total if total > 0 else 0.0
    return precision, recall, f1, acc, report


# ========== Main ==========
def main():
    """Train BioBERT+CRF on a BIO file, checkpoint the best epoch, and report on it.

    Pipeline:
      1. Load ``DATA_PATH`` and split it 80/20 into train/validation (seeded).
      2. Tokenize both splits.
      3. Train for ``EPOCHS`` epochs with AdamW and a linear warmup (10% of
         steps) / decay schedule.
      4. After each epoch, evaluate on the validation split. Save
         ``model.state_dict()`` to ``SAVE_PATH`` whenever the weighted-avg F1
         improves.
      5. Reload that checkpoint and print its validation report.

    Raises:
        RuntimeError: if no checkpoint was written during this run (e.g.
            ``EPOCHS < 1`` or a NaN F1). Without this check, ``torch.load``
            would load a stale file left by an earlier run, or fail with an
            unclear error.
    """
    # ---- Config ----
    DATA_PATH = "combined_dataset.txt"   # <<< change this to the path of your BIO data file
    BIOBERT_MODEL = "dmis-lab/biobert-base-cased-v1.1"
    MAX_LEN = 256
    BATCH_SIZE = 32
    EPOCHS = 10
    LR = 3e-5
    DROPOUT = 0.1
    SAVE_PATH = "biobert_crf_best.pt"

    # 1) Load data
    sentences, tag2idx, idx2tag, entity_types = load_data(DATA_PATH)
    print(f"Loaded {len(sentences)} sentences")
    print(f"Entity types: {entity_types}")
    print(f"Num tags: {len(tag2idx)}")

    # 2) Split. tag2idx was built from ALL sentences, so both splits share one
    #    tag vocabulary.
    train_sents, val_sents = train_test_split(sentences, test_size=0.2, random_state=SEED, shuffle=True)

    # 3) Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BIOBERT_MODEL)

    # 4) Datasets / Loaders
    train_ds = NERDataset(train_sents, tag2idx, tokenizer, MAX_LEN)
    val_ds = NERDataset(val_sents, tag2idx, tokenizer, MAX_LEN)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # 5) Model & Optimizer & Scheduler
    model = BERT_CRF_NER(BIOBERT_MODEL, len(tag2idx), DROPOUT).to(device)
    # NOTE(review): the CRF transitions and the emission layer train with the
    # same small LR as BERT. BERT+CRF setups often give these head parameters
    # a larger LR via optimizer parameter groups; consider that if the CRF
    # appears to learn slowly.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max(1, int(0.1 * total_steps)),
        num_training_steps=total_steps
    )

    # 6) Train loop with best-model tracking (selection metric: validation
    #    weighted-avg F1 as returned by evaluate()).
    best_f1 = -1.0
    best_epoch = 0  # 0 means "no checkpoint written during this run"

    print("Starting training with BioBERT+CRF ...")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, device)
        precision, recall, f1, acc, report = evaluate(model, val_loader, device, idx2tag, verbose=True)

        print(f"Epoch {epoch}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val  Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, ACC={acc:.4f}")
        print("-" * 60)

        # Save best
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            torch.save(model.state_dict(), SAVE_PATH)

    # 7) Final report for the best model
    if best_epoch == 0:
        raise RuntimeError(
            f"No checkpoint was saved to {SAVE_PATH!r} during this run "
            f"(EPOCHS={EPOCHS}); refusing to load a possibly stale file."
        )
    print("\n✅ Training finished. Loading best model and printing final report ...")
    print(f"Best epoch: {best_epoch}/{EPOCHS} (val weighted F1={best_f1:.4f})")
    model.load_state_dict(torch.load(SAVE_PATH, map_location=device))
    _, _, _, _, final_report = evaluate(model, val_loader, device, idx2tag, verbose=False)
    pretty_print_report("Final Report for the Best Model on Validation Set", final_report)


# Entry point: run the full training pipeline when this file is executed
# directly rather than imported as a module.
if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Loaded 758 sentences
Entity types: ['AGE_DEATH', 'AGE_FOLLOWUP', 'AGE_ONSET', 'GENE', 'GENE_VARIANT', 'HPO_TERM', 'PATIENT']
Tag mapping size: 15
Starting training with BioBERT+CRF ...


  return forward_call(*args, **kwargs)
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5
Train Loss: 117.3943
Test Metrics: Precision=0.0902, Recall=0.1315, F1=0.1066, ACC=0.4998
------------------------------------------------------------


KeyboardInterrupt: 