In [None]:
import mlflow
from matplotlib import pyplot as plt
import ast
import numpy as np
from sklearn.metrics import auc
import pandas as pd

In [None]:
# CSV file used as a cache: written when results are pulled from MLflow, read otherwise.
save_results_file = "reproduction_midvholo.csv"
# Set to True to rebuild the results table from the MLflow runs instead of the CSV cache.
retrieve_from_mlflow = False

In [None]:
if retrieve_from_mlflow:
    # Rebuild the results table from the local MLflow tracking store and cache it as CSV.
    print("retreiving results from mlflow")
    mlflow.set_tracking_uri("../mlruns") # local

    client = mlflow.MlflowClient()

    experiment_name = "cumulative_midv_baseline_fulldoc_nosplit"
    found_experiment = client.get_experiment_by_name(experiment_name)
    if found_experiment is None:
        # An unknown experiment name yields None; fail with an explicit message
        # instead of the opaque TypeError that dict(None) would raise.
        raise ValueError(
            f"MLflow experiment {experiment_name!r} was not found in the tracking store"
        )
    current_experiment = dict(found_experiment)
    print("found", current_experiment["name"])
    experiment = current_experiment['experiment_id']
    runs = mlflow.search_runs(
        [experiment],
        filter_string="attributes.status = 'FINISHED'",
        order_by=["start_time DESC"],
    )
    # Keep only the metric columns plus the two serialized parameter columns.
    runs_tmp = runs[["params.model", "params.decision"]]
    runs = runs.filter(regex="^metrics")
    runs = pd.concat((runs, runs_tmp), axis=1)

    # Expand the serialized parameter dicts into one column per parameter
    # (s_t, T, h_t -- the latter is stored under the name "th").
    for i, row in runs.iterrows():
        for source_column in ["params.model", "params.decision"]:
            params_model_dict = ast.literal_eval(row[source_column])
            for param_name, param_value in params_model_dict.items():
                if param_name == "_target_":
                    continue  # this entry is not turned into a column
                if param_name not in runs.columns:
                    runs[param_name] = None  # Create a new column if it doesn't exist
                runs.at[i, param_name] = param_value
    print(f"writing {save_results_file}")
    runs.to_csv(save_results_file)
else:
    # Reuse the cached table written by a previous MLflow retrieval.
    print("retrieving results from", save_results_file)
    runs = pd.read_csv(save_results_file, index_col=0)

In [None]:
# Derive the false positive rate as the complement of specificity.
specificity = runs["metrics.specificity"]
runs["metrics.fpr"] = 1 - specificity

In [None]:
# Preview the first five rows of the results table.
runs.head(n=5)

In [None]:
# Show recall / FPR for a single configuration, ordered by increasing T.
matches_s_t = runs["s_t"] == 50
matches_height = runs["hight_threshold"] == 240
matches_th = runs["th"] == 0.01
columns_shown = ["metrics.recall", "metrics.fpr", "s_t", "T", "hight_threshold", "th"]
runs.loc[matches_s_t & matches_height & matches_th].sort_values("T")[columns_shown]

In [None]:
# Configuration whose ROC curve is drawn.
s_t = 50
th = 0.01 # h_t
hight_threshold = 240
runs_filtered = runs.loc[
    (runs["s_t"] == s_t) & (runs["hight_threshold"] == hight_threshold) & (runs["th"] == th)
]
# One point per run, ordered by increasing T.
runs_filtered = runs_filtered.sort_values("T", ascending=True)
x = list(runs_filtered["metrics.fpr"])
y = list(runs_filtered["metrics.recall"])

# This is the ROC curve
# NOTE(review): the AUC shown in the legend comes from sklearn's `auc`, while the
# curve itself is drawn as steps -- confirm both are meant to describe the same area.
plt.figure(figsize=(6,6))
plt.step(x, y, where="pre", label=f"AUC {auc(x,y).round(3)}")
plt.plot([0,1], [0,1], label=f"Random {auc([0,1],[0,1]).round(3)}")
plt.xlabel("False positive rate (FPR)")
plt.ylabel("Recall")
plt.xlim([0,1])
plt.ylim([0,1])
# np.arange excludes its stop value, so stop past 100 to get a tick (and grid line) at 1.0.
ticks = np.arange(0, 101, 10)/100
plt.xticks(ticks, ticks)
plt.yticks(ticks, ticks)
plt.grid(which="both")
plt.title(f"ROC curve for s_t={s_t} and h_t={th}")
plt.legend(loc="lower right")
plt.show()

In [None]:
def auc_group(df):
    """Return the area under the (FPR, recall) curve described by the rows of `df`.

    The rows are ordered by increasing "T" before the "metrics.fpr" (x) and
    "metrics.recall" (y) values are handed to `auc`.
    """
    ordered = df.sort_values("T", ascending=True)
    fpr, recall = (list(ordered[column]) for column in ("metrics.fpr", "metrics.recall"))
    return auc(fpr, recall)

In [None]:
# One AUC per (s_t, th, hight_threshold) configuration. "hight_threshold" is part of the
# key so that runs with different values are never merged into a single ROC curve; this
# also matches the three header levels of the recorded LaTeX table.
auc_serie = runs.groupby(["s_t", "th", "hight_threshold"]).apply(auc_group)
auc_table = pd.DataFrame(auc_serie)
# Transpose so that configurations become columns, and display AUCs with 3 decimals.
auc_df = auc_table.T.style.format(precision=3)
auc_df

In [None]:
# Render the styled AUC table as LaTeX source and print it.
latex_source = auc_df.to_latex()
print(latex_source)

Result of the previous cell:
```latex
\begin{tabular}{lrrrrrrrrr}
s_t & \multicolumn{3}{r}{30} & \multicolumn{3}{r}{40} & \multicolumn{3}{r}{50} \\
th & 0.010000 & 0.020000 & 0.030000 & 0.010000 & 0.020000 & 0.030000 & 0.010000 & 0.020000 & 0.030000 \\
hight_threshold & 240 & 240 & 240 & 240 & 240 & 240 & 240 & 240 & 240 \\
0 & 0.838 & 0.846 & 0.844 & 0.855 & 0.844 & 0.831 & 0.857 & 0.826 & 0.790 \\
\end{tabular}
```