# Multi-Scale Gauss Linking Integral for Protein-Protein Binding Affinity Prediction

In [None]:
import os
import sys

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor

In [None]:
# Treat the parent of the current working directory as the project root and
# put it at the front of sys.path (only if it is not already listed).
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
print(project_root)
if all(entry != project_root for entry in sys.path):
    sys.path.insert(0, project_root)

### 1. Processing dataset 



#### 1.1 Processing binding affinity labels
**TODO:** This section needs to be rewritten: the different binding-affinity measures should all be converted to Gibbs free energy ($\Delta G$). See Supplementary Note 4 in the Supplementary Information of https://www.nature.com/articles/s42003-023-04866-3.

In [2]:
def extract_binding_affinity(tsv_file):
    """
    Load the binding-affinity table and split it into two parallel lists.

    Parameters
    ----------
    tsv_file : str or path-like
        Tab-separated file containing at least the columns ``PDB_ID`` and
        ``ΔG_kJ/mol``.

    Returns
    -------
    tuple of (list, list)
        ``(pdb_ids, affinities)``; entries at the same index come from the
        same row of the table.
    """
    table = pd.read_csv(tsv_file, sep='\t')
    # Columns are read in this order so a missing 'PDB_ID' is reported first.
    id_column, dg_column = (
        table[column].tolist() for column in ('PDB_ID', 'ΔG_kJ/mol')
    )
    return id_column, dg_column

In [3]:
# Load features and binding affinity data from TSV
# NOTE(review): `dir` shadows the built-in dir() and holds a machine-specific
# absolute path. The name and value are kept unchanged because other cells may
# reference them; consider moving this to a configurable constant.
dir = "/home/as4272/protein_design/"
tsv_file = dir + "topology/mGLI-PP/binding_affinity.tsv"

# Extract PDB IDs and binding affinities from TSV
pdb_ids, affinities = extract_binding_affinity(tsv_file)

X = []
y = []

# Load mGLI features for each PDB ID that has binding affinity data.
# Entries whose feature file is missing or unreadable are reported and skipped.
for pdb_id, affinity in zip(pdb_ids, affinities):
    features_path = f"{dir}topology/mGLI-PP/outputs/PDBBind_2020_PP/{pdb_id}_mGLI.pt"
    try:
        # weights_only=True restricts unpickling to tensor data, so a tampered
        # .pt file cannot run arbitrary code while being loaded.
        # flatten() guarantees one 1-D feature vector per complex, matching
        # how diagnose_embeddings inspects the same files.
        features = torch.load(features_path, weights_only=True).flatten().numpy()
    except FileNotFoundError:
        print(f"Warning: mGLI features not found for {pdb_id}")
        continue
    except Exception as e:
        print(f"Error loading features for {pdb_id}: {e}")
        continue
    else:
        X.append(features)
        y.append(affinity)  # Use ΔG directly (already in kJ/mol)

# Fail early with a clear message: with zero samples the NaN mask below would
# otherwise raise an opaque axis error on a 1-D empty array.
if not X:
    raise ValueError(
        "No mGLI feature files could be loaded; check `dir` and the feature output directory."
    )

X = np.array(X)
y = np.array(y)

# Remove samples with NaN in X or y
mask = ~(
    np.isnan(X).any(axis=1) | np.isnan(y)
)
if not mask.any():
    raise ValueError("Every loaded sample contains NaN in its features or its ΔG label.")
X = X[mask]
y = y[mask]

print(f"Number of samples: {len(X)}")
print(f"Number of features: {X.shape[1]}")
print(f"Number of targets: {len(y)}")
print(f"ΔG range: {np.min(y):.2f} to {np.max(y):.2f} kJ/mol")

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

  features = torch.load(features_path).numpy()


Number of samples: 2798
Number of features: 1440
Number of targets: 2798
ΔG range: -89.60 to -3.85 kJ/mol


#### 1.2 Checking mGLI features
**Question:** Why does each mGLI feature vector have 1440 entries?

In [None]:
# Check the shape and value range of the first few _mGLI.pt files
def diagnose_embeddings(emb_dir, labels_tsv, max_entries=5):
    """Print shape and value diagnostics for mGLI embedding files.

    For each inspected PDB ID the file ``<emb_dir>/<PDB_ID>_mGLI.pt`` is
    loaded and one summary line is printed: tensor shape, flattened shape,
    NaN/inf presence and min/max. Missing files and zero-element tensors are
    reported instead of raising.

    Parameters
    ----------
    emb_dir : str
        Directory containing the ``*_mGLI.pt`` files.
    labels_tsv : str
        Tab-separated file with a ``PDB_ID`` column listing the entries.
    max_entries : int or None, optional
        Number of leading PDB IDs to inspect (default 5). ``None`` inspects
        every entry in the table.

    Returns
    -------
    None
        Results are only printed.
    """
    df = pd.read_csv(labels_tsv, sep="\t")
    pdbs = df["PDB_ID"].tolist()
    if max_entries is not None:
        pdbs = pdbs[:max_entries]

    for pdb in pdbs:
        path = os.path.join(emb_dir, f"{pdb}_mGLI.pt")
        if not os.path.exists(path):
            print(f"{pdb}: File not found")
            continue

        # weights_only=True restricts unpickling to tensor data, so a tampered
        # file cannot run arbitrary code while being loaded.
        emb = torch.load(path, weights_only=True)
        emb_flat = emb.flatten().numpy()

        # min()/max() raise on zero-size arrays, so report that case explicitly.
        if emb_flat.size == 0:
            print(f"{pdb}: shape={emb.shape}, "
                  f"flat_shape={emb_flat.shape}, empty tensor")
            continue

        print(f"{pdb}: shape={emb.shape}, "
              f"flat_shape={emb_flat.shape}, "
              f"has_nan={np.isnan(emb_flat).any()}, "
              f"has_inf={np.isinf(emb_flat).any()}, "
              f"min={emb_flat.min():.3f}, "
              f"max={emb_flat.max():.3f}")

# Run the diagnostics against the feature directory and label table that live
# under the project root.
mgli_embedding_dir = project_root + "/src/features/mGLI"
affinity_labels_tsv = project_root + "/src/data/data_files/binding_affinity.tsv"
diagnose_embeddings(mgli_embedding_dir, affinity_labels_tsv)

### 2. Training ML models for prediction

#### 2.1 MLP

In [None]:
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import cross_val_score, cross_validate
import matplotlib.pyplot as plt
import numpy as np

# Metrics collected in every cross-validation pass of this cell.
CV_SCORING = ['neg_mean_squared_error', 'r2', 'neg_mean_absolute_error']


def _cross_validate_metrics(estimator, X, y, cv=5):
    """Run one k-fold pass and return (neg-MSE, R², neg-MAE) per-fold score arrays.

    A single multi-metric `cross_validate` call fits each fold once, instead
    of refitting the same folds separately for every metric.
    """
    results = cross_validate(estimator, X, y, cv=cv, scoring=CV_SCORING)
    return (
        results['test_neg_mean_squared_error'],
        results['test_r2'],
        results['test_neg_mean_absolute_error'],
    )


def _plot_predicted_vs_actual(ax, y_true, y_pred, color, title):
    """Scatter predicted against actual ΔG on `ax`, with the y = x reference line."""
    ax.scatter(y_true, y_pred, alpha=0.6, color=color)
    limits = [y_true.min(), y_true.max()]
    ax.plot(limits, limits, 'r--', lw=2)
    ax.set_xlabel('Actual ΔG (kJ/mol)')
    ax.set_ylabel('Predicted ΔG (kJ/mol)')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


# Train MLP regressor
# NOTE(review): scikit-learn's MLP user guide recommends standardizing inputs
# because the model is sensitive to feature scale; the mGLI features are passed
# unscaled here. Consider wrapping the model in a StandardScaler pipeline.
print("Training MLP regressor...")
mlp = MLPRegressor(hidden_layer_sizes=(100, 50), activation='relu', max_iter=1000, random_state=42)
mlp.fit(X_train, y_train)

# Cross-validation on training set (one pass, all three metrics)
print("\nPerforming cross-validation...")
cv_scores_mse, cv_scores_r2, cv_scores_mae = _cross_validate_metrics(mlp, X_train, y_train)

print("Cross-validation Results (5-fold):")
print(f"MSE: {-cv_scores_mse.mean():.3f} ± {cv_scores_mse.std():.3f}")
print(f"MAE: {-cv_scores_mae.mean():.3f} ± {cv_scores_mae.std():.3f}")
print(f"R²: {cv_scores_r2.mean():.3f} ± {cv_scores_r2.std():.3f}")

# Predict and evaluate
y_pred_train = mlp.predict(X_train)
y_pred_test = mlp.predict(X_test)

# Calculate metrics
train_mse = mean_squared_error(y_train, y_pred_train)
test_mse = mean_squared_error(y_test, y_pred_test)
train_r2 = r2_score(y_train, y_pred_train)
test_r2 = r2_score(y_test, y_pred_test)
train_mae = mean_absolute_error(y_train, y_pred_train)
test_mae = mean_absolute_error(y_test, y_pred_test)

print("\nMLP Performance:")
print(f"Training - MSE: {train_mse:.3f}, MAE: {train_mae:.3f}, R²: {train_r2:.3f}")
print(f"Test - MSE: {test_mse:.3f}, MAE: {test_mae:.3f}, R²: {test_r2:.3f}")

# Visualization of model performance
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Plot 1: Predictions vs Actual for training set
_plot_predicted_vs_actual(axes[0], y_train, y_pred_train, 'blue',
                          f'Training Set (R² = {train_r2:.3f})')

# Plot 2: Predictions vs Actual for test set
_plot_predicted_vs_actual(axes[1], y_test, y_pred_test, 'green',
                          f'Test Set (R² = {test_r2:.3f})')

# Plot 3: Residuals plot
residuals_test = y_test - y_pred_test
axes[2].scatter(y_pred_test, residuals_test, alpha=0.6, color='red')
axes[2].axhline(y=0, color='black', linestyle='--')
axes[2].set_xlabel('Predicted ΔG (kJ/mol)')
axes[2].set_ylabel('Residuals (kJ/mol)')
axes[2].set_title('Residuals Plot (Test Set)')
axes[2].grid(True, alpha=0.3)

# Plot 4: Cross-validation scores
# MSE/MAE are sign-flipped back to positive errors; R² is shown as-is.
x_pos = [1, 2, 3]
scores = [-cv_scores_mse.mean(), -cv_scores_mae.mean(), cv_scores_r2.mean()]
errors = [cv_scores_mse.std(), cv_scores_mae.std(), cv_scores_r2.std()]
labels = ['MSE', 'MAE', 'R²']
axes[3].bar(x_pos, scores, yerr=errors, capsize=5, alpha=0.7)
axes[3].set_xticks(x_pos)
axes[3].set_xticklabels(labels)
axes[3].set_ylabel('Score')
axes[3].set_title('Cross-validation Scores (±1 std)')
axes[3].grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

# Hyperparameter tuning with cross-validation
print("\nPerforming hyperparameter tuning...")
param_grid = {
    'hidden_layer_sizes': [(50,), (100,), (100, 50), (200, 100), (100, 50, 25)],
    'activation': ['relu', 'tanh'],
    'alpha': [0.0001, 0.001, 0.01],
    'learning_rate_init': [0.001, 0.01]
}

mlp_grid = MLPRegressor(max_iter=1000, random_state=42)
grid = GridSearchCV(mlp_grid, param_grid, cv=5, scoring='neg_mean_squared_error', n_jobs=-1)
grid.fit(X_train, y_train)

print(f"Best parameters: {grid.best_params_}")
print(f"Best CV score (MSE): {-grid.best_score_:.3f}")

# Evaluate best model
best_mlp = grid.best_estimator_
y_pred_best = best_mlp.predict(X_test)
best_r2 = r2_score(y_test, y_pred_best)
best_mse = mean_squared_error(y_test, y_pred_best)
best_mae = mean_absolute_error(y_test, y_pred_best)

print("\nBest MLP Performance on Test Set:")
print(f"MSE: {best_mse:.3f}, MAE: {best_mae:.3f}, R²: {best_r2:.3f}")

# Cross-validation on best model (one pass, all three metrics)
best_cv_scores_mse, best_cv_scores_r2, best_cv_scores_mae = _cross_validate_metrics(
    best_mlp, X_train, y_train
)

print("\nBest Model Cross-validation Results (5-fold):")
print(f"MSE: {-best_cv_scores_mse.mean():.3f} ± {best_cv_scores_mse.std():.3f}")
print(f"MAE: {-best_cv_scores_mae.mean():.3f} ± {best_cv_scores_mae.std():.3f}")
print(f"R²: {best_cv_scores_r2.mean():.3f} ± {best_cv_scores_r2.std():.3f}")

Training MLP regressor...


#### 2.2 Random forest 

In [None]:
# TODO: train and evaluate a random forest regressor on the mGLI features

#### 2.3 Gradient boost decision tree

In [None]:
# TODO: train and evaluate a gradient-boosted decision tree regressor on the mGLI features