In [1]:
import pandas as pd
import numpy as np
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
import torch
from tqdm import tqdm
from transformers import EsmTokenizer, EsmForMaskedLM
from peft import PeftModel
from umap import UMAP
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import fastcluster
from scipy.cluster.hierarchy import dendrogram
import shap
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, ParameterGrid
from sklearn.metrics import confusion_matrix, classification_report
from joblib.parallel import BatchCompletionCallBack

############################################
# Step 1: Read the input CSV file
############################################
df = pd.read_csv('./ABP_RF_SHAP_data/A_domain_786.csv')

#########################################################################
# 1. Data Augmentation: Add duplicate samples for substrates with < 10 occurrences
#########################################################################

# Count the occurrences of each substrate
substrate_counts = df['substrate'].value_counts()

# Identify substrates with occurrences less than 10
low_occurrence = substrate_counts[substrate_counts < 10]

# Create a copy of the original DataFrame for augmentation
df_enhanced = df.copy()
new_rows = []

# Add duplicate samples for low-frequency substrates
for substrate, count in low_occurrence.items():
    substrate_samples = df[df['substrate'] == substrate]
    copies_needed = 10 - count

    for i in range(copies_needed):
        # Randomly select a sample to duplicate
        sample_to_copy = substrate_samples.sample(1, random_state=np.random.randint(10000)).iloc[0]
        new_row = sample_to_copy.copy()
        new_row['id'] = f"{new_row['id']}_repeat_{i+1}"  # Modify ID to indicate duplication
        new_rows.append(new_row)

# Add the new rows to the augmented dataset
if new_rows:
    df_enhanced = pd.concat([df_enhanced, pd.DataFrame(new_rows)], ignore_index=True)

#########################################################################
# 2. Encode substrates as numeric values
#########################################################################

# Factorize the substrates in the augmented dataset
codes, uniques = pd.factorize(df_enhanced['substrate'])
df_enhanced['substrate_numeric'] = codes

# Create a mapping dictionary and apply it to the original dataset
mapping = dict(zip(uniques, range(len(uniques))))
df['substrate_numeric'] = df['substrate'].map(mapping)

#########################################################################
# 3. Save the augmented CSV and FASTA files
#########################################################################

# Save the augmented dataset as a CSV file
output_csv = './ABP_RF_SHAP_data/A_domain_augmented_1579.csv'
df_enhanced.to_csv(output_csv, index=False)

# Generate the augmented dataset as a FASTA file
records = [
    SeqRecord(Seq(row['ABP']), id=str(row['id']), description="")
    for _, row in df_enhanced.iterrows()
]
output_fasta = './ABP_RF_SHAP_data/A_domain_augmented_1579.fasta'
with open(output_fasta, 'w') as output_handle:
    SeqIO.write(records, output_handle, 'fasta')

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
#########################################################################
# 4. Load the ABP-ESM2 for ABP embedding
#########################################################################

# Load the model and tokenizer
model_name = '/nfs/home/9401_qinzhiwei/HJQ/1.NRPS/finetune/esm/esm2_t33_650M_UR50D'
tokenizer = EsmTokenizer.from_pretrained(model_name)
base_model = EsmForMaskedLM.from_pretrained(model_name)

# Load the LoRA fine-tuned model
lora_model = PeftModel.from_pretrained(base_model, '/nfs/home/9401_qinzhiwei/HJQ/1.NRPS/3.CPSL/pocket/PLM')
lora_model.eval()  # Set model to evaluation mode

# Configure device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lora_model.to(device)

def compute_lora_embedding(sequence):
    """
    Compute LoRA embedding for a single protein sequence.
    
    Args:
        sequence (str): Protein sequence as a string.
    
    Returns:
        torch.Tensor: Embedding tensor with shape [seq_length, hidden_dim],
                      where seq_length = len(sequence).
    """
    # Tokenize the input sequence
    inputs = tokenizer(sequence, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    
    # Perform forward pass through the model
    with torch.no_grad():
        output = lora_model.esm(**inputs).last_hidden_state[0]
    
    # Exclude CLS token and take only the actual sequence embeddings
    embedding = output[1:len(sequence)+1].cpu()
    return embedding

# Compute LoRA embeddings for all sequences in the augmented dataset
print("\nComputing LoRA embeddings for all sequences:")
lora_embeddings_1579 = []

# Iterate over all sequences in the DataFrame
for seq in tqdm(df_enhanced['ABP'], desc="Computing embeddings"):
    emb = compute_lora_embedding(seq)
    lora_embeddings_1579.append(emb)

# Convert the list of embeddings to a tensor
lora_tensor_1579 = torch.stack(lora_embeddings_1579, dim=0)  # Shape: (N, seq_length, hidden_dim)

# # Save the embeddings as a file for reuse
# torch.save(lora_tensor_1579, "lora_embeddings_1579.pt")


Computing LoRA embeddings for all sequences:


Computing embeddings: 100%|██████████| 1579/1579 [00:43<00:00, 36.05it/s]


In [2]:
# Optional: load the saved pt file
# File path to the saved tensor
embedding_file_1579 = "lora_embeddings_1579.pt"

# Load the tensor from the file
lora_tensor_1579 = torch.load(embedding_file_1579)

# Print the shape of the loaded tensor for verification
print("Loaded LoRA embeddings tensor shape:", lora_tensor_1579.shape)

Loaded LoRA embeddings tensor shape: torch.Size([1579, 27, 1280])


In [3]:
######################################################
# 5. UMAP Dimensionality Reduction: Reduce embeddings at each sequence position to 1D
######################################################

# Convert the LoRA tensor to a NumPy array
X_1579 = lora_tensor_1579.numpy()  # Shape: (1579, seq_length, hidden_dim)

# Extract dimensions
N_1579, seq_length, hidden_dim = X_1579.shape
print(f"\nFor df_enhanced: Number of sequences: {N_1579}, Sequence length: {seq_length}, Hidden dimension: {hidden_dim}")

# Initialize an array to store the UMAP-reduced embeddings
X_umap_1579 = np.zeros((N_1579, seq_length))  # Shape: (1579, seq_length)

# Perform UMAP dimensionality reduction for each sequence position
for pos in range(seq_length):
    # Initialize UMAP for 1D reduction
    umap = UMAP(n_components=1, random_state=42)
    
    # Extract embeddings at the current position for all sequences
    pos_data = X_1579[:, pos, :]  # Shape: (1579, hidden_dim)
    
    # Apply UMAP and store the reduced 1D embedding
    X_umap_1579[:, pos] = umap.fit_transform(pos_data).squeeze()  # Shape: (1579,)

print("After UMAP dimensionality reduction for df_enhanced:")
print(f"X_umap_1579 shape: {X_umap_1579.shape}")  # Shape: (1579, seq_length)


For df_enhanced: Number of sequences: 1579, Sequence length: 27, Hidden dimension: 1280
After UMAP dimensionality reduction for df_enhanced:
X_umap_1579 shape: (1579, 27)


In [5]:
###################################
# 6. Train Random Forest with GridSearchCV (5-fold cross-validation)
###################################

# Extract numeric labels for the substrate column
y = df_enhanced['substrate_numeric'].values
print("y shape:", y.shape)

# Define parameter grid for Random Forest
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [None, 10],
    'min_samples_split': [2, 5],
    'min_samples_leaf': [1, 2],
    'bootstrap': [True, False]
}

# Initialize Random Forest classifier and GridSearchCV
rf = RandomForestClassifier(random_state=42, n_jobs=-1)
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1, verbose=0)

# Custom progress bar for GridSearchCV
n_candidates = len(list(ParameterGrid(param_grid)))
cv_folds = grid_search.cv  # cv=5
total_tasks = n_candidates * cv_folds

class TqdmBatchCompletionCallBack(BatchCompletionCallBack):
    def __init__(self, pbar, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pbar = pbar

    def __call__(self, *args, **kwargs):
        self.pbar.update(n=self.batch_size)
        return super().__call__(*args, **kwargs)

old_batch_callback = BatchCompletionCallBack
def batch_callback_with_tqdm(*args, **kwargs):
    return TqdmBatchCompletionCallBack(main_pbar, *args, **kwargs)
import joblib
joblib.parallel.BatchCompletionCallBack = batch_callback_with_tqdm

# Perform GridSearchCV with progress bar
print("\nStarting GridSearchCV (5-fold cross-validation)...")
with tqdm(total=total_tasks, desc="GridSearchCV Progress") as main_pbar:
    grid_search.fit(X_umap_1579, y)

# Restore original BatchCompletionCallBack
joblib.parallel.BatchCompletionCallBack = old_batch_callback

# Output best parameters and accuracy
print("Best parameters:", grid_search.best_params_)
print("Best cross-validation accuracy:", grid_search.best_score_)

# Save the best model using pickle
best_rf = grid_search.best_estimator_
model_filename = "best_random_forest_model_umap.pkl"
with open(model_filename, "wb") as f:
    pickle.dump(best_rf, f)
print(f"Best model saved as: {model_filename}")

y shape: (1579,)

Starting GridSearchCV (5-fold cross-validation)...


GridSearchCV Progress: 360it [00:36,  9.78it/s]                         


Best parameters: {'bootstrap': True, 'max_depth': None, 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 200}
Best cross-validation accuracy: 0.93221418525216
Best model saved as: best_random_forest_model_umap.pkl


In [5]:
#Optional: load the ABP-RF trained in this study

# Define the filename of the saved model
model_filename = "./ABP_RF_SHAP_data/best_random_forest_model_umap_05012025.pkl"

# Load the model using pickle
with open(model_filename, "rb") as f:
    best_rf = pickle.load(f)

# Verify that the model has been loaded successfully by printing its parameters
print("Loaded best model parameters:", best_rf.get_params())

Loaded best model parameters: {'bootstrap': False, 'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': 'auto', 'max_leaf_nodes': None, 'max_samples': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'n_estimators': 200, 'n_jobs': -1, 'oob_score': False, 'random_state': 42, 'verbose': 0, 'warm_start': False}


In [7]:
###################################
# 8. SHAP Value Analysis
###################################

# Initialize SHAP TreeExplainer
explainer = shap.TreeExplainer(best_rf)
shap_values = explainer.shap_values(X_umap_1579)

# Ensure output directory exists
os.makedirs("SHAP_analysis", exist_ok=True)

# Load the augmented dataset and extract class labels
original_data = pd.read_csv('./ABP_RF_SHAP_data/A_domain_augmented_1579.csv')
actual_classes = original_data['substrate_numeric'].values

# Predict using the trained random forest
predictions = best_rf.predict(X_umap_1579)
pred_probs = best_rf.predict_proba(X_umap_1579)
num_classes = len(shap_values)

# SHAP value collection: site-level for all samples and all classes
all_shap_rows = []
for sample_idx in tqdm(range(len(X_umap_1579)), desc="Processing SHAP values"):
    pred_class = predictions[sample_idx]
    actual_class = actual_classes[sample_idx]
    pocket_sequence = original_data.loc[sample_idx, 'ABP']
    for site_idx in range(27):  # Assuming 27 sites (positions)
        site_key = f"Site_{site_idx+1}"
        aa = pocket_sequence[site_idx] if site_idx < len(pocket_sequence) else None
        for class_idx in range(num_classes):
            shap_val = shap_values[class_idx][sample_idx][site_idx]
            all_shap_rows.append({
                'Sample_ID': sample_idx,
                'Site': site_key,
                'Site_Index': site_idx + 1,
                'Amino_Acid': aa,
                'Predicted_Class': pred_class,
                'Actual_Class': actual_class,
                'Class': class_idx,
                'SHAP_Value': shap_val,
                'Abs_SHAP_Value': abs(shap_val)
            })

# Create DataFrame of all per-site SHAP values, with amino acid info
all_shap_df = pd.DataFrame(all_shap_rows)
all_shap_df.to_csv('all_samples_site_shap_analysis_with_aa.csv', index=False)
print("Saved all site-level SHAP values (with amino acid) to 'all_samples_site_shap_analysis_with_aa.csv'.")

# Calculate per-site (across all samples and all classes) importance statistics
site_importance = (
    all_shap_df
    .groupby('Site')['Abs_SHAP_Value']
    .agg(['mean', 'std'])
    .reset_index()
    .rename(columns={'mean': 'Mean_Abs_SHAP', 'std': 'Std_Abs_SHAP'})
)
site_importance['Importance_Rank'] = site_importance['Mean_Abs_SHAP'].rank(ascending=False)
site_importance = site_importance.sort_values('Importance_Rank')

# Save summary statistics
site_importance.to_csv('site_importance_all_classes.csv', index=False)
print("Saved site importance statistics to 'site_importance_all_classes.csv'.")

Processing SHAP values: 100%|██████████| 1579/1579 [00:03<00:00, 424.19it/s]


Saved all site-level SHAP values (with amino acid) to 'all_samples_site_shap_analysis_with_aa.csv'.
Saved site importance statistics to 'site_importance_all_classes.csv'.
