In [None]:
import os
import random
from typing import Dict, Any, List, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF
from transformers import BertModel, AutoTokenizer, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, precision_score, recall_score, f1_score

# ========== Reproducibility ==========
# Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with
# one shared value. SEED is also passed to train_test_split in main().
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ========== Device ==========
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ========== Data ==========
def load_data(file_path: str) -> Tuple[List[List[Tuple[str, str]]], dict, dict, List[str]]:
    """Read a BIO-tagged corpus and build its tag vocabulary.

    Every non-blank line holds whitespace-separated fields: the last field is
    the tag and everything before it (re-joined with single spaces) is the
    word. Blank lines separate sentences.

    Args:
        file_path: Path to the UTF-8 encoded BIO file.

    Returns:
        A 4-tuple ``(sentences, tag2idx, idx2tag, entity_types)`` where
        ``sentences`` is a list of ``[(word, tag), ...]`` lists, ``tag2idx``
        maps each tag (always including ``'O'``) to its index in sorted order,
        ``idx2tag`` is the inverse mapping, and ``entity_types`` is the sorted
        list of entity names found after the ``B-``/``I-`` style prefix.
    """
    sentences = []
    current = []
    entity_types = set()

    with open(file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            parts = raw_line.split()
            if not parts:
                # Blank line: close the sentence in progress, if any.
                if current:
                    sentences.append(current)
                    current = []
                continue

            *word_parts, tag = parts
            current.append((' '.join(word_parts), tag))
            if '-' in tag:
                # Split on the first hyphen only, so an entity type that itself
                # contains a hyphen (e.g. "B-HPO-TERM") is kept whole.
                entity_types.add(tag.split('-', 1)[1])

    # The file may end without a trailing blank line.
    if current:
        sentences.append(current)

    unique_tags = sorted({tag for sent in sentences for _, tag in sent} | {'O'})
    tag2idx = {tag: idx for idx, tag in enumerate(unique_tags)}
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    return sentences, tag2idx, idx2tag, sorted(entity_types)


class NERDataset(Dataset):
    """Dataset that aligns word-level BIO tags to tokenizer subwords.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors, all of length ``max_len`` (the tokenizer is asked to pad and
    truncate to that length).

    Args:
        sentences: List of sentences, each a list of ``(word, tag)`` pairs.
        tag2idx: Mapping from tag string to integer id; must contain ``'O'``.
        tokenizer: Callable tokenizer whose result exposes ``word_ids`` and
            the ``input_ids`` / ``attention_mask`` entries.
        max_len: Sequence length passed to the tokenizer.
    """

    def __init__(self, sentences, tag2idx, tokenizer, max_len=128):
        self.sentences = sentences
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.o_id = self.tag2idx['O']

    def __len__(self):
        return len(self.sentences)

    def _continuation_id(self, tag):
        """Return the label id for a non-first subword of a word tagged ``tag``."""
        if tag.startswith('B-'):
            # Split on the first hyphen only so hyphenated entity names
            # ("B-HPO-TERM") map to the right "I-HPO-TERM" tag. If that I- tag
            # is not in the vocabulary, fall back to the word's own tag.
            inside_tag = 'I-' + tag.split('-', 1)[1]
            return self.tag2idx.get(inside_tag, self.tag2idx[tag])
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        words = [w for w, _ in sentence]
        tags = [t for _, t in sentence]

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # word_ids maps every token position to the index of its source word,
        # or None when the position has no source word.
        word_ids = encoding.word_ids(batch_index=0)

        labels = []
        prev_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                # Positions without a source word are labelled 'O'.
                labels.append(self.o_id)
            elif word_idx != prev_word_idx:
                # First subword of a word keeps the word's tag.
                labels.append(self.tag2idx[tags[word_idx]])
            else:
                # Later subwords of the same word: B- continues as I-.
                labels.append(self._continuation_id(tags[word_idx]))
            prev_word_idx = word_idx

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.long)
        }


# ========== Model ==========
class BERT_CRF_NER(nn.Module):
    """Sequence tagger: BERT encoder -> dropout -> linear emissions -> CRF.

    Args:
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        num_tags: Size of the tag vocabulary (emission and CRF dimension).
        dropout_prob: Dropout probability applied to the encoder output.
    """

    def __init__(self, bert_model: str, num_tags: int, dropout_prob: float = 0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the training loss when ``labels`` is given, else decoded tag ids.

        The attention mask, cast to bool, is used as the CRF mask in both modes.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = self.dropout(encoded.last_hidden_state)
        emissions = self.classifier(hidden)
        token_mask = attention_mask.bool()

        if labels is None:
            # Inference: best tag path per sequence.
            return self.crf.decode(emissions, mask=token_mask)

        # Training: the CRF score is negated so it can be minimised as a loss.
        crf_score = self.crf(emissions, labels, mask=token_mask, reduction='mean')
        return -crf_score


# ========== Train / Eval ==========
def train_one_epoch(model, dataloader, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    For every batch: move tensors to ``device``, compute the loss by calling
    ``model(input_ids, attention_mask, labels)``, back-propagate, clip the
    gradient norm to 1.0, step the optimizer, and step ``scheduler`` when one
    is supplied. Returns 0.0 for an empty dataloader.
    """
    model.train()
    running_loss = 0.0

    for batch in dataloader:
        ids, mask, gold = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        )

        optimizer.zero_grad()
        batch_loss = model(ids, mask, gold)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        running_loss += batch_loss.item()

    # Guard against division by zero when the loader yields no batches.
    num_batches = len(dataloader)
    return running_loss / (num_batches if num_batches > 0 else 1)


def evaluate(model, dataloader, device, idx2tag):
    """Decode ``dataloader`` with ``model`` and score the predictions.

    Returns:
        ``(micro_p, micro_r, micro_f1, acc, report_text)`` where the first
        three are seqeval micro-averaged precision/recall/F1, ``acc`` is the
        share of gold non-'O' tokens whose predicted tag matches exactly, and
        ``report_text`` is seqeval's per-class classification report.
    """
    model.eval()
    gold_seqs, pred_seqs = [], []

    with torch.no_grad():
        for batch in dataloader:
            attention_mask = batch['attention_mask'].to(device)
            decoded = model(batch['input_ids'].to(device), attention_mask)  # list[list[int]]
            gold_rows = batch['labels'].cpu().tolist()
            # Number of non-masked positions per sequence.
            valid_lens = attention_mask.bool().sum(dim=1).tolist()

            for path, gold_row, valid_len in zip(decoded, gold_rows, valid_lens):
                pred_tags = [idx2tag[p] for p in path[:valid_len]]
                gold_tags = [idx2tag[g] for g in gold_row[:valid_len]]

                # Keep both sequences at the same length before scoring.
                shared = min(len(pred_tags), len(gold_tags))
                pred_seqs.append(pred_tags[:shared])
                gold_seqs.append(gold_tags[:shared])

    report_text = classification_report(gold_seqs, pred_seqs, digits=2, zero_division=0)

    # Micro average
    micro_p = precision_score(gold_seqs, pred_seqs, average='micro', zero_division=0)
    micro_r = recall_score(gold_seqs, pred_seqs, average='micro', zero_division=0)
    micro_f1 = f1_score(gold_seqs, pred_seqs, average='micro', zero_division=0)

    # Token accuracy restricted to positions whose gold tag is not 'O'.
    entity_hits = [
        gold == pred
        for gold_seq, pred_seq in zip(gold_seqs, pred_seqs)
        for gold, pred in zip(gold_seq, pred_seq)
        if gold != 'O'
    ]
    acc = sum(entity_hits) / len(entity_hits) if entity_hits else 0.0

    return micro_p, micro_r, micro_f1, acc, report_text


# ========== Main ==========
def main():
    """Train BioBERT+CRF on the BIO corpus and report validation metrics.

    Steps: load the data, make an 80/20 split, build datasets and loaders,
    train for ``EPOCHS`` epochs while saving the weights with the best
    validation micro-F1, then reload that checkpoint and print a full report
    on the held-out 20%.
    """
    # ---- Config ----
    DATA_PATH = "bio_dataset_cleaned.txt"
    BIOBERT_MODEL = "dmis-lab/biobert-base-cased-v1.1"
    MAX_LEN = 256
    BATCH_SIZE = 32
    EPOCHS = 20
    LR = 1e-4
    DROPOUT = 0.1
    SAVE_PATH = "best_BioBERT_model.pt"

    # 1) Load data
    sentences, tag2idx, idx2tag, entity_types = load_data(DATA_PATH)
    print(f"Loaded {len(sentences)} sentences")
    print(f"Entity types: {entity_types}")
    print(f"Num tags: {len(tag2idx)}")

    # 2) Split
    train_sents, val_sents = train_test_split(sentences, test_size=0.2, random_state=SEED, shuffle=True)

    # 3) Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BIOBERT_MODEL)

    # 4) Datasets / Loaders
    train_ds = NERDataset(train_sents, tag2idx, tokenizer, MAX_LEN)
    val_ds = NERDataset(val_sents, tag2idx, tokenizer, MAX_LEN)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # 5) Model & Optimizer & Scheduler
    model = BERT_CRF_NER(BIOBERT_MODEL, len(tag2idx), DROPOUT).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    # The scheduler is stepped once per batch, so it spans every batch of every epoch.
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max(1, int(0.1 * total_steps)),
        num_training_steps=total_steps
    )

    # 6) Train loop with best model tracking (by micro-F1)
    best_f1 = -1.0
    best_epoch = -1

    print("Starting training with BioBERT+CRF ...")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, device)
        p, r, f1, acc, _ = evaluate(model, val_loader, device, idx2tag)

        # === Requested output format ===
        print(f"Epoch {epoch}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Test Metrics: Precision={p:.4f}, Recall={r:.4f}, F1={f1:.4f}, ACC={acc:.4f}")
        print("-" * 50)

        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            torch.save(model.state_dict(), SAVE_PATH)

    print(f"Best epoch = {best_epoch}, micro-F1 = {best_f1:.4f}")
    print(f"Best model saved as {SAVE_PATH}")

    # 7) Final report with full table + overall indicators (micro)
    # Restore the best checkpoint first: after the loop `model` holds the
    # last-epoch weights, which are not necessarily the best-scoring ones.
    if best_epoch > 0:
        model.load_state_dict(torch.load(SAVE_PATH, map_location=device))

    print("Final Evaluation on Test Set:")
    p, r, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)
    print(report_text)
    print("\n--- Overall performance indicators ---")
    print(f"Overall Precision: {p:.4f}")
    print(f"Overall Recall:    {r:.4f}")
    print(f"Overall F1-Score:  {f1:.4f}")
    print(f"Entity-only Accuracy (ACC): {acc:.4f}")


# Run the training pipeline only when this code is executed as a script.
if __name__ == "__main__":
    main()


In [None]:
# -*Gold+Silver*-
import os
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizerFast
from torchcrf import CRF
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, precision_score, recall_score, f1_score

# NOTE(review): this suppresses every warning category for the whole process.
warnings.filterwarnings("ignore")

# ========== Configuration ==========
# When True, every entity type other than TARGET_ENTITY is remapped to "O"
# and the tag vocabulary is fixed to O / B-TARGET / I-TARGET.
FOCUS_SINGLE_ENTITY = True
TARGET_ENTITY = "HPO_TERM"

# BIO-format input files. The main script splits GOLD 80/20 and adds all of
# SILVER to the training portion only.
GOLD_BIO   = "bio_dataset_cleaned.txt"
SILVER_BIO = "chatgpt_integrated_bio_no_punctuation.txt"

# Model and training hyperparameters.
BERT_MODEL    = "dmis-lab/biobert-base-cased-v1.1"
MAX_LEN       = 256
BATCH_SIZE    = 32
EPOCHS        = 20
LEARNING_RATE = 1e-4
DROPOUT_PROB  = 0.1
# Epochs without a validation-F1 improvement before training stops;
# a value <= 0 disables early stopping.
EARLY_STOP_PATIENCE = 5

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== data load ==========
def load_sentences_only(file_path):
    """Parse a BIO file into sentences and collect every tag that occurs.

    Each non-blank line is split on whitespace: the last field is the tag and
    the preceding fields, joined with single spaces, form the word. Blank
    lines end the current sentence.

    Returns:
        ``(sentences, tags_seen)``: a list of ``[(word, tag), ...]`` lists and
        the set of raw tag strings found in the file.
    """
    sentences = []
    tags_seen = set()
    current = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for raw in f:
            fields = raw.split()
            if fields:
                *word_fields, tag = fields
                current.append((' '.join(word_fields), tag))
                tags_seen.add(tag)
            elif current:
                # Blank line closes the sentence in progress.
                sentences.append(current)
                current = []

    # Flush the last sentence when the file has no trailing blank line.
    if current:
        sentences.append(current)
    return sentences, tags_seen

def remap_to_single_entity(sentences, target="HPO_TERM"):
    """Return a copy of ``sentences`` in which only ``target`` stays annotated.

    Tags ``B-<target>`` and ``I-<target>`` are kept as they are; every other
    tag (other entities, malformed tags, and ``O`` itself) becomes ``"O"``.
    The input sentences are not modified.
    """
    kept_tags = {f"B-{target}", f"I-{target}"}
    return [
        [(word, tag if tag in kept_tags else "O") for word, tag in sent]
        for sent in sentences
    ]

# ========== Dataset ==========
class NERDataset(Dataset):
    """Dataset that aligns word-level BIO tags to tokenizer subwords.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors, all of length ``max_len`` (the tokenizer is asked to pad and
    truncate to that length).

    Args:
        sentences: List of sentences, each a list of ``(word, tag)`` pairs.
        tag2idx: Mapping from tag string to integer id; must contain ``'O'``.
        tokenizer: Callable tokenizer whose result exposes ``word_ids`` and
            the ``input_ids`` / ``attention_mask`` entries.
        max_len: Sequence length passed to the tokenizer.
    """

    def __init__(self, sentences, tag2idx, tokenizer, max_len=128):
        self.sentences = sentences
        self.tag2idx = tag2idx
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def _continuation_id(self, tag):
        """Return the label id for a non-first subword of a word tagged ``tag``."""
        if tag.startswith('B-'):
            # Split on the first hyphen only so hyphenated entity names are
            # kept whole, and fall back to the word's own tag instead of
            # raising KeyError when the matching I- tag is not in tag2idx.
            inside_tag = 'I-' + tag.split('-', 1)[1]
            return self.tag2idx.get(inside_tag, self.tag2idx[tag])
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        words = [w for w, _ in sentence]
        tags = [t for _, t in sentence]
        enc = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Map each token position to its source word index (None when the
        # position has no source word, e.g. [CLS]/[SEP]/PAD).
        word_ids = enc.word_ids(batch_index=0)

        o_id = self.tag2idx['O']  # look up 'O' rather than hard-coding 0
        labels = []
        prev_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                labels.append(o_id)
            elif word_idx != prev_word_idx:
                # First subword of a word keeps the word's tag.
                labels.append(self.tag2idx[tags[word_idx]])
            else:
                # Later subwords of the same word: B- continues as I-.
                labels.append(self._continuation_id(tags[word_idx]))
            prev_word_idx = word_idx

        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

# ========== Model ==========
class BERT_CRF_NER(nn.Module):
    """Sequence tagger: BERT encoder -> dropout -> linear emissions -> CRF."""

    def __init__(self, bert_model, num_tags, dropout_prob=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the training loss when ``labels`` is given, else decoded tag ids.

        The attention mask, cast to bool, is used as the CRF mask in both modes.
        """
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(bert_out.last_hidden_state)
        emissions = self.classifier(features)
        crf_mask = attention_mask.bool()

        if labels is None:
            # Inference: best tag path per sequence.
            return self.crf.decode(emissions, mask=crf_mask)

        # Training: negate the CRF score so it can be minimised as a loss.
        return -self.crf(emissions, labels, mask=crf_mask, reduction='mean')

# ========== Train ==========
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is moved to ``device``, the loss is obtained by calling
    ``model(input_ids, attention_mask, labels)``, and one optimizer step is
    taken. Returns 0.0 for an empty dataloader.
    """
    model.train()
    batch_losses = []

    for batch in dataloader:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        gold = batch['labels'].to(device)

        optimizer.zero_grad()
        batch_loss = model(ids, mask, gold)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())

    # Start the sum at 0.0 so the empty case still yields a float; divide by
    # at least 1 to avoid a ZeroDivisionError on an empty loader.
    return sum(batch_losses, 0.0) / max(1, len(dataloader))

# ========== Evalu ==========
def _collect_true_pred(model, dataloader, device, idx2tag):
    model.eval()
    true_labels, pred_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].cpu().numpy()
            preds = model(input_ids, attention_mask)  # List[List[int]]
            for i in range(len(preds)):
                mask = attention_mask[i].cpu().numpy().astype(bool)
                valid_len = int(mask.sum())
                pred_tags = [idx2tag[p] for p in preds[i][:valid_len]]
                true_tags = [idx2tag[int(l)] for l in labels[i][:valid_len]]
                L = min(len(true_tags), len(pred_tags))
                true_labels.append(true_tags[:L])
                pred_labels.append(pred_tags[:L])
    return true_labels, pred_labels

def evaluate(model, dataloader, device, idx2tag):
    """Score ``model`` on ``dataloader``.

    Returns:
        ``(precision, recall, f1, acc, report_text)`` where the first three
        are seqeval micro-averaged scores, ``acc`` is the share of gold
        non-'O' tokens whose predicted tag matches exactly, and
        ``report_text`` is seqeval's per-class classification report.
    """
    gold_seqs, pred_seqs = _collect_true_pred(model, dataloader, device, idx2tag)

    report_text = classification_report(gold_seqs, pred_seqs, digits=2, zero_division=0)

    # Overall (micro-averaged) scores
    precision = precision_score(gold_seqs, pred_seqs, average='micro', zero_division=0)
    recall = recall_score(gold_seqs, pred_seqs, average='micro', zero_division=0)
    f1 = f1_score(gold_seqs, pred_seqs, average='micro', zero_division=0)

    # Token accuracy restricted to positions whose gold tag is not 'O'.
    entity_hits = [
        gold == pred
        for gold_seq, pred_seq in zip(gold_seqs, pred_seqs)
        for gold, pred in zip(gold_seq, pred_seq)
        if gold != 'O'
    ]
    acc = sum(entity_hits) / len(entity_hits) if entity_hits else 0.0
    return precision, recall, f1, acc, report_text

# ========== main ==========
# ========== main ==========
if __name__ == "__main__":
    # Files written whenever the validation F1 improves.
    BEST_MODEL_PATH = "best_BioBERT_model_gold_silver.pt"
    LABEL_MAP_PATH = "label_mapping_gold_silver.pt"

    # 1) Load the gold and silver corpora
    gold_sents, gold_tags = load_sentences_only(GOLD_BIO)
    silver_sents, silver_tags = load_sentences_only(SILVER_BIO)

    print(f"[GOLD] sentences:   {len(gold_sents)}")
    print(f"[SILVER] sentences: {len(silver_sents)}")

    # 2) Focus only on HPO_TERM: map all non-target entities to O
    if FOCUS_SINGLE_ENTITY:
        gold_sents   = remap_to_single_entity(gold_sents, TARGET_ENTITY)
        silver_sents = remap_to_single_entity(silver_sents, TARGET_ENTITY)

    # 3) GOLD 80/20 (20% as validation set)
    train_gold, val_gold = train_test_split(gold_sents, test_size=0.2, random_state=42, shuffle=True)

    # 4) Training set = gold80% + all silver; validation set = gold20%
    train_sents = train_gold + silver_sents
    val_sents   = val_gold
    print(f"[TRAIN] {len(train_sents)} (= gold80% + silver100%)")
    print(f"[VAL]   {len(val_sents)}   (= gold20%)")

    # 5) Tag vocabulary
    if FOCUS_SINGLE_ENTITY:
        unique_tags = ["O", f"B-{TARGET_ENTITY}", f"I-{TARGET_ENTITY}"]
    else:
        all_tags = (gold_tags | silver_tags | {"O"})
        unique_tags = sorted(all_tags)

    tag2idx = {tag: i for i, tag in enumerate(unique_tags)}
    idx2tag = {i: tag for tag, i in tag2idx.items()}
    print("Tags:", unique_tags)

    # 6) DataLoader
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
    train_dataset = NERDataset(train_sents, tag2idx, tokenizer, MAX_LEN)
    val_dataset   = NERDataset(val_sents,   tag2idx, tokenizer, MAX_LEN)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

    # 7) Model & optimizer
    model = BERT_CRF_NER(BERT_MODEL, len(tag2idx), DROPOUT_PROB).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Per-epoch and final reports are written under ./reports
    os.makedirs("reports", exist_ok=True)

    # 8) Training + Validation Monitoring + Best Save + Early Stopping
    best_f1 = 0.0
    best_epoch = 0
    no_improve = 0

    print("Starting training...")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        precision, recall, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)

        print(f"Epoch {epoch}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, ACC={acc:.4f}")
        print("-" * 50)

        # Persist this epoch's per-class report and overall metrics.
        with open(f"reports/epoch_{epoch:02d}.txt", "w", encoding="utf-8") as f:
            f.write(f"Classification Report (per-class) - Epoch {epoch}\n")
            f.write(report_text)
            f.write("\n\n--- Overall (micro) ---\n")
            f.write(f"Precision: {precision:.4f}\n")
            f.write(f"Recall:    {recall:.4f}\n")
            f.write(f"F1:        {f1:.4f}\n")
            f.write(f"ACC(entity-only): {acc:.4f}\n")

        # Best save & early stop count
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            no_improve = 0
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            torch.save({"tag2idx": tag2idx, "idx2tag": idx2tag}, LABEL_MAP_PATH)
        else:
            no_improve += 1
            if EARLY_STOP_PATIENCE > 0 and no_improve >= EARLY_STOP_PATIENCE:
                print(f"Early stopping at epoch {epoch} (no improvement for {EARLY_STOP_PATIENCE} epochs).")
                break

    print(f"Best epoch = {best_epoch}, micro-F1 = {best_f1:.4f}")
    if best_epoch > 0:
        # Report the file that was actually written, then restore those
        # weights: after the loop `model` holds the last-epoch parameters,
        # which are not necessarily the best-scoring ones.
        print(f"Best model saved as {BEST_MODEL_PATH}")
        model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    else:
        # F1 never rose above the initial 0.0, so no checkpoint exists.
        print("No checkpoint was saved (validation F1 never exceeded 0); evaluating last-epoch weights.")

    # 9) Final evaluation on the held-out gold split
    print("Final Evaluation on Validation (20% GOLD):")
    precision, recall, f1, acc, report_text = evaluate(model, val_loader, device, idx2tag)
    print(report_text)
    print("\n--- Overall performance indicators ---")
    print(f"Overall Precision: {precision:.4f}")
    print(f"Overall Recall:    {recall:.4f}")
    print(f"Overall F1-Score:  {f1:.4f}")
    print(f"Entity-only Accuracy (ACC): {acc:.4f}")

    with open("reports/final_val_report.txt", "w", encoding="utf-8") as f:
        f.write(report_text)
        f.write("\n\n--- Overall performance indicators ---\n")
        f.write(f"Overall Precision: {precision:.4f}\n")
        f.write(f"Overall Recall:    {recall:.4f}\n")
        f.write(f"Overall F1-Score:  {f1:.4f}\n")
        f.write(f"Entity-only Accuracy (ACC): {acc:.4f}\n")
