In [None]:
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizerFast
from torchcrf import CRF
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, precision_score, recall_score, f1_score
import numpy as np

# ========== Device ==========
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== 数据读取 ==========
def load_data(file_path):
    """Read a BIO-tagged text file and build the tag vocabulary.

    Every non-blank line is split on whitespace: the last token is the tag and
    the remaining tokens, re-joined with single spaces, form the word. Blank
    lines separate sentences.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        A 4-tuple ``(sentences, tag2idx, idx2tag, entity_types)``:
        ``sentences`` is a list of sentences, each a list of ``(word, tag)``
        tuples; ``tag2idx`` maps every tag seen (plus ``'O'``) to its position
        in sorted order; ``idx2tag`` is the inverse mapping; ``entity_types``
        is the sorted list of entity type names found after the ``B-``/``I-``
        style prefix.
    """
    sentences = []
    sentence = []
    entity_types = set()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank line: close the sentence in progress, if any.
                if sentence:
                    sentences.append(sentence)
                    sentence = []
                continue
            word = ' '.join(parts[:-1])
            tag = parts[-1]
            sentence.append((word, tag))
            # Split on the FIRST hyphen only so hyphenated entity names such as
            # "B-GEO-LOC" are kept whole; tags without a hyphen ('O' or any
            # non-BIO tag) contribute no entity type instead of raising.
            _, sep, entity_type = tag.partition('-')
            if sep and entity_type:
                entity_types.add(entity_type)
    # The file may not end with a blank line.
    if sentence:
        sentences.append(sentence)
    # Tag mapping: 'O' is always present even when absent from the data.
    unique_tags = sorted({tag for sent in sentences for _, tag in sent} | {'O'})
    tag2idx = {tag: idx for idx, tag in enumerate(unique_tags)}
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    # Sorted so the returned list is deterministic across runs.
    return sentences, tag2idx, idx2tag, sorted(entity_types)

# ========== 数据集 ==========
class NERDataset(Dataset):
    """Torch dataset that turns (word, tag) sentences into sub-token level examples.

    Args:
        sentences: List of sentences, each a list of ``(word, tag)`` tuples.
        tag2idx: Mapping from tag string to integer label id.
        tokenizer: Callable tokenizer accepting pre-split words and returning an
            encoding that exposes ``word_ids(batch_index=...)`` plus
            ``'input_ids'`` / ``'attention_mask'`` entries with a leading batch
            dimension (as the call in ``__getitem__`` requires).
        max_len: Length every example is padded / truncated to.
    """

    def __init__(self, sentences, tag2idx, tokenizer, max_len=128):
        self.sentences = sentences
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def _continuation_label(self, tag):
        """Return the label id for a non-first sub-token of a word tagged ``tag``."""
        if tag.startswith('B-'):
            # Strip only the "B-" prefix so hyphenated entity names stay intact.
            inside_tag = 'I-' + tag[2:]
            # If the vocabulary has no I- form (entity type only ever seen as a
            # single word), reuse the word's own tag instead of raising KeyError.
            return self.tag2idx.get(inside_tag, self.tag2idx[tag])
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        sentence = self.sentences[idx]
        words = [w for w, _ in sentence]
        tags = [t for _, t in sentence]
        enc = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # The encoding is batched (return_tensors='pt'), so select row 0.
        word_ids = enc.word_ids(batch_index=0)

        # Positions without a source word ([CLS]/[SEP]/PAD) get the 'O' id.
        # A hard-coded 0 is wrong when 'O' is not the first tag in tag2idx
        # (e.g. a sorted vocabulary puts "B-..." first); 0 is kept only as a
        # fallback for vocabularies that lack 'O'.
        outside_id = self.tag2idx.get('O', 0)

        labels = []
        prev_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                labels.append(outside_id)
            elif word_idx != prev_word_idx:
                # First sub-token of a new word keeps the word's tag.
                labels.append(self.tag2idx[tags[word_idx]])
            else:
                # Later sub-tokens of the same word: B-X continues as I-X.
                labels.append(self._continuation_label(tags[word_idx]))
            prev_word_idx = word_idx

        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

# ========== 模型 ==========
class BERT_CRF_NER(nn.Module):
    """BERT encoder followed by dropout, a linear tag projection and a CRF layer.

    Args:
        bert_model: Name or path passed to ``BertModel.from_pretrained``.
        num_tags: Number of output tags (size of the projection and the CRF).
        dropout_prob: Dropout probability applied to the encoder output.
    """

    def __init__(self, bert_model, num_tags, dropout_prob=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the training loss when ``labels`` is given, else decoded tag ids.

        With labels, the result is the negated output of the CRF forward call
        (``reduction='mean'``). Without labels, the result is whatever
        ``self.crf.decode`` returns for the masked emissions.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        emissions = self.classifier(self.dropout(hidden_states))
        # The CRF takes a boolean mask; padded positions are excluded.
        token_mask = attention_mask.bool()

        if labels is None:
            return self.crf.decode(emissions, mask=token_mask)

        crf_score = self.crf(emissions, labels, mask=token_mask, reduction='mean')
        return -crf_score

# ========== 训练 ==========
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask, labels)`` and
            returning a scalar loss tensor.
        dataloader: Sized iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches
        (``0.0`` when the dataloader is empty).
    """
    model.train()
    batch_losses = []
    for batch in dataloader:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        gold = batch['labels'].to(device)

        optimizer.zero_grad()
        batch_loss = model(ids, mask, gold)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    # max(1, ...) guards against division by zero for an empty dataloader.
    n_batches = max(1, len(dataloader))
    return sum(batch_losses, 0.0) / n_batches

# ========== 评估：收集预测 + 计算指标 + 生成表格 ==========
def _collect_true_pred(model, dataloader, device, idx2tag):
    """Run the model in eval mode and gather gold / predicted tag sequences.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning one
            list of tag ids per example; must expose ``eval()``.
        dataloader: Iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: Device the inputs are moved to.
        idx2tag: Mapping from label id to tag string.

    Returns:
        ``(true_labels, pred_labels)``: two parallel lists of tag-string lists,
        each pair cut to the same length.
    """
    model.eval()
    gold_seqs = []
    pred_seqs = []
    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            att = batch['attention_mask'].to(device)
            gold_ids = batch['labels'].cpu().numpy()
            decoded = model(ids, att)  # List[List[int]]
            for row, path in enumerate(decoded):
                # Number of attended (non-padding) positions in this example.
                row_mask = att[row].cpu().numpy().astype(bool)
                n_valid = int(row_mask.sum())
                pred_tags = [idx2tag[tag_id] for tag_id in path[:n_valid]]
                gold_tags = [idx2tag[int(tag_id)] for tag_id in gold_ids[row][:n_valid]]
                # Defensive alignment in case the two lengths disagree.
                keep = min(len(gold_tags), len(pred_tags))
                gold_seqs.append(gold_tags[:keep])
                pred_seqs.append(pred_tags[:keep])
    return gold_seqs, pred_seqs

def evaluate(model, dataloader, device, idx2tag):
    """Evaluate ``model`` on ``dataloader`` and return summary metrics.

    Returns:
        ``(precision, recall, f1, acc, report_text)`` where the first three are
        micro-averaged seqeval scores, ``acc`` is the fraction of positions
        whose gold tag is not ``'O'`` that were predicted exactly, and
        ``report_text`` is the seqeval classification report string.
    """
    gold_seqs, pred_seqs = _collect_true_pred(model, dataloader, device, idx2tag)

    # Per-class table.
    report_text = classification_report(gold_seqs, pred_seqs, digits=2)

    # Overall scores (micro average).
    precision = precision_score(gold_seqs, pred_seqs, average='micro')
    recall = recall_score(gold_seqs, pred_seqs, average='micro')
    f1 = f1_score(gold_seqs, pred_seqs, average='micro')

    # Position-level accuracy restricted to positions whose gold tag is not 'O'.
    entity_pairs = [
        (gold, pred)
        for gold_seq, pred_seq in zip(gold_seqs, pred_seqs)
        for gold, pred in zip(gold_seq, pred_seq)
        if gold != 'O'
    ]
    total = len(entity_pairs)
    correct = sum(1 for gold, pred in entity_pairs if gold == pred)
    acc = correct / total if total else 0.0

    return precision, recall, f1, acc, report_text

# ========== 主程序 ==========
if __name__ == "__main__":
    # Script entry point: load data, train for EPOCHS epochs while evaluating
    # on a held-out split each epoch, keep the best checkpoint by micro-F1,
    # then report final metrics for that best checkpoint.

    # Hyper-parameters and paths
    BERT_MODEL   = "bert-base-uncased"
    MAX_LEN      = 256
    BATCH_SIZE   = 32
    EPOCHS       = 20
    LEARNING_RATE= 1e-4
    DROPOUT_PROB = 0.1
    DATA_FILE    = "bio_dataset_cleaned.txt"

    # 1) Load data
    sentences, tag2idx, idx2tag, entity_types = load_data(DATA_FILE)
    print(f"Loaded {len(sentences)} sentences")
    print(f"Entity types: {entity_types}")
    print(f"Tag mapping: {tag2idx}")

    # 2) Split (the held-out 20% is used as the validation set)
    train_sents, val_sents = train_test_split(sentences, test_size=0.2, random_state=42)

    # 3) DataLoaders
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
    train_dataset = NERDataset(train_sents, tag2idx, tokenizer, MAX_LEN)
    val_dataset   = NERDataset(val_sents,   tag2idx, tokenizer, MAX_LEN)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

    # 4) Model and optimizer
    model = BERT_CRF_NER(BERT_MODEL, len(tag2idx), DROPOUT_PROB).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Output directory for the per-epoch reports
    os.makedirs("reports", exist_ok=True)

    # 5) Train, evaluate every epoch, keep the best checkpoint
    best_f1 = 0.0
    best_epoch = 0

    print("Starting training...")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        precision, recall, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)

        # Console: brief four-metric summary
        print(f"Epoch {epoch}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Test Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, ACC={acc:.4f}")
        print("-" * 50)

        # File: per-class table (kept off the console to avoid flooding it)
        with open(f"reports/epoch_{epoch:02d}.txt", "w", encoding="utf-8") as f:
            f.write(f"Classification Report (per-class) - Epoch {epoch}\n")
            f.write(report_text)
            f.write("\n\n--- Overall (micro) ---\n")
            f.write(f"Precision: {precision:.4f}\n")
            f.write(f"Recall:    {recall:.4f}\n")
            f.write(f"F1:        {f1:.4f}\n")
            f.write(f"ACC(entity-only): {acc:.4f}\n")

        # Save the best checkpoint (selected by micro-F1)
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            torch.save(model.state_dict(), "best_model.pt")
            # Also save the label mapping so predictions can be decoded later
            torch.save({"tag2idx": tag2idx, "idx2tag": idx2tag}, "label_mapping.pt")

    print(f"Best epoch = {best_epoch}, micro-F1 = {best_f1:.4f}")
    print("Best model saved as best_model.pt")

    # Restore the best checkpoint before the final evaluation; otherwise the
    # report below would describe the last-epoch weights, not the saved model.
    # best_epoch > 0 guarantees the file was written during this run.
    if best_epoch > 0:
        model.load_state_dict(torch.load("best_model.pt", map_location=device))

    # 6) Final evaluation (print the per-class table and save it)
    print("Final Evaluation on Test Set:")
    precision, recall, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)
    print(report_text)
    print("\n--- Overall performance indicators ---")
    print(f"Overall Precision: {precision:.4f}")
    print(f"Overall Recall:    {recall:.4f}")
    print(f"Overall F1-Score:  {f1:.4f}")
    print(f"Entity-only Accuracy (ACC): {acc:.4f}")

    with open("reports/final_test_report.txt", "w", encoding="utf-8") as f:
        f.write(report_text)
        f.write("\n\n--- Overall performance indicators ---\n")
        f.write(f"Overall Precision: {precision:.4f}\n")
        f.write(f"Overall Recall:    {recall:.4f}\n")
        f.write(f"Overall F1-Score:  {f1:.4f}\n")
        f.write(f"Entity-only Accuracy (ACC): {acc:.4f}\n")


In [None]:
# -*- coding: utf-8 -*-
import os
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizerFast
from torchcrf import CRF
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, precision_score, recall_score, f1_score

warnings.filterwarnings("ignore")  # Keep the output clean; the reports also pass zero_division=0

# ========== Configuration ==========
# When True, every entity type other than TARGET_ENTITY is remapped to 'O'.
FOCUS_SINGLE_ENTITY = True
TARGET_ENTITY = "HPO_TERM"

# Paths: change these to your real files
GOLD_BIO   = "bio_dataset_cleaned.txt"
SILVER_BIO = "chatgpt_integrated_bio_no_punctuation.txt"  # Change this if the silver data lives elsewhere

# Training hyper-parameters
BERT_MODEL    = "bert-base-uncased"
MAX_LEN       = 256
BATCH_SIZE    = 32
EPOCHS        = 20
LEARNING_RATE = 1e-4
DROPOUT_PROB  = 0.1
EARLY_STOP_PATIENCE = 5  # Stop early when validation F1 has not improved for this many consecutive epochs; 0 disables early stopping

# ========== Device ==========
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== 数据读取 ==========
def load_sentences_only(file_path):
    """Read a BIO file into sentences and collect the raw tags it contains.

    Each non-blank line is split on whitespace; the last token is the tag and
    the preceding tokens, joined by single spaces, form the word. Blank lines
    end the current sentence.

    Returns:
        ``(sentences, tags_seen)``: a list of ``[(word, tag), ...]`` lists and
        the set of every tag string encountered.
    """
    all_sentences = []
    observed_tags = set()
    current = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            tokens = raw_line.strip().split()
            if tokens:
                *word_tokens, tag = tokens
                current.append((' '.join(word_tokens), tag))
                observed_tags.add(tag)
            elif current:
                # Blank line closes the sentence in progress.
                all_sentences.append(current)
                current = []
    # Flush a trailing sentence when the file does not end with a blank line.
    if current:
        all_sentences.append(current)
    return all_sentences, observed_tags

def remap_to_single_entity(sentences, target="HPO_TERM"):
    """Keep only the B-/I- tags of ``target``; every other tag becomes 'O'.

    Args:
        sentences: List of sentences, each a list of ``(word, tag)`` tuples.
        target: Entity type name to preserve.

    Returns:
        A new list of sentences with the same words and remapped tags.
    """
    # A tag survives only if it is exactly "B-<target>" or "I-<target>".
    kept_tags = {f"B-{target}", f"I-{target}"}
    return [
        [(word, tag if tag in kept_tags else "O") for word, tag in sent]
        for sent in sentences
    ]

# ========== 数据集 ==========
class NERDataset(Dataset):
    """Torch dataset that turns (word, tag) sentences into sub-token level examples.

    Args:
        sentences: List of sentences, each a list of ``(word, tag)`` tuples.
        tag2idx: Mapping from tag string to integer label id; must contain 'O'.
        tokenizer: Callable tokenizer accepting pre-split words and returning an
            encoding that exposes ``word_ids(batch_index=...)`` plus
            ``'input_ids'`` / ``'attention_mask'`` entries with a leading batch
            dimension (as the call in ``__getitem__`` requires).
        max_len: Length every example is padded / truncated to.
    """

    def __init__(self, sentences, tag2idx, tokenizer, max_len=128):
        self.sentences = sentences
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def _continuation_label(self, tag):
        """Return the label id for a non-first sub-token of a word tagged ``tag``."""
        if tag.startswith('B-'):
            # Strip only the "B-" prefix so hyphenated entity names stay intact
            # (split('-')[1] would turn "B-HPO-TERM" into "I-HPO").
            inside_tag = 'I-' + tag[2:]
            # If the vocabulary has no I- form, reuse the word's own tag
            # instead of raising KeyError.
            return self.tag2idx.get(inside_tag, self.tag2idx[tag])
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        sentence = self.sentences[idx]
        words = [w for w, _ in sentence]
        tags = [t for _, t in sentence]
        enc = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Word-id mapping; the encoding is batched, so select row 0.
        word_ids = enc.word_ids(batch_index=0)

        labels = []
        prev_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                # [CLS]/[SEP]/PAD: use the id of 'O' rather than a hard-coded 0.
                labels.append(self.tag2idx['O'])
            elif word_idx != prev_word_idx:
                # First sub-token of a new word keeps the word's tag.
                labels.append(self.tag2idx[tags[word_idx]])
            else:
                # Later sub-tokens of the same word: B-X continues as I-X.
                labels.append(self._continuation_label(tags[word_idx]))
            prev_word_idx = word_idx

        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

# ========== 模型 ==========
class BERT_CRF_NER(nn.Module):
    """BERT encoder followed by dropout, a linear tag projection and a CRF layer.

    Args:
        bert_model: Name or path passed to ``BertModel.from_pretrained``.
        num_tags: Number of output tags (size of the projection and the CRF).
        dropout_prob: Dropout probability applied to the encoder output.
    """

    def __init__(self, bert_model, num_tags, dropout_prob=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the training loss when ``labels`` is given, else decoded tag ids.

        With labels, the result is the negated output of the CRF forward call
        (``reduction='mean'``). Without labels, the result is whatever
        ``self.crf.decode`` returns for the masked emissions.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        emissions = self.classifier(self.dropout(hidden_states))
        # The CRF takes a boolean mask; padded positions are excluded.
        token_mask = attention_mask.bool()

        if labels is None:
            return self.crf.decode(emissions, mask=token_mask)

        crf_score = self.crf(emissions, labels, mask=token_mask, reduction='mean')
        return -crf_score
# ========== 训练 ==========
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask, labels)`` and
            returning a scalar loss tensor.
        dataloader: Sized iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches
        (``0.0`` when the dataloader is empty).
    """
    model.train()
    running = 0.0
    field_names = ('input_ids', 'attention_mask', 'labels')
    for batch in dataloader:
        # Move the three tensors to the target device in a fixed order.
        ids, mask, gold = [batch[name].to(device) for name in field_names]
        optimizer.zero_grad()
        step_loss = model(ids, mask, gold)
        step_loss.backward()
        optimizer.step()
        running += step_loss.item()
    # max(1, ...) guards against division by zero for an empty dataloader.
    return running / max(1, len(dataloader))

# ========== 评估 ==========
def _collect_true_pred(model, dataloader, device, idx2tag):
    """Run the model in eval mode and gather gold / predicted tag sequences.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning one
            list of tag ids per example; must expose ``eval()``.
        dataloader: Iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: Device the inputs are moved to.
        idx2tag: Mapping from label id to tag string.

    Returns:
        ``(true_labels, pred_labels)``: two parallel lists of tag-string lists,
        each pair cut to the same length.
    """
    model.eval()
    gold_seqs = []
    pred_seqs = []
    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            att = batch['attention_mask'].to(device)
            gold_ids = batch['labels'].cpu().numpy()
            decoded = model(ids, att)  # List[List[int]]
            for row, path in enumerate(decoded):
                # Number of attended (non-padding) positions in this example.
                row_mask = att[row].cpu().numpy().astype(bool)
                n_valid = int(row_mask.sum())
                pred_tags = [idx2tag[tag_id] for tag_id in path[:n_valid]]
                gold_tags = [idx2tag[int(tag_id)] for tag_id in gold_ids[row][:n_valid]]
                # Defensive alignment in case the two lengths disagree.
                keep = min(len(gold_tags), len(pred_tags))
                gold_seqs.append(gold_tags[:keep])
                pred_seqs.append(pred_tags[:keep])
    return gold_seqs, pred_seqs

def evaluate(model, dataloader, device, idx2tag):
    """Evaluate ``model`` on ``dataloader`` and return summary metrics.

    Returns:
        ``(precision, recall, f1, acc, report_text)`` where the first three are
        micro-averaged seqeval scores, ``acc`` is the fraction of positions
        whose gold tag is not ``'O'`` that were predicted exactly, and
        ``report_text`` is the seqeval classification report string.
    """
    gold_seqs, pred_seqs = _collect_true_pred(model, dataloader, device, idx2tag)

    # Per-class table; zero_division=0 avoids warnings for undefined scores.
    report_text = classification_report(gold_seqs, pred_seqs, digits=2, zero_division=0)

    # Overall scores (micro average).
    precision = precision_score(gold_seqs, pred_seqs, average='micro', zero_division=0)
    recall = recall_score(gold_seqs, pred_seqs, average='micro', zero_division=0)
    f1 = f1_score(gold_seqs, pred_seqs, average='micro', zero_division=0)

    # Position-level accuracy restricted to positions whose gold tag is not 'O'.
    hits = 0
    entity_positions = 0
    for gold_seq, pred_seq in zip(gold_seqs, pred_seqs):
        for gold, pred in zip(gold_seq, pred_seq):
            if gold == 'O':
                continue
            entity_positions += 1
            hits += int(gold == pred)
    acc = hits / entity_positions if entity_positions else 0.0

    return precision, recall, f1, acc, report_text

# ========== 主程序 ==========
if __name__ == "__main__":
    # Script entry point: train on gold(80%) + silver data, monitor the gold
    # validation split each epoch with early stopping, keep the best checkpoint
    # by micro-F1, then report final metrics for that best checkpoint.

    # 1) Load GOLD and SILVER separately
    gold_sents, gold_tags = load_sentences_only(GOLD_BIO)
    silver_sents, silver_tags = load_sentences_only(SILVER_BIO)

    print(f"[GOLD] sentences:   {len(gold_sents)}")
    print(f"[SILVER] sentences: {len(silver_sents)}")

    # 2) Focus on the target entity only: map every other entity tag to 'O'
    if FOCUS_SINGLE_ENTITY:
        gold_sents   = remap_to_single_entity(gold_sents, TARGET_ENTITY)
        silver_sents = remap_to_single_entity(silver_sents, TARGET_ENTITY)

    # 3) GOLD 80/20 split (20% is the validation set)
    train_gold, val_gold = train_test_split(gold_sents, test_size=0.2, random_state=42, shuffle=True)

    # 4) Training set = gold 80% + all silver; validation set = gold 20%
    train_sents = train_gold + silver_sents
    val_sents   = val_gold
    print(f"[TRAIN] {len(train_sents)} (= gold80% + silver100%)")
    print(f"[VAL]   {len(val_sents)}   (= gold20%)")

    # 5) Label space (only the three target-entity tags in single-entity mode)
    if FOCUS_SINGLE_ENTITY:
        unique_tags = ["O", f"B-{TARGET_ENTITY}", f"I-{TARGET_ENTITY}"]
    else:
        all_tags = (gold_tags | silver_tags | {"O"})
        unique_tags = sorted(all_tags)

    tag2idx = {tag: i for i, tag in enumerate(unique_tags)}
    idx2tag = {i: tag for tag, i in tag2idx.items()}
    print("Tags:", unique_tags)

    # 6) DataLoaders
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
    train_dataset = NERDataset(train_sents, tag2idx, tokenizer, MAX_LEN)
    val_dataset   = NERDataset(val_sents,   tag2idx, tokenizer, MAX_LEN)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

    # 7) Model and optimizer
    model = BERT_CRF_NER(BERT_MODEL, len(tag2idx), DROPOUT_PROB).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Output directory for the per-epoch reports
    os.makedirs("reports", exist_ok=True)

    # 8) Train + validation monitoring + best checkpoint + early stopping
    best_f1 = 0.0
    best_epoch = 0
    no_improve = 0

    print("Starting training...")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        precision, recall, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)

        # Console: brief four-metric summary
        print(f"Epoch {epoch}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, ACC={acc:.4f}")
        print("-" * 50)

        # File: this epoch's per-class report
        with open(f"reports/epoch_{epoch:02d}.txt", "w", encoding="utf-8") as f:
            f.write(f"Classification Report (per-class) - Epoch {epoch}\n")
            f.write(report_text)
            f.write("\n\n--- Overall (micro) ---\n")
            f.write(f"Precision: {precision:.4f}\n")
            f.write(f"Recall:    {recall:.4f}\n")
            f.write(f"F1:        {f1:.4f}\n")
            f.write(f"ACC(entity-only): {acc:.4f}\n")

        # Best-checkpoint saving and early-stopping counter
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            no_improve = 0
            torch.save(model.state_dict(), "best_model.pt")
            torch.save({"tag2idx": tag2idx, "idx2tag": idx2tag}, "label_mapping.pt")
        else:
            no_improve += 1
            if EARLY_STOP_PATIENCE > 0 and no_improve >= EARLY_STOP_PATIENCE:
                print(f"Early stopping at epoch {epoch} (no improvement for {EARLY_STOP_PATIENCE} epochs).")
                break

    print(f"Best epoch = {best_epoch}, micro-F1 = {best_f1:.4f}")
    print("Best model saved as best_model.pt")

    # Restore the best checkpoint before the final evaluation; otherwise the
    # report below would describe the last-epoch (possibly early-stopped)
    # weights, not the saved model. best_epoch > 0 guarantees the file was
    # written during this run.
    if best_epoch > 0:
        model.load_state_dict(torch.load("best_model.pt", map_location=device))

    # 9) Final evaluation (print and save the report)
    print("Final Evaluation on Validation (20% GOLD):")
    precision, recall, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)
    print(report_text)
    print("\n--- Overall performance indicators ---")
    print(f"Overall Precision: {precision:.4f}")
    print(f"Overall Recall:    {recall:.4f}")
    print(f"Overall F1-Score:  {f1:.4f}")
    print(f"Entity-only Accuracy (ACC): {acc:.4f}")

    with open("reports/final_val_report.txt", "w", encoding="utf-8") as f:
        f.write(report_text)
        f.write("\n\n--- Overall performance indicators ---\n")
        f.write(f"Overall Precision: {precision:.4f}\n")
        f.write(f"Overall Recall:    {recall:.4f}\n")
        f.write(f"Overall F1-Score:  {f1:.4f}\n")
        f.write(f"Entity-only Accuracy (ACC): {acc:.4f}\n")
