# BiLSTM 二分类模型 - 批量预测

加载训练好的 BiLSTM 二分类模型，对单个或整个目录下的 FASTA/FAA 蛋白序列文件进行抗生素抗性基因（ARG）预测。


In [6]:
import os
import glob
import csv
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO


## 1. 配置


In [None]:
# ================== Path configuration (edit before running) ==================
# Checkpoint saved by the training run.
MODEL_PATH = "/home/mayue/model_c/binary/model_train/well-trained/bilstm_20260206_1422.pth"

# Input selection -- choose one of the two:
#   1) INPUT_PATH: score a single .faa/.fasta file (recommended for a first
#      end-to-end run), e.g. /path/to/xxx.faa
#   2) INPUT_DIR: score every .faa/.fasta file found in a directory,
#      e.g. INPUT_DIR = "./binary_test_out"
INPUT_PATH = "/home/mayue/bilstm/data/predict_sequence/prodigal_result_remove/CRR029083_bin.1_remove.faa"

# Directory that receives the prediction outputs.
OUTPUT_DIR = "./results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prediction threshold.
THRESHOLD = 0.5

# Output switches:
#   SAVE_PRED_FASTA: write the sequences predicted as ARG to a FASTA file
#                    (input for the downstream multi-class step)
#   SAVE_SCORES_CSV: write the scores of all sequences to a CSV file
#                    (used for real performance evaluation)
SAVE_PRED_FASTA = True
SAVE_SCORES_CSV = True

# Compute device: first CUDA GPU when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda:0


## 2. 模型与数据处理定义

**注意**: 这里的定义必须与训练代码完全一致！


In [8]:
# Amino-acid encoding table (must stay identical to the one used in training).
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
# The 20 residues above map to 1..20; 'X' maps to 21 and 'PAD' to 0.
AA_DICT = dict(zip(AMINO_ACIDS, range(1, len(AMINO_ACIDS) + 1)))
AA_DICT['X'] = 21
AA_DICT['PAD'] = 0


def seq_to_indices(sequence, max_length):
    """Encode an amino-acid sequence as a fixed-length array of indices.

    Each symbol is looked up in ``AA_DICT``; any symbol that is not a key of
    the table is encoded as 21. Sequences longer than ``max_length`` are
    truncated and shorter ones are right-padded with 0.

    Args:
        sequence: The residue string to encode. The lookup is
            case-sensitive, so lower-case letters are encoded as 21.
        max_length: Length of the returned array.

    Returns:
        A 1-D ``np.int64`` array holding ``max_length`` indices.
    """
    codes = [AA_DICT.get(residue, 21) for residue in sequence[:max_length]]
    codes.extend([0] * (max_length - len(codes)))
    return np.array(codes, dtype=np.int64)


class BiLSTMModel(nn.Module):
    """Bidirectional LSTM with global pooling for binary classification.

    The architecture must stay identical to the training code, otherwise a
    saved ``state_dict`` will no longer fit.

    ``config`` is a mapping that provides ``vocab_size``, ``embedding_dim``,
    ``hidden_size``, ``num_layers`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config['embedding_dim']
        hidden = config['hidden_size']
        n_layers = config['num_layers']
        drop_p = config['dropout']

        # Index 0 is the padding token.
        self.embedding = nn.Embedding(config['vocab_size'], embed_dim, padding_idx=0)

        # Inter-layer LSTM dropout is only requested for stacked LSTMs.
        if n_layers > 1:
            lstm_dropout = drop_p
        else:
            lstm_dropout = 0
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
            dropout=lstm_dropout,
        )

        self.dropout = nn.Dropout(drop_p)

        # Max-pool and mean-pool features of both directions are concatenated,
        # hence 4 * hidden_size input features.
        self.classifier = nn.Sequential(
            nn.Linear(hidden * 4, hidden),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return one raw (pre-sigmoid) score per row of the index batch ``x``.

        Pooling runs over dim 1 of the LSTM output without any masking, so
        padded positions take part in both the max and the mean.
        """
        hidden_states, _ = self.lstm(self.embedding(x))
        pooled_max = hidden_states.max(dim=1).values
        pooled_mean = hidden_states.mean(dim=1)
        pooled = torch.cat((pooled_max, pooled_mean), dim=1)
        return self.classifier(self.dropout(pooled))


class InferenceDataset(Dataset):
    """Inference dataset yielding encoded sequences together with metadata.

    Every record of the FASTA file is parsed up front and kept in
    ``self.records``. An item is the tuple
    ``(index tensor, record id, sequence length)``, so a full per-sequence
    CSV can be produced from the batches.
    """

    def __init__(self, fasta_file, max_length):
        self.max_length = max_length
        self.records = list(SeqIO.parse(fasta_file, "fasta"))

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        entry = self.records[idx]
        # Upper-case first: the encoding table only knows upper-case residues.
        residues = str(entry.seq).upper()
        encoded = seq_to_indices(residues, self.max_length)
        # The reported length is that of the full sequence, before truncation.
        return torch.from_numpy(encoded), entry.id, len(residues)


## 3. 加载模型


In [9]:
# Restore the trained model; its hyper-parameters are read from the
# configuration stored inside the checkpoint.
print(f"Loading model from: {MODEL_PATH}")
# NOTE: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints that come from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
config = checkpoint['config']
print(f"Model config: {config}")

model = BiLSTMModel(config)
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to inference mode so the dropout layers are disabled.
model.eval()
print("Model loaded successfully!")


Loading model from: /home/mayue/model_c/binary/model_train/well-trained/bilstm_20260206_1422.pth
Model config: {'vocab_size': 22, 'embedding_dim': 48, 'hidden_size': 48, 'num_layers': 1, 'dropout': 0.5, 'max_length': 1000}
Model loaded successfully!


## 4. 批量预测


In [10]:
# Fault tolerance: if the "path configuration" cell was not run, give the key
# variables a default value so the code below does not hit a NameError.
# Names that already exist keep their current value.
globals().update({
    name: globals().get(name, fallback)
    for name, fallback in (
        ('INPUT_PATH', ''),
        ('INPUT_DIR', ''),
        ('OUTPUT_DIR', './predicted_results'),
        ('THRESHOLD', 0.5),
        ('SAVE_PRED_FASTA', True),
        ('SAVE_SCORES_CSV', True),
    )
})
os.makedirs(OUTPUT_DIR, exist_ok=True)

def predict_file(model, input_path, config, threshold, output_dir, save_pred_fasta=True, save_scores_csv=True):
    """Score every sequence of one FASTA/FAA file and write the requested outputs.

    Outputs, both placed in ``output_dir`` and named after the input file:

    * ``<filename>_pred.fasta`` -- records whose probability exceeds
      ``threshold``, each with ``[ARG_prob=...]`` appended to its description.
      Only written when ``save_pred_fasta`` is true and at least one record
      qualifies.
    * ``<filename>_scores.csv`` -- one row per input sequence. Only written
      when ``save_scores_csv`` is true.

    An input file without records produces no output at all.

    Args:
        model: The trained classifier; it is called on batches of index
            tensors and must return one logit per sequence.
        input_path: Path of the FASTA/FAA file to score.
        config: Model configuration; only ``config['max_length']`` is read.
        threshold: A sequence is predicted positive when its probability is
            strictly greater than this value.
        output_dir: Existing directory that receives the output files.
        save_pred_fasta: Whether to collect and write the positive records.
        save_scores_csv: Whether to write the per-sequence score table.

    Returns:
        The number of records collected for the FASTA output. This is 0 when
        ``save_pred_fasta`` is false, even if sequences were predicted positive.
    """
    dataset = InferenceDataset(input_path, config['max_length'])
    if len(dataset) == 0:
        return 0

    filename = os.path.basename(input_path)
    pred_fasta_path = os.path.join(output_dir, f"{filename}_pred.fasta")
    scores_csv_path = os.path.join(output_dir, f"{filename}_scores.csv")

    # Run on whatever device the model lives on instead of relying on a
    # module-level ``device`` variable that may not match the passed model.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    loader = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=4)
    arg_records = []
    rows = []
    # Position of the current batch's first item in dataset.records; valid
    # because the loader iterates in order (shuffle=False).
    offset = 0

    # Dropout must be disabled while scoring; the caller's mode is restored
    # afterwards so the function has no lasting effect on the model.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, _ids, lens in loader:
                logits = model(inputs.to(model_device))
                probs = torch.sigmoid(logits).cpu().numpy().flatten()

                for i, prob in enumerate(probs):
                    record = dataset.records[offset + i]
                    pred = 1 if prob > threshold else 0
                    rows.append({
                        'FileName': filename,
                        'SequenceID': record.id,
                        'SeqLen': int(lens[i]),
                        'ARG_Prob': float(prob),
                        'ARG_Pred': int(pred),
                        'Threshold': float(threshold),
                    })
                    if pred == 1 and save_pred_fasta:
                        record.description += f" [ARG_prob={prob:.3f}]"
                        arg_records.append(record)

                offset += len(probs)
    finally:
        model.train(was_training)

    if save_pred_fasta and arg_records:
        SeqIO.write(arg_records, pred_fasta_path, "fasta")

    if save_scores_csv:
        fieldnames = ['FileName', 'SequenceID', 'SeqLen', 'ARG_Prob', 'ARG_Pred', 'Threshold']
        with open(scores_csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)

    return len(arg_records)


# Collect the files to process; a non-empty INPUT_PATH takes precedence over
# INPUT_DIR.
if INPUT_PATH:
    all_files = [INPUT_PATH]
elif INPUT_DIR:
    all_files = [
        *glob.glob(os.path.join(INPUT_DIR, "*.faa")),
        *glob.glob(os.path.join(INPUT_DIR, "*.fasta")),
    ]
else:
    raise ValueError('请设置 INPUT_PATH（单文件）或 INPUT_DIR（目录）')

print(f"Found {len(all_files)} files to process")

# Score the files one by one, reporting progress after every 10th file.
total_args = 0
for done, input_path in enumerate(all_files, start=1):
    total_args += predict_file(
        model,
        input_path,
        config,
        THRESHOLD,
        OUTPUT_DIR,
        save_pred_fasta=SAVE_PRED_FASTA,
        save_scores_csv=SAVE_SCORES_CSV,
    )
    if done % 10 == 0:
        print(f"Progress: {done}/{len(all_files)}")

print(f"\nDone! Total predicted ARG sequences written: {total_args}")


Found 1 files to process

Done! Total predicted ARG sequences written: 334
