# In [None]:
# 导入所需的库和模块
import torch
import torch.nn as nn
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer  # pretrained model loaders (ESM-1b, ESM-IF1, Vicuna-13b)

### Step 1: 数据集的加载与预处理
class ProteinDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing protein sequences, 3D structures and text descriptions.

    The three sources are loaded eagerly by ``load_sequences``, ``load_structures`` and
    ``load_descriptions`` (defined elsewhere, not in this cell) and are indexed in
    parallel: item ``i`` is ``(sequences[i], structures[i], descriptions[i])``.

    Args:
        seq_file: path of the 1D sequence file.
        structure_file: path of the 3D structure file.
        description_file: path of the text-description file.

    Raises:
        ValueError: if the three loaded collections do not have the same length.
    """

    def __init__(self, seq_file, structure_file, description_file):
        self.sequences = load_sequences(seq_file)  # 1D sequences
        self.structures = load_structures(structure_file)  # 3D structures
        self.descriptions = load_descriptions(description_file)  # matching text descriptions
        # __len__ only looks at the sequences, so a shorter structure/description list
        # would otherwise surface as an IndexError mid-epoch, and a longer one would
        # silently have its extra entries ignored.
        self._check_aligned()

    def _check_aligned(self):
        """Raise ValueError unless sequences, structures and descriptions have equal length."""
        lengths = {
            "sequences": len(self.sequences),
            "structures": len(self.structures),
            "descriptions": len(self.descriptions),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(f"ProteinDataset sources have mismatched lengths: {lengths}")

    def __len__(self):
        """Return the number of samples (identical for all three sources)."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the ``(sequence, structure, description)`` triple at position ``idx``."""
        return self.sequences[idx], self.structures[idx], self.descriptions[idx]

# Build the training dataset and its loader.
train_dataset = ProteinDataset('train_seq.txt', 'train_structure.pdb', 'train_description.txt')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

### Step 2: load the pretrained backbones (ESM-1b, ESM-IF1, Vicuna-13b)

# Pretrained protein encoders.
# NOTE(review): EsmModel.forward expects token ids from the ESM tokenizer; raw sequence
# strings from the dataset presumably need tokenizing before reaching it — confirm.
seq_encoder = AutoModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")  # 1D sequence encoder, ESM-1b
# NOTE(review): ESM-IF1 is distributed through the `fair-esm` package
# (esm.pretrained.esm_if1_gvp4_t16_142M_UR50), not as a Transformers checkpoint, so this
# call is expected to fail — confirm the intended loader.
structure_encoder = AutoModel.from_pretrained("facebook/esm_if1_gvp4_t16")  # 3D structure encoder, ESM-IF1

# Pretrained language model (Vicuna-13b).
# It must be loaded WITH its language-modelling head: AutoModel returns the bare decoder,
# whose forward accepts no `labels` and returns no `logits`, both of which the model
# wrapper and the training loop rely on.
# NOTE(review): the Hub hosts versioned repos (e.g. "lmsys/vicuna-13b-v1.5"); confirm
# that this unversioned id resolves.
language_model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-13b")  # Vicuna-13b LLM
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b")

### Step 3: 定义 PLP-former 模块
class PLPFormer(nn.Module):
    """Fuse per-residue sequence and structure embeddings with a Transformer encoder.

    Both inputs are linearly projected to ``hidden_dim``, concatenated along the token
    axis and passed through a stack of self-attention encoder layers.

    Args:
        seq_dim: feature width of the sequence embeddings.
        structure_dim: feature width of the structure embeddings.
        hidden_dim: shared model width (must be divisible by ``nhead``).
        nhead: number of attention heads (default 8, as before).
        num_layers: number of encoder layers (default 4, as before).
    """

    def __init__(self, seq_dim, structure_dim, hidden_dim, nhead=8, num_layers=4):
        super().__init__()
        # nn.Transformer is an encoder-decoder whose forward requires a `tgt` sequence, so
        # calling it with a single tensor raises TypeError (and it allocates an unused
        # decoder stack). Only the encoder stack is needed to fuse the two streams.
        # batch_first=True matches the (batch, tokens, dim) layout implied by forward(),
        # which concatenates the two streams along dim=1.
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.linear_seq = nn.Linear(seq_dim, hidden_dim)  # project 1D embeddings to hidden_dim
        self.linear_structure = nn.Linear(structure_dim, hidden_dim)  # project 3D embeddings to hidden_dim

    def forward(self, seq_embedding, structure_embedding):
        """Return fused embeddings of shape (batch, seq_len + struct_len, hidden_dim).

        Args:
            seq_embedding: (batch, seq_len, seq_dim) tensor.
            structure_embedding: (batch, struct_len, structure_dim) tensor.
        """
        # Map both streams into the shared hidden space.
        seq_proj = self.linear_seq(seq_embedding)
        structure_proj = self.linear_structure(structure_embedding)

        # Concatenate along the token axis and let self-attention mix 1D and 3D information.
        combined_embedding = torch.cat((seq_proj, structure_proj), dim=1)
        return self.transformer(combined_embedding)

# Build the PLP-former. seq_dim=1280 presumably matches the ESM-1b hidden size and
# structure_dim=512 the structure encoder's output width — confirm against the encoders.
plp_former = PLPFormer(hidden_dim=768, seq_dim=1280, structure_dim=512)

### Step 4: 定义投影适配器
class ProjectionAdapter(nn.Module):
    """Linear bridge from the PLP-former output width to the language model's embedding width.

    Args:
        input_dim: feature width of the incoming (PLP-former) embeddings.
        output_dim: feature width expected by the language model.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.adapter = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, plp_output):
        """Apply the linear projection to the last dimension of ``plp_output``."""
        projected = self.adapter(plp_output)
        return projected

# Projection adapter: map PLP-former outputs (width 768) into the LLM's token-embedding space.
# Vicuna-13b is LLaMA-13B based, whose hidden size is 5120 (4096 is the 7B width), so read
# the target width from the loaded model's config instead of hard-coding it.
projection_adapter = ProjectionAdapter(input_dim=768, output_dim=language_model.config.hidden_size)

### Step 5: 定义完整的 ProtChatGPT 模型
class ProtChatGPT(nn.Module):
    """End-to-end protein question-answering model.

    Pipeline: sequence/structure encoders -> PLP-former -> projection adapter -> causal LM.
    The projected protein embeddings act as a soft prefix in front of the embedded text
    tokens, and the LM loss is computed on the text tokens only.

    Args:
        seq_encoder: 1D sequence encoder; its output must expose ``last_hidden_state``.
        structure_encoder: 3D structure encoder; its output must expose ``last_hidden_state``.
        plp_former: module fusing the two embedding streams.
        projection_adapter: module mapping fused embeddings to the LM embedding width.
        language_model: causal LM accepting ``inputs_embeds``/``attention_mask``/``labels``
            and exposing ``get_input_embeddings()``.
        tokenizer: optional tokenizer for the text; when None, the module-level
            ``tokenizer`` is looked up at call time (previous behaviour).
    """

    def __init__(self, seq_encoder, structure_encoder, plp_former, projection_adapter, language_model,
                 tokenizer=None):
        super().__init__()
        self.seq_encoder = seq_encoder  # 1D sequence encoder
        self.structure_encoder = structure_encoder  # 3D structure encoder
        self.plp_former = plp_former  # PLP-former
        self.projection_adapter = projection_adapter  # projection adapter
        self.language_model = language_model  # Vicuna-13b
        self.tokenizer = tokenizer

    def forward(self, seq_input, structure_input, question):
        """Run the model and return the language model's output (including ``loss``).

        Args:
            seq_input: input accepted by ``seq_encoder``.
            structure_input: input accepted by ``structure_encoder``.
            question: a string or a sequence of strings, one per protein in the batch.

        Raises:
            ValueError: if the number of texts does not match the protein batch size.
        """
        # 1. Encode the 1D sequence and the 3D structure.
        seq_embedding = self.seq_encoder(seq_input).last_hidden_state
        structure_embedding = self.structure_encoder(structure_input).last_hidden_state

        # 2. Align the protein information with the PLP-former, then
        # 3. project it into the language model's input-embedding space.
        plp_output = self.plp_former(seq_embedding, structure_embedding)
        protein_embeds = self.projection_adapter(plp_output)
        batch_size, prefix_len = protein_embeds.shape[:2]
        device = protein_embeds.device

        # 4. Tokenize the text. Padding is done by hand because LLaMA-style tokenizers may
        # have no pad token, which makes `tokenizer(..., padding=True)` raise.
        texts = [question] if isinstance(question, str) else list(question)
        if len(texts) != batch_size:
            raise ValueError(f"got {len(texts)} text(s) for a protein batch of size {batch_size}")
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        token_ids = [torch.as_tensor(ids, dtype=torch.long) for ids in tok(texts).input_ids]
        pad_id = tok.pad_token_id if tok.pad_token_id is not None else 0
        input_ids = nn.utils.rnn.pad_sequence(token_ids, batch_first=True, padding_value=pad_id).to(device)
        lengths = torch.tensor([len(ids) for ids in token_ids], device=device)
        text_mask = (torch.arange(input_ids.size(1), device=device)[None, :] < lengths[:, None]).long()

        # The question must also be an INPUT to the LM (not only its labels): embed it with
        # the LM's own embedding table and prepend the protein prefix.
        text_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat((protein_embeds.to(text_embeds.dtype), text_embeds), dim=1)
        prefix_mask = torch.ones(batch_size, prefix_len, dtype=text_mask.dtype, device=device)
        attention_mask = torch.cat((prefix_mask, text_mask), dim=1)

        # Labels must line up position-by-position with inputs_embeds: -100 (ignored by the
        # Hugging Face causal-LM loss) over the protein prefix and padding, token ids elsewhere.
        text_labels = input_ids.masked_fill(text_mask == 0, -100)
        prefix_labels = torch.full((batch_size, prefix_len), -100, dtype=torch.long, device=device)
        labels = torch.cat((prefix_labels, text_labels), dim=1)

        return self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

# Assemble the end-to-end model from the components built above.
protchatgpt = ProtChatGPT(
    seq_encoder=seq_encoder,
    structure_encoder=structure_encoder,
    plp_former=plp_former,
    projection_adapter=projection_adapter,
    language_model=language_model,
)

### Step 6: training
# NOTE(review): `num_epochs` was never defined in this cell (NameError); placeholder value, adjust as needed.
num_epochs = 1

# Only parameters that still require gradients are handed to Adam, so freezing a backbone
# (e.g. `protchatgpt.language_model.requires_grad_(False)`) also keeps it out of the optimizer state.
optimizer = torch.optim.Adam((p for p in protchatgpt.parameters() if p.requires_grad), lr=1e-5)
# Kept so other cells referencing it still work; the loss below comes from the language model itself.
criterion = nn.CrossEntropyLoss()

# from_pretrained() returns models in eval mode; switch the whole stack to training mode.
protchatgpt.train()
for epoch in range(num_epochs):
    for seq_input, structure_input, description in train_loader:
        optimizer.zero_grad()

        # Forward pass; the description text is passed through as the LM target.
        output = protchatgpt(seq_input, structure_input, description)

        # `description` is the collated text batch (presumably a list of strings, which has
        # no `.view`), so recomputing the loss against it cannot work. When `labels` are
        # supplied, a Hugging Face causal LM returns its own shifted, ignore-masked loss.
        loss = output.loss
        loss.backward()

        # Update parameters.
        optimizer.step()

        print(f'Epoch {epoch}, Loss: {loss.item()}')

print("模型训练完毕")


| 数据集类型            | 读取样本数量 | 读取用时 (s) | CPU 使用率 (%) | 内存使用 (MB) | 数据处理样本数量 | 平均处理时间 (s) | 总处理用时 (s) |
|-----------------------|--------------|--------------|----------------|---------------|------------------|------------------|----------------|
| KITTI                 | 5000         | 4.4650       | 0.80           | 23.88         | 5000             | 0.089098         | 445.4990       |
| MindRecord (KITTI)    | 5000         | 0.1638       | 2.10           | 11.21         | 5000             | 0.039375         | 196.8923       |
| LibriTTS              | 5000         | 2.7889       | 1.20           | 45.55         | 500              | 0.014486         | 7.2438         |
| MindRecord (LibriTTS) | 5000         | 2.4033       | 0.30           | 8.06          | 500              | 0.008484         | 4.2429         |
| LJSpeech              | 50000        | 2.6579       | 0.50           | 49.95         | 5000             | 0.008219         | 41.1026        |
| MindRecord (LJSpeech) | 50000        | 2.4380       | 0.20           | 3.86          | 5000             | 0.009397         | 46.9934        |
| SQuAD                 | 5000         | 9.0251       | 0.00           | 938.77        | 50               | 6.166194         | 308.3099       |
| MindRecord (SQuAD)    | 5000         | 2.8578       | 0.00           | 1.63          | 50               | 0.464135         | 23.2070        |
| SST                   | 500          | 0.2428       | 0.60           | 10.21         | 500              | 0.214583         | 107.2924       |
| MindRecord (SST)      | 500          | 0.2616       | 1.60           | 38.04         | 500              | 0.235432         | 117.7181       |
| MindRecord (VOC)      | 100          | 0.3102       | 0.60           | 61.35         | 100              | 0.210007         | 21.0011        |
| VOC                   | 100          | 0.4328       | 3.10           | 8.24          | 100              | 0.373909         | 37.3911        |
| Wiki                  | 50000        | 1.4066       | 0.80           | 13.94         | 500              | 0.052930         | 26.4657        |
| MindRecord (Wiki)     | 50000        | 4.7785       | 1.70           | 29.96         | 500              | 0.152549         | 76.2847        |