In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from torchcrf import CRF
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report
import numpy as np
import re

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据读取与处理
def load_data(file_path):
    sentences = []
    sentence = []
    labels = set()
    
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                if sentence:
                    sentences.append(sentence)
                    sentence = []
            else:
                parts = line.split()
                word = ' '.join(parts[:-1])
                tag = parts[-1]
                sentence.append((word, tag))
                if tag != 'O':
                    entity_type = tag.split('-')[1]
                    labels.add(entity_type)
    
    if sentence:
        sentences.append(sentence)
    
    # 创建标签映射
    unique_tags = sorted({tag for sent in sentences for _, tag in sent} | {'O'})
    tag2idx = {tag: idx for idx, tag in enumerate(unique_tags)}
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    
    return sentences, tag2idx, idx2tag, list(labels)

# 数据加载器
class NERDataset(Dataset):
    def __init__(self, sentences, tag2idx, tokenizer, max_len=128):
        self.sentences = sentences
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        
        encoding = self.tokenizer(
            words,
            is_split_into_words=True,  # 输入已分词
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # 获取单词ID映射
        word_ids = encoding.word_ids()
        
        # 创建标签ID列表
        labels = []
        prev_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:  # 特殊token
                labels.append(0)
            elif word_idx != prev_word_idx:  # 新单词的第一个子词
                labels.append(self.tag2idx[tags[word_idx]])
            else:  # 同一单词的后续子词
                # 处理B-标签的延续
                if tags[word_idx].startswith('B-'):
                    labels.append(self.tag2idx['I-' + tags[word_idx].split('-')[1]])
                else:
                    labels.append(self.tag2idx[tags[word_idx]])
            prev_word_idx = word_idx
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(labels)
        }

# 模型构建
class BERT_CRF_NER(nn.Module):
    def __init__(self, bert_model, num_tags, dropout_prob=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)
    
    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        emissions = self.classifier(self.dropout(outputs.last_hidden_state))
        mask = attention_mask.bool()

        if labels is not None:
            # 训练模式：计算损失
            loss = -self.crf(emissions, labels, mask=mask, reduction='mean')
            return loss
        else:
            # 预测模式：返回解码后的标签序列
            return self.crf.decode(emissions, mask=mask)
        
# 训练与评估
def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        loss = model(input_ids, attention_mask, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, device, idx2tag):
    model.eval()
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].cpu().numpy()

            # 获取预测结果
            preds = model(input_ids, attention_mask)
            if preds is None:
                raise ValueError("Model returned None during evaluation")

            for i in range(len(preds)):
                mask = attention_mask[i].cpu().numpy().astype(bool)
                valid_length = mask.sum()

                # 获取预测标签
                pred_tags = [idx2tag[p] for p in preds[i][:valid_length]]
                
                # 获取真实标签
                true_tags = [idx2tag[l] for l in labels[i][:valid_length] if l != -100]
                min_len = min(len(pred_tags), len(true_tags))
                pred_labels.append(pred_tags[:min_len])
                true_labels.append(true_tags[:min_len])

    # 计算评估指标
    report = classification_report(true_labels, pred_labels, output_dict=True)
    precision = report['weighted avg']['precision']
    recall = report['weighted avg']['recall']
    f1 = report['weighted avg']['f1-score']
    
    # 计算准确率
    correct = 0
    total = 0
    for true, pred in zip(true_labels, pred_labels):
        for t, p in zip(true, pred):
            if t != 'O':  # 只计算实体标签的准确率
                total += 1
                if t == p:
                    correct += 1
    acc = correct / total if total > 0 else 0
    return precision, recall, f1, acc

# 主程序
if __name__ == "__main__":
    # 参数配置
    BERT_MODEL = "bert-base-uncased"
    MAX_LEN = 256
    BATCH_SIZE = 32
    EPOCHS = 10
    LEARNING_RATE = 0.0001
    DROPOUT_PROB = 0.1
    
    # 1. 加载数据
    sentences, tag2idx, idx2tag, entity_types = load_data('gold_data/bio_dataset_cleaned.txt')
    print(f"Loaded {len(sentences)} sentences")
    print(f"Entity types: {entity_types}")
    print(f"Tag mapping: {tag2idx}")
    
    # 划分训练测试集
    train_sents, test_sents = train_test_split(sentences, test_size=0.2, random_state=42)

    # 2. 创建数据加载器
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
    train_dataset = NERDataset(train_sents, tag2idx, tokenizer, MAX_LEN)
    test_dataset = NERDataset(test_sents, tag2idx, tokenizer, MAX_LEN)
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # 3. 初始化模型
    model = BERT_CRF_NER(BERT_MODEL, len(tag2idx), DROPOUT_PROB).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    
    # 4. 训练循环
    print("Starting training...")
    for epoch in range(EPOCHS):
        train_loss = train(model, train_loader, optimizer, device)
        precision, recall, f1, acc = evaluate(model, test_loader, device, idx2tag)
        
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Test Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, ACC={acc:.4f}")
        print("-" * 50)
    
    # 最终评估
    print("Final Evaluation on Test Set:")
    precision, recall, f1, acc = evaluate(model, test_loader, device, idx2tag)
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Accuracy: {acc:.4f}")

Loaded 758 sentences
Entity types: ['AGE_ONSET', 'GENE', 'GENE_VARIANT', 'AGE_DEATH', 'PATIENT', 'AGE_FOLLOWUP', 'HPO_TERM']
Tag mapping: {'B-AGE_DEATH': 0, 'B-AGE_FOLLOWUP': 1, 'B-AGE_ONSET': 2, 'B-GENE': 3, 'B-GENE_VARIANT': 4, 'B-HPO_TERM': 5, 'B-PATIENT': 6, 'I-AGE_DEATH': 7, 'I-AGE_FOLLOWUP': 8, 'I-AGE_ONSET': 9, 'I-GENE': 10, 'I-GENE_VARIANT': 11, 'I-HPO_TERM': 12, 'I-PATIENT': 13, 'O': 14}
Starting training...


  return forward_call(*args, **kwargs)
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/10
Train Loss: 112.6284
Test Metrics: Precision=0.2744, Recall=0.1358, F1=0.1817, ACC=0.0377
--------------------------------------------------
Epoch 2/10
Train Loss: 60.6609
Test Metrics: Precision=0.4052, Recall=0.4424, F1=0.4195, ACC=0.5947
--------------------------------------------------
Epoch 3/10
Train Loss: 37.1979
Test Metrics: Precision=0.5690, Recall=0.5925, F1=0.5763, ACC=0.6473
--------------------------------------------------
Epoch 4/10
Train Loss: 25.6099
Test Metrics: Precision=0.6188, Recall=0.6944, F1=0.6473, ACC=0.7718
--------------------------------------------------
Epoch 5/10
Train Loss: 18.0556
Test Metrics: Precision=0.6343, Recall=0.7078, F1=0.6659, ACC=0.7336
--------------------------------------------------
Epoch 6/10
Train Loss: 13.1220
Test Metrics: Precision=0.6160, Recall=0.7364, F1=0.6664, ACC=0.7830
--------------------------------------------------
Epoch 7/10
Train Loss: 9.5474
Test Metrics: Precision=0.6693, Recall=0.7185, F1=0.6901, ACC=0