In [1]:
import torch
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer
import numpy as np
import re
import pandas as pd
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
import time
import datetime
import math

## 数据集ProteinDataset部分

In [2]:
class ProteinDataset(Dataset):
    """Dataset of protein-sequence pairs with a similarity score, read from a CSV file.

    The file is loaded once with ``pandas.read_csv`` and rows are addressed by
    position: column 0 is the first sequence, column 1 the second sequence and
    column 2 the structural similarity score.
    """

    # Positional CSV layout: (sequence 1, sequence 2, similarity score).
    _COLUMNS = (0, 1, 2)

    def __init__(self, csv_file):
        """Load the whole CSV at ``csv_file`` into memory."""
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        """Number of rows (sequence pairs) in the CSV."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(seq1, seq2, similarity_score)`` for row ``idx``."""
        frame = self.data
        return tuple(frame.iloc[idx, column] for column in self._COLUMNS)

## Model部分

In [3]:
class GRU_CNN_Block(nn.Module):
    """Siamese sequence encoder built on frozen T5 encoder embeddings.

    Each batch of raw sequences is embedded with the T5 encoder loaded from
    ``../prot_t5_xl_uniref50`` and then passed through two bidirectional GRUs,
    a residual self-attention layer, a residual two-kernel convolution block,
    average pooling over the sequence axis and a final linear projection to
    ``out_dim``. Both inputs of a pair go through the same weights.

    Notes:
        * The T5 encoder is used as a fixed feature extractor: its parameters
          are frozen and it is kept in eval mode even when ``train()`` is
          called on this module (otherwise its dropout layers would perturb
          the embeddings during training).
        * The T5 encoder is still a registered submodule, so ``state_dict()``
          includes its weights.
        * NOTE(review): padded positions are not masked in the GRU, attention,
          convolution or pooling stages, so the embedding of a sequence can
          depend on the longest sequence in its batch. Consider packing /
          masking if batch-independent embeddings are required.

    Args:
        input_size: Feature size of the T5 embeddings fed to the first GRU.
        hidden_size: Hidden size per GRU direction (features are ``2 * hidden_size``).
        num_layers: Number of stacked layers in each GRU.
        out_dim: Size of the final embedding.
        dropout: Dropout probability used after each stage and inside attention.
        nheads: Number of attention heads.
    """

    def __init__(self, input_size=1024, hidden_size=256, num_layers=1, out_dim=512, dropout=0.1, nheads=4):
        super(GRU_CNN_Block, self).__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = T5Tokenizer.from_pretrained("../prot_t5_xl_uniref50", do_lower_case=False)
        self.t5model = T5EncoderModel.from_pretrained("../prot_t5_xl_uniref50")
        # The encoder is only ever run under no_grad: freeze it explicitly so it
        # never accumulates gradients, and keep it in eval mode (see ``train``).
        self.t5model.requires_grad_(False)
        self.t5model.eval()
        self.t5model.to(self.device)

        self.hidden_size = hidden_size

        self.gru1 = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                           batch_first=True, bidirectional=True)
        self.gru2 = nn.GRU(input_size=hidden_size * 2, hidden_size=hidden_size, num_layers=num_layers,
                           batch_first=True, bidirectional=True)

        self.ln1 = nn.LayerNorm(hidden_size * 2)
        self.ln2 = nn.LayerNorm(hidden_size * 2)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.leaky_relu1 = nn.LeakyReLU()
        self.leaky_relu2 = nn.LeakyReLU()

        # Two parallel length-preserving convolutions (kernel 3 and kernel 7).
        self.conv3 = nn.Conv1d(in_channels=hidden_size * 2, out_channels=hidden_size * 2, kernel_size=3, padding=1)
        self.conv7 = nn.Conv1d(in_channels=hidden_size * 2, out_channels=hidden_size * 2, kernel_size=7, padding=3)

        self.ln_conv = nn.LayerNorm(hidden_size * 2)
        self.dropout_conv = nn.Dropout(dropout)
        self.leaky_relu3 = nn.LeakyReLU()

        self.attention = nn.MultiheadAttention(embed_dim=hidden_size*2, num_heads=nheads, dropout=dropout, batch_first=True)
        self.ln_attn = nn.LayerNorm(hidden_size * 2)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.mlp = nn.Linear(hidden_size * 2, out_dim)

        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.l1_loss = nn.L1Loss(reduction='mean')

    def train(self, mode=True):
        """Set train/eval mode on the trainable layers while pinning the T5 encoder to eval.

        ``nn.Module.train`` recurses into every child, which would re-enable
        dropout inside the frozen encoder; this override undoes that.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.t5model.eval()
        return self

    def featurize_prottrans(self, sequences):
        """Embed raw sequences with the frozen T5 encoder.

        Args:
            sequences: Iterable of sequence strings (one character per residue).

        Returns:
            A float32 tensor of shape ``(batch, max_len, embed_dim)`` on the
            encoder's device, where ``max_len`` is the largest number of kept
            tokens in the batch. For every sequence the last attended token is
            dropped (presumably the end-of-sequence token added by
            ``add_special_tokens=True``) and the remaining tail is zero-filled.
        """
        # The tokenizer expects space-separated residues; U/Z/O/B are mapped to X.
        spaced = [re.sub(r"[UZOB]", "X", " ".join(seq)) for seq in sequences]

        encoded = self.tokenizer.batch_encode_plus(spaced, add_special_tokens=True, padding=True)
        input_ids = torch.tensor(encoded['input_ids']).to(self.device)
        attention_mask = torch.tensor(encoded['attention_mask']).to(self.device)

        with torch.no_grad():
            hidden = self.t5model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Keep the first (attended tokens - 1) positions of each row and zero the
        # rest. This is done on-device instead of round-tripping through numpy.
        lengths = (attention_mask == 1).sum(dim=1) - 1
        max_len = int(lengths.max())
        hidden = hidden[:, :max_len]
        positions = torch.arange(max_len, device=hidden.device)
        is_padding = positions.unsqueeze(0) >= lengths.to(hidden.device).unsqueeze(1)
        return hidden.masked_fill(is_padding.unsqueeze(-1), 0.0).to(dtype=torch.float32)

    def _encode_sequences(self, sequences):
        """Map one batch of raw sequences to ``(batch, out_dim)`` embeddings."""
        x = self.featurize_prottrans(sequences)

        # First GRU + LayerNorm + dropout + activation.
        x, _ = self.gru1(x)
        x = self.leaky_relu1(self.dropout1(self.ln1(x)))

        # Second GRU + LayerNorm + dropout + activation.
        x, _ = self.gru2(x)
        x = self.leaky_relu2(self.dropout2(self.ln2(x)))

        # Self-attention (Q = K = V) with a residual connection.
        attended, _ = self.attention(x, x, x)
        x = self.ln_attn(x + attended)

        # Local features from two parallel convolutions, with a residual connection.
        # Conv1d expects (batch, channels, length), hence the transposes.
        channels_first = x.transpose(1, 2)
        local = (self.conv3(channels_first) + self.conv7(channels_first)).transpose(1, 2)
        local = self.leaky_relu3(self.dropout_conv(self.ln_conv(local)))
        x = local + x

        # Average over the sequence axis, then project to out_dim.
        pooled = self.pooling(x.transpose(1, 2)).squeeze(2)
        return self.mlp(pooled)

    def forward(self, seq1, seq2):
        """Encode both sides of a batch of sequence pairs with shared weights.

        Args:
            seq1: Batch of first sequences (strings).
            seq2: Batch of second sequences (strings).

        Returns:
            Tuple ``(emb1, emb2)`` of ``(batch, out_dim)`` tensors.
        """
        return self._encode_sequences(seq1), self._encode_sequences(seq2)

    def distance_loss(self, output_seq1, output_seq2, tm_score):
        """Mean absolute error between pairwise cosine similarity and ``tm_score``.

        Args:
            output_seq1: Embeddings of the first sequences, ``(batch, dim)``.
            output_seq2: Embeddings of the second sequences, ``(batch, dim)``.
            tm_score: Target similarity per pair; cast to float before comparison.

        Returns:
            Scalar loss tensor.
        """
        dist_seq = self.cos(output_seq1, output_seq2)
        dist_tm = self.l1_loss(dist_seq.unsqueeze(0), tm_score.float().unsqueeze(0))
        return dist_tm

## 工具函数

In [4]:
def format_time(seconds):
    """Format a duration in seconds as the ``str()`` of a ``timedelta`` (e.g. ``0:01:02``).

    Fractional seconds are rounded up to the next whole second first.
    """
    whole_seconds = math.ceil(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))

## 训练函数

In [5]:
def train_epoch(model, dataloader, optimizer, device, writer, epoch, estimated_step_time, scheduler,
                scheduler_interval=5000):
    """Run one training epoch and report progress.

    Args:
        model: Module exposing ``forward(seq1, seq2)`` and ``distance_loss``.
        dataloader: Yields ``(seq1_batch, seq2_batch, similarity_scores)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the target scores are moved to.
        writer: TensorBoard writer; per-batch loss is logged under ``Loss/train``.
        epoch: Zero-based epoch index (used for the global step and messages).
        estimated_step_time: Returned unchanged if the loader yields no batch,
            otherwise replaced by the duration of the first batch.
        scheduler: Plateau scheduler; stepped with a loss metric.
        scheduler_interval: Number of batches averaged into each scheduler
            metric. The scheduler is stepped every ``scheduler_interval``
            batches and once more at the end of the epoch for any remainder.

    Returns:
        ``(mean batch loss, epoch duration in seconds, estimated_step_time)``.

    Raises:
        ZeroDivisionError: If ``dataloader`` has length 0.
    """
    model.train()
    num_batches = len(dataloader)
    total_loss = 0
    # Running sum / count of the losses seen since the last scheduler step.
    window_loss, window_steps = 0.0, 0
    start_time = time.time()
    for batch_idx, (seq1_batch, seq2_batch, similarity_scores) in enumerate(dataloader):
        batch_start_time = time.time()

        optimizer.zero_grad()
        out_seq1, out_seq2 = model(seq1_batch, seq2_batch)
        loss = model.distance_loss(out_seq1, out_seq2, similarity_scores.to(device))
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        print("train loss:", batch_loss)

        # Log training loss
        writer.add_scalar('Loss/train', batch_loss, epoch * num_batches + batch_idx)

        batch_end_time = time.time()
        batch_time = batch_end_time - batch_start_time
        print(f"Epoch {epoch + 1}, Step {batch_idx + 1}/{num_batches}, Step Time: {format_time(batch_time)}")

        if batch_idx == 0:
            estimated_step_time = batch_time

        # Feed the plateau scheduler the mean loss of the window rather than a
        # single (noisy) mini-batch loss, and pass a plain float, not a tensor.
        window_loss += batch_loss
        window_steps += 1
        if window_steps >= scheduler_interval:
            scheduler.step(window_loss / window_steps)
            window_loss, window_steps = 0.0, 0

        avg_step_time = (batch_end_time - start_time) / (batch_idx + 1)
        estimated_remaining_time = avg_step_time * (num_batches - batch_idx - 1)
        print(f"Estimated time remaining for this epoch: {format_time(estimated_remaining_time)}")

    # Flush the final partial window so short epochs still drive the scheduler.
    if window_steps:
        scheduler.step(window_loss / window_steps)

    end_time = time.time()
    epoch_time = end_time - start_time
    print(f"Epoch {epoch + 1} training time: {format_time(epoch_time)}")

    return total_loss / num_batches, epoch_time, estimated_step_time


def validate_epoch(model, dataloader, device, writer, epoch, estimated_step_time):
    """Evaluate the model on ``dataloader`` without gradient tracking.

    Per-batch loss is printed and logged to ``writer`` under ``Loss/val``;
    step timing and a remaining-time estimate are printed after every batch.

    Returns:
        ``(mean batch loss, epoch duration in seconds, estimated_step_time)``,
        where ``estimated_step_time`` is the duration of the first batch (or
        the value passed in if no batch was produced).
    """
    model.eval()
    num_batches = len(dataloader)
    running_loss = 0
    epoch_started = time.time()

    for step, (seq1_batch, seq2_batch, similarity_scores) in enumerate(dataloader):
        step_started = time.time()

        with torch.no_grad():
            emb1, emb2 = model(seq1_batch, seq2_batch)
            loss_value = model.distance_loss(emb1, emb2, similarity_scores.to(device)).item()
            running_loss += loss_value
            print("val loss:", loss_value)

            # Global step = batches seen in previous epochs + current batch index.
            writer.add_scalar('Loss/val', loss_value, epoch * num_batches + step)

        step_finished = time.time()
        step_time = step_finished - step_started
        print(f"Epoch {epoch + 1}, Step {step + 1}/{num_batches}, Step Time: {format_time(step_time)}")

        if step == 0:
            estimated_step_time = step_time

        # Extrapolate from the mean step time observed so far in this epoch.
        mean_step_time = (step_finished - epoch_started) / (step + 1)
        remaining = mean_step_time * (num_batches - step - 1)
        print(f"Estimated time remaining for this epoch: {format_time(remaining)}")

    epoch_time = time.time() - epoch_started
    print(f"Epoch {epoch + 1} validation time: {format_time(epoch_time)}")

    return running_loss / num_batches, epoch_time, estimated_step_time


## 训练主函数

In [6]:
def save_checkpoint(model, optimizer, epoch, checkpoint_dir, best_loss, num_checkpoints=1):
    """Write a training checkpoint and prune older ones.

    The checkpoint is saved as ``model_epoch_{epoch}_loss_{best_loss:.4f}.pth``
    inside ``checkpoint_dir`` (created if missing). Afterwards only the
    ``num_checkpoints`` most recent checkpoint files are kept.

    Only regular files named ``model_epoch_*.pth`` take part in the rotation,
    so unrelated files or sub-directories in ``checkpoint_dir`` are never
    removed. The file written by this call is always treated as the newest,
    regardless of filesystem timestamp resolution.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint and in the file name.
        checkpoint_dir: Directory that holds the checkpoints.
        best_loss: Loss value recorded in the checkpoint and in the file name.
        num_checkpoints: Maximum number of checkpoint files to keep.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_name = f'model_epoch_{epoch}_loss_{best_loss:.4f}.pth'
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': best_loss,
    }, checkpoint_path)

    def _rotation_key(name):
        # Oldest first by mtime; the checkpoint just written always sorts last.
        return (name == checkpoint_name, os.path.getmtime(os.path.join(checkpoint_dir, name)))

    # Manage number of checkpoints (checkpoint files only).
    checkpoints = sorted(
        (
            name for name in os.listdir(checkpoint_dir)
            if name.startswith('model_epoch_') and name.endswith('.pth')
            and os.path.isfile(os.path.join(checkpoint_dir, name))
        ),
        key=_rotation_key,
    )
    excess = max(len(checkpoints) - num_checkpoints, 0)
    for stale_name in checkpoints[:excess]:
        os.remove(os.path.join(checkpoint_dir, stale_name))


In [7]:
def main(csv_file='../data/data.csv', num_epochs=2, batch_size=64, learning_rate=0.0001, seed=42):
    """Train ``GRU_CNN_Block`` on the sequence-pair CSV and save checkpoints.

    The dataset is split 95% / 5% into train and validation subsets. After
    every epoch the model is checkpointed if the validation loss improved, and
    the final weights are written to ``final_model.pth``.

    Args:
        csv_file: Path of the CSV consumed by ``ProteinDataset``.
        num_epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        learning_rate: Initial learning rate for Adam.
        seed: Seed for the global torch RNG and for the train/validation split,
            so that reruns use the same split.
    """
    torch.manual_seed(seed)

    # Load the dataset.
    dataset = ProteinDataset(csv_file)

    # Split into training and validation sets (seeded so the split is reproducible).
    train_size = int(0.95 * len(dataset))
    val_size = len(dataset) - train_size
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Build the model, optimizer and LR scheduler.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GRU_CNN_Block().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    # ``verbose`` is deprecated on LR schedulers; the learning rate is logged
    # explicitly after each epoch instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

    # Checkpoint directory.
    checkpoint_dir = 'checkpoints/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # TensorBoard
    writer = SummaryWriter(log_dir='../../tf-logs/Training')

    # try/finally so the event file is flushed and closed even if training fails.
    try:
        best_val_loss = float('inf')
        total_train_time = 0
        total_val_time = 0

        # Initial step-time estimate.
        estimated_step_time = 0

        for epoch in range(num_epochs):
            train_loss, train_time, estimated_step_time = train_epoch(model, train_dataloader, optimizer, device,
                                                                      writer, epoch, estimated_step_time, scheduler)
            val_loss, val_time, estimated_step_time = validate_epoch(model, val_dataloader, device, writer, epoch,
                                                                     estimated_step_time)

            total_train_time += train_time
            total_val_time += val_time
            avg_train_time = total_train_time / (epoch + 1)
            avg_val_time = total_val_time / (epoch + 1)
            epochs_left = num_epochs - epoch - 1

            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('LR', current_lr, epoch)

            print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Current learning rate: {current_lr}")
            print(
                f"Estimated time remaining: {format_time(avg_train_time * epochs_left + avg_val_time * epochs_left)}")

            # Keep the best model so far.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(model, optimizer, epoch + 1, checkpoint_dir, best_val_loss)

        # Save the final model.
        torch.save(model.state_dict(), 'final_model.pth')
    finally:
        writer.close()
    print("Training completed successfully!")

## 启动训练

In [None]:
# Entry point: start training only when this code is executed directly
# (not when it is imported as a module).
if __name__ == '__main__':
    main()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  return self.fget.__get__(instance, owner)()
