In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BertTokenizerFast, BertModel
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from seqeval.scheme import IOB2
from tqdm import tqdm
import numpy as np

def read_bio_file(file_path):
    """从BIO格式的文件中读取句子和标签。"""
    sentences = []
    labels = []
    with open(file_path, 'r', encoding='utf-8') as f:
        tokens, tags = [], []
        for line in f:
            line = line.strip()
            if not line:
                if tokens:
                    sentences.append(tokens)
                    labels.append(tags)
                    tokens, tags = [], []
            else:
                splits = line.split()
                if len(splits) >= 2:
                    tokens.append(splits[0])
                    tags.append(splits[1])
        if tokens:
            sentences.append(tokens)
            labels.append(tags)
    return sentences, labels

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# --- 请修改为您自己的文件路径 ---
file_path = r"C:\Users\Administrator\Desktop\Project\combined_dataset.txt"
sentences, labels = read_bio_file(file_path)

# 加载BERT分词器
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# 创建标签到ID的映射
label_list = sorted(set(label for label_seq in labels for label in label_seq))
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label2id.items()}
num_labels = len(label2id)

def encode_examples(sentences, labels, max_length=128):
    """将文本和标签编码为模型输入格式。"""
    input_ids = []
    attention_masks = []
    label_ids = []

    for sent, label_seq in zip(sentences, labels):
        encoding = tokenizer(
            sent,
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )

        word_ids = encoding.word_ids(batch_index=0)
        aligned_labels = []
        prev_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                aligned_labels.append(label2id["O"])
            elif word_idx != prev_word_idx:
                aligned_labels.append(label2id[label_seq[word_idx]])
            else:
                aligned_labels.append(label2id["O"])
            prev_word_idx = word_idx

        input_ids.append(encoding['input_ids'][0])
        attention_masks.append(encoding['attention_mask'][0])
        label_ids.append(torch.tensor(aligned_labels))

    return input_ids, attention_masks, label_ids

In [9]:
# 按照80/20的比例划分训练集和验证集
train_texts, val_texts, train_tags, val_tags = train_test_split(
    sentences, labels, test_size=0.2, random_state=42
)

# 为两组数据分别进行编码
train_input_ids, train_masks, train_labels = encode_examples(train_texts, train_tags)
val_input_ids, val_masks, val_labels = encode_examples(val_texts, val_tags)

# 定义Dataset类
class NERDataset(Dataset):
    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_masks[idx],
            "labels": self.labels[idx]
        }

# 创建Dataset对象
train_dataset = NERDataset(train_input_ids, train_masks, train_labels)
val_dataset = NERDataset(val_input_ids, val_masks, val_labels)

# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"训练集批次数: {len(train_loader)}")
print(f"验证集批次数: {len(val_loader)}")

训练集批次数: 296
验证集批次数: 74


In [10]:
from torchcrf import CRF

class BERT_CRF(nn.Module):
    def __init__(self, bert_model_name, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True) # 设置 batch_first=True

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        emissions = self.classifier(sequence_output)

        if labels is not None:
            # 计算损失，mask需要是bool类型
            loss = -self.crf(emissions, labels, mask=attention_mask.bool(), reduction='mean')
            return loss
        else:
            # 解码预测，mask需要是bool类型
            predictions = self.crf.decode(emissions, mask=attention_mask.bool())
            return predictions

In [11]:
def train(model, dataloader, optimizer, device):
    """训练一个epoch。"""
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss.backward() # 直接反向传播，因为loss已经是单个值
        optimizer.step()
        total_loss += loss.item() # 直接取item

    return total_loss / len(dataloader)

def predict(model, dataloader, device):
    """进行预测。"""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            predictions = model(input_ids=input_ids, attention_mask=attention_mask)
            all_preds.extend(predictions)

            for i in range(labels.shape[0]):
                actual_len = attention_mask[i].sum().item()
                all_labels.append(labels[i][:actual_len].tolist())
    return all_preds, all_labels

def evaluate(preds, trues, id2label):
    """计算并打印评估指标。"""
    preds_label = [[id2label[idx] for idx in seq] for seq in preds]
    trues_label = [[id2label[idx] for idx in seq] for seq in trues]

    print("📊 分类报告:")
    print(classification_report(trues_label, preds_label, mode='strict', scheme=IOB2))

    p = precision_score(trues_label, preds_label)
    r = recall_score(trues_label, preds_label)
    f1 = f1_score(trues_label, preds_label)

    print("\n--- 总体性能指标 ---")
    print(f"Overall Precision: {p:.4f}")
    print(f"Overall Recall:    {r:.4f}")
    print(f"Overall F1-Score:  {f1:.4f}")

In [12]:
# 初始化模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERT_CRF('bert-base-uncased', num_labels).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 5 # 训练轮数

# --- 训练与验证循环 ---
best_val_f1 = 0.0
model_save_path = "best_model_on_validation.pth"

for epoch in range(epochs):
    print(f"--- Epoch {epoch+1}/{epochs} ---")
    
    avg_loss = train(model, train_loader, optimizer, device)
    print(f"Average Training Loss: {avg_loss:.4f}")
    
    print("Evaluating on validation set...")
    preds, trues = predict(model, val_loader, device)
    
    preds_label = [[id2label[idx] for idx in seq] for seq in preds]
    trues_label = [[id2label[idx] for idx in seq] for seq in trues]
    val_f1 = f1_score(trues_label, preds_label)
    
    print(f"Validation F1-Score: {val_f1:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        print(f"✅ New best F1-Score! Saving model to {model_save_path}")
        torch.save(model.state_dict(), model_save_path)
    
    print("-" * 30)

print("Training finished.")
print(f"Best F1-Score on Validation Set: {best_val_f1:.4f}")

--- Epoch 1/5 ---


  return forward_call(*args, **kwargs)
Training: 100%|██████████| 296/296 [14:16<00:00,  2.89s/it]


Average Training Loss: 17.5083
Evaluating on validation set...
Validation F1-Score: 0.7865
✅ New best F1-Score! Saving model to best_model_on_validation.pth
------------------------------
--- Epoch 2/5 ---


Training: 100%|██████████| 296/296 [14:02<00:00,  2.85s/it]


Average Training Loss: 6.0278
Evaluating on validation set...
Validation F1-Score: 0.8293
✅ New best F1-Score! Saving model to best_model_on_validation.pth
------------------------------
--- Epoch 3/5 ---


Training: 100%|██████████| 296/296 [14:17<00:00,  2.90s/it]


Average Training Loss: 3.8373
Evaluating on validation set...
Validation F1-Score: 0.8537
✅ New best F1-Score! Saving model to best_model_on_validation.pth
------------------------------
--- Epoch 4/5 ---


Training: 100%|██████████| 296/296 [14:30<00:00,  2.94s/it]


Average Training Loss: 2.5748
Evaluating on validation set...
Validation F1-Score: 0.8550
✅ New best F1-Score! Saving model to best_model_on_validation.pth
------------------------------
--- Epoch 5/5 ---


Training: 100%|██████████| 296/296 [14:39<00:00,  2.97s/it]


Average Training Loss: 1.6995
Evaluating on validation set...
Validation F1-Score: 0.8619
✅ New best F1-Score! Saving model to best_model_on_validation.pth
------------------------------
Training finished.
Best F1-Score on Validation Set: 0.8619


In [13]:
# 加载性能最好的模型
model.load_state_dict(torch.load(model_save_path))

print("\n--- Final Report for the Best Model on Validation Set ---")
val_preds, val_trues = predict(model, val_loader, device)
evaluate(val_preds, val_trues, id2label)


--- Final Report for the Best Model on Validation Set ---
📊 分类报告:
              precision    recall  f1-score   support

   AGE_DEATH       0.50      0.22      0.31         9
AGE_FOLLOWUP       0.60      0.60      0.60        10
   AGE_ONSET       0.42      0.56      0.48        18
        GENE       0.79      0.76      0.78        75
GENE_VARIANT       0.74      0.81      0.78       102
    HPO_TERM       0.89      0.88      0.89      1592
     PATIENT       0.66      0.76      0.71        80

   micro avg       0.86      0.86      0.86      1886
   macro avg       0.66      0.66      0.65      1886
weighted avg       0.86      0.86      0.86      1886


--- 总体性能指标 ---
Overall Precision: 0.8554
Overall Recall:    0.8685
Overall F1-Score:  0.8619
