In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.interpolate import interp1d
from sklearn.metrics import auc, precision_recall_curve

In [None]:
# Run selection: every run_id containing "slurmId=<slurmId>" is analysed,
# so the empty default matches all runs that carry a slurmId tag.
slurmId = ""  # id of the SLURM job assigned to the run to analyze

# Plot colours (hex) for the two model variants.
colour_pretrain = "#5D97BF"  # runs whose run_id contains "pretrain"
colour_train = "#F79647"  # all other runs

In [None]:
predictions_raw = pd.read_csv("../../data/output/predictions.csv", index_col=0)

# Keep only the runs of the selected SLURM job. The job tag is an identifier,
# so match it literally; na=False drops rows without a run_id instead of
# producing a mask that cannot be used for indexing.
data = predictions_raw[
    predictions_raw["run_id"].str.contains(f"slurmId={slurmId}", regex=False, na=False)
].assign(
    # .assign builds a new frame, avoiding SettingWithCopyWarning from writing
    # columns into a filtered slice.
    pretrain=lambda df: df["run_id"].str.contains("pretrain", regex=False),
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    # expand=False returns a Series instead of a one-column DataFrame.
    seed=lambda df: df["run_id"].str.extract(r"seed=(\d+)", expand=False),
)
assert not data.empty, f"no predictions found for slurmId={slurmId!r}"

pretrain_data = data[data["pretrain"]]
non_pretrain_data = data[~data["pretrain"]]

data.head()

In [None]:
# Runs whose run_id mentions "roberta" are labelled Ditto-R; all others Ditto-D.
uses_roberta = data["run_id"].str.contains("roberta").any()
if uses_roberta:
    model_name = "Ditto-R"
else:
    model_name = "Ditto-D"

# Precision Recall plots

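For each seed, compute the precision–recall curve from the labels and predicted match probabilities, linearly interpolate precision onto a shared grid of recall values, and average the interpolated curves across seeds. This is done separately for the pretrained and non-pretrained runs.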


In [None]:
def average_precision_recall_curves(pr_curves, recall_points=None):
    """Average several precision-recall curves on a shared recall grid.

    Each curve is linearly interpolated (precision as a function of recall)
    at ``recall_points``, and the interpolated precisions are averaged
    point-wise. Outside a curve's recall range, precision is taken to be 0.

    Parameters
    ----------
    pr_curves : iterable of (precision, recall, thresholds)
        Curves in the format returned by
        ``sklearn.metrics.precision_recall_curve`` (recall non-increasing).
        Thresholds are ignored.
    recall_points : array-like, optional
        Recall values at which to evaluate the curves. Defaults to 100 evenly
        spaced points in [0, 1].

    Returns
    -------
    recall_points : numpy.ndarray
        The recall grid used.
    avg_precision : numpy.ndarray
        Mean interpolated precision at each recall point.

    Raises
    ------
    ValueError
        If ``pr_curves`` contains no curves.

    Notes
    -----
    Linear interpolation between PR points can overstate achievable precision
    (Davis & Goadrich, 2006); treat the averaged curve as a visual summary.
    """
    if recall_points is None:
        recall_points = np.linspace(0, 1, 100)
    recall_points = np.asarray(recall_points, dtype=float)

    interpolated_precisions = []
    for precisions, recalls, _ in pr_curves:
        # sklearn returns recall in non-increasing order, so reversing gives
        # ascending recall; the stable sort only reorders curves that were not
        # already monotone. Duplicate recall values (common at recall 0 and 1)
        # are handled by np.interp without dividing by zero, whereas
        # scipy's interp1d yields NaN at a duplicated boundary point.
        recalls = np.asarray(recalls, dtype=float)[::-1]
        precisions = np.asarray(precisions, dtype=float)[::-1]
        order = np.argsort(recalls, kind="stable")
        interpolated_precisions.append(
            np.interp(
                recall_points,
                recalls[order],
                precisions[order],
                left=0.0,
                right=0.0,
            )
        )

    if not interpolated_precisions:
        raise ValueError("pr_curves is empty; there are no curves to average")

    # Average the interpolated precisions across curves.
    avg_precision = np.mean(interpolated_precisions, axis=0)

    return recall_points, avg_precision


def calculate_pr_curve(df, recall_points=None):
    """Compute the seed-averaged precision-recall curve for one set of runs.

    A precision-recall curve is computed separately for each ``seed`` group
    from its ``label`` and ``probability_match`` columns. The curves are then
    averaged with ``average_precision_recall_curves``.

    Parameters
    ----------
    df : pandas.DataFrame
        Predictions with ``seed``, ``label`` and ``probability_match`` columns.
        Rows whose ``seed`` is missing are dropped by ``groupby`` (pandas'
        default ``dropna=True``).
    recall_points : array-like, optional
        Recall grid forwarded to ``average_precision_recall_curves``; ``None``
        keeps that function's default grid.

    Returns
    -------
    pandas.DataFrame
        Columns ``precision`` (mean interpolated precision) and ``recall``.
    """
    pr_curves = [
        precision_recall_curve(group["label"], group["probability_match"])
        for _, group in df.groupby("seed")
    ]
    recall_grid, avg_precision = average_precision_recall_curves(
        pr_curves, recall_points
    )

    return pd.DataFrame({"precision": avg_precision, "recall": recall_grid})


pretrain_pr = calculate_pr_curve(pretrain_data)
non_pretrain_pr = calculate_pr_curve(non_pretrain_data)

# savefig raises if the target directory is missing, so create it up front.
plot_dir = Path("../../data/output/plots")
plot_dir.mkdir(parents=True, exist_ok=True)

# Scope the enlarged fonts to this figure instead of mutating global rcParams
# (the original updated them only after the figure already existed).
with plt.rc_context({"font.size": 20, "axes.titlesize": 20}):
    fig, ax = plt.subplots(figsize=(6.5, 5))
    ax.plot(
        pretrain_pr["recall"],
        pretrain_pr["precision"],
        color=colour_pretrain,
        label=f"P-{model_name}",
    )
    ax.plot(
        non_pretrain_pr["recall"],
        non_pretrain_pr["precision"],
        color=colour_train,
        label=model_name,
    )
    ax.set(xlabel="Recall", ylabel="Precision")
    ax.legend(title="Model Type", loc="upper right")
    fig.tight_layout()

    fig.savefig(plot_dir / f"pr_curve_calculated_{model_name}.pdf")
    plt.show()