In [51]:
import os
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from torch import nn

In [60]:
# 设定设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device being used: {device}")

# 文件路径
data_path = "../data/processed/test_data.csv"
model_path = "../src/model/best_model_batch_2.pt"

Device being used: mps


In [61]:
# 读取数据
df = pd.read_csv(data_path)
print(f"Test dataset size: {df.shape}")

# 打印样本数据
print(df[['PMID', 'Text_combined', 'Terms']].head())


Test dataset size: (20, 6)
       PMID                                      Text_combined  \
0  27798626  Stabilizing mutations of KLHL24 ubiquitin liga...   
1  20015111  GGA autoinhibition revisited. The cytosolic ad...   
2   2068106  Bacterial chemotaxis signaling complexes: form...   
3  14722083  Comparative analyses of the three-dimensional ...   
4  23449916  Biochemical analysis of three putative KaiC cl...   

                 Terms  
0   autoubiquitination  
1       autoinhibition  
2  autophosphorylation  
3  autophosphorylation  
4  autophosphorylation  


In [62]:
# 固定标签顺序，与模型训练时保持一致
model_labels = [
    "non-autoregulatory", "autophosphorylation", "autocatalytic", "autoregulation",
    "autoubiquitination", "autoinhibition", "autoregulatory", "autoinducer",
    "autolysis", "autoinhibitory", "autoactivation", "autocatalysis",
    "autofeedback", "autoinduction", "autokinase"
]

print(f"Model Labels (15 fixed): {model_labels}")
print(f"Number of labels in model: {len(model_labels)}")


Model Labels (15 fixed): ['non-autoregulatory', 'autophosphorylation', 'autocatalytic', 'autoregulation', 'autoubiquitination', 'autoinhibition', 'autoregulatory', 'autoinducer', 'autolysis', 'autoinhibitory', 'autoactivation', 'autocatalysis', 'autofeedback', 'autoinduction', 'autokinase']
Number of labels in model: 15


In [63]:
# 设备映射：优先加载到 CPU，然后再转移到 MPS
map_location = "cpu"

try:
    # 加载模型权重到 CPU
    state_dict = torch.load(model_path, map_location=map_location)

    # 初始化模型，输出层大小固定为 15
    model = PubMedBERTClassifier(n_classes=15)
    model.load_state_dict(state_dict)

    # 将模型转移到 MPS
    model = model.to(device)
    model.eval()

    print(f"Model successfully loaded to {device}.")

except RuntimeError as e:
    print(f"Error loading model: {e}")


  state_dict = torch.load(model_path, map_location=map_location)


Model successfully loaded to mps.


In [64]:
# 初始化 tokenizer
tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')

def preprocess_text(text, tokenizer, max_length=512):
    """
    预处理输入文本
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    return encoding


In [65]:
# 设置阈值
threshold = 0.5

# 从数据集中获取实际出现的标签
actual_labels_set = df['Terms'].dropna().unique().tolist()
actual_labels_set.insert(0, "non-autoregulatory")
actual_labels_set = sorted(set(actual_labels_set))


# 逐行进行预测
for idx, row in df.iterrows():
    # 使用 Text_combined 列进行预测
    text = str(row['Text_combined'])
    actual_label = row['Terms']
    
    # 预处理
    encoding = preprocess_text(text, tokenizer)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # 模型预测
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        probabilities = torch.sigmoid(outputs).squeeze().cpu().numpy()

    # 检查形状匹配性
    if len(probabilities) != len(model_labels):
        print(f"Shape mismatch: Probabilities length = {len(probabilities)}, Label names length = {len(model_labels)}")
        continue

    # 映射回标签（仅输出实际出现的标签）
    predicted_labels = [
        model_labels[i] for i, prob in enumerate(probabilities)
        if prob >= threshold and model_labels[i] in actual_labels_set
    ]
    
    # 打印预测结果
    print(f"PMID: {row['PMID']} | Actual Label: {actual_label} | Predicted Labels: {predicted_labels}")


PMID: 27798626 | Actual Label: autoubiquitination | Predicted Labels: []
PMID: 20015111 | Actual Label: autoinhibition | Predicted Labels: ['autoregulatory']
PMID: 2068106 | Actual Label: autophosphorylation | Predicted Labels: []
PMID: 14722083 | Actual Label: autophosphorylation | Predicted Labels: []
PMID: 23449916 | Actual Label: autophosphorylation | Predicted Labels: []
PMID: 20519438 | Actual Label: autoinhibition | Predicted Labels: ['autoregulatory']
PMID: 22216903 | Actual Label: autophosphorylation | Predicted Labels: ['autophosphorylation']
PMID: 9856465 | Actual Label: autoregulatory | Predicted Labels: []
PMID: 7871721 | Actual Label: autocatalytic | Predicted Labels: ['autocatalytic']
PMID: 19812038 | Actual Label: autophosphorylation | Predicted Labels: []
PMID: 11889109 | Actual Label: nan | Predicted Labels: []
PMID: 19690332 | Actual Label: nan | Predicted Labels: []
PMID: 15815621 | Actual Label: nan | Predicted Labels: []
PMID: 12384590 | Actual Label: nan | Predic