In [3]:
import torch
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    precision_recall_curve,
    auc
)
import glob
from dataset import *
from model import *

def test_model_on_9_proteins(model_path, protein_path, feature_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sequences, labels = load_encoding_from_txt(protein_path)
    features = load_features_from_txt(feature_path)
    
    sequences = np.array(sequences)
    features = np.array(features)
    labels = np.array(labels)
    
    dataset = MyDataSet(sequences, features, labels)
    data_loader = Data.DataLoader(dataset, batch_size=9, shuffle=False)

    model = FusionPepNet().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    all_preds = []
    all_probs = []
    with torch.no_grad():
        for input_ids, seq_feat, lbl in data_loader:
            input_ids = input_ids.to(device)
            seq_feat = seq_feat.to(device)
            outputs, _, _ = model(input_ids, seq_feat)

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()

            all_probs.extend(probs)
            all_preds.extend(preds)
    
    print("\n========= 9 Proteins Prediction Result =========")
    for i, (pred, prob, true_label) in enumerate(zip(all_preds, all_probs, labels)):
        print(f"Sample {i+1}: Pred={pred} | Prob={prob:.4f} | True={true_label}")
    
    num_positive = sum(all_preds)


In [None]:
if __name__ == "__main__":
    model_path = "fold_4_acc_0.9563.pth" 

    protein_path = ""
    feature_path = ""

    test_model_on_9_proteins(model_path, protein_path, feature_path)


  model.load_state_dict(torch.load(model_path, map_location=device))



Sample 1: Pred=1 | Prob=0.9581 | True=0
Sample 2: Pred=1 | Prob=0.9947 | True=0
Sample 3: Pred=1 | Prob=0.9701 | True=0
Sample 4: Pred=1 | Prob=0.9776 | True=0
Sample 5: Pred=1 | Prob=0.9396 | True=0
Sample 6: Pred=1 | Prob=0.8892 | True=0
Sample 7: Pred=1 | Prob=0.9806 | True=0
Sample 8: Pred=1 | Prob=0.9491 | True=0
Sample 9: Pred=1 | Prob=0.9018 | True=0

预测为正类的数量：9 / 9
预测为正类的比例：100.00%
