# BiLSTM 二分类模型 - 批量预测

加载训练好的模型，对FASTA文件进行批量ARG预测


In [None]:
import os
import glob
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO


## 1. 配置


In [None]:
# ================== Path configuration (edit before running) ==================
MODEL_PATH = ""                        # checkpoint saved by the training run
INPUT_DIR = ""                         # folder holding the input FASTA files
OUTPUT_DIR = "./predicted_results"     # folder that receives the prediction results
THRESHOLD = 0.5                        # probability cutoff used for prediction

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Make sure the output folder exists before anything is written to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)


## 2. 模型与数据处理定义

**注意**: 这里的定义必须与训练代码完全一致！


In [None]:
# Amino-acid vocabulary (must match the training code):
# the 20 standard residues map to 1..20, 'X' (unknown) to 21, and 0 is padding.
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
AA_DICT = dict(zip(AMINO_ACIDS, range(1, len(AMINO_ACIDS) + 1)))
AA_DICT['X'] = 21
AA_DICT['PAD'] = 0


def seq_to_indices(sequence, max_length):
    """Encode an amino-acid sequence as a fixed-length array of vocabulary indices.

    Args:
        sequence: Iterable of single-character residue codes.
        max_length: Target length; longer inputs are truncated, shorter ones
            are right-padded with 0.

    Returns:
        A one-dimensional ``np.int64`` array. Residues that are not in
        ``AA_DICT`` are encoded as 21.
    """
    encoded = [AA_DICT.get(residue, 21) for residue in sequence][:max_length]
    # Multiplying a list by a non-positive count yields [], so no padding is
    # added once the sequence already fills max_length.
    padding = [0] * (max_length - len(encoded))
    return np.array(encoded + padding, dtype=np.int64)


class BiLSTMModel(nn.Module):
    """Bidirectional LSTM binary classifier with global max + mean pooling.

    The attribute names and layer shapes must stay identical to the training
    code, otherwise the saved ``state_dict`` will not load.

    Args:
        config: Mapping that provides ``vocab_size``, ``embedding_dim``,
            ``hidden_size``, ``num_layers`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        vocab_size = config['vocab_size']
        emb_dim = config['embedding_dim']
        hidden = config['hidden_size']
        n_layers = config['num_layers']
        p_drop = config['dropout']

        # Index 0 is the padding token.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout is only requested for stacked LSTMs.
            dropout=p_drop if n_layers > 1 else 0,
        )
        self.dropout = nn.Dropout(p_drop)
        # Input width is 4 * hidden: two directions, for both the max-pooled
        # and the mean-pooled feature vectors.
        self.classifier = nn.Sequential(
            nn.Linear(hidden * 4, hidden),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return one raw logit per input row.

        Args:
            x: Batch-first integer tensor of token indices.

        Returns:
            Tensor with a trailing dimension of size 1 (pre-sigmoid logits).
        """
        hidden_states = self.lstm(self.embedding(x))[0]
        # Pool over dim 1 (the sequence axis, since the LSTM is batch_first).
        pooled = torch.cat(
            (hidden_states.max(dim=1).values, hidden_states.mean(dim=1)),
            dim=1,
        )
        return self.classifier(self.dropout(pooled))


class InferenceDataset(Dataset):
    """Dataset serving encoded sequences from a single FASTA file for inference.

    Every record is parsed up front and kept in ``self.records`` so that
    callers can look the original record up again by index.

    Args:
        fasta_file: FASTA source handed to ``SeqIO.parse``.
        max_length: Fixed length each sequence is truncated or padded to.
    """

    def __init__(self, fasta_file, max_length):
        self.max_length = max_length
        self.records = list(SeqIO.parse(fasta_file, "fasta"))

    def __len__(self):
        """Return the number of records read from the file."""
        return len(self.records)

    def __getitem__(self, idx):
        """Return the index tensor for record ``idx`` (sequence upper-cased first)."""
        record = self.records[idx]
        encoded = seq_to_indices(str(record.seq).upper(), self.max_length)
        return torch.from_numpy(encoded)


## 3. 加载模型


In [None]:
# Load the checkpoint and rebuild the model from the config stored inside it.
print(f"Loading model from: {MODEL_PATH}")
# NOTE(review): weights_only=False makes torch.load use the full pickle
# unpickler, which can execute arbitrary code while loading. Only point
# MODEL_PATH at checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
# The checkpoint must provide both a 'config' and a 'model_state_dict' entry.
config = checkpoint['config']
print(f"Model config: {config}")

model = BiLSTMModel(config).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch to inference mode so the dropout layers are inactive
print("Model loaded successfully!")


## 4. 批量预测


In [None]:
def predict_file(model, input_path, output_path, config, threshold,
                 batch_size=512, num_workers=4):
    """Predict ARGs for a single FASTA file and write the positive records.

    Records whose predicted probability is strictly greater than ``threshold``
    get `` [ARG_prob=…]`` appended to their description and are written to
    ``output_path`` in FASTA format. When nothing passes the threshold no
    output file is created.

    The output is first written to ``<output_path>.tmp`` and then moved into
    place, so an interrupted run never leaves a truncated file at
    ``output_path`` that a size-based resume check would mistake for a
    finished result.

    Args:
        model: Model producing one logit per input row. This function does
            not change its train/eval mode; the caller should have called
            ``model.eval()``.
        input_path: FASTA file to read.
        output_path: Destination FASTA file for the predicted ARG records.
        config: Mapping that provides ``max_length``.
        threshold: Probability cutoff for calling a record an ARG.
        batch_size: Number of sequences per batch.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        The number of records written (0 if the input has no records or none
        exceed the threshold).
    """
    dataset = InferenceDataset(input_path, config['max_length'])
    if len(dataset) == 0:
        return 0

    # shuffle=False keeps batch order aligned with dataset.records, which is
    # what lets a running offset map predictions back to their records.
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    arg_records = []
    offset = 0  # index in dataset.records of the first item of the current batch

    with torch.no_grad():
        for inputs in loader:
            inputs = inputs.to(device)
            probs = torch.sigmoid(model(inputs)).cpu().numpy().flatten()

            for i in np.flatnonzero(probs > threshold):
                record = dataset.records[offset + i]
                record.description += f" [ARG_prob={probs[i]:.3f}]"
                arg_records.append(record)
            offset += len(probs)

    if arg_records:
        # Write-then-rename: os.replace swaps the finished file in atomically.
        tmp_path = f"{output_path}.tmp"
        SeqIO.write(arg_records, tmp_path, "fasta")
        os.replace(tmp_path, output_path)

    return len(arg_records)


# Collect every FASTA file to process; sorted so the run order is reproducible.
all_files = sorted(
    glob.glob(os.path.join(INPUT_DIR, "*.faa"))
    + glob.glob(os.path.join(INPUT_DIR, "*.fasta"))
)
print(f"Found {len(all_files)} files to process")

# Batch processing
total_args = 0
skipped_files = 0
for i, input_path in enumerate(all_files):
    filename = os.path.basename(input_path)
    output_path = os.path.join(OUTPUT_DIR, f"{filename}_pred.fasta")

    # Resume support: a non-empty output file means this input is already done.
    # NOTE: inputs that yielded no ARGs have no output file, so they are
    # predicted again on every re-run.
    if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
        skipped_files += 1
    else:
        n_args = predict_file(model, input_path, output_path, config, THRESHOLD)
        total_args += n_args

    # Report progress for skipped files too, so the counter never stalls.
    if (i + 1) % 100 == 0:
        print(f"Progress: {i+1}/{len(all_files)}")

print(f"\nDone! Total ARGs found: {total_args}")
if skipped_files:
    # The total above only covers files predicted in this run.
    print(
        f"Skipped {skipped_files} already-processed files "
        "(their ARGs are not included in the total)"
    )
