In [13]:
import pickle

# Load the pickled studies (presumably a list of optuna Study objects — see the
# TwoObjectiveSolutions usage below).
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open('pilot_studies.pickle', 'rb') as handle:
    studies = pickle.load(handle)

study = studies[0]
# Kept as the cell's last expression so the notebook actually displays the
# number of trials (an expression in the middle of a cell is discarded).
len(study.trials)

In [48]:
from typing import List, Optional, Tuple

import optuna
import pandas as pd
import plotly.graph_objects as go
from optuna.visualization._pareto_front import _get_pareto_front_info

class TwoObjectiveSolutions:
    """
    Helpers for inspecting the solutions of a two-objective Optuna study.

    The first objective is referred to as "AUC" and the second as "S".

    Attributes:
    study (optuna.study.Study): The Optuna study object.
    auc_cutoff (float or None): Exclusive lower bound on AUC; None disables it.
    s_cutoff (float or None): Exclusive lower bound on S; None disables it.
    df (pd.DataFrame or None): All solutions; filled by ``generate_dataframes``.
    df_filtered (pd.DataFrame or None): Solutions passing both cutoffs; only
        computed when both cutoffs are set, otherwise None.
    filtered_trials (List[optuna.trial.FrozenTrial] or None): Trials passing the
        cutoffs; filled by ``get_filtered_trials``.
    """

    def __init__(
        self,
        study: "optuna.study.Study",
        auc_cutoff: Optional[float] = None,
        s_cutoff: Optional[float] = None,
    ):
        """
        Initialize the TwoObjectiveSolutions object with an Optuna study.

        Args:
        study (optuna.study.Study): The Optuna study object.
        auc_cutoff (float, optional): Exclusive lower bound on the AUC objective.
        s_cutoff (float, optional): Exclusive lower bound on the S objective.
        """
        self.study = study
        self.df = None
        self.df_filtered = None
        self.filtered_trials = None
        self.auc_cutoff = auc_cutoff
        self.s_cutoff = s_cutoff

    def generate_dataframes(self):
        """
        Build ``df`` (all solutions) and ``df_filtered`` (solutions passing the cutoffs).

        ``df_filtered`` is only computed when both ``auc_cutoff`` and ``s_cutoff``
        are set; otherwise it is reset to None.

        Returns:
        self
        """
        info = _get_pareto_front_info(self.study)
        # Reuse the axis order we already have instead of recomputing the front.
        self.df = self._create_trials_dataframe(
            info.best_trials_with_values + info.non_best_trials_with_values,
            axis_order=info.axis_order,
        )

        if self.auc_cutoff is not None and self.s_cutoff is not None:
            self.df_filtered = self.df[
                (self.df["AUC"] > self.auc_cutoff) & (self.df["S"] > self.s_cutoff)
            ]
        else:
            self.df_filtered = None

        return self

    def get_filtered_trials(self):
        """
        Collect the study's trials whose objectives strictly exceed the cutoffs.

        A cutoff that is None is not applied. Trials whose ``values`` is None
        (in Optuna, trials that did not complete) are skipped instead of
        raising.

        Returns:
        self (with ``filtered_trials`` set to the matching trials, in study order)
        """
        self.filtered_trials = [
            trial
            for trial in self.study.get_trials(deepcopy=False)
            if trial.values is not None
            and self._passes_cutoffs(trial.values[0], trial.values[1])
        ]
        return self

    def plot_solutions(self):
        """
        Show a scatter plot of all solutions (AUC on x, S on y).

        Builds ``df`` first if it has not been generated yet. When both cutoffs
        are set, the region above both cutoffs is shaded.

        Returns:
        self
        """
        if self.df is None:
            # generate_dataframes reads the cutoffs from self; it takes no arguments.
            self.generate_dataframes()

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=self.df["AUC"], y=self.df["S"], mode='markers'))

        if self.auc_cutoff is not None and self.s_cutoff is not None:
            fig.add_shape(type="rect",
                          x0=self.auc_cutoff, y0=self.s_cutoff,
                          x1=self.df["AUC"].max(), y1=self.df["S"].max(),
                          line=dict(color="LightSeaGreen", width=2),
                          fillcolor="LightSeaGreen", opacity=0.3)

        fig.show()

        return self

    def _passes_cutoffs(self, auc: float, s: float) -> bool:
        """Return True if ``auc`` and ``s`` strictly exceed every cutoff that is set."""
        # `not (x > c)` (rather than `x <= c`) also rejects NaN, like the original `>` test.
        if self.auc_cutoff is not None and not (auc > self.auc_cutoff):
            return False
        if self.s_cutoff is not None and not (s > self.s_cutoff):
            return False
        return True

    def _create_trials_dataframe(self, trials_with_values: List[Tuple], axis_order=None):
        """
        Build a two-column ("AUC", "S") DataFrame from ``(trial, values)`` pairs.

        Args:
        trials_with_values: ``(trial, values)`` pairs, as in the Pareto-front
            info's ``best_trials_with_values`` / ``non_best_trials_with_values``.
        axis_order (sequence of int, optional): Indices into ``values`` for AUC
            and S. When None, it is taken from ``_get_pareto_front_info(self.study)``.

        Returns:
        pd.DataFrame with columns "AUC" and "S", one row per pair.
        """
        if axis_order is None:
            axis_order = _get_pareto_front_info(self.study).axis_order
        auc_idx, s_idx = axis_order[0], axis_order[1]
        return pd.DataFrame({
            "AUC": [values[auc_idx] for _, values in trials_with_values],
            "S": [values[s_idx] for _, values in trials_with_values],
        })



# Trials of `study` whose first objective (AUC) exceeds 0.8 and second (S) exceeds 0.2.
filtered_trials = (
    TwoObjectiveSolutions(study, auc_cutoff=0.8, s_cutoff=0.2)
    .get_filtered_trials()
    .filtered_trials
)
