In [17]:
from model_one import BepiPredDDPM, ESM2Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch import optim
import os
import pandas as pd
from tqdm import tqdm
import math
from scipy.stats import norm
import random
import pathlib
from pathlib import Path
random.seed(418)  # 设置种子值


In [18]:
### SET GPU OR CPU ###
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU device detected: {device}")
else:
    device = torch.device("cpu")
    print(f"GPU device not detected. Using CPU: {device}")


def get_alpha_cumprod(t, steps=1000, beta_start=0.0001, beta_end=0.02):
    """
    计算alpha的累计乘积
    """
    beta = torch.linspace(beta_start, beta_end, steps)
    alpha = 1 - beta
    alpha_cumprod = torch.cumprod(alpha, dim=0)
    return alpha_cumprod[t]

GPU device not detected. Using CPU: cpu


In [19]:
def train(model, dataloader, optimizer, steps=1000, device="cuda", epochs=10):
    model.train()
    for epoch in range(epochs):
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)
        loss_record= []
        for x0, epitope_labels in progress_bar:  # x0: 真实ESM嵌入 [B, L, D]
            print(x0.shape)
            epitope_labels = epitope_labels.to(device)
            x0 = x0.to(device)
            
            # 1. 随机采样时间步和噪声
            t = torch.randint(0, steps, (x0.size(0),), device=device)
            noise = torch.randn_like(x0)
            
            # 2. 前向加噪（根据噪声调度）
            alpha_cumprod = get_alpha_cumprod(t)  
            xt = torch.sqrt(alpha_cumprod) * x0 + torch.sqrt(1 - alpha_cumprod) * noise
            
            # 3. 预测噪声并计算损失
            pred_noise, epitope_prob = model(xt, t)
            loss_diffusion = F.mse_loss(pred_noise, noise)
            loss_epitope = F.binary_cross_entropy(epitope_prob, epitope_labels)  # 需提供真实表位标签
            total_loss = loss_diffusion + 0.1 * loss_epitope  # 加权平衡
            
            # 4. 反向传播
            optimizer.zero_grad()
            total_loss.backward()
            loss_record.append(total_loss.item())
            optimizer.step()
            
            # 5. 记录损失
            progress_bar.set_postfix({"loss": total_loss.item()})

            # 每个 epoch 结束后打印损失并保存
            print(f"Epoch {epoch + 1}, Loss_Mean: {torch.tensor(loss_record).mean()}", end="\r")



In [21]:
# 加载数据集
dataset = ESM2Dataset(esm_encoding_dir=Path("data/esm_encodings"))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
for esm, epitope in dataset:
    print(esm.shape)
    print(epitope.shape)
    break
for esm, epitope in dataloader:
    print(esm.shape)
    print(epitope.shape)
    break

torch.Size([7, 1281])


AttributeError: 'list' object has no attribute 'shape'

In [None]:


# 初始化模型和优化器
model = BepiPredDDPM()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
train(model, dataloader, optimizer, epochs=10)

                                                 

RuntimeError: stack expects each tensor to be equal size, but got [500, 1281] at entry 0 and [17, 1281] at entry 1