In [2]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
import numpy as np
'''
class VirusDataset(Dataset):
    def __init__(self, X, y):
        self.X = X#torch.tensor(X, dtype=torch.float32)#因为之前scaler.fit_transform(X)过后是array形状
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
'''
class VirusDataset(Dataset):
    def __init__(self, X, Y,Z,label, max_length=256,max_length_gene=1024):
        self.aacid_to_index = {'<cls>': 0,
                                 '<pad>': 1,
                                 '<eos>': 2,
                                 '<unk>': 3,
                                 'L': 4,
                                 'J': 4,
                                 'A': 5,
                                 'G': 6,
                                 'V': 7,
                                 'S': 8,
                                 'E': 9,
                                 'R': 10,
                                 'T': 11,
                                 'I': 12,
                                 'D': 13,
                                 'P': 14,
                                 'K': 15,
                                 'Q': 16,
                                 'N': 17,
                                 'F': 18,
                                 'Y': 19,
                                 'M': 20,
                                 'H': 21,
                                 'W': 22,
                                 'C': 23,
                                 'X': 24,
                                 'B': 25,
                                 'U': 26,
                                 'Z': 27,
                                 'O': 28,
                                 '.': 29,
                                 '-': 30,
                                 '<null_1>': 31,
                                 '<mask>': 32}
        self.start_token = '<cls>'
        self.end_token = '<eos>'
        self.pad_token = '<pad>'
        self.X = [self.tokenize_aacid_sequence(seq, max_length) for seq in X]
        self.Y = [self.tokenize_aacid_sequence(seq, max_length) for seq in Y]
        self.Z = [self.tokenize_aacid_sequence(seq, max_length) for seq in Z]
        self.label = label
    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx],self.Z[idx],self.label[idx]

    def tokenize_aacid_sequence(self, sequence, max_length):
        # 将序列截断或填充到max_length
        sequence = sequence.replace(' ','')
        sequence = [self.aacid_to_index[aacid] for aacid in sequence]
        sequence = [self.aacid_to_index[self.start_token]] + sequence + [self.aacid_to_index[self.end_token]]
        sequence = sequence[:max_length] + [self.aacid_to_index[self.pad_token]] * (max_length - len(sequence))

        # 转换为tensor
        sequence = torch.tensor(sequence, dtype=torch.long)

        return sequence
# 读取数据
selected_columns = pd.read_csv('/public/home/ligroupprotein/ckx/affinity/data/merged_dataset.tsv', sep='\t')

# X是每行的3-6列元素，Y是第一列的元素 第一列是基因
X = selected_columns.iloc[:, 1].values.reshape(-1, 1).tolist()
for i in range(len(X)):
    X[i] = ' '.join(X[i])
Y = selected_columns.iloc[:, 2].values.reshape(-1, 1).tolist()
for i in range(len(Y)):
    Y[i] = ' '.join(Y[i])
Z = selected_columns.iloc[:, 3].values.reshape(-1, 1).tolist()
for i in range(len(Z)):
    Z[i] = ' '.join(Z[i])
label = selected_columns.iloc[:, 4].values.reshape(-1, 1).tolist()
for i in range(len(label)):
    label[i] =  label[i][0]
X_train, X_test, y_train, y_test, Z_train, Z_test,label_train, label_test = train_test_split(X, Y,Z,label, test_size=0.2, random_state=42)

# 创建数据集
train_dataset = VirusDataset(X_train, y_train,Z_train,label_train)
test_dataset = VirusDataset(X_test, y_test,Z_test,label_test)

train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
# 创建数据加载器
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)

In [3]:
import pytorch_lightning as pl
from torch import nn
import torch
import torchmetrics
from transformers import EsmTokenizer,EsmModel
import torch.nn.functional as F
from esm.models.esmc import ESMC

class TextCNN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes, num_classes,maxpool):
        super(TextCNN, self).__init__()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=maxpool)
            )
            for kernel_size in kernel_sizes
        ])

    def forward(self, x):
        x = x.permute(0, 2, 1)#batch representation length
        x = [conv(x) for conv in self.convs]
        x = torch.cat(x, dim=2)
        return x

class HuberLoss(nn.Module):
    def __init__(self, delta=1.0):
        super().__init__()
        self.delta = delta
    
    def forward(self, pred, target):
        # 确保pred和target维度一致
        pred = pred.view(-1, 1)
        target = target.view(-1, 1)
        
        # 计算差值
        diff = pred - target
        abs_diff = torch.abs(diff)
        condition = abs_diff <= self.delta
        
        # 分段计算损失
        quadratic = 0.5 * diff ** 2
        linear = self.delta * abs_diff - 0.5 * self.delta ** 2
        
        loss = torch.where(condition, quadratic, linear)
        return loss.mean()

class LogCoshLoss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, pred, target):
        loss = torch.log(torch.cosh(pred - target))
        return torch.mean(loss)

class CustomProteinLoss(nn.Module):
    def __init__(self, mse_weight=0.7, mae_weight=0.3):
        super().__init__()
        self.mse_weight = mse_weight
        self.mae_weight = mae_weight
        
    def forward(self, pred, target):
        mse_loss = F.mse_loss(pred, target)
        mae_loss = F.l1_loss(pred, target)
        return self.mse_weight * mse_loss + self.mae_weight * mae_loss

class ClassifierNet(pl.LightningModule):
    def __init__(self):#应该是要上面的Transformer里从encoder里出来的用来做分类
        super(ClassifierNet, self).__init__()
        # 定义参数
        num_features = 512  # 特征数量，也是Transformer的d_model参数
        num_classes = 30  # 类别数量 因为是预测结果也是氨基酸 所以是词表大小 为30
        nhead = 8  # Transformer的头的数量
        num_encoder_layers = 3  # Transformer编码器的层数
        num_decoder_layers = 3  # Transformer解码器的层数
        learning_rate = 0.0001  # 学习率
        num_epochs = 100
        # 初始化模型
        seed = random.randint(0, 10000)
        self.esm = ESMC.from_pretrained("esmc_300m")
        self.esm_antigen = ESMC.from_pretrained("esmc_300m")
        self.layer1 = nn.Sequential(
            nn.Conv1d(960, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2))
        self.light_layer1 = nn.Sequential(
            nn.Conv1d(960, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.light_layer2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2))
        self.antigen_layer1 = nn.Sequential(
            nn.Conv1d(960, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.antigen_layer2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()
        self.fc1 = None#nn.Linear(self.out_channels * len(self.kernel_sizes) * (256 - max(self.kernel_sizes) + 1), 1000)
        self.fc2 = nn.Linear(1000,1)
        self.multihead_attention = nn.MultiheadAttention(embed_dim=128, num_heads=nhead)
        self.multihead_attentio_light = nn.MultiheadAttention(embed_dim=128, num_heads=nhead)
        self.multihead_attentio_antigen = nn.MultiheadAttention(embed_dim=128, num_heads=nhead)
        # 回归任务的评价指标
        self.mse = torchmetrics.MeanSquaredError()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.rmse = torchmetrics.MeanSquaredError(squared=False)  # RMSE
        self.r2score = torchmetrics.R2Score()  # R²决定系数
        self.pearson = torchmetrics.PearsonCorrCoef()  # 皮尔森相关系数
        dtype = torch.float32  # 或 torch.bfloat16
        self.to(dtype)

    def on_train_epoch_start(self):
        # 记录当前使用哪个模型，便于日志记录
        current_model = "esm" if self.current_epoch % 2 == 0 else "esm_antigen"
        self.log("current_model", current_model)
        print(f"Epoch {self.current_epoch}: Using {current_model} model")
        
    def one_hot_encode(self,input_string):
        mapping = {
            'A': 0,
            'T': 1,
            'G': 2,
            'C': 3,
            '<eos>': 4,
            '<sep>': 5,
            '<mask>': 6,
            '<pad>': 7,
        }

        # 将整数转换为one-hot编码
        one_hot_encoded = F.one_hot(input_string, num_classes=len(mapping))

        return one_hot_encoded
    
    def pad_or_truncate_tensor(self,tensor):
        target_length = 1024
        padding_value = [0, 0, 0, 0, 0, 0, 0, 0, 1]
        # 如果张量的长度小于目标长度，那么补齐它
        if tensor.size(0) < target_length:
            padding_length = target_length - tensor.size(0)
            padding_tensor = torch.tensor(padding_value).repeat(padding_length, 1).to(device)
            tensor = torch.cat([tensor, padding_tensor], dim=0)
        # 如果张量的长度大于目标长度，那么截断它
        elif tensor.size(0) > target_length:
            tensor = tensor[:target_length]

        return tensor
    def forward(self, x,y,z):
        embeddings_h = None
        embeddings_l = None
        embeddings_g = None
        device = torch.device("cuda:0")
        for i in range(0,len(x)):
            token = x[i].unsqueeze(0)
            #print(token,self.esm.device)
            if embeddings_h is None:               
                embeddings_h = self.esm(token).embeddings
            else:
                abitembeddings_h = self.esm(token).embeddings
                embeddings_h = torch.cat((embeddings_h, abitembeddings_h), dim=0)
        out_heavy = embeddings_h
        out_heavy = out_heavy.permute(0, 2, 1)#batch representation length
        out_heavy = self.layer1(out_heavy)
        out_heavy = self.layer2(out_heavy)
        for i in range(0,len(y)):
            token = y[i].unsqueeze(0)
            #print(token,self.esm.device)
            if embeddings_l is None:               
                embeddings_l = self.esm(token).embeddings
            else:
                abitembeddings_l = self.esm(token).embeddings
                embeddings_l = torch.cat((embeddings_l, abitembeddings_l), dim=0)
        out_light = embeddings_l
        out_light = out_light.permute(0, 2, 1)
        out_light = self.light_layer1(out_light)
        out_light = self.light_layer2(out_light)
        for i in range(0,len(z)):
            token = z[i].unsqueeze(0)
            #print(token,self.esm.device)
            if embeddings_g is None:               
                embeddings_g = self.esm_antigen(token).embeddings
            else:
                abitembeddings_g = self.esm_antigen(token).embeddings
                embeddings_g = torch.cat((embeddings_g, abitembeddings_g), dim=0)
        out_antigen = embeddings_g
        out_antigen = out_antigen.permute(0, 2, 1)
        out_antigen = self.antigen_layer1(out_antigen)
        out_antigen = self.antigen_layer2(out_antigen)
        out_heavy = out_heavy.permute(2, 0, 1)  # Change the shape to (seq_len, batch, embed_dim)
        out_heavy, attn_weights_h = self.multihead_attention(out_heavy, out_heavy, out_heavy)
        out_heavy = out_heavy.permute(1, 2, 0)  # Change the shape back to (batch, embed_dim, seq_len)
        out_light = out_light.permute(2, 0, 1)  # Change the shape to (seq_len, batch, embed_dim)
        out_light, attn_weights_l = self.multihead_attentio_light(out_light, out_light, out_light)
        out_light = out_light.permute(1, 2, 0)  # Change the shape back to (batch, embed_dim, seq_len)
        out_antigen = out_antigen.permute(2, 0, 1)  # Change the shape to (seq_len, batch, embed_dim)
        out_antigen, attn_weights_g = self.multihead_attentio_antigen(out_antigen, out_antigen, out_antigen)
        out_antigen = out_antigen.permute(1, 2, 0)  # Change the shape back to (batch, embed_dim, seq_len)
        out = torch.cat((out_heavy, out_light,out_antigen), dim=2)
        out = out.reshape(out.size(0), -1)
        if self.fc1 is None:
            self.fc1 = nn.Linear(out.size(1), 1000).to(out.device)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.drop_out(out)
        out = self.fc2(out)
        return out

    def training_step(self, batch, batch_idx):
        self.train()
        x, y,z,label = batch
        x_hat = self.forward(x,y,z)
        # 使用示例
        criterion = HuberLoss(delta=1.0)
        '''
        # 或
        criterion = LogCoshLoss()
        # 或
        criterion = CustomProteinLoss(mse_weight=0.7, mae_weight=0.3)
        '''
            # 调整label维度以匹配x_hat
        label = label.view(-1, 1)  # [100] -> [100, 1]
        loss = criterion(x_hat, label)
        # 计算各项评价指标
        mse = self.mse(x_hat, label)
        mae = self.mae(x_hat, label)
        rmse = self.rmse(x_hat, label)
        r2score = self.r2score(x_hat, label)
        pearson = self.pearson(x_hat, label)
        
        # 记录训练过程中的指标
        self.log('train_loss', loss,  on_epoch=True, prog_bar=True)
        self.log('train_mse', mse, on_epoch=True, prog_bar=True)
        self.log('train_mae', mae, on_epoch=True, prog_bar=True)
        self.log('train_rmse', rmse, on_epoch=True, prog_bar=True)
        self.log('train_r2', r2score, on_epoch=True, prog_bar=True)
        self.log('train_pearson', pearson, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_start(self):
        if self.current_epoch % 2 == 0:
            # 偶数轮：冻结 esm_antigen，解冻 esm
            for param in self.esm_antigen.parameters():
                param.requires_grad = False
            for param in self.esm.parameters():
                param.requires_grad = True
            current_model = "esm (esm_antigen frozen)"
            model_id = 0  # 使用数值表示当前模型
        else:
            # 奇数轮：冻结 esm，解冻 esm_antigen
            for param in self.esm.parameters():
                param.requires_grad = False
            for param in self.esm_antigen.parameters():
                param.requires_grad = True
            current_model = "esm_antigen (esm frozen)"
            model_id = 1  # 使用数值表示当前模型
        
        # 使用数值记录当前模型类型
        self.log("current_model_id", float(model_id))
        print(f"Epoch {self.current_epoch}: Training {current_model}")
        
        # 输出模型参数状态以验证
        esm_trainable = sum(p.numel() for p in self.esm.parameters() if p.requires_grad)
        antigen_trainable = sum(p.numel() for p in self.esm_antigen.parameters() if p.requires_grad)
        print(f"ESM 参数: {esm_trainable:,} trainable")
        print(f"ESM_antigen 参数: {antigen_trainable:,} trainable")
    
    def test_step(self, batch, batch_idx):
        self.eval()
        x, y,z,label = batch
        x_hat = self.forward(x,y,z)
        criterion = HuberLoss(delta=1.0)
            # 调整label维度以匹配x_hat
        label = label.view(-1, 1)  # [100] -> [100, 1]
        loss = criterion(x_hat, label)
    
        # 计算各项评价指标
        mse = self.mse(x_hat, label)
        mae = self.mae(x_hat, label)
        rmse = self.rmse(x_hat, label)
        r2score = self.r2score(x_hat, label)
        pearson = self.pearson(x_hat, label)
        
        # 记录训练过程中的指标
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
        self.log('test_mse', mse, on_epoch=True, prog_bar=True)
        self.log('test_mae', mae,  on_epoch=True, prog_bar=True)
        self.log('test_rmse', rmse,  on_epoch=True, prog_bar=True)
        self.log('test_r2', r2score, on_epoch=True, prog_bar=True)
        self.log('test_pearson', pearson,  on_epoch=True, prog_bar=True)
        return {"test_loss": loss, "test_r2": r2score, "test_pearson": pearson}

    '''
    def validation_step(self, batch, batch_idx):
        self.eval()
        x, y, z, label = batch
        x_hat = self.forward(x, y, z)
        label = label.view(-1, 1)
        loss = LogCoshLoss()(x_hat, label)  # 与训练一致
        pearson = self.pearson(x_hat, label)
        # 记录验证集指标
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_mse', self.mse(x_hat, label), on_epoch=True)
        self.log("val_pearson", pearson, on_epoch=True, prog_bar=True)
        return pearson
    '''
    def validation_step(self, batch, batch_idx):
        self.eval()
        x, y, z, label = batch
        x_hat = self.forward(x, y, z)
        label = label.view(-1, 1)
        loss = LogCoshLoss()(x_hat, label)  # 与训练一致
        pearson = self.pearson(x_hat, label)
        
        # 记录验证集指标
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_mse', self.mse(x_hat, label), on_epoch=True)
        self.log("val_pearson", pearson, on_epoch=True, prog_bar=True)
        
        # 打印预测值与实际值 (添加此部分)
        if batch_idx == 0:  # 只打印第一个batch，避免过多输出
            print("\n===== 验证集预测与实际值比较 =====")
            for i in range(min(5, len(x_hat))):  # 打印前5个样本
                print(f"样本 {i}: 预测值 = {x_hat[i].item():.4f}, 实际值 = {label[i].item():.4f}, 差值 = {(x_hat[i] - label[i]).item():.4f}")
            
            # 计算整个batch的统计信息
            mean_pred = x_hat.mean().item()
            mean_true = label.mean().item()
            print(f"\n批次统计: 平均预测值 = {mean_pred:.4f}, 平均实际值 = {mean_true:.4f}")
            print(f"预测值范围: [{x_hat.min().item():.4f}, {x_hat.max().item():.4f}]")
            print(f"实际值范围: [{label.min().item():.4f}, {label.max().item():.4f}]")
        
        # 返回更多信息用于epoch_end汇总
        return {"val_loss": loss, "val_pearson": pearson, 
                "pred": x_hat.detach(), "true": label.detach()}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        steplr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size  = 1, gamma = 0.8)#每一步都进行学习率的衰减
        return {"optimizer": optimizer , "lr_scheduler": steplr_scheduler}

In [4]:
from pytorch_lightning import Trainer, loggers
import random
# 定义参数
num_features = 512  # 特征数量，也是Transformer的d_model参数
num_classes = 30  # 类别数量 因为是预测结果也是氨基酸 所以是词表大小 为30
nhead = 8  # Transformer的头的数量
num_encoder_layers = 3  # Transformer编码器的层数
num_decoder_layers = 3  # Transformer解码器的层数
learning_rate = 0.0001  # 学习率
num_epochs = 30
# 初始化模型
seed = random.randint(0, 10000)
model = ClassifierNet()
device = torch.device('cuda:0')
model = model.to(device)

csv_logger = loggers.CSVLogger('L_20_AF_lc_alldata_ESM3CNN+MUTIATTN_conlogs/')
# 初始化训练器
trainer = Trainer(max_epochs=num_epochs,logger = csv_logger,accelerator="gpu", devices=[0])

# 训练模型
#tokenized_sequence = tokenize_aacid_sequence(sequence)
trainer.fit(model, train_dataloaders=train_dataloader,val_dataloaders=val_dataloader)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA L20') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                       | Type               | Params | Mode 
---------------------------------------------------------------------------
0  | esm                        | ESMC               | 332 M  | eval 
1  | esm_antigen                | ESMC               | 332 M  | eval 
2  | layer1                     | Sequential         | 307 K  | train
3

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/public/home/ligroupprotein/.conda/envs/esm3test/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/public/home/ligroupprotein/.conda/envs/esm3test/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=87` in the `DataLoader` to improve performance.



===== 验证集预测与实际值比较 =====
样本 0: 预测值 = 0.0053, 实际值 = -8.9200, 差值 = 8.9253
样本 1: 预测值 = 0.0054, 实际值 = -10.7700, 差值 = 10.7754
样本 2: 预测值 = 0.0043, 实际值 = -10.0380, 差值 = 10.0423
样本 3: 预测值 = 0.0062, 实际值 = -13.4700, 差值 = 13.4762
样本 4: 预测值 = 0.0052, 实际值 = -12.7922, 差值 = 12.7974

批次统计: 平均预测值 = 0.0055, 平均实际值 = -10.8935
预测值范围: [0.0043, 0.0063]
实际值范围: [-14.1550, -6.3500]


/public/home/ligroupprotein/.conda/envs/esm3test/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=87` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Epoch 0: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -13.3423, 实际值 = -13.5000, 差值 = 0.1577
样本 1: 预测值 = -11.8821, 实际值 = -12.5000, 差值 = 0.6179
样本 2: 预测值 = -9.7485, 实际值 = -8.0391, 差值 = -1.7094
样本 3: 预测值 = -10.8799, 实际值 = -9.5497, 差值 = -1.3302
样本 4: 预测值 = -11.8503, 实际值 = -11.1180, 差值 = -0.7323

批次统计: 平均预测值 = -11.1840, 平均实际值 = -10.8077
预测值范围: [-13.4032, -7.4194]
实际值范围: [-13.7000, -8.0391]
Epoch 1: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -8.8504, 实际值 = -8.8391, 差值 = -0.0113
样本 1: 预测值 = -12.2710, 实际值 = -12.1702, 差值 = -0.1009
样本 2: 预测值 = -12.0196, 实际值 = -11.5600, 差值 = -0.4596
样本 3: 预测值 = -13.3894, 实际值 = -15.7200, 差值 = 2.3306
样本 4: 预测值 = -11.6424, 实际值 = -10.3700, 差值 = -1.2724

批次统计: 平均预测值 = -10.9346, 平均实际值 = -11.2518
预测值范围: [-13.3894, -8.8204]
实际值范围: [-15.7200, -7.9391]
Epoch 2: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.1811, 实际值 = -10.4000, 差值 = 0.2189
样本 1: 预测值 = -9.9916, 实际值 = -11.5000, 差值 = 1.5084
样本 2: 预测值 = -10.4075, 实际值 = -12.5000, 差值 = 2.0925
样本 3: 预测值 = -10.3509, 实际值 = -10.4400, 差值 = 0.0891
样本 4: 预测值 = -10.6727, 实际值 = -7.8600, 差值 = -2.8127

批次统计: 平均预测值 = -10.4453, 平均实际值 = -10.1824
预测值范围: [-13.6517, -8.9938]
实际值范围: [-14.4500, -5.1785]
Epoch 3: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.1521, 实际值 = -8.2413, 差值 = -0.9107
样本 1: 预测值 = -14.8838, 实际值 = -14.7400, 差值 = -0.1438
样本 2: 预测值 = -10.3893, 实际值 = -11.0688, 差值 = 0.6795
样本 3: 预测值 = -9.7171, 实际值 = -6.3500, 差值 = -3.3671
样本 4: 预测值 = -11.1364, 实际值 = -8.9200, 差值 = -2.2164

批次统计: 平均预测值 = -10.7399, 平均实际值 = -10.8160
预测值范围: [-14.8838, -8.9317]
实际值范围: [-14.9000, -5.8600]
Epoch 4: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.4383, 实际值 = -8.7400, 差值 = -1.6983
样本 1: 预测值 = -10.6041, 实际值 = -11.8600, 差值 = 1.2559
样本 2: 预测值 = -11.1571, 实际值 = -12.4400, 差值 = 1.2829
样本 3: 预测值 = -9.1780, 实际值 = -9.4900, 差值 = 0.3120
样本 4: 预测值 = -11.7612, 实际值 = -14.9000, 差值 = 3.1388

批次统计: 平均预测值 = -10.3670, 平均实际值 = -11.0135
预测值范围: [-11.8980, -8.6630]
实际值范围: [-14.9000, -8.7400]
Epoch 5: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.1246, 实际值 = -12.2700, 差值 = 2.1454
样本 1: 预测值 = -12.1361, 实际值 = -9.9100, 差值 = -2.2261
样本 2: 预测值 = -10.4974, 实际值 = -10.8400, 差值 = 0.3426
样本 3: 预测值 = -9.6527, 实际值 = -8.6500, 差值 = -1.0027
样本 4: 预测值 = -9.1789, 实际值 = -11.2800, 差值 = 2.1011

批次统计: 平均预测值 = -10.3977, 平均实际值 = -11.2016
预测值范围: [-12.1361, -8.2172]
实际值范围: [-15.4000, -8.6500]
Epoch 6: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.8651, 实际值 = -9.2124, 差值 = -1.6527
样本 1: 预测值 = -10.2231, 实际值 = -11.2500, 差值 = 1.0269
样本 2: 预测值 = -13.3201, 实际值 = -14.2290, 差值 = 0.9089
样本 3: 预测值 = -13.3504, 实际值 = -14.7400, 差值 = 1.3896
样本 4: 预测值 = -9.4132, 实际值 = -9.7500, 差值 = 0.3368

批次统计: 平均预测值 = -10.9711, 平均实际值 = -11.2992
预测值范围: [-13.3504, -8.8454]
实际值范围: [-14.7400, -5.7858]
Epoch 7: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.4107, 实际值 = -9.8400, 差值 = 0.4293
样本 1: 预测值 = -11.8020, 实际值 = -11.5000, 差值 = -0.3020
样本 2: 预测值 = -12.3423, 实际值 = -14.0000, 差值 = 1.6577
样本 3: 预测值 = -9.5700, 实际值 = -10.2436, 差值 = 0.6736
样本 4: 预测值 = -9.2942, 实际值 = -10.4300, 差值 = 1.1358

批次统计: 平均预测值 = -10.8620, 平均实际值 = -11.9124
预测值范围: [-13.2088, -8.6968]
实际值范围: [-14.6900, -9.8400]
Epoch 8: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.6755, 实际值 = -8.2700, 差值 = -1.4055
样本 1: 预测值 = -8.8193, 实际值 = -10.0600, 差值 = 1.2407
样本 2: 预测值 = -9.6150, 实际值 = -9.5497, 差值 = -0.0653
样本 3: 预测值 = -10.7358, 实际值 = -12.7200, 差值 = 1.9842
样本 4: 预测值 = -9.4309, 实际值 = -10.2900, 差值 = 0.8591

批次统计: 平均预测值 = -10.8597, 平均实际值 = -10.7129
预测值范围: [-15.0766, -8.8193]
实际值范围: [-14.0400, -8.2700]
Epoch 9: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.2774, 实际值 = -9.2500, 差值 = -0.0274
样本 1: 预测值 = -10.7938, 实际值 = -9.3600, 差值 = -1.4338
样本 2: 预测值 = -15.1986, 实际值 = -15.0067, 差值 = -0.1919
样本 3: 预测值 = -9.5700, 实际值 = -10.9698, 差值 = 1.3998
样本 4: 预测值 = -9.8430, 实际值 = -10.3400, 差值 = 0.4970

批次统计: 平均预测值 = -11.4762, 平均实际值 = -11.5344
预测值范围: [-15.1986, -9.2774]
实际值范围: [-15.0067, -9.0800]
Epoch 10: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -11.0905, 实际值 = -10.7700, 差值 = -0.3205
样本 1: 预测值 = -9.4094, 实际值 = -13.4300, 差值 = 4.0206
样本 2: 预测值 = -10.2178, 实际值 = -5.1785, 差值 = -5.0393
样本 3: 预测值 = -9.8103, 实际值 = -11.5000, 差值 = 1.6897
样本 4: 预测值 = -11.8051, 实际值 = -9.8000, 差值 = -2.0051

批次统计: 平均预测值 = -10.7302, 平均实际值 = -10.6114
预测值范围: [-13.0915, -8.7554]
实际值范围: [-14.0000, -2.9990]
Epoch 11: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -12.8763, 实际值 = -6.0900, 差值 = -6.7863
样本 1: 预测值 = -8.2759, 实际值 = -7.8372, 差值 = -0.4387
样本 2: 预测值 = -10.7934, 实际值 = -10.3700, 差值 = -0.4234
样本 3: 预测值 = -10.7697, 实际值 = -9.3700, 差值 = -1.3997
样本 4: 预测值 = -9.7840, 实际值 = -11.7000, 差值 = 1.9160

批次统计: 平均预测值 = -10.6234, 平均实际值 = -9.9848
预测值范围: [-12.8763, -8.2759]
实际值范围: [-12.7000, -6.0900]
Epoch 12: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -11.1215, 实际值 = -12.2000, 差值 = 1.0785
样本 1: 预测值 = -9.0768, 实际值 = -10.6500, 差值 = 1.5732
样本 2: 预测值 = -9.3185, 实际值 = -7.7748, 差值 = -1.5438
样本 3: 预测值 = -9.6244, 实际值 = -14.2300, 差值 = 4.6056
样本 4: 预测值 = -12.4030, 实际值 = -12.5000, 差值 = 0.0970

批次统计: 平均预测值 = -10.3030, 平均实际值 = -10.5993
预测值范围: [-13.0331, -8.1144]
实际值范围: [-14.6584, -3.0000]
Epoch 13: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -13.2897, 实际值 = -13.8800, 差值 = 0.5903
样本 1: 预测值 = -10.6638, 实际值 = -10.7000, 差值 = 0.0362
样本 2: 预测值 = -11.9010, 实际值 = -14.6584, 差值 = 2.7574
样本 3: 预测值 = -12.1815, 实际值 = -12.9000, 差值 = 0.7185
样本 4: 预测值 = -11.4875, 实际值 = -7.8600, 差值 = -3.6275

批次统计: 平均预测值 = -10.3687, 平均实际值 = -11.0821
预测值范围: [-13.2897, -8.0096]
实际值范围: [-14.6584, -5.7858]
Epoch 14: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -8.0672, 实际值 = -7.5000, 差值 = -0.5672
样本 1: 预测值 = -12.3113, 实际值 = -12.2260, 差值 = -0.0853
样本 2: 预测值 = -9.3669, 实际值 = -9.5497, 差值 = 0.1828
样本 3: 预测值 = -9.1841, 实际值 = -9.0826, 差值 = -0.1016
样本 4: 预测值 = -8.2565, 实际值 = -10.0600, 差值 = 1.8035

批次统计: 平均预测值 = -10.4373, 平均实际值 = -10.8736
预测值范围: [-13.4451, -7.4278]
实际值范围: [-15.8280, -5.0800]
Epoch 15: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -12.3386, 实际值 = -15.8280, 差值 = 3.4895
样本 1: 预测值 = -8.7844, 实际值 = -9.0068, 差值 = 0.2225
样本 2: 预测值 = -14.2180, 实际值 = -14.2990, 差值 = 0.0810
样本 3: 预测值 = -10.9036, 实际值 = -11.2500, 差值 = 0.3464
样本 4: 预测值 = -9.4481, 实际值 = -9.7400, 差值 = 0.2919

批次统计: 平均预测值 = -11.0738, 平均实际值 = -11.4030
预测值范围: [-14.2180, -8.7844]
实际值范围: [-15.8280, -2.9990]
Epoch 16: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -12.8859, 实际值 = -9.3500, 差值 = -3.5359
样本 1: 预测值 = -11.1243, 实际值 = -11.4100, 差值 = 0.2857
样本 2: 预测值 = -9.8327, 实际值 = -11.6897, 差值 = 1.8570
样本 3: 预测值 = -11.1095, 实际值 = -11.9000, 差值 = 0.7905
样本 4: 预测值 = -7.6915, 实际值 = -10.6543, 差值 = 2.9628

批次统计: 平均预测值 = -10.6549, 平均实际值 = -10.3430
预测值范围: [-14.4510, -7.6915]
实际值范围: [-14.1000, -3.0000]
Epoch 17: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.2522, 实际值 = -11.0924, 差值 = 0.8401
样本 1: 预测值 = -9.4280, 实际值 = -10.8500, 差值 = 1.4220
样本 2: 预测值 = -12.2004, 实际值 = -11.8900, 差值 = -0.3104
样本 3: 预测值 = -12.4756, 实际值 = -10.9965, 差值 = -1.4792
样本 4: 预测值 = -10.1257, 实际值 = -10.3400, 差值 = 0.2143

批次统计: 平均预测值 = -11.0685, 平均实际值 = -11.0888
预测值范围: [-14.8831, -8.8048]
实际值范围: [-14.2990, -9.0826]
Epoch 18: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.7609, 实际值 = -9.5041, 差值 = -0.2567
样本 1: 预测值 = -12.2917, 实际值 = -14.0000, 差值 = 1.7083
样本 2: 预测值 = -13.1074, 实际值 = -13.0800, 差值 = -0.0274
样本 3: 预测值 = -9.7419, 实际值 = -10.0780, 差值 = 0.3361
样本 4: 预测值 = -10.5144, 实际值 = -11.0600, 差值 = 0.5456

批次统计: 平均预测值 = -11.1380, 平均实际值 = -11.0589
预测值范围: [-14.0512, -8.6140]
实际值范围: [-14.0000, -8.8300]
Epoch 19: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -12.6405, 实际值 = -13.3600, 差值 = 0.7195
样本 1: 预测值 = -15.0105, 实际值 = -16.0800, 差值 = 1.0695
样本 2: 预测值 = -9.7617, 实际值 = -10.2500, 差值 = 0.4883
样本 3: 预测值 = -11.9479, 实际值 = -13.4022, 差值 = 1.4543
样本 4: 预测值 = -13.7742, 实际值 = -14.5600, 差值 = 0.7858

批次统计: 平均预测值 = -11.5106, 平均实际值 = -11.8941
预测值范围: [-15.0105, -9.4833]
实际值范围: [-16.0800, -6.8900]
Epoch 20: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.6842, 实际值 = -10.5690, 差值 = -0.1152
样本 1: 预测值 = -14.7566, 实际值 = -14.1000, 差值 = -0.6566
样本 2: 预测值 = -10.2752, 实际值 = -10.2690, 差值 = -0.0062
样本 3: 预测值 = -12.1582, 实际值 = -11.3000, 差值 = -0.8582
样本 4: 预测值 = -9.9381, 实际值 = -10.9698, 差值 = 1.0317

批次统计: 平均预测值 = -11.3165, 平均实际值 = -11.0244
预测值范围: [-14.8061, -9.1284]
实际值范围: [-14.9000, -5.8600]
Epoch 21: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -11.6118, 实际值 = -9.8000, 差值 = -1.8118
样本 1: 预测值 = -9.8756, 实际值 = -8.8100, 差值 = -1.0656
样本 2: 预测值 = -10.3276, 实际值 = -10.3000, 差值 = -0.0276
样本 3: 预测值 = -11.4258, 实际值 = -11.5000, 差值 = 0.0742
样本 4: 预测值 = -11.9228, 实际值 = -12.5700, 差值 = 0.6472

批次统计: 平均预测值 = -10.4481, 平均实际值 = -10.1086
预测值范围: [-13.8611, -7.8415]
实际值范围: [-14.2300, -6.3500]
Epoch 22: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -14.4381, 实际值 = -13.5200, 差值 = -0.9181
样本 1: 预测值 = -10.9980, 实际值 = -9.0800, 差值 = -1.9180
样本 2: 预测值 = -9.6123, 实际值 = -9.0800, 差值 = -0.5323
样本 3: 预测值 = -10.6704, 实际值 = -11.5200, 差值 = 0.8496
样本 4: 预测值 = -11.3354, 实际值 = -11.5500, 差值 = 0.2146

批次统计: 平均预测值 = -10.9516, 平均实际值 = -10.8722
预测值范围: [-14.4381, -8.0405]
实际值范围: [-14.7400, -4.1060]
Epoch 23: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -11.8354, 实际值 = -12.5980, 差值 = 0.7626
样本 1: 预测值 = -8.9982, 实际值 = -8.0900, 差值 = -0.9082
样本 2: 预测值 = -9.1303, 实际值 = -9.2200, 差值 = 0.0897
样本 3: 预测值 = -10.2521, 实际值 = -9.2690, 差值 = -0.9831
样本 4: 预测值 = -14.4322, 实际值 = -12.8800, 差值 = -1.5522

批次统计: 平均预测值 = -10.4316, 平均实际值 = -9.9158
预测值范围: [-14.4322, -6.8447]
实际值范围: [-14.5300, -3.0000]
Epoch 24: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -12.9133, 实际值 = -14.0000, 差值 = 1.0867
样本 1: 预测值 = -13.2210, 实际值 = -13.7000, 差值 = 0.4790
样本 2: 预测值 = -10.8655, 实际值 = -12.2782, 差值 = 1.4127
样本 3: 预测值 = -9.4695, 实际值 = -9.3000, 差值 = -0.1695
样本 4: 预测值 = -10.4545, 实际值 = -11.5900, 差值 = 1.1355

批次统计: 平均预测值 = -10.5838, 平均实际值 = -10.6073
预测值范围: [-13.2320, -6.7817]
实际值范围: [-14.6584, -3.0000]
Epoch 25: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.9608, 实际值 = -11.8000, 差值 = 1.8392
样本 1: 预测值 = -9.9150, 实际值 = -9.8400, 差值 = -0.0750
样本 2: 预测值 = -10.0090, 实际值 = -8.9200, 差值 = -1.0890
样本 3: 预测值 = -9.6987, 实际值 = -8.3201, 差值 = -1.3786
样本 4: 预测值 = -9.9639, 实际值 = -10.0800, 差值 = 0.1161

批次统计: 平均预测值 = -10.8238, 平均实际值 = -10.9121
预测值范围: [-12.7202, -9.6269]
实际值范围: [-13.4022, -7.5200]
Epoch 26: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -9.0597, 实际值 = -8.6500, 差值 = -0.4097
样本 1: 预测值 = -10.5786, 实际值 = -10.6000, 差值 = 0.0214
样本 2: 预测值 = -13.8760, 实际值 = -14.2290, 差值 = 0.3530
样本 3: 预测值 = -10.0867, 实际值 = -10.4000, 差值 = 0.3133
样本 4: 预测值 = -13.6362, 实际值 = -14.7400, 差值 = 1.1038

批次统计: 平均预测值 = -11.3356, 平均实际值 = -11.1308
预测值范围: [-14.7073, -7.8658]
实际值范围: [-15.0067, -6.3500]
Epoch 27: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.0277, 实际值 = -11.1000, 差值 = 1.0723
样本 1: 预测值 = -10.9140, 实际值 = -10.2000, 差值 = -0.7140
样本 2: 预测值 = -11.9927, 实际值 = -10.9965, 差值 = -0.9963
样本 3: 预测值 = -9.1680, 实际值 = -9.8200, 差值 = 0.6520
样本 4: 预测值 = -9.7550, 实际值 = -8.2612, 差值 = -1.4938

批次统计: 平均预测值 = -10.2479, 平均实际值 = -10.7224
预测值范围: [-12.1362, -9.1680]
实际值范围: [-14.6584, -8.2612]
Epoch 28: Training esm (esm_antigen frozen)
ESM 参数: 332,997,184 trainable
ESM_antigen 参数: 0 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -10.2193, 实际值 = -9.8623, 差值 = -0.3569
样本 1: 预测值 = -10.9906, 实际值 = -9.7900, 差值 = -1.2006
样本 2: 预测值 = -11.0994, 实际值 = -11.6600, 差值 = 0.5606
样本 3: 预测值 = -9.5238, 实际值 = -10.4300, 差值 = 0.9062
样本 4: 预测值 = -9.0547, 实际值 = -9.0068, 差值 = -0.0479

批次统计: 平均预测值 = -10.7034, 平均实际值 = -10.8148
预测值范围: [-12.9997, -9.0547]
实际值范围: [-14.0400, -6.7648]
Epoch 29: Training esm_antigen (esm frozen)
ESM 参数: 0 trainable
ESM_antigen 参数: 332,997,184 trainable


Validation: |          | 0/? [00:00<?, ?it/s]


===== 验证集预测与实际值比较 =====
样本 0: 预测值 = -12.1614, 实际值 = -12.5519, 差值 = 0.3906
样本 1: 预测值 = -10.8472, 实际值 = -11.2500, 差值 = 0.4028
样本 2: 预测值 = -8.9485, 实际值 = -8.3800, 差值 = -0.5685
样本 3: 预测值 = -13.8132, 实际值 = -14.5300, 差值 = 0.7168
样本 4: 预测值 = -11.3168, 实际值 = -10.9500, 差值 = -0.3668

批次统计: 平均预测值 = -11.7972, 平均实际值 = -11.9569
预测值范围: [-14.5676, -8.9485]
实际值范围: [-15.7200, -7.8200]


`Trainer.fit` stopped: `max_epochs=30` reached.


In [5]:
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/public/home/ligroupprotein/.conda/envs/esm3test/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/public/home/ligroupprotein/.conda/envs/esm3test/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=87` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.7563793683693961
        test_mae            1.1437164545059204
        test_mse             2.960632562637329
      test_pearson          0.7118431925773621
         test_r2            0.4668962591179704
        test_rmse            1.645863652229309
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.7563793683693961,
  'test_mse': 2.960632562637329,
  'test_mae': 1.1437164545059204,
  'test_rmse': 1.645863652229309,
  'test_r2': 0.4668962591179704,
  'test_pearson': 0.7118431925773621}]