In [None]:
import pandas as pd

In [None]:
import os

# Directory holding the exported prediction CSVs. Defaults to the original
# author's local checkout; set the DSDM_OUTPUT_PATH environment variable to
# run the notebook on another machine without editing this cell.
output_path = os.environ.get(
    "DSDM_OUTPUT_PATH",
    "/Users/benseimon/GitHub/teaching/DSDM_forecasting_Apr25/output",
)

In [None]:
# Load the saved prediction tables (incidence model first, then onset model)
# from the output directory configured above.
inc_preds = pd.read_csv(filepath_or_buffer=f"{output_path}/incidence_predictions.csv")
ons_preds = pd.read_csv(filepath_or_buffer=f"{output_path}/onset_predictions.csv")

In [None]:
from sklearn.metrics import precision_recall_curve, auc
import pandas as pd
import matplotlib.pyplot as plt
from typing import Tuple

class Eval:
    """
    Quick evaluation helper for precision-recall (PR) curves.

    Computes the PR AUC of ``preds_col`` against the binary ``target_col``
    on the rows whose ``since_col`` value is at least a given onset
    threshold, and plots the PR curves for several thresholds on one axes.

    Args
    ----------
    df : pd.DataFrame
        Table holding the target, prediction and "since" columns.
    target_col : str
        Column with the true labels; rows where it is missing are ignored.
    preds_col : str
        Column with the model scores fed to ``precision_recall_curve``.
    since_col : str
        Column compared against the onset threshold; a row is kept when
        ``df[since_col] >= threshold`` (missing values compare False and
        are therefore dropped).
    figsize : Tuple[float, float]
        Figure size passed to ``plt.subplots``.
    dpi : int
        Figure resolution passed to ``plt.subplots``.
    label_fontsize : int
        Font size of the axis labels.
    legend_fontsize : int
        Font size of the legend entries and legend title.
    tick_fontsize : int
        Font size of the major tick labels.
    linewidth : float
        Width of the plotted PR curves.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        target_col: str,
        preds_col: str,
        since_col: str,
        figsize: Tuple[float, float] = (8, 8),
        dpi: int = 400,
        label_fontsize: int = 24,
        legend_fontsize: int = 16,
        tick_fontsize: int = 16,
        linewidth: float = 3,
    ):
        self.df = df
        self.target_col = target_col
        self.preds_col = preds_col
        self.since_col = since_col

        # plotting
        self.figsize = figsize
        self.dpi = dpi
        self.label_fontsize = label_fontsize
        self.legend_fontsize = legend_fontsize
        self.tick_fontsize = tick_fontsize
        self.linewidth = linewidth

    def _default_fig(self) -> Tuple[plt.Figure, plt.Axes]:
        """
        Create a new figure and axes using the instance's figsize and dpi.

        Returns
        -------
        Tuple[plt.Figure, plt.Axes]
            Figure and axes objects from ``plt.subplots``.
        """
        return plt.subplots(figsize=self.figsize, dpi=self.dpi)

    def pr_auc(self, onset_threshold: int):
        """
        Calculate the PR AUC on rows where ``since_col >= onset_threshold``.

        Only rows with a non-missing target are used. The area is computed
        with ``sklearn.metrics.auc`` (trapezoidal rule) over the PR curve,
        which is not the same quantity as ``average_precision_score``.

        Args
        ----------
        onset_threshold : int
            Minimum value of ``since_col`` for a row to be included.

        Returns
        -------
        score : float
            PR AUC score.
        recall : np.ndarray
            Recall values of the PR curve.
        precision : np.ndarray
            Precision values of the PR curve.

        Raises
        ------
        ValueError
            If no rows satisfy the filter (sklearn may also raise
            ValueError for other invalid inputs, e.g. NaN predictions).
        """
        mask = self.df[self.target_col].notna() & (
            self.df[self.since_col] >= onset_threshold
        )
        subset = self.df[mask]

        # Fail with an explicit message rather than an opaque sklearn error.
        if subset.empty:
            raise ValueError(
                f"No rows with non-missing {self.target_col!r} and "
                f"{self.since_col!r} >= {onset_threshold}; cannot compute PR AUC."
            )

        precision, recall, _ = precision_recall_curve(
            subset[self.target_col], subset[self.preds_col]
        )
        return auc(recall, precision), recall, precision

    def pr_auc_plot(self, onset_thresholds: list) -> Tuple[plt.Figure, plt.Axes]:
        """
        Plot one PR curve per onset threshold on a shared axes.

        Each curve's legend entry reads ``"<threshold>: <PR AUC>"`` with the
        score rounded to two decimals.

        Args
        ----------
        onset_thresholds : list
            Onset thresholds to evaluate; each is passed to ``pr_auc``.

        Returns
        -------
        Tuple[plt.Figure, plt.Axes]
            Figure and axes containing the PR curves.
        """
        fig, ax = self._default_fig()

        for threshold in onset_thresholds:
            score, recall, precision = self.pr_auc(threshold)
            ax.plot(
                recall,
                precision,
                label=f"{threshold}: {round(score, 2)}",
                linewidth=self.linewidth,
            )

        ax.set_xlabel("Recall", fontsize=self.label_fontsize)
        ax.set_ylabel("Precision", fontsize=self.label_fontsize)
        # Small margin so curves lying on the 0/1 boundaries stay visible.
        ax.set_xlim([-0.05, 1.05])
        ax.set_ylim([-0.05, 1.05])
        ax.legend(
            title="Threshold: PR AUC",
            fontsize=self.legend_fontsize,
            title_fontsize=self.legend_fontsize,
        )
        ax.tick_params(axis="both", which="major", labelsize=self.tick_fontsize)

        return fig, ax


In [None]:
# Display the column index of the onset-prediction table loaded earlier.
ons_preds.columns

In [None]:
# One evaluator per prediction table. Both filter rows on the same
# "violence_since_0" column; they differ in target and prediction columns.
inc_eval = Eval(inc_preds, target_col="inc_anyviolence_th0_h3",
                preds_col="inc_preds", since_col="violence_since_0")
ons_eval = Eval(ons_preds, target_col="ons_anyviolence_th0_h3",
                preds_col="ons_preds", since_col="violence_since_0")

In [None]:
# PR curves for the incidence predictions at onset thresholds 0, 1 and 60.
inc_eval.pr_auc_plot([0, 1, 60])

In [None]:
# PR curves for the onset predictions at onset thresholds 1 and 60.
ons_eval.pr_auc_plot([1, 60])