In [36]:
from model_one import BepiPredDDPM, ESM2Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch import optim
from torch.nn.utils.rnn import pad_sequence
import os
import pandas as pd
from tqdm import tqdm
import math
from scipy.stats import norm
import random
import pathlib
from pathlib import Path
random.seed(418)  # 设置种子值


In [48]:
### SET GPU OR CPU ###
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU device detected: {device}")
else:
    device = torch.device("cpu")
    print(f"GPU device not detected. Using CPU: {device}")
    


def get_alpha_cumprod(t, steps=1000, beta_start=0.0001, beta_end=0.02):
    """
    计算alpha的累计乘积
    """
    beta = torch.linspace(beta_start, beta_end, steps)
    alpha = 1 - beta
    alpha_cumprod = torch.cumprod(alpha, dim=0)
    return alpha_cumprod[t]

GPU device not detected. Using CPU: cpu


In [None]:
# 加载数据集
dataset = ESM2Dataset(esm_encoding_dir=Path("data/esm_encodings"))

def collate_fn(batch):
    # 每个样本是 (input_tensor, target_tensor)
    input_seqs = [item[0] for item in batch]
    target_seqs = [item[1] for item in batch]

    input_lengths = torch.tensor([len(seq) for seq in input_seqs])
    target_lengths = torch.tensor([len(seq) for seq in target_seqs])

    # 两个序列的长度是一样的，可以统一 pad 成 max_len
    # 你也可以分别 pad（如果以后变成不一样长）

    input_padded = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    target_padded = pad_sequence(target_seqs, batch_first=True, padding_value=0)

    return input_padded, input_lengths, target_padded, target_lengths
    
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)


for esm, epitope, esm_len, epitope_len in dataloader:
    print(esm.shape)
    print(epitope.shape)
    break

In [60]:
class Diffusion:
    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02, device=device):
        self.timesteps = timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.device = device
        
        # 线性噪声调度（公式4）
        self.betas = torch.linspace(self.beta_start, self.beta_end, self.timesteps, device=self.device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)  # ᾱ_t（公式4推导）
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)  # √ᾱ_t
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)  # √(1-ᾱ_t)
        
        # 后验q(x_{t-1}|x_t,x_0)的计算（公式7）
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    def extract(self, a, t, x_shape):
        """从a中根据t提取系数并重塑使其能与x_shape广播"""
        batch_size = t.shape[0]
        out = a.gather(-1, t)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

    def q_sample(self, x_start, t, noise=None):
        """前向扩散过程：q(x_t | x_0)（公式4推导）"""
        if noise is None:
            noise = torch.randn_like(x_start)
            
        sqrt_alphas_cumprod_t = self.extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self.extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    

In [101]:
 

class MyDenseNetWithSeqLen(nn.Module):
    def __init__(self,
                 esm_embedding_size = 1281,
                 fc1_size = 150,
                 fc2_size = 120,
                 fc3_size = 45,
                 fc1_dropout = 0.7,
                 fc2_dropout = 0.7,
                 fc3_dropout = 0.7,
                 num_of_classes = 2):
        super(MyDenseNetWithSeqLen, self).__init__()
        
        
        self.esm_embedding_size = esm_embedding_size
        self.fc1_size = fc1_size
        self.fc2_size = fc2_size
        self.fc3_size = fc3_size
        self.fc1_dropout = fc1_dropout
        self.fc2_dropout = fc2_dropout
        self.fc3_dropout = fc3_dropout
        
        self.ff_model = nn.Sequential(nn.Linear(esm_embedding_size, fc1_size),
                                      nn.ReLU(),
                                      nn.Dropout(fc1_dropout),
                                      nn.Linear(fc1_size, fc2_size),
                                      nn.ReLU(),
                                      nn.Dropout(fc2_dropout),
                                      nn.Linear(fc2_size, fc3_size),
                                      nn.ReLU(),
                                      nn.Dropout(fc3_dropout),
                                      nn.Linear(fc3_size, num_of_classes))
    
    def forward(self, antigen):
        batch_size = antigen.size(0)
        seq_len = antigen.size(1)
        print(antigen.shape)
        #convert dim (N, L, esm_embedding) --> (N*L, esm_embedding)
        output = torch.reshape(antigen, (batch_size*seq_len, self.esm_embedding_size))
        output = self.ff_model(output)                                               
        return output


class BepiPredDDPM(nn.Module):
    def __init__(self, esm_embedding_size=1280, timestep_dim=64):
        super().__init__()
        # 时间步嵌入层
        self.timestep_embed = nn.Sequential(
            nn.Linear(timestep_dim, esm_embedding_size + 1),
            nn.SiLU(),
            nn.Linear(esm_embedding_size+1, esm_embedding_size+1)
        )
        
        # 继承 BepiPred-3.0 的 FFNN
        self.ffnn = MyDenseNetWithSeqLen()  # 输入拼接时间步信息

        
        self.classifier = nn.Linear(esm_embedding_size+1, 1)  # 二分类
        

    def get_time_embedding(self, t, dim):
        # t: 时间步张量, shape=(batch_size, 1)
        # dim: 嵌入向量的维度
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
        emb = t.float() * emb.unsqueeze(0).to(device)  
        # emb.shape ==> (batch_size, half_dim)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # shape=(batch_size, dim)
        return emb.to(device)

    def forward(self, x, t):
        # x: 噪声化ESM嵌入 [B, L, D]
        # t: 时间步 [B]
        t = t.unsqueeze(-1)
        t_embed = self.get_time_embedding(t, dim=64)  # [B, D]
        t_embed = self.timestep_embed(t_embed)
        t_embed = t_embed.unsqueeze(1).expand(-1, x.size(1), -1)  # [B, L, D]
        # print(t_embed.shape, x.shape)
        x = t_embed + x
        # x = torch.cat([x, t_embed], dim=-1)
        print(x.shape)
        noise_pred = self.ffnn(x)  # 预测噪声 [B, L, D]
        # 若需联合表位分类
        print(x.shape)
        
        epitope_prob = torch.sigmoid(self.classifier(x))  # [B, L, 1]
        return noise_pred, epitope_prob


In [102]:
epochs = 10
model = BepiPredDDPM()
diffusion = Diffusion()

optimizer = optim.Adam(model.parameters(), lr=0.001)

steps = 1000

In [103]:
# def train(model, dataloader, optimizer, steps=1000, device=device, epochs=10):
model.train()

for epoch in range(epochs):
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)
    loss_record= []
    for x0, epitope_labels, x0_len, epitope_len in progress_bar:  # x0: 真实ESM嵌入 [B, L, D]
        print(x0.shape)
        
        # epitope_labels = epitope_labels.to(device)
        # x0 = x0.to(device)
        
        # 1. 随机采样时间步和噪声
        t = torch.randint(0, steps, (x0.size(0),), device=device)
        noise = torch.randn_like(x0)
        
        # 2. 前向加噪（根据噪声调度）
        
        xt = diffusion.q_sample(x0, t)
        t = t.float()
        # 3. 预测噪声并计算损失
        pred_noise, epitope_prob = model(xt, t)
        loss_diffusion = F.mse_loss(pred_noise, noise)
        loss_epitope = F.binary_cross_entropy(epitope_prob, epitope_labels)  # 需提供真实表位标签
        total_loss = loss_diffusion + 0.1 * loss_epitope  # 加权平衡
        
        # 4. 反向传播
        optimizer.zero_grad()
        total_loss.backward()
        loss_record.append(total_loss.item())
        optimizer.step()
        
        # 5. 记录损失
        progress_bar.set_postfix({"loss": total_loss.item()})

        # 每个 epoch 结束后打印损失并保存
        print(f"Epoch {epoch + 1}, Loss_Mean: {torch.tensor(loss_record).mean()}", end="\r")



Epoch 1/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([20, 500, 1281])


  loss_diffusion = F.mse_loss(pred_noise, noise)
                                                 

torch.Size([20, 500, 1281])
torch.Size([20, 500, 1281])
torch.Size([20, 500, 1281])




RuntimeError: The size of tensor a (2) must match the size of tensor b (1281) at non-singleton dimension 2

torch.Size([20, 500, 1281])
torch.Size([20])
