In [None]:
import random

import pandas as pd
import math
from tqdm import tqdm
from sklearn.metrics import (
    f1_score, accuracy_score, recall_score, precision_score, 
    precision_recall_curve, confusion_matrix, roc_auc_score, 
    matthews_corrcoef, roc_curve
)
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
import numpy as np

## Metric Calculation Utilities

In [None]:
def get_aupr(pre, rec):
    """Approximate the area under a precision-recall curve with the trapezoid rule.

    Parameters
    ----------
    pre : sequence of float
        Precision values, one per curve point.
    rec : sequence of float
        Recall values aligned index-by-index with ``pre``.

    Returns
    -------
    float
        Sum, over consecutive point pairs, of ``|rec[k] - rec[k + 1]|`` times
        the mean of the two neighbouring precision values. ``0.0`` when ``rec``
        holds fewer than two points.
    """
    area = 0.0
    for k, (rec_here, rec_next) in enumerate(zip(rec[:-1], rec[1:])):
        # abs() makes the segment width independent of the recall ordering.
        width = abs(rec_here - rec_next)
        area += width * (pre[k] + pre[k + 1]) * 0.5
    return area

def scores(y_test, y_pred, th=0.5):
    """Compute thresholded and ranking-based metrics for a binary classifier.

    Parameters
    ----------
    y_test : array-like
        Ground-truth binary labels.
    y_pred : array-like
        Predicted scores for the positive class.
    th : float, default 0.5
        Decision threshold; a score strictly below ``th`` is labelled 0,
        anything else is labelled 1.

    Returns
    -------
    list
        ``[aupr, auc, f1, acc, sen, spe, pre, fpr, tpr, precision, recall]``.
        The first seven entries are scalars; ``fpr``/``tpr`` come from
        ``roc_curve`` and ``precision``/``recall`` from
        ``precision_recall_curve``. The Matthews correlation coefficient is
        computed as well but is not part of the returned list.
    """
    hard_labels = [(0. if prob < th else 1.) for prob in y_pred]

    tn, fp, fn, tp = confusion_matrix(y_test, hard_labels).flatten()
    specificity = tn / (tn + fp)
    mcc_value = matthews_corrcoef(y_test, hard_labels)
    fpr, tpr, _ = roc_curve(y_test, y_pred)

    # Packing everything into one array gives all scalars a common NumPy dtype.
    packed = np.array([
        recall_score(y_test, hard_labels),
        specificity,
        precision_score(y_test, hard_labels),
        f1_score(y_test, hard_labels),
        mcc_value,
        accuracy_score(y_test, hard_labels),
        roc_auc_score(y_test, y_pred),
        tn, fp, fn, tp,
    ])
    sen, spe, pre, f1 = packed[:4]
    acc, auc = packed[5:7]  # index 4 is the MCC, which is not returned

    precision, recall, _ = precision_recall_curve(y_test, y_pred)
    aupr = get_aupr(precision, recall)
    return [aupr, auc, f1, acc, sen, spe, pre, fpr, tpr, precision, recall]


# Data Loader

In [None]:
def construct_feature_matrices2(lociembeddings_path, rbpembeddings_path, interactions):
    """
    Build the pairwise feature matrix for host-phage interaction prediction.

    Loci and RBP embeddings are read from CSV, restricted to the hosts and
    phages present in ``interactions``, the RBP embeddings of each phage are
    averaged into one vector, and every (locus, phage) pair with a non-NaN
    interaction value becomes one row: locus embedding followed by the mean
    RBP embedding.

    Parameters:
    - lociembeddings_path: path (or file-like object) of the loci embeddings
      CSV. Must contain an 'accession' column; every column after the first
      is used as an embedding dimension.
    - rbpembeddings_path: path (or file-like object) of the RBP embeddings
      CSV. Must contain a 'phage_ID' column; every column after the first two
      is used as an embedding dimension.
    - interactions: DataFrame indexed by locus accession with one column per
      phage. NaN marks an unknown pair; values >= 1 are labelled positive,
      all other non-NaN values negative.

    Returns:
    - features_lan: 2-D array with one row per retained pair. Its dtype
      follows the embedding columns. When no pair is retained it has zero
      rows but still the full feature width.
    - labels: list of 0/1 ints, one per row.
    - groups_loci: list with the positional index of each row's locus within
      the filtered loci table.
    - groups_phage: list with the positional index of each row's phage within
      the sorted list of retained phages.
    """
    RBP_embeddings = pd.read_csv(rbpembeddings_path)
    loci_embeddings = pd.read_csv(lociembeddings_path)

    # Get valid phages and loci from interactions
    valid_phages = interactions.columns.tolist()
    valid_phage_set = set(valid_phages)  # constant-time membership in the loop below
    valid_loci = interactions.index.tolist()

    # Filter embeddings to only use valid entries. regex=False makes '.fna' a
    # literal on every pandas version (the default of `regex` has changed).
    phage_names = RBP_embeddings['phage_ID'].str.replace('.fna', '', regex=False)
    RBP_embeddings = RBP_embeddings[phage_names.isin(valid_phages)]
    loci_embeddings = loci_embeddings[loci_embeddings['accession'].isin(valid_loci)]
    print(len(RBP_embeddings), len(loci_embeddings))

    # Construct multi-RBP representations: one mean vector per phage.
    rbp_ids = RBP_embeddings['phage_ID']
    rbp_values = RBP_embeddings.iloc[:, 2:].to_numpy()
    phage_keys = []
    phage_means = []
    for phage_id in sorted(set(rbp_ids)):
        # NOTE(review): the key is the text before the first '.', whereas the
        # filter above only strips '.fna'. IDs with additional dots pass the
        # filter but may be dropped here -- confirm this is intended.
        phage_key = phage_id.split('.')[0]
        if phage_key in valid_phage_set:
            in_phage = (rbp_ids == phage_id).to_numpy()
            phage_keys.append(phage_key)
            phage_means.append(np.mean(rbp_values[in_phage], axis=0))

    if phage_means:
        rbp_matrix = np.asarray(phage_means)
    else:
        # Keep the column count so the result stays 2-D without any phage.
        rbp_matrix = np.empty((0, rbp_values.shape[1]))
    loci_values = loci_embeddings.iloc[:, 1:].to_numpy()

    # Collect the (locus, phage) index pairs that have a known interaction.
    # Both filters above guarantee that the accession and the phage key exist
    # in `interactions`, so no further membership check is needed.
    labels = []
    groups_loci = []
    groups_phage = []
    for i, accession in enumerate(loci_embeddings['accession']):
        host_row = interactions.loc[accession]  # looked up once per locus
        for j, phage_key in enumerate(phage_keys):
            interaction = host_row[phage_key]
            if not math.isnan(interaction):
                labels.append(1 if interaction >= 1 else 0)
                groups_loci.append(i)
                groups_phage.append(j)

    # Assemble all rows at once instead of concatenating one Series per pair.
    features_lan = np.hstack([
        loci_values[np.asarray(groups_loci, dtype=int)],
        rbp_matrix[np.asarray(groups_phage, dtype=int)],
    ])
    print("Dimensions match?", features_lan.shape[1] == (loci_values.shape[1] + rbp_matrix.shape[1]))

    return features_lan, labels, groups_loci, groups_phage

In [None]:
# Escherichia dataset: embedding tables and the ordinal interaction matrix.
general_output_path = "esm-features_escherichia"

# The interaction matrix is ';'-separated; its first column becomes the row index.
interactions = pd.read_csv(
    "esm-features_escherichia/ordinal_interaction_matrix.csv",
    sep=";",
    index_col=0,
)
rbp_embeddings_path = "esm-features_escherichia/esm2_embeddings_rbpJonas2.csv"
loci_embeddings_path = "esm-features_escherichia/esm2_embeddings_lociJonas.csv"

features_esm2, labels, groups_loci, groups_phage = construct_feature_matrices2(
    lociembeddings_path=loci_embeddings_path,
    rbpembeddings_path=rbp_embeddings_path,
    interactions=interactions,
)

In [None]:
# Klebsiella dataset: embedding tables and the interaction matrix.
# Running this cell replaces the variables set by the Escherichia cell.
general_output_path = "esm-features-klebsiella"

# Comma-separated interaction matrix; its first column becomes the row index.
interactions = pd.read_csv("./phage_host_interactions.csv", index_col=0)
rbp_embeddings_path = "esm-features-klebsiella/esm2_embeddings_rbp.csv"
loci_embeddings_path = "esm-features-klebsiella/esm2_embeddings_loci.csv"

features_esm2, labels, groups_loci, groups_phage = construct_feature_matrices2(
    lociembeddings_path=loci_embeddings_path,
    rbpembeddings_path=rbp_embeddings_path,
    interactions=interactions,
)

In [None]:
# Sanity check: the feature matrix and the three per-pair lists should have
# the same number of entries.
for description, value in (
    ("Shape of features_esm2:", features_esm2.shape),
    ("Length of labels:", len(labels)),
    ("Length of groups_loci:", len(groups_loci)),
    ("Length of groups_phage:", len(groups_phage)),
):
    print(description, value)

# Model Training and Evaluation

## XGBoost

**Note:** the XGBoost random state was 1 on dataset A and 42 on dataset B. This was not intentional: the two runs came from two different experimentation notebooks. The discrepancy affects only XGBoost here.

In [None]:
# Seed the global RNGs and make sure the labels are an array so they can be
# compared element-wise.
random.seed(42)
np.random.seed(42)
labels = np.asarray(labels)

# Ratio of positive to negative pairs; its inverse is offered to the grid
# search as a candidate `scale_pos_weight`.
imbalance = (labels >= 1).sum() / (labels == 0).sum()

# imbalance left out for dataset A
xgb = XGBClassifier(random_state=1, eval_metric='logloss')
param_grid = dict(
    n_estimators=[150, 100],
    max_depth=[10, 8],
    learning_rate=[0.1, 0.3],
    scale_pos_weight=[1, 1 / imbalance],
)

# Perform GridSearchCV once on the whole dataset (3-fold, optimising F1)
grid_search = GridSearchCV(
    estimator=xgb,
    param_grid=param_grid,
    scoring='f1',
    cv=3,
    n_jobs=-1,
    verbose=1,
)
grid_search.fit(features_esm2, labels)
best_params = grid_search.best_params_
print("Best XGB hyperparameters:", best_params)

In [None]:
# 5-fold stratified evaluation of XGBoost with the tuned hyperparameters.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
results_all = []                                   # seven scalar metrics per fold
fprs, tprs, precisions, recalls = [], [], [], []   # curve coordinates per fold

# Wrapping the split iterator lets tqdm advance once per completed fold.
pbar = tqdm(kf.split(features_esm2, labels), total=kf.get_n_splits(features_esm2, labels))
for train_index, test_index in pbar:
    Xlan_train, y_train = features_esm2[train_index], labels[train_index]
    Xlan_test, y_test = features_esm2[test_index], labels[test_index]

    best_model = XGBClassifier(eval_metric='logloss', **best_params, random_state=1)
    best_model.fit(Xlan_train, y_train)

    # Probability of the positive class for each held-out pair
    score_xgb = best_model.predict_proba(Xlan_test)[:, 1]
    fold_scores = scores(y_test, score_xgb)

    # First seven entries are scalars, the remaining four are the curve arrays.
    results_all.append(fold_scores[:7])
    for store, curve in zip((fprs, tprs, precisions, recalls), fold_scores[7:]):
        store.append(curve)

### Metrics Across Folds

In [None]:
# Per-fold XGBoost metrics as a table, followed by mean and standard deviation.
metric_names = ["AUPR", "ROC-AUC", "F1", "ACC", "SEN", "SPE", "PRE"]
results_df = pd.DataFrame(
    results_all,
    columns=metric_names,
    index=[f"Fold {fold_no}" for fold_no in range(1, len(results_all) + 1)],
)

print("=== Fold-wise Performance ===")
display(results_df)

fold_mean = results_df.mean()
print("\n=== Mean ===")
print(fold_mean)

fold_std = results_df.std()
print("\n=== Std Dev ===")
display(fold_std)

In [None]:
# ROC curves: one line per fold plus the chance diagonal for reference
plt.figure(figsize=(10, 5))
for fold_no, (fold_fpr, fold_tpr) in enumerate(zip(fprs, tprs), start=1):
    plt.plot(fold_fpr, fold_tpr, label=f'Fold {fold_no}')
plt.plot([0, 1], [0, 1], 'k--')
plt.title("XGBoost ROC Curve Across Folds")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.grid()
plt.show()

# Precision-recall curves: one line per fold
plt.figure(figsize=(10, 5))
for fold_no, (fold_recall, fold_precision) in enumerate(zip(recalls, precisions), start=1):
    plt.plot(fold_recall, fold_precision, label=f'Fold {fold_no}')
plt.title("XGBoost Precision-Recall Curve Across Folds")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.grid()
plt.show()


In [None]:
# Distribution of each XGBoost metric across the cross-validation folds.
metric_names = ["AUPR", "ROC-AUC", "F1", "Accuracy", "Sensitivity", "Specificity", "Precision"]
results_array = np.array(results_all)  # rows = folds, columns = metrics

plt.figure(figsize=(12, 6))
# `tick_labels` replaces the deprecated `labels` keyword and matches the
# Random Forest boxplot later in this notebook.
plt.boxplot(results_array, tick_labels=metric_names, patch_artist=True,
            boxprops=dict(facecolor='lightblue', color='black'),
            medianprops=dict(color='red'),
            whiskerprops=dict(color='black'))
plt.title("XGBoost Cross-Fold Metric Distribution")
plt.grid()
plt.show()

In [None]:
# Rebuild the per-fold table with the current `metric_names` and show the means.
results_df = pd.DataFrame(
    results_all,
    columns=metric_names,
    index=[f"Fold {fold_no}" for fold_no in range(1, len(results_all) + 1)],
)
display(results_df)

print("Mean Metrics:")
display(results_df.mean())

##  Random Forest Model


In [None]:
# Make sure both inputs are arrays so they can be indexed by fold later on.
features_esm2 = np.asarray(features_esm2)
labels = np.asarray(labels)

# Define hyperparameter grid
param_grid = dict(
    n_estimators=[100, 150],
    min_samples_split=[2, 5],
    min_samples_leaf=[1, 2],
)

rf = RandomForestClassifier(random_state=42, n_jobs=-1)

# 3-fold grid search on the whole dataset, optimising F1
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid,
                           scoring='f1', cv=3, n_jobs=-1, verbose=1)
grid_search.fit(features_esm2, labels)

best_rf_params = grid_search.best_params_
print("Best RF hyperparameters:", best_rf_params)

In [None]:
# 5-fold stratified evaluation of the Random Forest with the tuned hyperparameters.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

rf_results_all = []                                            # seven scalar metrics per fold
rf_fprs, rf_tprs, rf_precisions, rf_recalls = [], [], [], []   # curve coordinates per fold

# Wrapping the split iterator lets tqdm advance once per completed fold.
pbar = tqdm(kf.split(features_esm2, labels), total=kf.get_n_splits(features_esm2, labels))
for train_index, test_index in pbar:
    X_train, y_train = features_esm2[train_index], labels[train_index]
    X_test, y_test = features_esm2[test_index], labels[test_index]

    # Train using best parameters found earlier
    best_rf = RandomForestClassifier(n_jobs=-1, random_state=42, **best_rf_params)
    best_rf.fit(X_train, y_train)

    # Probability of the positive class for each held-out pair
    y_pred_prob = best_rf.predict_proba(X_test)[:, 1]

    # Compute metrics: seven scalars first, then the four curve arrays
    metrics = scores(y_test, y_pred_prob)
    rf_results_all.append(metrics[:7])
    for store, curve in zip((rf_fprs, rf_tprs, rf_precisions, rf_recalls), metrics[7:]):
        store.append(curve)

    print(f"AUPR: {metrics[0]:.4f}, AUC: {metrics[1]:.4f}, "
          f"F1: {metrics[2]:.4f}, Acc: {metrics[3]:.4f}")

### Metrics Across Folds

In [None]:
plt.style.use('default')

# ROC curves: one line per fold plus the chance diagonal for reference
plt.figure(figsize=(10, 5))
for fold_no, (fold_fpr, fold_tpr) in enumerate(zip(rf_fprs, rf_tprs), start=1):
    plt.plot(fold_fpr, fold_tpr, label=f'Fold {fold_no}')
plt.plot([0, 1], [0, 1], 'k--')
plt.title("Random Forest ROC Curve Across Folds")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.grid(True, color='black', linestyle='--', linewidth=0.5)
plt.gca().set_facecolor('white')
plt.show()

# Precision-recall curves: one line per fold
plt.figure(figsize=(10, 5))
for fold_no, (fold_recall, fold_precision) in enumerate(zip(rf_recalls, rf_precisions), start=1):
    plt.plot(fold_recall, fold_precision, label=f'Fold {fold_no}')
plt.title("Random Forest Precision-Recall Curve Across Folds")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.grid(True, color='black', linestyle='--', linewidth=0.5)
plt.gca().set_facecolor('white')
plt.show()


In [None]:
# Distribution of each Random Forest metric across the cross-validation folds.
rf_metric_names = ["AUPR", "ROC-AUC", "F1", "Accuracy", "Sensitivity", "Specificity", "Precision"]
rf_results_array = np.array(rf_results_all)  # rows = folds, columns = metrics

# Boxplot: one box per metric
plt.figure(figsize=(12, 6))
plt.boxplot(
    rf_results_array,
    tick_labels=rf_metric_names,
    patch_artist=True,
    boxprops={'facecolor': 'lightgreen', 'color': 'black'},
    medianprops={'color': 'red'},
    whiskerprops={'color': 'black'},
)
plt.title("Random Forest Cross-Fold Metric Distribution")
plt.grid()
plt.show()


In [None]:
# Per-fold Random Forest metrics as a table, followed by the column means.
rf_results_df = pd.DataFrame(
    rf_results_all,
    columns=rf_metric_names,
    index=[f"Fold {fold_no}" for fold_no in range(1, len(rf_results_all) + 1)],
)
display(rf_results_df)

print("Random Forest Mean Metrics:")
display(rf_results_df.mean())

In [None]:
# Side-by-side boxplots of the per-fold metrics of both models.
all_metric_names = ["AUPR", "ROC-AUC", "F1", "Accuracy", "Sensitivity", "Specificity", "Precision"]

# Convert once instead of rebuilding the array for every metric column.
xgb_scores = np.array(results_all)      # rows = folds, columns = metrics
rf_scores = np.array(rf_results_all)

plt.figure(figsize=(14, 6))
# XGBoost boxes sit slightly left of each metric position. `tick_labels`
# replaces the deprecated `labels` keyword.
xgb_boxes = plt.boxplot([xgb_scores[:, i] for i in range(7)],
                        positions=np.arange(1, 8) - 0.2, widths=0.3, patch_artist=True,
                        boxprops=dict(facecolor='lightblue'), medianprops=dict(color='blue'),
                        tick_labels=all_metric_names)

# Random Forest boxes sit slightly right of each metric position.
rf_boxes = plt.boxplot([rf_scores[:, i] for i in range(7)],
                       positions=np.arange(1, 8) + 0.2, widths=0.3, patch_artist=True,
                       boxprops=dict(facecolor='lightgreen'), medianprops=dict(color='green'))

# Pass one box patch per model as an explicit handle; with labels alone the
# legend would attach them to the first two artists drawn on the axes.
plt.legend([xgb_boxes["boxes"][0], rf_boxes["boxes"][0]],
           ['XGBoost - blue', 'Random Forest - green'])
plt.title("Metric Distribution Comparison: XGBoost vs Random Forest")
plt.grid()
# Centre one tick per metric between the two boxes.
plt.xticks(np.arange(1, 8), all_metric_names)
plt.show()


In [None]:
# Mean of every metric over the folds, one row per model.
xgb_mean = pd.DataFrame([np.mean(results_all, axis=0)], columns=all_metric_names, index=["XGBoost"])
rf_mean = pd.DataFrame([np.mean(rf_results_all, axis=0)], columns=all_metric_names, index=["Random Forest"])

# Show both models; previously only the Random Forest row was displayed.
display(pd.concat([xgb_mean, rf_mean]))