In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision import models
import numpy as np
import pandas as pd
import os
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_curve

In [5]:
# Load the precomputed test-set embeddings and the matching labels.
# map_location="cpu" lets files that were saved from a GPU session load on a
# CPU-only machine too; batches are moved to `device` later, when they are used.
# NOTE: torch.load unpickles its input -- only point it at trusted files.
test_dna = torch.load('../embedding/test_set_DNA_embedding_v3.pt', map_location='cpu')
test_tf = torch.load('../embedding/test_set_tf_embedding_v3.pt', map_location='cpu')
test_labels = pd.read_csv('../dataset/test_set_v3.csv')['label']


In [6]:
# Load the RegPrecise DNA / TF embeddings.
# map_location='cpu' keeps the load working on machines without CUDA even if
# the tensors were saved from a GPU; batches are moved to `device` when used.
regprecise_dna_data = torch.load('regprecise/regprecise_DNA_embedding.pt', map_location='cpu')
regprecise_tf_data = torch.load('regprecise/regprecise_tf_embedding.pt', map_location='cpu')

In [7]:
# Convert test data to tensors
test_labels_tensor = torch.tensor(test_labels.values, dtype=torch.float32)

# test_dna may come back from torch.load either as a list of per-sample 1-D
# embeddings or as an already-stacked tensor; normalise the list case.
if isinstance(test_dna, list):
    test_num_samples = len(test_dna)
    test_feature_dim = test_dna[0].size(0)
    # torch.stack builds the (num_samples, feature_dim) tensor in one call.
    # The cast keeps the result a float32 CPU tensor, like the zero-filled
    # buffer this replaces.
    test_dna_tensor = torch.stack(test_dna).to(device='cpu', dtype=torch.float32)
else:
    test_dna_tensor = test_dna

print(f"Test DNA tensor shape: {test_dna_tensor.shape}")
print(f"Test TF data shape: {test_tf.shape}")
print(f"Test labels shape: {test_labels_tensor.shape}")


Test DNA tensor shape: torch.Size([3890, 768])
Test TF data shape: torch.Size([3890, 960])
Test labels shape: torch.Size([3890])


In [10]:
# The checkpoint below is a fully pickled model object, so the module that
# defines its class has to be importable when torch.load unpickles it
# (presumably this `model` module -- confirm against the training script).
import model

# Create test dataset and dataloader (fixed order, 64 samples per batch)
test_dataset = torch.utils.data.TensorDataset(test_dna_tensor, test_tf, test_labels_tensor)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the saved model. torch.load already returns the complete module, so no
# placeholder instance is constructed first. map_location makes a GPU-saved
# checkpoint loadable when only the CPU is available.
# weights_only=False runs the unpickler without restrictions -- trusted files only.
loaded_model = torch.load(
    '/opt/WS/WS2/human_genome/code/binding/models_v3/dna_protein_classifier_full_v3r2.pt',
    map_location=device,
    weights_only=False,
)
loaded_model.to(device)
loaded_model.eval()


DNAProteinClassifier(
  (dna_feature_extractor): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
    (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (protein_feature_extractor): Sequential(
    (0): Linear(in_features=960, out_features=960, bias=True)
    (1): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (bi_cross_attn): BiCrossAttention(
    (dna_proj): Linear(in_features=768, out_features=960, bias=True)
    (cross_attn_dna): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=960, out_features=960, bias=True)
    )
    (cross_attn_protein): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=960, out_features=960, bias=True)
    )
  )
  (pool): PoolingLayer()
  (self_attn1): SelfAttentionBlock(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamic

In [11]:
# Evaluate the model on test data.
# A single pass over test_loader collects the thresholded predictions, the true
# labels and the raw model outputs, so ROC AUC does not need a second
# inference pass over the whole test set.
correct = 0
total = 0
predictions_all = []
true_labels_all = []
raw_probs = []

with torch.no_grad():
    for dna_batch, protein_batch, label_batch in test_loader:
        dna_batch, protein_batch, label_batch = dna_batch.to(device), protein_batch.to(device), label_batch.to(device)
        outputs = loaded_model(dna_batch, protein_batch)
        # Scores above 0.5 are called positive.
        predictions = (outputs > 0.5).float()
        correct += (predictions == label_batch).sum().item()
        total += label_batch.size(0)

        # Keep per-sample results for the metrics below
        predictions_all.extend(predictions.cpu().numpy())
        true_labels_all.extend(label_batch.cpu().numpy())
        raw_probs.extend(outputs.cpu().numpy())

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Calculate additional metrics

# Convert lists to numpy arrays
predictions_all = np.array(predictions_all)
true_labels_all = np.array(true_labels_all)
raw_probs = np.array(raw_probs)

# Print classification report
print("\nClassification Report:")
print(classification_report(true_labels_all, predictions_all))

# Print confusion matrix
cm = confusion_matrix(true_labels_all, predictions_all)
print("\nConfusion Matrix:")
print(cm)

# ROC AUC is only defined when both classes are present in the labels
if len(np.unique(true_labels_all)) > 1:
    roc_auc = roc_auc_score(true_labels_all, raw_probs)
    print(f"\nROC AUC Score: {roc_auc:.4f}")

Test Accuracy: 86.50%

Classification Report:
              precision    recall  f1-score   support

         0.0       0.89      0.90      0.90      2593
         1.0       0.80      0.79      0.80      1297

    accuracy                           0.87      3890
   macro avg       0.85      0.85      0.85      3890
weighted avg       0.86      0.87      0.86      3890


Confusion Matrix:
[[2343  250]
 [ 275 1022]]

ROC AUC Score: 0.9221


In [None]:
# 0.5 threshold
# Test Accuracy: 92.78%

# Classification Report:
#               precision    recall  f1-score   support

#          0.0       0.96      0.92      0.94      1331
#          1.0       0.88      0.94      0.91       844

#     accuracy                           0.93      2175
#    macro avg       0.92      0.93      0.92      2175
# weighted avg       0.93      0.93      0.93      2175


# Confusion Matrix:
# [[1222  109]
#  [  48  796]]

# ROC AUC Score: 0.9614

# 0.7 threshold
# Test Accuracy: 93.20%

# Classification Report:
#               precision    recall  f1-score   support

#          0.0       0.95      0.93      0.94      1331
#          1.0       0.90      0.93      0.91       844

#     accuracy                           0.93      2175
#    macro avg       0.93      0.93      0.93      2175
# weighted avg       0.93      0.93      0.93      2175


# Confusion Matrix:
# [[1242   89]
#  [  59  785]]

# ROC AUC Score: 0.9614

#------------------------------------------------------------------------------

# 0.8 threshold
# Test Accuracy: 93.01%

# Classification Report:
#               precision    recall  f1-score   support

#          0.0       0.95      0.94      0.94      1331
#          1.0       0.90      0.92      0.91       844

#     accuracy                           0.93      2175
#    macro avg       0.93      0.93      0.93      2175
# weighted avg       0.93      0.93      0.93      2175


# Confusion Matrix:
# [[1247   84]
#  [  68  776]]

# ROC AUC Score: 0.9614

In [None]:
# V3r1
# Test Accuracy: 91.93%

# Classification Report:
#               precision    recall  f1-score   support

#          0.0       0.95      0.93      0.94      2593
#          1.0       0.86      0.91      0.88      1297

#     accuracy                           0.92      3890
#    macro avg       0.91      0.92      0.91      3890
# weighted avg       0.92      0.92      0.92      3890


# Confusion Matrix:
# [[2399  194]
#  [ 120 1177]]

# ROC AUC Score: 0.9653



# V4r1
# Test Accuracy: 92.78%

# Classification Report:
#               precision    recall  f1-score   support

#          0.0       0.96      0.92      0.94      1331
#          1.0       0.88      0.94      0.91       844

#     accuracy                           0.93      2175
#    macro avg       0.92      0.93      0.92      2175
# weighted avg       0.93      0.93      0.93      2175


# Confusion Matrix:
# [[1222  109]
#  [  48  796]]

# ROC AUC Score: 0.9614

In [14]:
# Wrap the RegPrecise DNA / TF embeddings in a dataset of (dna, tf) pairs.
# No label tensor is passed, so each batch is a 2-tuple; batches of 64 are
# served in the stored row order (shuffle=False).
regprecise_test_dataset = torch.utils.data.TensorDataset(regprecise_dna_data, regprecise_tf_data)
regprecise_test_loader = torch.utils.data.DataLoader(regprecise_test_dataset, batch_size=64, shuffle=False)

In [15]:
# Score every RegPrecise (DNA, TF) pair with the loaded classifier and report
# how many pairs it calls positive at the 0.5 cut-off.
regprecise_predictions = []

with torch.no_grad():
    for dna_batch, protein_batch in regprecise_test_loader:
        dna_batch = dna_batch.to(device)
        protein_batch = protein_batch.to(device)
        outputs = loaded_model(dna_batch, protein_batch)
        hard_calls = (outputs > 0.5).float()
        regprecise_predictions.extend(hard_calls.cpu().numpy())

# Collapse the per-sample list into one numpy array
regprecise_predictions = np.array(regprecise_predictions)

# Positives are encoded as 1.0, so the sum is the positive count
num_positives = np.sum(regprecise_predictions)
positive_pct = num_positives / len(regprecise_predictions) * 100
print(f"Number of positive predictions: {num_positives}")
print(f"Percentage of positive predictions: {positive_pct:.2f}%")

Number of positive predictions: 16886.0
Percentage of positive predictions: 51.35%


### Finetune the model

In [19]:
# Read the RegPrecise TF identifier list. The file is tab-separated with no
# header row, so pandas labels the column 0.
tf_id = pd.read_csv('/opt/WS/WS2/human_genome/code/binding/src/regprecise/tf_id.txt', sep='\t', header=None)

In [20]:
# Show the identifier table as the cell output for a quick visual check.
tf_id

Unnamed: 0,0
0,YP_001800414.1
1,YP_002906236.1
2,YP_250752.1
3,YP_225825.1
4,NP_738273.1
...,...
32880,YP_003346261.1
32881,YP_002334534.1
32882,YP_001305684.1
32883,YP_001568035.1


In [24]:
# tf_id, the TF embeddings and the DNA embeddings are indexed by the same row
# number further down, so all three must have the same number of rows. The
# message names all three objects and reports their sizes.
assert tf_id.shape[0] == regprecise_tf_data.shape[0] == regprecise_dna_data.shape[0], (
    "Mismatch in number of rows between tf_id, regprecise_tf_data and regprecise_dna_data: "
    f"{tf_id.shape[0]} vs {regprecise_tf_data.shape[0]} vs {regprecise_dna_data.shape[0]}"
)

In [26]:
import random
from sklearn.model_selection import train_test_split

# Build a balanced fine-tuning set from RegPrecise:
#   * positives: the original (DNA row i, TF row i) pairings of a random 20% subset
#   * negatives: random re-pairings inside that subset whose TF ids differ
# then split the pairs 80/20 into stratified train / validation sets.
#
# NOTE(review): a negative is only guaranteed to use a *different TF id*; nothing
# here checks that the re-paired TF truly does not bind that DNA, and the same
# negative pair can be drawn more than once.
# NOTE(review): the split is over pairs, so one DNA or TF embedding can appear
# on both sides of the train / validation split.

# Set random seeds for reproducibility. torch is seeded as well because
# DataLoader(shuffle=True) below draws its permutation from torch's global RNG.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

# Calculate the size of the subset for finetuning (20% of regprecise data)
regprecise_total_size = regprecise_dna_data.shape[0]
finetune_subset_size = int(regprecise_total_size * 0.2)

# Generate random indices for the finetuning subset from regprecise data
subset_indices = np.random.choice(regprecise_total_size, finetune_subset_size, replace=False)

# Create positive pairs (original pairs from the subset); each entry holds
# original row indices into regprecise_dna_data / regprecise_tf_data
positive_pairs_indices = [(i, i) for i in subset_indices]
positive_labels = np.ones(len(positive_pairs_indices))

# Generate negative pairs by re-pairing within the subset, ensuring different TFs
negative_pairs_indices = []
max_attempts = len(subset_indices) * 10  # upper bound so the loop always terminates
attempts = 0

# TF ids aligned with subset_indices: tf_ids_in_subset[k] belongs to row
# subset_indices[k]. Looking ids up in this array avoids a pandas .iloc call
# on every loop iteration.
tf_ids_in_subset = tf_id.iloc[subset_indices, 0].values
subset_indices_list = list(range(len(subset_indices)))  # local positions within the subset

while len(negative_pairs_indices) < len(positive_pairs_indices) and attempts < max_attempts:
    # Pick local positions within the subset
    local_dna_idx = random.choice(subset_indices_list)
    local_tf_idx = random.choice(subset_indices_list)

    # Convert local positions back to original regprecise row indices
    original_dna_idx = subset_indices[local_dna_idx]
    original_tf_idx = subset_indices[local_tf_idx]

    # Reject the original pairing and any re-pairing that lands on the same TF id
    if original_dna_idx != original_tf_idx and tf_ids_in_subset[local_dna_idx] != tf_ids_in_subset[local_tf_idx]:
        negative_pairs_indices.append((original_dna_idx, original_tf_idx))
    attempts += 1

if len(negative_pairs_indices) < len(positive_pairs_indices):
    print(f"Warning: Could only generate {len(negative_pairs_indices)} negative pairs, less than {len(positive_pairs_indices)} positive pairs.")
    # Optionally, trim positive pairs to match negative pairs count for balance
    # positive_pairs_indices = positive_pairs_indices[:len(negative_pairs_indices)]
    # positive_labels = positive_labels[:len(negative_pairs_indices)]

negative_labels = np.zeros(len(negative_pairs_indices))

# Combine positive and negative pairs' original indices
all_pairs_original_indices = np.array(positive_pairs_indices + negative_pairs_indices)
all_labels = np.concatenate((positive_labels, negative_labels))

# Shuffle pairs and labels with the same permutation
shuffle_idx = np.random.permutation(len(all_pairs_original_indices))
all_pairs_original_indices = all_pairs_original_indices[shuffle_idx]
all_labels = all_labels[shuffle_idx]

# Gather the embeddings of every pair with one indexed lookup per modality.
# Column 0 holds the DNA row, column 1 the TF row. The cast keeps the result a
# float32 CPU tensor, like the zero-filled buffers this replaces.
dna_row_index = torch.as_tensor(np.ascontiguousarray(all_pairs_original_indices[:, 0]), dtype=torch.long)
tf_row_index = torch.as_tensor(np.ascontiguousarray(all_pairs_original_indices[:, 1]), dtype=torch.long)
finetune_dna_all_tensor = regprecise_dna_data[dna_row_index].to(device='cpu', dtype=torch.float32)
finetune_tf_all_tensor = regprecise_tf_data[tf_row_index].to(device='cpu', dtype=torch.float32)

finetune_labels_all_tensor = torch.tensor(all_labels, dtype=torch.float32)

# Split into training and validation sets (80% train, 20% val), stratified on the label
val_split_ratio = 0.2
num_total_samples = len(finetune_labels_all_tensor)

train_indices, val_indices = train_test_split(np.arange(num_total_samples), test_size=val_split_ratio, random_state=42, stratify=all_labels)

finetune_train_dna_tensor = finetune_dna_all_tensor[train_indices]
finetune_train_tf_tensor = finetune_tf_all_tensor[train_indices]
finetune_train_labels_tensor = finetune_labels_all_tensor[train_indices]

finetune_val_dna_tensor = finetune_dna_all_tensor[val_indices]
finetune_val_tf_tensor = finetune_tf_all_tensor[val_indices]
finetune_val_labels_tensor = finetune_labels_all_tensor[val_indices]

# Create datasets and dataloaders for finetuning
finetune_train_dataset = torch.utils.data.TensorDataset(finetune_train_dna_tensor, finetune_train_tf_tensor, finetune_train_labels_tensor)
finetune_train_loader = torch.utils.data.DataLoader(finetune_train_dataset, batch_size=32, shuffle=True)

finetune_val_dataset = torch.utils.data.TensorDataset(finetune_val_dna_tensor, finetune_val_tf_tensor, finetune_val_labels_tensor)
finetune_val_loader = torch.utils.data.DataLoader(finetune_val_dataset, batch_size=32, shuffle=False)

# Print dataset statistics
print(f"Regprecise total size: {regprecise_total_size}")
print(f"Subset size for finetuning generation: {finetune_subset_size}")
print(f"Total generated pairs for finetuning: {num_total_samples}")
print(f"  Positive pairs: {len(positive_pairs_indices)}")
print(f"  Negative pairs: {len(negative_pairs_indices)}")
print(f"Finetuning training set size: {len(finetune_train_dataset)}")
print(f"Finetuning validation set size: {len(finetune_val_dataset)}")
print(f"Balance of positive/negative examples in combined set: {np.sum(all_labels)/len(all_labels)*100:.2f}% / {(1-np.sum(all_labels)/len(all_labels))*100:.2f}%")

Regprecise total size: 32885
Subset size for finetuning generation: 6577
Total generated pairs for finetuning: 13154
  Positive pairs: 6577
  Negative pairs: 6577
Finetuning training set size: 10523
Finetuning validation set size: 2631
Balance of positive/negative examples in combined set: 50.00% / 50.00%


In [33]:
# Fraction of the model's parameter *tensors* (in the order .parameters()
# yields them) that stay frozen during fine-tuning. The fraction counts
# tensors, not scalar weights, so the share of frozen weights printed below
# can differ from this value.
FREEZE_FRACTION = 0.6

# Set the model to training mode
loaded_model.train()

# Calculate how many parameter tensors to freeze
all_params = list(loaded_model.parameters())
total_params = len(all_params)
params_to_freeze = int(total_params * FREEZE_FRACTION)

# Freeze the leading tensors and leave the remaining ones trainable
for i, param in enumerate(all_params):
    param.requires_grad = i >= params_to_freeze

# Report trainable vs. frozen counts in scalar weights
trainable_params = sum(p.numel() for p in loaded_model.parameters() if p.requires_grad)
total_params_count = sum(p.numel() for p in loaded_model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params_count:.2%})")
print(f"Frozen parameters: {total_params_count - trainable_params:,} ({(total_params_count - trainable_params)/total_params_count:.2%})")


Trainable parameters: 21,192,577 (30.46%)
Frozen parameters: 48,382,848 (69.54%)


In [34]:

# Fine-tune loaded_model on the fine-tuning train loader, validate after every
# epoch, and checkpoint the whole model whenever validation accuracy improves.

# Define optimizer and loss function
# Using a smaller learning rate for fine-tuning
# Only parameters with requires_grad=True are handed to Adam.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, loaded_model.parameters()), lr=0.0001)
# BCELoss takes probabilities rather than logits -- presumably the model's
# forward pass ends in a sigmoid; TODO confirm in model.py.
criterion = nn.BCELoss()

# Fine-tuning loop
num_epochs = 15
best_val_accuracy = 0.0

for epoch in range(num_epochs):
    # Training
    loaded_model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for dna_batch, tf_batch, labels_batch in finetune_train_loader:
        dna_batch, tf_batch, labels_batch = dna_batch.to(device), tf_batch.to(device), labels_batch.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = loaded_model(dna_batch, tf_batch)
        loss = criterion(outputs, labels_batch)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Calculate statistics
        # loss is a batch mean, so weight it by the batch size before summing;
        # dividing by the dataset size below then gives a per-sample average.
        running_loss += loss.item() * dna_batch.size(0)
        predictions = (outputs > 0.5).float()
        correct += (predictions == labels_batch).sum().item()
        total += labels_batch.size(0)
    
    epoch_loss = running_loss / len(finetune_train_dataset)
    epoch_acc = correct / total
    
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
    
    # Validation on finetuning validation set
    loaded_model.eval()
    val_correct = 0
    val_total = 0
    val_raw_probs_all = []
    val_true_labels_all = []
    
    with torch.no_grad():
        for dna_batch, tf_batch, labels_batch in finetune_val_loader:
            dna_batch, tf_batch, labels_batch = dna_batch.to(device), tf_batch.to(device), labels_batch.to(device)
            outputs = loaded_model(dna_batch, tf_batch)
            predictions = (outputs > 0.5).float()
            val_correct += (predictions == labels_batch).sum().item()
            val_total += labels_batch.size(0)
            # Raw outputs are kept for ROC AUC, which needs scores rather than hard calls
            val_raw_probs_all.extend(outputs.cpu().numpy())
            val_true_labels_all.extend(labels_batch.cpu().numpy())
    
    val_acc = val_correct / val_total
    val_roc_auc = -1.0 # Default if not calculable
    if len(np.unique(val_true_labels_all)) > 1: # Check for both classes
        val_roc_auc = roc_auc_score(val_true_labels_all, val_raw_probs_all)
    
    print(f"Validation Accuracy: {val_acc:.4f}, Validation ROC AUC: {val_roc_auc:.4f}")
    
    # Save best model based on validation accuracy
    # torch.save pickles the entire module object (not just a state_dict), and
    # the file is overwritten each time a new best accuracy is reached.
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(loaded_model, '/opt/WS/WS2/human_genome/code/binding/models_v3/dna_protein_classifier_v3_finetuned.pt')
        print(f"Saved new best model with validation accuracy: {best_val_accuracy:.4f}")
    
    print("-" * 50)

print(f"Fine-tuning completed! Best validation accuracy: {best_val_accuracy:.4f}")

Epoch 1/15
Training Loss: 0.6340, Accuracy: 0.6335
Validation Accuracy: 0.6043, Validation ROC AUC: 0.6766
Saved new best model with validation accuracy: 0.6043
--------------------------------------------------
Epoch 2/15
Training Loss: 0.6093, Accuracy: 0.6590
Validation Accuracy: 0.6127, Validation ROC AUC: 0.6805
Saved new best model with validation accuracy: 0.6127
--------------------------------------------------
Epoch 3/15
Training Loss: 0.5862, Accuracy: 0.6807
Validation Accuracy: 0.6268, Validation ROC AUC: 0.6874
Saved new best model with validation accuracy: 0.6268
--------------------------------------------------
Epoch 4/15
Training Loss: 0.5682, Accuracy: 0.7005
Validation Accuracy: 0.6271, Validation ROC AUC: 0.6871
Saved new best model with validation accuracy: 0.6271
--------------------------------------------------
Epoch 5/15
Training Loss: 0.5492, Accuracy: 0.7172
Validation Accuracy: 0.6287, Validation ROC AUC: 0.6918
Saved new best model with validation accuracy

In [38]:
def predict_binding(classifier, loader, device, threshold=0.5):
    """Run `classifier` over an unlabeled (dna, tf) loader.

    Args:
        classifier: module called as classifier(dna_batch, tf_batch).
        loader: iterable yielding (dna_batch, tf_batch) tensor pairs.
        device: device the batches are moved to before the forward pass.
        threshold: outputs strictly above this value are called positive.

    Returns:
        (hard_calls, scores): numpy arrays with one entry per sample, holding
        the 0.0/1.0 calls and the raw model outputs respectively.
    """
    classifier.eval()
    hard_calls = []
    scores = []
    with torch.no_grad():
        for dna_batch, tf_batch in loader:
            outputs = classifier(dna_batch.to(device), tf_batch.to(device))
            hard_calls.extend((outputs > threshold).float().cpu().numpy())
            scores.extend(outputs.cpu().numpy())
    return np.array(hard_calls), np.array(scores)


# Load the fine-tuned model (map_location keeps this working on CPU-only hosts;
# weights_only=False unpickles a full module -- trusted files only)
finetuned_model = torch.load(
    '/opt/WS/WS2/human_genome/code/binding/models_v3/dna_protein_classifier_v3_finetuned.pt',
    map_location=device,
    weights_only=False,
)
finetuned_model.to(device)
finetuned_model.eval()

# Rows of the RegPrecise data that were not drawn into the fine-tuning subset
regprecise_total_size = regprecise_dna_data.shape[0]
test_indices = np.setdiff1d(np.arange(regprecise_total_size), subset_indices)
print(f"Number of test samples: {len(test_indices)}")

# Held-out tensors / loader. These deliberately use their own names so the
# labelled test_dna_tensor / test_dataset / test_loader built earlier in the
# notebook are not overwritten (their batches are 3-tuples, these are 2-tuples).
# Every held-out row is an original (DNA, TF) pairing, i.e. what the
# fine-tuning cell labels as positive.
holdout_dna_tensor = regprecise_dna_data[test_indices]
holdout_tf_tensor = regprecise_tf_data[test_indices]
holdout_dataset = torch.utils.data.TensorDataset(holdout_dna_tensor, holdout_tf_tensor)
holdout_loader = torch.utils.data.DataLoader(holdout_dataset, batch_size=64, shuffle=False)

# Get predictions from the fine-tuned model
predictions, raw_scores = predict_binding(finetuned_model, holdout_loader, device)

# Calculate the number of positive predictions
num_positives = np.sum(predictions)
print(f"Number of positive predictions: {num_positives}")
print(f"Percentage of positive predictions: {num_positives / len(predictions) * 100:.2f}%")

# Compare with original model predictions. torch.load returns the complete
# module, so no placeholder instance is constructed first.
original_model = torch.load(
    '/opt/WS/WS2/human_genome/code/binding/models_v3/dna_protein_classifier_full_v3r2.pt',
    map_location=device,
    weights_only=False,
)
original_model.to(device)
original_model.eval()

original_predictions, original_raw_scores = predict_binding(original_model, holdout_loader, device)

# Compare overall statistics
orig_positives = np.sum(original_predictions)
print("\nComparison between original and finetuned models:")
print(f"Original model positive predictions: {orig_positives} ({orig_positives / len(original_predictions) * 100:.2f}%)")
print(f"Finetuned model positive predictions: {num_positives} ({num_positives / len(predictions) * 100:.2f}%)")

# Calculate prediction changes
changed_predictions = np.sum(original_predictions != predictions)
print(f"Number of predictions that changed: {changed_predictions} ({changed_predictions / len(predictions) * 100:.2f}%)")

Number of test samples: 26308
Number of positive predictions: 18417.0
Percentage of positive predictions: 70.01%

Comparison between original and finetuned models:
Original model positive predictions: 13495.0 (51.30%)
Finetuned model positive predictions: 18417.0 (70.01%)
Number of predictions that changed: 11502 (43.72%)
