# BiLSTM 多分类模型 - 批量预测

加载训练好的模型，对ARG序列进行类别分类


In [None]:
import os
import glob
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from Bio import SeqIO
from tqdm import tqdm


## 1. 配置


In [None]:
# ================== Path configuration (edit these before running) ==================
MODEL_PATH = "/path/to/well-trained/bilstm_multi_xxx.pth"  # checkpoint saved by training
INPUT_DIR = "/path/to/predicted_ARGs"                      # folder of ARG files predicted by the binary classifier
OUTPUT_CSV = "./classification_results.csv"                # where the results are written

# Device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


## 2. 模型与数据处理定义

**注意**: 这里的定义必须与训练代码完全一致！


In [None]:
# Amino-acid lookup table (must stay identical to the one used for training).
# The 20 standard residues map to column indices 0-19, in alphabet order.
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
AA_DICT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
# Ambiguity codes map to the pair of residues they may stand for.
AA_DICT['B'] = [AA_DICT['D'], AA_DICT['N']]
AA_DICT['Z'] = [AA_DICT['E'], AA_DICT['Q']]
AA_DICT['J'] = [AA_DICT['I'], AA_DICT['L']]
# 'X' is a marker for "any residue"; 'PAD' names the padding column.
AA_DICT['X'] = 'ANY'
AA_DICT['PAD'] = 20


def one_hot_encode(sequence, max_length):
    """Encode an amino-acid sequence as a ``(max_length, 21)`` float32 matrix.

    Columns 0-19 are the residues of ``AMINO_ACIDS``; column 20 marks padding.
    A standard residue sets its own column to 1.0, a two-way ambiguity code
    (B, Z, J) sets both candidate columns to 0.5, and 'X' or any symbol that
    is not in ``AA_DICT`` spreads 0.05 over the 20 residue columns.

    Positions beyond ``max_length`` are dropped; rows past the end of a
    shorter sequence have only the padding column set to 1.0.
    """
    matrix = np.zeros((max_length, 21), dtype=np.float32)
    seq_len = len(sequence)

    for pos, residue in enumerate(sequence[:max_length]):
        # Symbols missing from the table are handled exactly like 'X'.
        code = AA_DICT.get(residue, 'ANY')
        if isinstance(code, list):
            matrix[pos, code] = 0.5
        elif code == 'ANY':
            matrix[pos, :20] = 0.05
        else:
            matrix[pos, code] = 1.0

    if seq_len < max_length:
        matrix[seq_len:, 20] = 1.0
    return matrix


class BiLSTMClassifier(nn.Module):
    """Bidirectional LSTM with global max/mean pooling for multi-class output.

    The architecture must stay identical to the training code so that saved
    weights load: attribute names and layer order determine the
    ``state_dict`` keys.

    ``config`` is read for the keys ``embedding_size``, ``hidden_size``,
    ``num_layers`` and ``dropout``; ``num_classes`` is the width of the
    final linear layer.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        hidden = config['hidden_size']
        layers = config['num_layers']
        drop_p = config['dropout']

        # Dropout is handed to the LSTM only when it has more than one layer.
        lstm_dropout = drop_p if layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=config['embedding_size'],
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=True,
            dropout=lstm_dropout,
        )
        self.dropout = nn.Dropout(drop_p)

        # Input width is hidden * 4: two directions, times two pooled views
        # (max and mean) concatenated in forward().
        head = [
            nn.Linear(hidden * 4, hidden),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch-first input ``x``."""
        hidden_states = self.lstm(x)[0]
        # Pool over dim 1 two ways and concatenate along the feature dim.
        pooled_max = torch.max(hidden_states, dim=1)[0]
        pooled_avg = torch.mean(hidden_states, dim=1)
        pooled = torch.cat((pooled_max, pooled_avg), dim=1)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


## 3. 加载模型


In [None]:
# Load the checkpoint; the model settings are read from it automatically.
print(f"Loading model from: {MODEL_PATH}")
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file -- only load
# checkpoints from a trusted source. Presumably it is needed because the
# checkpoint stores non-tensor metadata next to the weights; confirm whether
# weights_only=True would also work for these files.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

# Read the configuration stored in the checkpoint.
config = checkpoint['model_config']
class_names = checkpoint['class_names']
max_length = checkpoint['max_length']

print(f"Model config: {config}")
print(f"Classes: {class_names}")
print(f"Max length: {max_length}")

# Build the network, restore its weights and switch it to evaluation mode.
model = BiLSTMClassifier(config, len(class_names)).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully!")


## 4. 批量预测


In [None]:
def process_batch(model, sequences, metadata, class_names, max_length):
    """Classify one batch of sequences and return one result row per sequence.

    Args:
        model: Callable classifier; ``model(inputs)`` is soft-maxed over
            dim 1, so it is expected to return one row of class scores per
            sequence.
        sequences: Sequence strings, encoded with ``one_hot_encode``.
        metadata: One ``(file_name, seq_id, seq_str)`` tuple per entry of
            ``sequences``, in the same order.
        class_names: Indexable collection mapping a predicted class index to
            its label.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        A list of dicts with the keys ``FileName``, ``SequenceID``,
        ``PredictedClass``, ``Probability`` and ``Sequence``. An empty batch
        yields an empty list.

    Raises:
        ValueError: If ``sequences`` and ``metadata`` differ in length.
    """
    if len(sequences) != len(metadata):
        raise ValueError(
            f"sequences ({len(sequences)}) and metadata ({len(metadata)}) "
            "must have the same length"
        )
    # An empty batch would produce a 1-D tensor the model cannot consume.
    if len(sequences) == 0:
        return []

    # Encode every sequence to the same (max_length, 21) shape and stack them.
    encoded = np.array([one_hot_encode(seq, max_length) for seq in sequences])
    inputs = torch.tensor(encoded, dtype=torch.float32).to(device)

    with torch.no_grad():
        probs = torch.softmax(model(inputs), dim=1)
        max_probs, preds = torch.max(probs, dim=1)

    # Convert once to Python lists instead of calling .item() per element,
    # which would trigger a separate device-to-host read for every value.
    pred_indices = preds.tolist()
    confidences = max_probs.tolist()

    return [
        {
            'FileName': file_name,
            'SequenceID': seq_id,
            'PredictedClass': class_names[class_idx],
            'Probability': prob,
            'Sequence': seq_str,
        }
        for (file_name, seq_id, seq_str), class_idx, prob
        in zip(metadata, pred_indices, confidences)
    ]


def run_classification(batch_size=256):
    """Classify every sequence in ``INPUT_DIR`` and write the results to CSV.

    All ``*.fasta`` and ``*.faa`` files directly inside ``INPUT_DIR`` are
    parsed as FASTA. Sequences are upper-cased, classified in batches with
    ``process_batch`` and the rows are written to ``OUTPUT_CSV``.

    Args:
        batch_size: Number of sequences sent to the model per forward pass.

    Returns:
        A DataFrame with the columns ``FileName``, ``SequenceID``,
        ``PredictedClass``, ``Probability`` and ``Sequence``. It is empty,
        but still has these columns, when no sequence was found.

    Raises:
        Exception: Errors raised by ``process_batch`` are not caught here.
            Only a file that fails to parse is reported and skipped.
    """
    result_columns = [
        'FileName', 'SequenceID', 'PredictedClass', 'Probability', 'Sequence',
    ]

    # Sort so the row order of the output does not depend on directory order.
    files = sorted(
        glob.glob(os.path.join(INPUT_DIR, "*.fasta"))
        + glob.glob(os.path.join(INPUT_DIR, "*.faa"))
    )
    print(f"Found {len(files)} files to process")

    all_results = []
    batch_seqs = []
    batch_meta = []

    for file_path in tqdm(files, desc="Processing"):
        # Best-effort applies to parsing only. Keeping inference outside the
        # try stops a model failure from being reported as a file error and
        # from leaving a stale, never-cleared batch behind.
        try:
            # Parse the whole file first so a file that breaks midway
            # contributes no partial records.
            records = list(SeqIO.parse(file_path, "fasta"))
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

        file_name = os.path.basename(file_path)

        for record in records:
            seq_str = str(record.seq).upper()
            batch_seqs.append(seq_str)
            batch_meta.append((file_name, record.id, seq_str))

            if len(batch_seqs) >= batch_size:
                all_results.extend(
                    process_batch(model, batch_seqs, batch_meta, class_names, max_length)
                )
                batch_seqs = []
                batch_meta = []

    # Flush whatever is left after the last file.
    if batch_seqs:
        all_results.extend(
            process_batch(model, batch_seqs, batch_meta, class_names, max_length)
        )

    # Explicit columns keep the frame well-formed when there are no rows, so
    # the CSV still gets a header and the class statistics below do not fail.
    df = pd.DataFrame(all_results, columns=result_columns)
    df.to_csv(OUTPUT_CSV, index=False)
    print(f"\nDone! Results saved to: {OUTPUT_CSV}")
    print(f"Total sequences classified: {len(df)}")

    # Per-class counts of the predictions.
    print("\nClass distribution:")
    print(df['PredictedClass'].value_counts())

    return df


## 5. 运行


In [None]:
# 运行分类
results_df = run_classification()
