In [None]:
from model_one import BepiPredDDPM
from data import ESM2Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torch import optim
from torch.nn.utils.rnn import pad_sequence
import os
import pandas as pd
from tqdm import tqdm
import math
from scipy.stats import norm
import random
import pathlib
from pathlib import Path
import matplotlib.pyplot as plt
random.seed(418)  # 设置种子值


ModuleNotFoundError: No module named 'esm'

In [None]:
### SET GPU OR CPU ###
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU device detected: {device}")
else:
    device = torch.device("cpu")
    print(f"GPU device not detected. Using CPU: {device}")
    


def get_alpha_cumprod(t, steps=1000, beta_start=0.0001, beta_end=0.02):
    """
    计算alpha的累计乘积
    """
    beta = torch.linspace(beta_start, beta_end, steps)
    alpha = 1 - beta
    alpha_cumprod = torch.cumprod(alpha, dim=0)
    return alpha_cumprod[t]

GPU device not detected. Using CPU: cpu


In [None]:
# 划分训练集和测试集并加载数据集
esm_encoding_dir=Path("data/esm_encodings")
train_esm_encoding_dir, test_esm_encoding_dir = train_test_split(esm_encoding_dir, test_size=0.2, random_state=42)
train_dataset = ESM2Dataset(train_esm_encoding_dir)
test_dataset = ESM2Dataset(test_esm_encoding_dir)


def collate_fn(batch):
    # 每个样本是 (input_tensor, target_tensor)
    input_seqs = [item[0] for item in batch]
    target_seqs = [item[1] for item in batch]

    input_lengths = torch.tensor([len(seq) for seq in input_seqs])
    target_lengths = torch.tensor([len(seq) for seq in target_seqs])

    # 两个序列的长度是一样的，可以统一 pad 成 max_len
    # 你也可以分别 pad（如果以后变成不一样长）

    input_padded = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    target_padded = pad_sequence(target_seqs, batch_first=True, padding_value=0)

    return input_padded, input_lengths, target_padded, target_lengths
    
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)



for esm, epitope, esm_len, epitope_len in train_dataloader:
    print(esm.shape)
    print(epitope.shape)
    break

In [None]:
# checkpoint
checkpoint_dir = './checkpoints'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

def save_checkpoint(epoch, model, optimizer, loss, checkpoint_dir):
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f'Checkpoint saved to {checkpoint_path}')

def load_checkpoint(model, optimizer, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss

In [102]:
epochs = 10
model = BepiPredDDPM()
diffusion = Diffusion()

optimizer = optim.Adam(model.parameters(), lr=0.001)

steps = 1000

In [None]:
# def train(model, dataloader, optimizer, steps=1000, device=device, epochs=10):
model.train()

train_losses = []
epochs_losses = []  # 记录每个epoch的平均损失

# 设置绘图风格
plt.style.use('seaborn')

for epoch in range(epochs):
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)
    loss_record= []
    for x0, epitope_labels, x0_len, epitope_len in progress_bar:  # x0: 真实ESM嵌入 [B, L, D]
        print(x0.shape)
        
        # epitope_labels = epitope_labels.to(device)
        # x0 = x0.to(device)
        
        # 1. 随机采样时间步和噪声
        t = torch.randint(0, steps, (x0.size(0),), device=device)
        noise = torch.randn_like(x0)
        
        # 2. 前向加噪（根据噪声调度）
        
        xt = diffusion.q_sample(x0, t)
        t = t.float()
        # 3. 预测噪声并计算损失
        pred_noise, epitope_prob = model(xt, t)

        # 4. 计算表位分类损失
        loss_diffusion = F.mse_loss(pred_noise, noise)
        loss_epitope = F.binary_cross_entropy(epitope_prob, epitope_labels)  # 需提供真实表位标签
        total_loss = 0.3 * loss_diffusion + 0.7 * loss_epitope  # 加权平衡
        
        # 4. 反向传播
        optimizer.zero_grad()
        total_loss.backward()
        loss_record.append(total_loss.item())
        optimizer.step()
        
        # 5. 记录损失
        progress_bar.set_postfix({"loss": total_loss.item()})
        epoch_loss = torch.tensor(loss_record).mean().item()
        epochs_losses.append(epoch_loss)
        
    # 每15个epoch保存一次checkpoint
    if (epoch + 1) % 15 == 0:
        save_checkpoint(epoch + 1, model, optimizer, epoch_loss, checkpoint_dir)
        print(f"Epoch {epoch + 1}, Loss_Mean: {epoch_loss}", end="\r")
        plt.plot(epochs_losses)
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.show()



Epoch 1/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([20, 500, 1281])


  loss_diffusion = F.mse_loss(pred_noise, noise)
                                                 

torch.Size([20, 500, 1281])
torch.Size([20, 500, 1281])
torch.Size([20, 500, 1281])




RuntimeError: The size of tensor a (2) must match the size of tensor b (1281) at non-singleton dimension 2

In [None]:
def eval_model(model, dataloader, device=device):
    probs, labels = [], []
    model.eval()
    with torch.no_grad():
        for x0, epitope_labels in dataloader:
            _, epitope_prob = model(x0.to(device), torch.zeros(x0.size(0)).to(device))
            probs.append(epitope_prob.cpu())
            labels.append(epitope_labels.cpu())
    
    probs = torch.cat(probs).numpy()
    labels = torch.cat(labels).numpy()
    precision, recall, _ = precision_recall_curve(labels, probs)
    pr_auc = auc(recall, precision)  # 更关注正类的指标
    return pr_auc

torch.Size([20, 500, 1281])
torch.Size([20])
