In [None]:
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoModel
from torch.utils.data import DataLoader, Dataset
import os
import transformers
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import requests
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support,precision_recall_curve, roc_curve, auc,matthews_corrcoef 
import re
from sklearn.model_selection import train_test_split, StratifiedKFold
from evaluate import load
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import copy
import seaborn as sns
import matplotlib.pyplot as plt
import random
import time
from torch.cuda.amp import autocast, GradScaler
import gc
import pickle

In [None]:
# Checkpoint for the ESM-1b backbone and its tokenizer.
# NOTE(review): the Hugging Face Hub id is "facebook/esm1b_t33_650M_UR50S"; without the
# org prefix this presumably resolves to a local directory of that name — confirm.
model_checkpoint: str = "esm1b_t33_650M_UR50S"

In [None]:
# Load the pretrained ESM backbone on its own, only to inspect its structure.
# NOTE(review): this module-level copy is never deleted, so it stays in memory alongside
# the per-fold copies created by MultiPathProteinClassifier.
esm1b = AutoModel.from_pretrained(model_checkpoint)

# Module hierarchy of the backbone.
print(esm1b)

# Every parameter name (handy when choosing layers to freeze/unfreeze).
for param_name, _param in esm1b.named_parameters():
    print(param_name)

In [None]:
# Annotated protein table; read_table uses a tab separator by default.
df = pd.read_table('Data/total40.tsv')
df

In [None]:
# Drop every row that has a missing value in any column.
df = df.dropna(axis=0, how="any")

In [None]:
# Boolean masks over the GO annotation column (plain substring matching).
go_annotations = df['Gene Ontology (GO)']
dna = go_annotations.str.contains("DNA-binding", regex=False)
rna = go_annotations.str.contains("RNA-binding", regex=False)
# Negatives: proteins annotated with neither DNA- nor RNA-binding.
non = ~(dna | rna)


In [None]:
# DNA-binding only (proteins that are also RNA-binding are excluded).
dna_only_mask = dna & ~rna & ~non
dna_df = df.loc[dna_only_mask]
dna_df

In [None]:
# RNA-binding only (proteins that are also DNA-binding are excluded).
rna_only_mask = rna & ~dna & ~non
rna_df = df.loc[rna_only_mask]
rna_df

In [None]:
# Proteins with neither binding annotation.
non_only_mask = non & ~dna & ~rna
non_df = df.loc[non_only_mask]
non_df

In [None]:
# Stage 1 labels: 0 = not nucleic-acid-binding, 1 = DNA- and/or RNA-binding.
non_sequences = non_df["Sequence"].tolist()
non_labels = [0] * len(non_sequences)
nucleic_sequences = df.loc[dna | rna, "Sequence"].tolist()
nucleic_labels = [1] * len(nucleic_sequences)

In [None]:
# Stage 2 labels: 0 = DNA-binding, 1 = RNA-binding.
dna_sequences = dna_df["Sequence"].tolist()
dna_labels = [0] * len(dna_sequences)
rna_sequences = rna_df["Sequence"].tolist()
rna_labels = [1] * len(rna_sequences)

In [None]:
# Stage 1 data: negatives first, then every nucleic-acid-binding protein.
sequences = [*non_sequences, *nucleic_sequences]
labels = [*non_labels, *nucleic_labels]

# Sequences and labels must stay aligned one-to-one.
assert len(sequences) == len(labels), "序列和标签数量不匹配"

In [None]:
# Stage 1 table: one row per protein with its binary label.
data = dict(
    sequence=sequences,
    label=labels,
)
full_data = pd.DataFrame(data)

In [None]:
# Stage 2 table: DNA-binding rows (label 0) followed by RNA-binding rows (label 1).
second_stage_data = dict(
    sequence=dna_sequences + rna_sequences,
    label=dna_labels + rna_labels,
)
second_stage_full_data = pd.DataFrame(second_stage_data)

In [None]:
# Tokenizer for the same checkpoint as the backbone; create_data_loaders passes this
# notebook-level object into every ProteinSequenceDataset.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
class ProteinSequenceDataset(Dataset):
    """Protein sequences with binary labels, tokenized lazily per item.

    Args:
        sequences: list of amino-acid strings.
        labels: list of 0/1 labels aligned with ``sequences``.
        tokenizer: tokenizer called as ``tokenizer(seq, return_tensors="pt", ...)``; its
            ``input_ids`` / ``attention_mask`` outputs have their leading batch dimension
            squeezed off here.
        max_length: pad/truncate length passed to the tokenizer.
        augmentation_prob: probability that ``augment_sequence`` is applied to an item
            before tokenization (0 disables augmentation).
    """

    def __init__(self, sequences, labels, tokenizer, max_length=1000, augmentation_prob=0.0):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augmentation_prob = augmentation_prob

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and a float ``labels`` tensor of shape [1]."""
        sequence = self.sequences[idx]
        label = self.labels[idx]

        # Augmentation only happens when a positive probability is configured
        # (create_data_loaders sets 0 for validation/test datasets).
        if self.augmentation_prob > 0 and random.random() < self.augmentation_prob:
            sequence = self.augment_sequence(sequence)

        # Pad/truncate to max_length so every item in a batch has the same length.
        encoded_sequence = self.tokenizer(sequence, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)

        input_ids = encoded_sequence['input_ids'].squeeze(0)  # drop the batch dimension
        attention_mask = encoded_sequence['attention_mask'].squeeze(0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            # Shape [1] so a batch of labels matches the model's [batch, 1] logits.
            "labels": torch.tensor(label, dtype=torch.float).unsqueeze(0),
        }

    def augment_sequence(self, sequence):
        """Apply one random single-residue edit: delete, swap, insert or replace.

        Each edit has a length guard; when the chosen edit's guard fails, the sequence
        is returned unchanged. Empty sequences are always returned unchanged, because
        there is no residue to sample from.
        """
        seq_list = list(sequence)
        seq_len = len(seq_list)
        if seq_len == 0:
            # 'insert' would otherwise call random.choice([]) and raise IndexError.
            return sequence

        augmentation_choice = random.choice(['delete', 'swap', 'insert', 'replace'])

        if augmentation_choice == 'delete' and seq_len > 200:
            # Only delete from sequences longer than 200 residues.
            del seq_list[random.randint(0, seq_len - 1)]

        elif augmentation_choice == 'swap' and seq_len > 1:
            idx1, idx2 = random.sample(range(seq_len), 2)
            seq_list[idx1], seq_list[idx2] = seq_list[idx2], seq_list[idx1]

        elif augmentation_choice == 'insert' and seq_len < 1000:
            # Insert a residue drawn from the sequence itself; the 1000 cap matches the
            # default max_length.
            amino_acid = random.choice(seq_list)
            seq_list.insert(random.randint(0, seq_len), amino_acid)

        elif augmentation_choice == 'replace':
            # Overwrite one position with a residue drawn from the sequence itself.
            idx = random.randint(0, seq_len - 1)
            seq_list[idx] = random.choice(seq_list)

        return ''.join(seq_list)

    def set_augmentation_prob(self, augmentation_prob):
        """Set the per-item augmentation probability (train_model lowers it every epoch)."""
        self.augmentation_prob = augmentation_prob

In [None]:
def create_data_loaders(train_data, val_data, test_data, batch_size=8, augmentation_prob=0):
    """Build (train, val, test) DataLoaders from DataFrames with 'sequence'/'label' columns.

    Only the training split is shuffled and augmented; validation and test use an
    augmentation probability of 0. Datasets are built with the notebook-level ``tokenizer``.
    """
    def _as_dataset(frame, prob):
        # Wrap one DataFrame split as a ProteinSequenceDataset.
        return ProteinSequenceDataset(frame['sequence'].tolist(), frame['label'].tolist(), tokenizer, augmentation_prob=prob)

    split_specs = (
        (train_data, augmentation_prob, True),
        (val_data, 0, False),
        (test_data, 0, False),
    )
    loaders = [
        DataLoader(_as_dataset(frame, prob), batch_size=batch_size, shuffle=shuffle, num_workers=4)
        for frame, prob, shuffle in split_specs
    ]
    return tuple(loaders)

In [None]:
def _snapshot_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter/buffer tensors, so storing it
    without copying would silently follow every later optimizer step.
    """
    return {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}


def _run_phase(model, loader, is_train, optimizer, loss_function, device):
    """Run one pass over ``loader``, back-propagating only when ``is_train``.

    Returns ``(mean_loss, accuracy, all_labels, all_preds)``. ``all_labels`` and ``all_preds``
    are lists of per-sample numpy arrays; predictions threshold the sigmoid at 0.5.
    """
    model.train(is_train)
    running_loss = 0.0
    running_corrects = 0
    all_preds, all_labels = [], []

    for batch in loader:
        inputs = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs, attention_mask)
            loss = loss_function(outputs, labels)
            # Logits -> probabilities -> hard 0/1 predictions at threshold 0.5.
            preds = (torch.sigmoid(outputs) >= 0.5).float()
            if is_train:
                loss.backward()
                optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += int((preds == labels).sum().item())
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    num_samples = len(loader.dataset)
    return running_loss / num_samples, running_corrects / num_samples, all_labels, all_preds


def train_model(model, dataloaders, optimizer, loss_function, scheduler, device, num_epochs, stage, fold, initial_augmentation_prob, patience=None):
    """Train one stage with validation-based model selection and optional early stopping.

    Each epoch runs a 'train' pass (with backprop) and then a 'val' pass (no grad). The
    weights with the highest validation accuracy are snapshotted. They are restored into
    ``model`` before returning, so the returned model and state always agree.

    Args:
        model: classifier exposing ``set_stage`` (optionally wrapped in nn.DataParallel),
            called as ``model(input_ids, attention_mask)`` and returning logits shaped like
            ``batch['labels']``.
        dataloaders: dict with 'train' and 'val' DataLoaders; the train dataset must
            implement ``set_augmentation_prob``.
        optimizer: optimizer over the model's trainable parameters.
        loss_function: loss taking ``(logits, labels)``.
        scheduler: stepped once per completed epoch with the validation loss
            (e.g. ReduceLROnPlateau).
        device: device to train on.
        num_epochs: maximum number of epochs.
        stage: stage id passed to ``set_stage`` and used in log lines.
        fold: fold index (unused; kept for interface compatibility).
        initial_augmentation_prob: augmentation probability for epoch 1; multiplied by
            0.9 after each completed epoch.
        patience: stop after this many consecutive epochs without a validation-accuracy
            improvement; ``None`` or 0 disables early stopping.

    Returns:
        ``(model, best_model_state)``: the model with its best weights loaded, and a
        detached CPU copy of that state dict. The state is ``None`` only if no epoch ran.
    """
    model.to(device)
    core_model = model.module if isinstance(model, nn.DataParallel) else model
    core_model.set_stage(stage)

    best_accuracy = 0.0
    best_model_state = None
    epochs_no_improve = 0

    train_loss_history, val_loss_history = [], []
    train_acc_history, val_acc_history = [], []
    learning_rates = []  # learning rate of param group 0 after each epoch
    train_metrics_history, val_metrics_history = [], []  # weighted precision/recall/F1 per epoch

    augmentation_prob = initial_augmentation_prob

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        dataloaders['train'].dataset.set_augmentation_prob(augmentation_prob)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc, all_labels, all_preds = _run_phase(
                model, dataloaders[phase], phase == 'train', optimizer, loss_function, device
            )
            precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
            metrics = {'precision': precision, 'recall': recall, 'f1': f1}

            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc)
                train_metrics_history.append(metrics)
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc)
                val_metrics_history.append(metrics)

            print(f"{phase.capitalize()} Stage {stage} Epoch {epoch+1}: Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        # Model selection / early stopping on this epoch's validation pass.
        val_loss = val_loss_history[-1]
        val_acc = val_acc_history[-1]
        if best_model_state is None or val_acc > best_accuracy:
            best_accuracy = val_acc
            best_model_state = _snapshot_state_dict(model)
            epochs_no_improve = 0
            print(f"the best_model_state with accuracy {best_accuracy:.4f}")
        else:
            epochs_no_improve += 1
            if patience and epochs_no_improve >= patience:
                print(f"Early stopping after {patience} epochs without improvement")
                model.load_state_dict(best_model_state)
                return model, best_model_state

        scheduler.step(val_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # Decay the augmentation probability by 10% per epoch.
        augmentation_prob *= 0.90

    print(f"Training complete with best validation accuracy: {best_accuracy:.4f}")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, best_model_state

In [None]:
def evaluate_model(model, dataloader, device, stage, fold):
    """Run inference on ``dataloader`` and print/return binary classification metrics.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``, returning logits.
        dataloader: yields dicts with 'input_ids', 'attention_mask' and 'labels'.
        device: device to run inference on.
        stage: 1 (non-binding vs. nucleic-acid-binding) or 2 (DNA vs. RNA); selects the
            class names used in the report.
        fold: fold index (unused; kept for interface compatibility).

    Returns:
        ``(accuracy, precision, recall, f1, sensitivity, specificity, mcc, all_labels,
        all_probs, cm)``.
        - precision/recall/f1 are macro averages.
        - sensitivity = TP / (TP + FN) for class 1; specificity = TN / (TN + FP). Either
          is NaN when its denominator is 0.
        - cm is always the 2x2 confusion matrix for labels [0, 1].
        - all_labels/all_probs are per-sample numpy arrays as collected.

    Raises:
        ValueError: if ``stage`` is not 1 or 2.
    """
    if stage == 1:
        target_names = ['Non-Nucleic Acid-Binding', 'Nucleic Acid-Binding']
    elif stage == 2:
        target_names = ['DNA-Binding', 'RNA-Binding']
    else:
        raise ValueError("Stage must be 1 or 2")

    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            probs = torch.sigmoid(model(inputs, attention_mask))
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # 1-D integer class ids; threshold 0.5 on the sigmoid probability.
    y_true = np.asarray(all_labels).ravel().astype(int)
    y_pred = (np.asarray(all_probs).ravel() >= 0.5).astype(int)
    # Pin both classes so the report and the 2x2 matrix still work if one class is absent.
    class_ids = [0, 1]

    report = classification_report(y_true, y_pred, labels=class_ids, target_names=target_names, output_dict=True)
    print(classification_report(y_true, y_pred, labels=class_ids, target_names=target_names))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    print("Confusion Matrix:")
    print(cm)

    # The report only has an 'accuracy' key when every listed label is present, so use
    # accuracy_score directly.
    accuracy = accuracy_score(y_true, y_pred)
    precision = report['macro avg']['precision']
    recall = report['macro avg']['recall']
    f1 = report['macro avg']['f1-score']

    tn, fp, fn, tp = cm.ravel()
    # Sensitivity is the positive-class recall. Macro recall averages it with
    # specificity (balanced accuracy), so it must not be used here.
    sensitivity = tp / (tp + fn) if (tp + fn) else float('nan')
    specificity = tn / (tn + fp) if (tn + fp) else float('nan')
    mcc = matthews_corrcoef(y_true, y_pred)

    print(f"Accuracy: {accuracy:.4f}, Precision (macro): {precision:.4f}, Recall (macro): {recall:.4f}, F1 (macro): {f1:.4f}")
    print(f"Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}, MCC: {mcc:.4f}")

    return accuracy, precision, recall, f1, sensitivity, specificity, mcc, all_labels, all_probs, cm

In [None]:
class MultiPathProteinClassifier(nn.Module):
    """ESM backbone with two linear heads, one per classification stage.

    Stage 1 head: nucleic-acid-binding vs. non-binding.
    Stage 2 head: DNA-binding vs. RNA-binding.
    ``forward`` uses only the head for ``current_stage`` and returns one logit per
    sequence (shape [batch, 1]).
    """

    def __init__(self):
        super(MultiPathProteinClassifier, self).__init__()
        # Backbone is loaded from the notebook-level ``model_checkpoint``.
        self.esm1b = AutoModel.from_pretrained(model_checkpoint)

        # Size the heads from the backbone config rather than hard-coding 1280
        # (ESM-1b's width), so other ESM checkpoints work too.
        hidden_size = self.esm1b.config.hidden_size
        self.classifier_stage1 = nn.Linear(hidden_size, 1)  # nucleic-acid-binding vs. not
        self.classifier_stage2 = nn.Linear(hidden_size, 1)  # DNA vs. RNA

        # Which head forward() uses; switched via set_stage().
        self.current_stage = 1

    def forward(self, input_ids, attention_mask):
        hidden = self.esm1b(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Masked mean pooling. Inputs are padded to max_length, so a plain mean over
        # all positions would let padding states dominate short sequences.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        features = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        if self.current_stage == 1:
            return self.classifier_stage1(features)
        return self.classifier_stage2(features)

    def set_stage(self, stage):
        """Select the head used by forward (1 or 2; any value other than 1 uses head 2)."""
        self.current_stage = stage

    def freeze_layers(self):
        """Disable gradients for every backbone parameter (the heads stay trainable)."""
        for param in self.esm1b.parameters():
            param.requires_grad = False

    def unfreeze_layers(self, layers_to_unfreeze):
        """Re-enable gradients for backbone parameters whose name contains any given substring."""
        for name, param in self.esm1b.named_parameters():
            if any(layer in name for layer in layers_to_unfreeze):
                param.requires_grad = True

In [None]:
def get_optimizer(model, lr_pretrained=5e-6, lr_custom=5e-5, weight_decay_pretrained=1e-2, weight_decay_custom=3e-1):
    """Build an AdamW optimizer with three parameter groups, in this order.

    0. Trainable backbone ("esm1b") weights: ``lr_pretrained``, ``weight_decay_pretrained``.
    1. Trainable backbone biases and "LayerNorm.weight" tensors: ``lr_pretrained``, no
       weight decay.
    2. Every other trainable parameter (the heads): ``lr_custom``, ``weight_decay_custom``.

    Frozen parameters are skipped, so a group may be empty. An nn.DataParallel wrapper
    is unwrapped first.
    """
    target = model.module if isinstance(model, nn.DataParallel) else model

    decayed, undecayed, heads = [], [], []
    for pname, tensor in target.named_parameters():
        if not tensor.requires_grad:
            continue
        if "esm1b" not in pname:
            heads.append(tensor)
        elif "bias" in pname or "LayerNorm.weight" in pname:
            undecayed.append(tensor)
        else:
            decayed.append(tensor)

    return optim.AdamW([
        {'params': decayed, 'lr': lr_pretrained, 'weight_decay': weight_decay_pretrained},
        {'params': undecayed, 'lr': lr_pretrained, 'weight_decay': 0.0},
        {'params': heads, 'lr': lr_custom, 'weight_decay': weight_decay_custom},
    ])

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def stratified_k_fold_cross_validation_stage1(full_data, tokenizer, k=5, num_epochs=50, initial_augmentation_prob=0.5):
    """Stage 1 (nucleic-acid-binding vs. not): hold-out split plus stratified k-fold CV.

    - 20% of ``full_data`` is held out as an independent test set and written to
      ``Data/test_data_stage1.csv``.
    - The remaining 80% is split into ``k`` stratified folds. Each fold trains a fresh
      classifier with a frozen backbone.
    - The fold whose model scores best on its validation fold is kept. The held-out test
      set is only reported, never used for selection, so its metrics stay unbiased.

    Args:
        full_data: DataFrame with 'sequence' and 'label' (0/1) columns.
        tokenizer: unused here; create_data_loaders reads the notebook-level
            ``tokenizer``. Kept for interface compatibility.
        k: number of CV folds.
        num_epochs: maximum epochs per fold.
        initial_augmentation_prob: starting augmentation probability for training.

    Returns:
        The selected fold's best state dict as returned by ``train_model``.
    """
    train_val_data_stage1, test_data_stage1 = train_test_split(full_data, test_size=0.2, random_state=42, stratify=full_data['label'])
    test_data_stage1.to_csv('Data/test_data_stage1.csv', index=False)
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

    best_accuracy = 0
    best_model_state = None

    folds = skf.split(train_val_data_stage1['sequence'], train_val_data_stage1['label'])
    for fold, (train_index, val_index) in enumerate(folds, start=1):
        print(f"Fold {fold}/{k}")

        train_data_stage1 = train_val_data_stage1.iloc[train_index]
        val_data_stage1 = train_val_data_stage1.iloc[val_index]

        train_loader_stage1, val_loader_stage1, test_loader_stage1 = create_data_loaders(train_data_stage1, val_data_stage1, test_data_stage1, augmentation_prob=initial_augmentation_prob)

        # Fresh model per fold.
        model = MultiPathProteinClassifier()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(device)
        core_model = model.module if isinstance(model, nn.DataParallel) else model

        core_model.freeze_layers()
        optimizer = get_optimizer(model)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.6, patience=2, min_lr=1e-6)
        core_model.set_stage(1)

        # BCEWithLogitsLoss expects pos_weight = #negatives / #positives. That is > 1 for a
        # minority positive class and < 1 for a majority one. max() guards against a fold
        # with no positives.
        num_positive1 = int(train_data_stage1['label'].sum())
        num_negative1 = len(train_data_stage1) - num_positive1
        pos_weight1 = torch.tensor([num_negative1 / max(num_positive1, 1)], dtype=torch.float32).to(device)
        loss_function = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)

        print("Training Stage 1:")
        torch.cuda.empty_cache()  # release cached, unused GPU memory
        model, best_model_state_fold = train_model(model, {'train': train_loader_stage1, 'val': val_loader_stage1}, optimizer, loss_function, scheduler, device, num_epochs, stage=1, fold=fold, initial_augmentation_prob=initial_augmentation_prob, patience=4)

        # Model selection uses the validation fold only.
        print("Evaluating Stage 1 on validation fold:")
        val_accuracy = evaluate_model(model, val_loader_stage1, device, stage=1, fold=fold)[0]

        print("Evaluating Stage 1:")
        accuracy, precision, recall, f1, sensitivity, specificity, mcc, fold_labels, fold_probs, cm_stage1 = evaluate_model(model, test_loader_stage1, device, stage=1, fold=fold)
        print(f"Stage 1 Evaluation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Sensitivity: {sensitivity:.4f}, F1: {f1:.4f}, Specificity: {specificity:.4f}, MCC: {mcc:.4f}")

        if best_model_state is None or val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_model_state = best_model_state_fold

        # Drop every reference to this fold's model (including the unwrapped one) before
        # the next fold allocates a new backbone.
        del model, core_model, optimizer, scheduler, loss_function, train_loader_stage1, val_loader_stage1, test_loader_stage1
        torch.cuda.empty_cache()
        gc.collect()

    return best_model_state

# Run stage-1 stratified k-fold cross-validation.
best_model_state_stage1 = stratified_k_fold_cross_validation_stage1(full_data, tokenizer, k=5, num_epochs=50, initial_augmentation_prob=0.5)

In [None]:
# Save the best stage-1 model state. Fail loudly rather than silently writing None,
# and make sure the output directory exists (torch.save does not create it).
if best_model_state_stage1 is None:
    raise RuntimeError("Stage 1 produced no model state to save")
os.makedirs("Model", exist_ok=True)
torch.save(best_model_state_stage1, "Model/best_model_stage1_esm1b.pth")


In [None]:
def load_model_weights(model, best_model_state_stage1):
    """Load a saved state dict into ``model``, reconciling the nn.DataParallel ``module.`` prefix.

    A DataParallel wrapper prefixes every state-dict key with ``module.``. Keys are
    re-prefixed or stripped so weights saved from a wrapped model load into a bare one,
    and vice versa. Loading is strict.

    Args:
        model (torch.nn.Module): the model (bare or DataParallel) to load into.
        best_model_state_stage1 (dict): state dict to load.

    Returns:
        torch.nn.Module: the same ``model`` object, with weights loaded.

    Raises:
        RuntimeError: raised by ``load_state_dict`` on missing/unexpected keys or shape
            mismatches (it never raises KeyError for these).
    """
    prefix = "module."
    # Decide from the target's first key whether it expects prefixed keys ('' if no params).
    first_target_key = next(iter(model.state_dict().keys()), "")
    target_is_wrapped = first_target_key.startswith(prefix)

    new_state_dict = {}
    for key, value in best_model_state_stage1.items():
        source_is_wrapped = key.startswith(prefix)
        if source_is_wrapped and not target_is_wrapped:
            key = key[len(prefix):]
        elif target_is_wrapped and not source_is_wrapped:
            key = prefix + key
        new_state_dict[key] = value

    try:
        model.load_state_dict(new_state_dict)
        print("Model weights loaded successfully.")
    except RuntimeError as e:
        print("Error in loading model weights: ", e)
        raise

    return model

In [None]:
def stratified_k_fold_cross_validation_stage2(second_stage_full_data, tokenizer, best_model_state_stage1, k=5, num_epochs=50, initial_augmentation_prob=0.5):
    """Stage 2 (DNA- vs. RNA-binding): hold-out split plus stratified k-fold CV.

    - 20% of ``second_stage_full_data`` is held out as an independent test set and written
      to ``Data/test_data_stage2.csv``.
    - The remaining 80% is split into ``k`` stratified folds. Each fold starts from the
      stage-1 weights, freezes the backbone and trains the stage-2 head.
    - The fold whose model scores best on its validation fold is kept. The held-out test
      set is only reported, never used for selection.

    Args:
        second_stage_full_data: DataFrame with 'sequence' and 'label' (0 = DNA, 1 = RNA).
        tokenizer: unused here; create_data_loaders reads the notebook-level
            ``tokenizer``. Kept for interface compatibility.
        best_model_state_stage1: state dict produced by stage 1.
        k: number of CV folds.
        num_epochs: maximum epochs per fold.
        initial_augmentation_prob: starting augmentation probability for training.

    Returns:
        The selected fold's best state dict as returned by ``train_model``.

    Raises:
        ValueError: if ``best_model_state_stage1`` is None.
    """
    if best_model_state_stage1 is None:
        raise ValueError("best_model_state_stage1 is None; stage 1 must produce a model state first")

    train_val_data_stage2, test_data_stage2 = train_test_split(second_stage_full_data, test_size=0.2, random_state=42, stratify=second_stage_full_data['label'])
    test_data_stage2.to_csv('Data/test_data_stage2.csv', index=False)
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

    best_accuracy = 0
    best_model_state = None

    folds = skf.split(train_val_data_stage2['sequence'], train_val_data_stage2['label'])
    for fold, (train_index, val_index) in enumerate(folds, start=1):
        print(f"Fold {fold}/{k}")

        train_data_stage2 = train_val_data_stage2.iloc[train_index]
        val_data_stage2 = train_val_data_stage2.iloc[val_index]

        train_loader_stage2, val_loader_stage2, test_loader_stage2 = create_data_loaders(train_data_stage2, val_data_stage2, test_data_stage2, augmentation_prob=initial_augmentation_prob)

        # Fresh model per fold, initialised from the stage-1 weights.
        model = MultiPathProteinClassifier()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(device)
        model = load_model_weights(model, best_model_state_stage1)
        core_model = model.module if isinstance(model, nn.DataParallel) else model

        core_model.freeze_layers()
        optimizer = get_optimizer(model)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.6, patience=2, min_lr=1e-6)
        core_model.set_stage(2)

        # BCEWithLogitsLoss expects pos_weight = #negatives / #positives. That is > 1 for a
        # minority positive class and < 1 for a majority one. max() guards against a fold
        # with no positives.
        num_positive2 = int(train_data_stage2['label'].sum())
        num_negative2 = len(train_data_stage2) - num_positive2
        pos_weight2 = torch.tensor([num_negative2 / max(num_positive2, 1)], dtype=torch.float32).to(device)
        loss_function = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)

        print("Training Stage 2:")
        torch.cuda.empty_cache()  # release cached, unused GPU memory
        model, best_model_state_fold = train_model(model, {'train': train_loader_stage2, 'val': val_loader_stage2}, optimizer, loss_function, scheduler, device, num_epochs, stage=2, fold=fold, initial_augmentation_prob=initial_augmentation_prob, patience=4)

        # Model selection uses the validation fold only.
        print("Evaluating Stage 2 on validation fold:")
        val_accuracy = evaluate_model(model, val_loader_stage2, device, stage=2, fold=fold)[0]

        print("Evaluating Stage 2:")
        accuracy, precision, recall, f1, sensitivity, specificity, mcc, fold_labels, fold_probs, cm_stage2 = evaluate_model(model, test_loader_stage2, device, stage=2, fold=fold)
        print(f"Stage 2 Evaluation - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, Sensitivity: {sensitivity:.4f}, F1: {f1:.4f}, Specificity: {specificity:.4f}, MCC: {mcc:.4f}")

        if best_model_state is None or val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_model_state = best_model_state_fold

        # Drop every reference to this fold's model (including the unwrapped one) before
        # the next fold allocates a new backbone.
        del model, core_model, optimizer, scheduler, loss_function, train_loader_stage2, val_loader_stage2, test_loader_stage2
        torch.cuda.empty_cache()
        gc.collect()

    return best_model_state

# Run stage-2 stratified k-fold cross-validation.
best_model_state_stage2 = stratified_k_fold_cross_validation_stage2(second_stage_full_data, tokenizer, best_model_state_stage1, k=5, num_epochs=50, initial_augmentation_prob=0.5)

In [None]:
# Save the best stage-2 model state. Fail loudly rather than silently writing None,
# and make sure the output directory exists (torch.save does not create it).
if best_model_state_stage2 is None:
    raise RuntimeError("Stage 2 produced no model state to save")
os.makedirs("Model", exist_ok=True)
torch.save(best_model_state_stage2, "Model/best_model_stage2_esm1b.pth")
