In [None]:
import math
import pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold
from xgboost import XGBRegressor
import numpy as np
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

## Metric Calculation Utilities


In [None]:
def get_aupr(pre, rec):
    """Approximate the area under a precision-recall curve.

    Each pair of consecutive points ``(rec[k], pre[k])`` and
    ``(rec[k + 1], pre[k + 1])`` contributes one trapezoid. The width is the
    absolute recall difference, so the ordering of ``rec`` (ascending or
    descending) does not change the sign of a segment.

    Parameters:
    - pre: Indexable sequence of precision values.
    - rec: Indexable sequence of recall values; its length determines the
      number of segments.

    Returns:
    - The accumulated trapezoid area (0.0 when ``rec`` has fewer than two
      points).
    """
    area = 0.0
    for k in range(len(rec) - 1):
        width = abs(rec[k] - rec[k + 1])
        height_sum = pre[k] + pre[k + 1]
        area += width * height_sum * 0.5
    return area

from sklearn.metrics import confusion_matrix

from sklearn.metrics import accuracy_score

def scores_regression(y_true, y_pred):
    """Score continuous predictions both as a regression and as 5-class labels.

    Parameters:
    - y_true: Ground-truth values.
    - y_pred: Continuous model predictions, same length as ``y_true``.

    Returns:
    - ``[mse, rmse, mae, r2, acc]``: error metrics on the raw predictions,
      followed by the accuracy of the discretised predictions.
    - cm: confusion matrix of ``y_true`` against the discretised predictions,
      computed over the fixed label set 0..4.
    """
    # Discretise: round to the nearest integer (np.round sends exact halves
    # to the even neighbour), then clamp into the label range 0..4.
    rounded = np.round(y_pred).astype(int)
    y_pred_class = np.clip(rounded, 0, 4)

    # Regression metrics use the raw, un-rounded predictions.
    mse = mean_squared_error(y_true, y_pred)
    metrics = [
        mse,
        np.sqrt(mse),
        mean_absolute_error(y_true, y_pred),
        r2_score(y_true, y_pred),
    ]

    # Classification-style metrics use the discretised predictions.
    metrics.append(accuracy_score(y_true, y_pred_class))
    cm = confusion_matrix(y_true, y_pred_class, labels=[0, 1, 2, 3, 4])

    return metrics, cm

# Data Loader

In [None]:
def construct_feature_matrices2(lociembeddings_path, rbpembeddings_path, interactions):
    """
    Constructs feature matrices for host-phage interaction prediction using ESM-2 embeddings.
    It reads RBP and loci embeddings, filters them based on valid entries in the interaction matrix,
    and combines them to form feature vectors representing phage-host pairs.

    Parameters:
    - lociembeddings_path: Path to loci embeddings CSV. Must contain an 'accession' column;
      columns from the second onward are used as the locus embedding.
    - rbpembeddings_path: Path to RBP embeddings CSV. Must contain a 'phage_ID' column;
      columns from the third onward are used as the RBP embedding.
    - interactions: DataFrame indexed by locus accession with one column per phage key
      (the phage ID up to its first '.'). NaN cells mark unlabelled pairs and are skipped.

    Returns:
    - features: float32 array of shape (n_pairs, loci_dim + rbp_dim); each row is the locus
      embedding followed by the mean RBP embedding of the phage. Shape (0, loci_dim + rbp_dim)
      when no labelled pair exists.
    - labels: float32 array with the interaction value of each pair.
    - groups_loci: list with, per pair, the row position of the locus in the filtered loci table.
    - groups_phage: list with, per pair, the position of the phage in the sorted retained phage IDs.
    """

    RBP_embeddings = pd.read_csv(rbpembeddings_path)
    loci_embeddings = pd.read_csv(lociembeddings_path)

    # Get valid phages and loci from interactions (set -> O(1) membership tests below)
    valid_phages = set(interactions.columns)
    valid_loci = interactions.index.tolist()

    # Filter embeddings to only use valid entries.
    # regex=False: '.fna' is a literal suffix. The default of `regex` differs between
    # pandas versions, and in regex mode the leading '.' would match any character.
    phage_names = RBP_embeddings['phage_ID'].str.replace('.fna', '', regex=False)
    RBP_embeddings = RBP_embeddings[phage_names.isin(valid_phages)]
    loci_embeddings = loci_embeddings[loci_embeddings['accession'].isin(valid_loci)]
    print(len(RBP_embeddings), len(loci_embeddings))

    # Construct multi-RBP representations: one vector per phage, the mean over all of its RBP rows
    rbp_values = RBP_embeddings.iloc[:, 2:]
    phage_keys = []
    phage_vectors = []
    for phage_id in sorted(set(RBP_embeddings['phage_ID'])):
        phage_key = phage_id.split('.')[0]
        if phage_key in valid_phages:
            rows = rbp_values[RBP_embeddings['phage_ID'] == phage_id]
            phage_vectors.append(np.mean(np.asarray(rows), axis=0))
            phage_keys.append(phage_key)

    rbp_dim = rbp_values.shape[1]
    if phage_vectors:
        rbp_matrix = np.asarray(phage_vectors, dtype=np.float32)
    else:
        # Keep the matrix two-dimensional so the stacking below still works
        rbp_matrix = np.empty((0, rbp_dim), dtype=np.float32)
    loci_matrix = loci_embeddings.iloc[:, 1:].to_numpy(dtype=np.float32)

    # Collect the (locus, phage) index pairs that carry a label. The feature rows are
    # assembled once afterwards instead of concatenating two pandas Series per pair.
    labels = []
    groups_loci = []
    groups_phage = []

    for i, accession in enumerate(loci_embeddings['accession']):
        host_row = interactions.loc[accession]  # looked up once per locus
        for j, phage_key in enumerate(phage_keys):
            interaction = host_row[phage_key]
            if not math.isnan(interaction):
                labels.append(interaction)
                groups_loci.append(i)
                groups_phage.append(j)

    # Explicit integer dtype so that empty index lists still select zero rows
    loci_rows = loci_matrix[np.asarray(groups_loci, dtype=np.intp)]
    phage_rows = rbp_matrix[np.asarray(groups_phage, dtype=np.intp)]
    features_lan = np.hstack([loci_rows, phage_rows])
    print("Dimensions match?", features_lan.shape[1] == (loci_matrix.shape[1] + rbp_dim))

    labels = np.array(labels, dtype=np.float32)

    return features_lan, labels, groups_loci, groups_phage

In [None]:
# Escherichia dataset: embedding files and the ordinal interaction matrix.
general_output_path = "esm-features_escherichia"

loci_embeddings_path = "esm-features_escherichia/esm2_embeddings_lociJonas.csv"
rbp_embeddings_path = "esm-features_escherichia/esm2_embeddings_rbpJonas2.csv"

# The interaction matrix is ';'-separated; its first column becomes the index.
interactions = pd.read_csv(
    "esm-features_escherichia/ordinal_interaction_matrix.csv",
    sep=";",
    index_col=0,
)

# Build one feature row per labelled pair, plus the grouping indices.
features_esm2, labels, groups_loci, groups_phage = construct_feature_matrices2(
    lociembeddings_path=loci_embeddings_path,
    rbpembeddings_path=rbp_embeddings_path,
    interactions=interactions,
)

# Model Training and Evaluation

## Random Forest Regressor

In [None]:
# Hyperparameter search for the random forest, fitted on the entire dataset.
# NOTE(review): the same samples are evaluated again by the cross-validation
# that follows, so those fold scores may be optimistic; consider nested CV.
param_grid = dict(
    n_estimators=[100, 150],
    max_depth=[12, 10],
    min_samples_leaf=[8, 10],
)

base_model = RandomForestRegressor(random_state=42)
grid_search = GridSearchCV(
    estimator=base_model,
    param_grid=param_grid,
    scoring='neg_mean_squared_error',
    cv=3,
    n_jobs=-1,
)
grid_search.fit(features_esm2, labels)

# Parameter combination with the best mean score across the 3 search folds.
best_params = grid_search.best_params_
print("Best Hyperparameters:", best_params)

In [None]:
# 5-fold stratified cross-validation of the random forest using `best_params`.
labels = np.asarray(labels)
features_esm2 = np.asarray(features_esm2)

# Folds are stratified on the interaction label; the fixed seed makes every
# call to kf.split(...) return the same partition.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

results_all = []     # one [MSE, RMSE, MAE, R2, ACC] list per fold
conf_matrices = []   # one 5x5 confusion matrix per fold
models = []          # fitted model of each fold, in fold order

# The context manager closes the progress bar even if a fold raises.
with tqdm(total=kf.get_n_splits(features_esm2, labels)) as pbar:
    for train_idx, test_idx in kf.split(features_esm2, labels):
        X_train, X_test = features_esm2[train_idx], features_esm2[test_idx]
        y_train, y_test = labels[train_idx], labels[test_idx]

        best_model = RandomForestRegressor(**best_params, n_jobs=-1, random_state=42)
        best_model.fit(X_train, y_train)
        models.append(best_model)

        y_pred = best_model.predict(X_test)

        try:
            fold_scores, cm = scores_regression(y_test, y_pred)
        except Exception as e:
            print(f"Error on fold: {e}")
            # scores_regression returns five metrics. The placeholder needs the
            # same length, otherwise results_all becomes ragged and cannot be
            # turned into the five-column metrics table.
            fold_scores = [np.nan] * 5
            cm = np.zeros((5, 5), dtype=int)

        results_all.append(fold_scores)
        conf_matrices.append(cm)
        pbar.update(1)

### Metrics Across Folds


In [None]:
# Column order matches the metric list returned by scores_regression.
metric_names = ["MSE", "RMSE", "MAE", "R2", "ACC"]

# One row per fold, labelled "Fold 1" .. "Fold n".
results_array = np.array(results_all)
results_df = pd.DataFrame(
    results_array,
    columns=metric_names,
    index=[f"Fold {fold_no}" for fold_no in range(1, len(results_all) + 1)],
)

# Fold-wise table
print("=== Fold-wise Performance ===")
display(results_df)

# Column-wise mean over the folds
print("\n=== Mean ===")
display(results_df.mean())

# Column-wise standard deviation over the folds
print("\n=== Std Dev ===")
display(results_df.std())

In [None]:
# Per-fold MSE (first metric column) with the mean drawn as a reference line.
mse_values = results_array[:, 0]
fold_ticks = range(1, len(mse_values) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(fold_ticks, mse_values, marker='o', linestyle='-', color='blue', label='MSE per fold')
ax.axhline(np.mean(mse_values), color='red', linestyle='--', label='Mean MSE')
ax.set_title('Mean Squared Error (MSE) Across Folds')
ax.set_xlabel('Fold')
ax.set_ylabel('MSE')
ax.set_xticks(fold_ticks)
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# One box per metric column, showing its spread across the folds.
fig, ax = plt.subplots(figsize=(12, 6))
ax.boxplot(
    results_array,
    tick_labels=metric_names,
    patch_artist=True,  # boxes are drawn as patches so they can be filled
    boxprops={'facecolor': 'lightgreen', 'color': 'black'},
    medianprops={'color': 'darkred'},
    whiskerprops={'color': 'black'},
)
ax.set_title("Random Forest Regressor Cross-Fold Metric Distribution")
ax.grid()
plt.show()

In [None]:
# Element-wise sum of the per-fold confusion matrices. Despite the name
# `avg_cm`, there is no division by the number of folds.
avg_cm = np.asarray(conf_matrices).sum(axis=0)
print(avg_cm)

# Heatmap of the summed matrix; rows are true labels, columns predicted labels.
class_ticks = [0, 1, 2, 3, 4]
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(avg_cm, annot=True, fmt='g', cmap='Blues', cbar=False,
            xticklabels=class_ticks, yticklabels=class_ticks, ax=ax)
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_title("Confusion Matrix (All Folds)")
plt.show()

In [None]:
# One annotated heatmap per fold; `i` is the zero-based fold position.
for i, cm in enumerate(conf_matrices):
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=False,
        xticklabels=[0, 1, 2, 3, 4],
        yticklabels=[0, 1, 2, 3, 4],
        ax=ax,
    )
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title(f"Confusion Matrix - Fold {i+1}")
    plt.show()

In [None]:
# Pool the held-out predictions of every fold into one true-vs-predicted plot.
all_y_true = []
all_y_pred = []

# kf shuffles with a fixed integer random_state, so splitting again reproduces
# the folds the models were trained on. enumerate() pairs each test fold with
# the model fitted for that fold. Without it, `i` was a stale value left over
# from an earlier loop and a single model predicted every fold, including
# samples it had been trained on.
for i, (train_idx, test_idx) in enumerate(kf.split(features_esm2, labels)):
    X_test = features_esm2[test_idx]
    y_test = labels[test_idx]
    y_pred = models[i].predict(X_test)

    all_y_true.extend(y_test)
    all_y_pred.extend(y_pred)

all_y_true = np.array(all_y_true)
all_y_pred = np.array(all_y_pred)

# Plotting combined true vs. predicted labels
plt.figure(figsize=(6, 4))
plt.scatter(all_y_true, all_y_pred, alpha=0.6, color='darkorange')
plt.xlabel("True Labels")
plt.ylabel("Predicted Labels")
plt.title("True vs. Predicted Labels (All Folds Combined)")
plt.plot([0, 4], [0, 4], 'r--')  # identity line
plt.xlim(-0.5, 4.5)
plt.ylim(-0.5, 4.5)
plt.grid(True)
plt.show()

In [None]:
# True-vs-predicted scatter for each fold, using that fold's own model.
for i, (train_idx, test_idx) in enumerate(kf.split(features_esm2, labels)):
    X_test, y_test = features_esm2[test_idx], labels[test_idx]
    y_pred = models[i].predict(X_test)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(y_test, y_pred, alpha=0.5)
    ax.set_xlabel("True Labels")
    ax.set_ylabel("Predicted Labels")
    ax.set_title(f"Fold {i+1} - True vs. Predicted")
    ax.grid()
    ax.plot([0, 4], [0, 4], 'r--')  # identity line
    ax.set_xlim(-0.5, 4.5)
    ax.set_ylim(-0.5, 4.5)
    plt.show()

## XGBoost Regressor

In [None]:
# Hyperparameter grid for XGBoost (2 x 2 x 2 combinations).
param_grid = dict(
    n_estimators=[100, 120],
    max_depth=[8, 10],
    learning_rate=[0.3, 0.1],
)

xgb_model = XGBRegressor(objective='reg:squarederror', random_state=42)

# 3-fold search on the full dataset, scored by negative MSE.
# NOTE(review): the cross-validation that follows evaluates on the same
# samples, so its scores may be optimistic.
grid_search = GridSearchCV(xgb_model, param_grid, scoring='neg_mean_squared_error', cv=3, n_jobs=7)
grid_search.fit(features_esm2, labels)

best_params = grid_search.best_params_
print("Best XGBoost Hyperparameters:", best_params)

In [None]:
# 5-fold stratified cross-validation of XGBoost using `best_params`.
labels = np.asarray(labels)
features_esm2 = np.asarray(features_esm2)

kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Per-fold outputs, appended in fold order.
results_all, models, conf_matrices = [], [], []

# Wrapping the split iterator lets tqdm advance once per fold.
pbar = tqdm(kf.split(features_esm2, labels), total=kf.get_n_splits(features_esm2, labels))

for train_index, test_index in pbar:
    X_train, y_train = features_esm2[train_index], labels[train_index]
    X_test, y_test = features_esm2[test_index], labels[test_index]

    model = XGBRegressor(objective='reg:squarederror', random_state=42, **best_params)
    model.fit(X_train, y_train)
    models.append(model)
    y_pred = model.predict(X_test)

    try:
        fold_scores, cm = scores_regression(y_test, y_pred)
    except Exception as e:
        # Keep the run going; record NaNs for the five metrics and an empty matrix.
        print(f"Error on fold: {e}")
        fold_scores = [np.nan] * 5
        cm = np.zeros((5, 5), dtype=int)

    results_all.append(fold_scores)
    conf_matrices.append(cm)

pbar.close()

### Metrics Across Folds


In [None]:
# Column order matches the metric list returned by scores_regression.
metric_names = ["MSE", "RMSE", "MAE", "R2", "ACC"]

# One row per fold, labelled "Fold 1" .. "Fold n".
results_array = np.array(results_all)
results_df = pd.DataFrame(
    results_array,
    columns=metric_names,
    index=[f"Fold {fold_no}" for fold_no in range(1, len(results_all) + 1)],
)

# Plain-text report: fold table, then column-wise mean and standard deviation.
print("=== Fold-wise Performance ===", results_df.to_string(), sep="\n")
print("\n=== Mean ===", results_df.mean().to_string(), sep="\n")
print("\n=== Std Dev ===", results_df.std().to_string(), sep="\n")

In [None]:
# Per-fold MSE (first metric column) against its mean across folds.
mse_values = results_array[:, 0]
fold_ticks = range(1, len(mse_values) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(fold_ticks, mse_values, marker='o', linestyle='-', color='blue', label='MSE per fold')
ax.axhline(np.mean(mse_values), color='red', linestyle='--', label='Mean MSE')
ax.set_title('Mean Squared Error (MSE) Across Folds')
ax.set_xlabel('Fold')
ax.set_ylabel('MSE')
ax.set_xticks(fold_ticks)
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Boxplot: one box per metric column, showing its spread across the folds.
fig, ax = plt.subplots(figsize=(12, 6))
ax.boxplot(
    results_df.values,
    tick_labels=metric_names,
    patch_artist=True,  # boxes are drawn as patches so they can be filled
    boxprops={'facecolor': 'lightblue', 'color': 'black'},
    medianprops={'color': 'red'},
    whiskerprops={'color': 'black'},
)
ax.set_title("XGBoost Regressor Cross-Fold Metric Distribution")
ax.grid()
plt.show()

In [None]:
# Sum all confusion matrices element-wise. The result holds total counts over
# the folds; despite the name `avg_cm`, nothing is divided by the fold count.
avg_cm = np.sum(conf_matrices, axis=0)

print(avg_cm)

# Plot the summed confusion matrix. The title used to say "Average", which
# mislabelled the totals; it now matches the random-forest plot.
plt.figure(figsize=(6, 5))
sns.heatmap(avg_cm, annot=True, fmt='g', cmap='Blues', cbar=False,
            xticklabels=[0, 1, 2, 3, 4], yticklabels=[0, 1, 2, 3, 4])
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix (All Folds)")
plt.show()

In [None]:
# One annotated heatmap per fold; `i` is the zero-based fold position.
for i, cm in enumerate(conf_matrices):
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=False,
        xticklabels=[0, 1, 2, 3, 4],
        yticklabels=[0, 1, 2, 3, 4],
        ax=ax,
    )
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title(f"Confusion Matrix - Fold {i+1}")
    plt.show()

In [None]:
# Pool the held-out predictions of every fold into one true-vs-predicted plot.
all_y_true = []
all_y_pred = []

# kf shuffles with a fixed integer random_state, so splitting again reproduces
# the folds the models were trained on. enumerate() pairs each test fold with
# the model fitted for that fold. Without it, `i` was a stale value left over
# from an earlier loop and a single model predicted every fold, including
# samples it had been trained on.
for i, (train_idx, test_idx) in enumerate(kf.split(features_esm2, labels)):
    X_test = features_esm2[test_idx]
    y_test = labels[test_idx]
    y_pred = models[i].predict(X_test)

    all_y_true.extend(y_test)
    all_y_pred.extend(y_pred)

all_y_true = np.array(all_y_true)
all_y_pred = np.array(all_y_pred)

# Plotting combined true vs. predicted labels
plt.figure(figsize=(6, 4))
plt.scatter(all_y_true, all_y_pred, alpha=0.6, color='darkorange')
plt.xlabel("True Labels")
plt.ylabel("Predicted Labels")
plt.title("True vs. Predicted Labels (All Folds Combined)")
plt.plot([0, 4], [0, 4], 'r--')  # identity line
plt.xlim(-0.5, 4.5)
plt.ylim(-0.5, 4.5)
plt.grid(True)
plt.show()