# Create PR Curves

## Load the data
Load both the datasets themselves and the results of the AFD measures.

In [1]:
import os
import sys

import pandas as pd

# for Jupyter notebooks: add the path of 'code' to allow importing module
sys.path.append(os.path.join(os.getcwd(), ".."))
from afd_measures import utils as afd_utils

# Locations of the input datasets, the ground-truth file and the measure
# results, relative to the notebook's working directory.
data_path = "../../data"
gt_path = "../../data/ground_truth.csv"
results_path = "../../results"

# Real-world datasets keyed by their CSV file name, with every column name
# passed through afd_utils.clean_colname.
rwd_dir = os.path.join(data_path, "rwd")
rwd_data = {}
# os.listdir returns entries in arbitrary order; sort for reproducible runs.
for file in sorted(os.listdir(rwd_dir)):
    if not file.endswith(".csv"):
        continue
    frame = pd.read_csv(os.path.join(rwd_dir, file))
    frame.columns = [afd_utils.clean_colname(c) for c in frame.columns]
    rwd_data[file] = frame

# Read every "rwd_results_*.csv" file and concatenate them in a single call.
# The original row indices of the individual files are kept as they are.
result_frames = [
    pd.read_csv(os.path.join(results_path, file))
    for file in sorted(os.listdir(results_path))
    if file.startswith("rwd_results_") and file.endswith(".csv")
]
# pd.concat raises on an empty list, so fall back to an empty frame when no
# result files are present.
rwd_results = pd.concat(result_frames) if result_frames else pd.DataFrame()


## Create plotting data

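This creates the data files used for plotting: one tab-separated `.dat` file per measure, containing the precision and recall values indexed by threshold.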

In [2]:
from typing import Dict

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve

from afd_measures import utils as afd_utils


def make_pr_data(dataset: pd.DataFrame, y_true_key: str) -> Dict[str, pd.DataFrame]:
    """Compute precision-recall curve data for every AFD measure.

    For each measure name in ``afd_utils.measure_order``, the rows of
    ``dataset`` with a non-missing value in that measure's column are
    selected. The ``y_true_key`` column provides the labels and the measure
    column provides the scores passed to
    ``sklearn.metrics.precision_recall_curve``.

    Args:
        dataset: Frame containing one column per measure and the label column.
        y_true_key: Name of the column holding the true labels.

    Returns:
        A dict mapping each measure name to a DataFrame with the columns
        ``precision`` and ``recall``, indexed by threshold.
    """
    result_dfs = {}
    for measure in afd_utils.measure_order:
        # Filter with a boolean mask instead of a query string so the measure
        # name does not have to be a valid Python identifier.
        scored = dataset.loc[dataset[measure].notna()]
        precision, recall, threshold = precision_recall_curve(
            scored[y_true_key], scored[measure]
        )
        # Per the scikit-learn documentation, precision and recall carry one
        # more entry than the thresholds; the sentinel 2.0 is appended so the
        # index has the same length as the two columns.
        result_dfs[measure] = pd.DataFrame(
            {"precision": precision, "recall": recall},
            index=np.append(threshold, 2.0),
        )

    return result_dfs

In [3]:
# Keep only the rows not flagged as trivial FDs; the copy detaches the
# filtered frame from rwd_results.
_rwd = rwd_results.query("trivial_fd == False").copy()

# Rows flagged as exact FDs are left out of the curves; the "afd" column is
# handed to make_pr_data as the label column.
_rwd_without_exact = _rwd.query("exact_fd == False")
plot_data = make_pr_data(_rwd_without_exact, "afd")

# Write one tab-separated file per measure, with the index stored under the
# "threshold" header.
for measure, df in plot_data.items():
    target_file = f"../../paper/figure1_rwd_{measure}.dat"
    df.to_csv(target_file, sep="\t", index_label="threshold")