In [None]:
# This script requires pymoo; install it with: pip install pymoo
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
from pymoo.indicators.hv import HV
from pandas.plotting import parallel_coordinates

def _best_epoch_for_seed(seed, df_seed):
    """Pick the Pareto-front epoch with the largest single-point hypervolume.

    Args:
        seed: Seed identifier, used only in the printed diagnostics.
        df_seed: Rows of one seed with columns ``epoch``, ``trait`` and ``r2``.

    Returns:
        Tuple ``(best_epoch, pivot, norm_values, nds_indices)`` where ``pivot``
        is the epoch x trait table of r2, ``norm_values`` its per-trait
        min-max normalized values and ``nds_indices`` the row positions of the
        non-dominated epochs.
    """
    # Epochs as rows, traits as columns, r2 as values.
    pivot = df_seed.pivot(index='epoch', columns='trait', values='r2')

    # Missing (epoch, trait) cells get the smallest observed r2 so the scaler
    # never sees NaN and such epochs are not favoured.
    pivot_filled = pivot.fillna(pivot.min().min())

    # Scale every trait column to [0, 1].
    norm_values = MinMaxScaler().fit_transform(pivot_filled.values)
    epochs = pivot.index.to_numpy()

    # The sorter is fed 1 - value so that a higher r2 counts as better.
    nds_indices = NonDominatedSorting().do(
        1 - norm_values, only_non_dominated_front=True
    )
    print(f'nds indices for seed {seed}: {nds_indices}')

    pareto_points = norm_values[nds_indices]
    pareto_epochs = epochs[nds_indices]
    print(f"\nSeed {seed} Pareto front epochs: {list(pareto_epochs)}")
    print(f"Pareto front points (normalized R²):\n{pareto_points}")

    # Hypervolume of each front point taken on its own, in the transformed
    # (1 - value) space, against a reference point of 1.1 per trait.
    minim_points = 1 - pareto_points
    ref_point = np.ones(minim_points.shape[1]) * 1.1
    hv_indicator = HV(ref_point=ref_point)
    pareto_hvs = [hv_indicator.do(np.array([pt])) for pt in minim_points]

    best_idx = int(np.argmax(pareto_hvs))
    best_epoch = pareto_epochs[best_idx]
    print(f"Best epoch by HV: {best_epoch} with HV = {pareto_hvs[best_idx]:.6f}")

    return best_epoch, pivot, norm_values, nds_indices


def _plot_seed_parallel_coordinates(seed, pivot, norm_values, nds_indices, best_epoch):
    """Show a parallel-coordinates plot of one seed's normalized r2 values."""
    df_parallel = pd.DataFrame(norm_values, columns=pivot.columns)
    df_parallel['epoch'] = pivot.index.to_numpy()
    df_parallel['Type'] = 'Checkpoint'

    # nds_indices are row positions, hence positional assignment.
    df_parallel.iloc[nds_indices, df_parallel.columns.get_loc('Type')] = 'Pareto Front'
    # The best point is labelled last so it overrides its 'Pareto Front' label.
    df_parallel.loc[df_parallel['epoch'] == best_epoch, 'Type'] = 'Best HV Point'

    df_parallel['epoch'] = df_parallel['epoch'].astype(str)

    color_map = {
        'Checkpoint': 'gray',
        'Pareto Front': 'red',
        'Best HV Point': 'green',
    }
    # Colors are listed in order of first appearance of each class, which is
    # the order in which parallel_coordinates is expected to pair them with
    # classes -- NOTE(review): confirm against the installed pandas version.
    class_colors = [color_map[t] for t in df_parallel['Type'].unique()]

    plt.figure(figsize=(12, 6))
    parallel_coordinates(
        df_parallel,
        class_column='Type',
        cols=pivot.columns,
        color=class_colors,
        alpha=0.7,
    )
    plt.title(f'Seed {seed} - Parallel Coordinates Plot of Normalized R²')
    plt.xlabel('Traits')
    plt.ylabel('Normalized R²')
    plt.legend(title='Point Type')
    plt.show()


def _aggregate_best_metrics(df, best_checkpoints, metrics):
    """Build the per-trait 'mean ± std' table over each seed's best epoch."""
    df_best = pd.concat(
        [
            df[(df['seed'] == seed) & (df['epoch'] == epoch)]
            for seed, epoch in best_checkpoints.items()
        ]
    )

    grouped = df_best.groupby('trait')[metrics]
    agg = pd.concat(
        [grouped.mean().add_suffix('_mean'), grouped.std().add_suffix('_std')],
        axis=1,
    )

    # Render "mean ± std" strings rounded to two decimals.
    for metric in metrics:
        agg[f'{metric}_mean±std'] = (
            agg[f'{metric}_mean'].round(2).astype(str)
            + ' ± '
            + agg[f'{metric}_std'].round(2).astype(str)
        )

    return agg[[f'{metric}_mean±std' for metric in metrics]].reset_index()


def select_best_epochs_by_hv(df, *, show_plots=True):
    """Select one epoch per seed by hypervolume and summarize its metrics.

    Rows with negative ``r2`` are discarded first. For every seed, the r2
    values are pivoted to an epoch x trait table, min-max normalized per
    trait, and the non-dominated epochs are identified. Among those, the
    epoch whose point alone has the largest hypervolume is chosen. Metrics of
    the chosen epochs are then averaged across seeds for each trait.

    Args:
        df: DataFrame with columns ``seed``, ``epoch``, ``trait``, ``r2``,
            ``nmae`` and ``r``. It is not modified. Rows whose ``seed`` is
            missing are ignored by the per-seed grouping.
        show_plots: When true (the default), display a parallel-coordinates
            figure for each seed. Pass ``False`` for headless use.

    Returns:
        Tuple ``(best_checkpoints, final)``: a dict mapping each seed to its
        chosen epoch, and a DataFrame with one row per trait holding
        ``'<metric>_mean±std'`` strings for ``r2``, ``nmae`` and ``r``. When
        no rows remain after filtering, the dict is empty and the DataFrame
        has the same columns but no rows.
    """
    metrics = ['r2', 'nmae', 'r']

    # Step 0: remove rows with negative r2.
    df = df[df['r2'] >= 0].copy()

    best_checkpoints = {}
    # sort=False keeps seeds in order of first appearance.
    for seed, df_seed in df.groupby('seed', sort=False):
        best_epoch, pivot, norm_values, nds_indices = _best_epoch_for_seed(seed, df_seed)
        best_checkpoints[seed] = best_epoch
        if show_plots:
            _plot_seed_parallel_coordinates(
                seed, pivot, norm_values, nds_indices, best_epoch
            )

    if not best_checkpoints:
        # Nothing survived the filter; pd.concat would raise on an empty list.
        print("\nNo rows with non-negative r2; nothing to aggregate.")
        empty_columns = ['trait'] + [f'{metric}_mean±std' for metric in metrics]
        return best_checkpoints, pd.DataFrame(columns=empty_columns)

    final = _aggregate_best_metrics(df, best_checkpoints, metrics)

    print("\nAggregated metrics per trait (mean ± std):")
    print(final)

    return best_checkpoints, final


# === USAGE EXAMPLE ===
# df = pd.read_csv('your_data.csv')  # Your dataframe with columns: seed, epoch, trait, r2, nmae, r
# best_epochs, aggregated_metrics = select_best_epochs_by_hv(df)




In [None]:
# The epoch-wise summary CSV is expected to provide the columns:
# seed, epoch, trait, r2, nmae, r
df = pd.read_csv(filepath_or_buffer='path/to/output_folder/epoch_wise_summary.csv')
best_epochs, aggregated_metrics = select_best_epochs_by_hv(df)