In [None]:

from models import classifier_v3
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Load the data

tf_data_v4 = torch.load('../embedding/training_set_tf_embedding_v4.pt')
dna_data_v4 = torch.load('../embedding/training_set_DNA_embedding_v4.pt')

In [3]:
tf_data_v4.shape, len(dna_data_v4), dna_data_v4[0].shape

(torch.Size([35715, 960]), 35715, torch.Size([768]))

In [4]:
tf_data_v4.to('cpu')

tensor([[ 8.7420e-03, -5.9488e-03, -7.2344e-03,  ..., -4.7327e-03,
         -1.4604e-02, -1.5792e-02],
        [-3.3671e-03,  7.8256e-03, -2.2358e-03,  ...,  1.2239e-03,
         -2.0149e-02, -1.0832e-02],
        [ 1.1129e-02, -1.5702e-02, -2.6762e-03,  ...,  1.0946e-02,
         -1.7149e-02, -1.4492e-02],
        ...,
        [-7.5461e-05, -1.2941e-02,  1.3136e-02,  ..., -1.4155e-02,
         -3.1340e-02,  1.1025e-02],
        [-7.5706e-03, -8.7105e-05, -1.4696e-02,  ..., -4.4630e-03,
         -2.2061e-02, -2.7349e-02],
        [ 4.3478e-03, -1.8013e-02, -1.6710e-02,  ...,  1.2651e-03,
         -1.6302e-02, -1.4976e-02]])

In [5]:
# Get the dimensions of the first tensor to determine shape
feature_dim = dna_data_v4[0].size(0)
num_samples = len(dna_data_v4)

# Create a new tensor with the appropriate dimensions
dna_tensor = torch.zeros((num_samples, feature_dim))

# Copy data from each tensor in the list to the new tensor
for i, tensor in enumerate(dna_data_v4):
    dna_tensor[i] = tensor

# Verify shape
print(f"DNA tensor shape: {dna_tensor.shape}")

DNA tensor shape: torch.Size([35715, 768])


In [6]:
import pandas as pd
#get labels
labels = pd.read_csv('../dataset/training_dataset_with_negatives_v4.csv')['label']

In [None]:
# data V3: new training set with pos:neg = 1:2, negtive samples are generatedby shuffling across different species

# Model R1
# === Cross Validation Mean Accuracy: 91.98% ===
# === Cross Validation Mean ROC AUC: 0.9623 ===
# === Overall ROC AUC (all folds combined): 0.9599 ===

# Model R2: removed feature extraction layer before cross attention
# === Cross Validation Mean Accuracy: 90.46% ===
# === Cross Validation Mean ROC AUC: 0.9458 ===
# === Overall ROC AUC (all folds combined): 0.9439 ===

# Data v4: clustering based on proteins to split training and test set

# === Cross Validation Mean Accuracy: 90.32% ===
# === Cross Validation Mean ROC AUC: 0.9486 ===
# === Overall ROC AUC (all folds combined): 0.9459 ===


In [8]:
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 交叉验证参数
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# 记录每折的性能
fold_results = []
fold_auc_scores = []
all_true_labels = []
all_pred_probs = []

# Learning rate warmup parameters
warmup_steps = 100  # Number of iterations for warmup
total_steps = 0  # Will be calculated based on epochs and batches

# 交叉验证循环
for fold, (train_idx, val_idx) in enumerate(kf.split(dna_tensor)):
    print(f"\n==== Fold {fold+1}/{num_folds} ====")

    # 划分训练集和验证集
    dna_train, dna_val = dna_tensor[train_idx], dna_tensor[val_idx]
    protein_train, protein_val = tf_data_v4[train_idx], tf_data_v4[val_idx]

    # Convert labels pandas Series to tensor
    labels_tensor = torch.tensor(labels, dtype=torch.float32)
    labels_train, labels_val = labels_tensor[train_idx], labels_tensor[val_idx]

    # 转换成 DataLoader
    train_dataset = torch.utils.data.TensorDataset(dna_train, protein_train, labels_train)
    val_dataset = torch.utils.data.TensorDataset(dna_val, protein_val, labels_val)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    # 初始化模型
    classifier = model.DNAProteinClassifier().to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(classifier.parameters(), lr=1e-5)

    # Calculate total steps for the scheduler
    num_epochs = 20
    total_steps = len(train_loader) * num_epochs

    # Define warmup function
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0

    # Create scheduler with warmup
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

    # 训练模型
    for epoch in range(num_epochs):
        classifier.train()
        total_loss = 0
        current_lr = optimizer.param_groups[0]['lr']

        for dna_batch, protein_batch, label_batch in train_loader:
            dna_batch, protein_batch, label_batch = dna_batch.to(device), protein_batch.to(device), label_batch.to(device)

            optimizer.zero_grad()
            outputs = classifier(dna_batch, protein_batch)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()
            scheduler.step()  # Update learning rate

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}, LR: {current_lr:.6f}")

    # 评估模型
    classifier.eval()
    correct, total = 0, 0
    true_labels = []
    pred_probs = []

    with torch.no_grad():
        for dna_batch, protein_batch, label_batch in val_loader:
            dna_batch, protein_batch, label_batch = dna_batch.to(device), protein_batch.to(device), label_batch.to(device)
            outputs = classifier(dna_batch, protein_batch)
            predictions = (outputs > 0.5).float()
            correct += (predictions == label_batch).sum().item()
            total += label_batch.size(0)

            # Collect true labels and prediction probabilities for ROC AUC calculation
            true_labels.extend(label_batch.cpu().numpy())
            pred_probs.extend(outputs.cpu().numpy())

    # Calculate metrics
    accuracy = correct / total
    auc_score = roc_auc_score(true_labels, pred_probs)

    print(f"Validation Accuracy (Fold {fold+1}): {accuracy * 100:.2f}%")
    print(f"ROC AUC Score (Fold {fold+1}): {auc_score:.4f}")

    fold_results.append(accuracy)
    fold_auc_scores.append(auc_score)

    # Store for overall ROC AUC
    all_true_labels.extend(true_labels)
    all_pred_probs.extend(pred_probs)

# 计算平均准确率和AUC
print(f"\n=== Cross Validation Mean Accuracy: {np.mean(fold_results) * 100:.2f}% ===")
print(f"=== Cross Validation Mean ROC AUC: {np.mean(fold_auc_scores):.4f} ===")

# Calculate overall ROC AUC from all folds combined
overall_auc = roc_auc_score(all_true_labels, all_pred_probs)
print(f"=== Overall ROC AUC (all folds combined): {overall_auc:.4f} ===")


==== Fold 1/5 ====
Epoch 1/20, Loss: 0.5747, LR: 0.000000
Epoch 2/20, Loss: 0.4428, LR: 0.000010
Epoch 3/20, Loss: 0.3861, LR: 0.000010
Epoch 4/20, Loss: 0.3518, LR: 0.000010
Epoch 5/20, Loss: 0.3266, LR: 0.000010
Epoch 6/20, Loss: 0.3067, LR: 0.000010
Epoch 7/20, Loss: 0.2848, LR: 0.000010
Epoch 8/20, Loss: 0.2649, LR: 0.000010
Epoch 9/20, Loss: 0.2501, LR: 0.000010
Epoch 10/20, Loss: 0.2341, LR: 0.000010
Epoch 11/20, Loss: 0.2175, LR: 0.000010
Epoch 12/20, Loss: 0.2049, LR: 0.000010
Epoch 13/20, Loss: 0.1946, LR: 0.000010
Epoch 14/20, Loss: 0.1837, LR: 0.000010
Epoch 15/20, Loss: 0.1732, LR: 0.000010
Epoch 16/20, Loss: 0.1627, LR: 0.000010
Epoch 17/20, Loss: 0.1528, LR: 0.000010
Epoch 18/20, Loss: 0.1445, LR: 0.000010
Epoch 19/20, Loss: 0.1416, LR: 0.000010
Epoch 20/20, Loss: 0.1288, LR: 0.000010
Validation Accuracy (Fold 1): 89.51%
ROC AUC Score (Fold 1): 0.9464

==== Fold 2/5 ====
Epoch 1/20, Loss: 0.5736, LR: 0.000000
Epoch 2/20, Loss: 0.4445, LR: 0.000010
Epoch 3/20, Loss: 0.390

In [9]:
# Save the model to a local file
model_path = '../models_v3/dna_protein_classifier_v4r1.pth'
torch.save(classifier.state_dict(), model_path)
print(f"Model saved to {model_path}")

# To save the entire model (not just state_dict)
full_model_path = '../models_v3/dna_protein_classifier_full_v4r1.pt'
torch.save(classifier, full_model_path)
print(f"Full model saved to {full_model_path}")

Model saved to ../models_v3/dna_protein_classifier_v4r1.pth
Full model saved to ../models_v3/dna_protein_classifier_full_v4r1.pt


In [None]:
classifier = model.DNAProteinClassifier()
# Print the model structure
print("Model Structure:")
print(classifier)

# Get the number of parameters in the model
total_params = sum(p.numel() for p in classifier.parameters())
trainable_params = sum(p.numel() for p in classifier.parameters() if p.requires_grad)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Print the model's layer types and sizes
for name, module in classifier.named_children():
    print(f"\n{name}:")
    print(module)

Model Structure:
DNAProteinClassifier(
  (dna_feature_extractor): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
    (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (protein_feature_extractor): Sequential(
    (0): Linear(in_features=960, out_features=960, bias=True)
    (1): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (bi_cross_attn): BiCrossAttention(
    (dna_proj): Linear(in_features=768, out_features=960, bias=True)
    (cross_attn_dna): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=960, out_features=960, bias=True)
    )
    (cross_attn_protein): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=960, out_features=960, bias=True)
    )
  )
  (pool): PoolingLayer()
  (self_attn1): SelfAttentionBlock(
    (self_attn): MultiheadAttention(
      (out_