In [1]:
import pickle

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import (
    auc,
    average_precision_score,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)

## Generating multi-threshold ROC curves

In [None]:
# Load one result set per threshold and overlay all of their ROC curves on a single figure.
#
# NOTE: pickle.load can execute arbitrary code — only load result files you produced yourself.
with open("../Results/0_AUCs_original_replica.pkl", 'rb') as f:
    loaded_data = pickle.load(f)

# .pkl paths:
# "0_AUCs_original_replica.pkl"                   - Replicating the original results after splitting the individual proteins, then re-averaging them
# "1_AUCs_max_max.pkl"                            - Maximum scores among protein-protein pairs
# "2_AUCs_max_max_original_sero.pkl"              - Replicating the original results, but using One-Hot Encoded serotypes instead of bacterial proteins
# "3_AUCs_max_max_sero.pkl"                       - Combining the previous two approaches: individual viral RBPs and One-Hot Encoded serotypes for bacterial proteins
# "4_AUCs_motif_focus_increased.pkl"              - Motif-containing RBPs and viruses that only express one RBP as positives
# "5_AUCs_motif_focus.pkl"                        - Motif-containing RBPs only as positives

# Legend labels for the entries of `loaded_data`, in order: entry i is labelled threshold[i].
threshold = [1, 0.995, 0.99, 0.95, 0.90, 0.85, 0.80, 0.75]

# Fail fast, before anything is drawn, instead of an IndexError part-way through the loop.
if len(loaded_data) > len(threshold):
    raise ValueError(
        f"{len(loaded_data)} result sets but only {len(threshold)} threshold labels; "
        "extend `threshold` to match the pickle contents."
    )

# Create a single figure for all ROC curves
fig, ax = plt.subplots(figsize=(10, 8))

for thresh, el in zip(threshold, loaded_data):
    # Each entry is indexed as (labels, scores, ...); any further fields are ignored here.
    labels = el[0]
    scoreslr = el[1]

    fpr, tpr, _ = roc_curve(labels, scoreslr)
    rauclr = round(auc(fpr, tpr), 3)
    print("THRESHOLD: ", thresh)
    print("AUC: ", rauclr)
    print("##############################################")

    # Plot each ROC curve on the same axes
    ax.plot(fpr, tpr, linewidth=2.5, label=f'Threshold: {thresh} (AUC = {rauclr})')

# Final plot adjustments
ax.set_xlabel('False positive rate', size=24)
ax.set_ylabel('True positive rate', size=24)
ax.legend(loc="lower right", prop={'size': 16})
ax.grid(True, linestyle=':')
ax.yaxis.set_tick_params(labelsize=14)
ax.xaxis.set_tick_params(labelsize=14)
ax.set_title('ROC Curves for Different Thresholds\n(Full Embeddings)', fontsize=26)

plt.show()

## ROC and Precision-Recall Curves for the Highest Threshold

In [None]:
# Plot ROC and Precision-Recall curves side by side for the first result set (highest threshold).
# Imports (pickle, plt, roc_curve, precision_recall_curve, auc, average_precision_score)
# come from the notebook's first cell.

# NOTE(review): this loads "0_original_replica.pkl", while the path list below and the ROC
# cell above use "0_AUCs_original_replica.pkl" — confirm which file is intended.
# pickle.load can execute arbitrary code — only load result files you produced yourself.
with open("../Results/0_original_replica.pkl", 'rb') as f:
    loaded_data = pickle.load(f)

# .pkl paths:
# "0_AUCs_original_replica.pkl"                   - Replicating the original results after splitting the individual proteins, then re-averaging them
# "1_AUCs_max_max.pkl"                            - Maximum scores among protein-protein pairs
# "2_AUCs_max_max_original_sero.pkl"              - Replicating the original results, but using One-Hot Encoded serotypes instead of bacterial proteins
# "3_AUCs_max_max_sero.pkl"                       - Combining the previous two approaches: individual viral RBPs and One-Hot Encoded serotypes for bacterial proteins
# "4_AUCs_motif_focus_increased.pkl"              - Motif-containing RBPs and viruses that only express one RBP as positives
# "5_AUCs_motif_focus.pkl"                        - Motif-containing RBPs only as positives

threshold = [1, 0.995, 0.99, 0.95, 0.90, 0.85, 0.80, 0.75]

# Without this guard an empty file would fall through to stale fpr/tpr left in the kernel
# by earlier cells.
if len(loaded_data) == 0:
    raise ValueError("The results pickle is empty; nothing to plot.")

# Only the first entry (labelled with the first, i.e. highest, threshold) is plotted.
# Entries are indexed as (labels, scores, ...); further fields are ignored.
thresh = threshold[0]
labels = loaded_data[0][0]
scores = loaded_data[0][1]

# Compute ROC curve and ROC AUC
fpr, tpr, _ = roc_curve(labels, scores)
roc_auc = auc(fpr, tpr)

# Compute Precision-Recall curve and its trapezoidal AUC. Linear interpolation between PR
# points can be overly optimistic, so the step-wise average precision is reported as well.
precision, recall, _ = precision_recall_curve(labels, scores)
pr_auc = auc(recall, precision)
avg_precision = average_precision_score(labels, scores)

# Plot both curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
fig.suptitle(f"Threshold: {thresh}")

# ROC curve plot
ax1.plot(fpr, tpr, label=f"AUC-ROC = {roc_auc:.3f}")
ax1.set_title("ROC Curve")
ax1.set_xlabel("False Positive Rate")
ax1.set_ylabel("True Positive Rate")
ax1.grid(True)
ax1.legend()

# PR curve plot
ax2.plot(recall, precision, label=f"AUC-PR = {pr_auc:.3f} (AP = {avg_precision:.3f})", color="orange")
ax2.set_title("Precision-Recall Curve")
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
ax2.grid(True)
ax2.legend()

plt.tight_layout()
plt.show()