In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Read the trials CSV into a DataFrame; the path is relative to the notebook's
# working directory.
trials_data = pd.read_csv("trials_training_data.csv")

In [None]:
# Preview only the first rows: rendering the whole frame bloats the saved
# notebook and buries the narrative.
trials_data.head()

In [4]:
# Show the column index of the loaded frame; the per-variable metric columns
# selected for plotting further down are picked from this list.
trials_data.columns

Index(['Name', 'State', 'Notes', 'User', 'Tags', 'Created', 'Runtime', 'Sweep',
       'base_lr', 'epochs', 'loss', 'loss_weights', 'max_lr', 'mode',
       'optimizer', 'optuna_trial', 'scheduler', 'RhiresD/train',
       'RhiresD/val', 'TabsD/train', 'TabsD/val', 'TmaxD/train', 'TmaxD/val',
       'TminD/train', 'TminD/val', 'best_val_loss',
       'best_val_loss_per_channel', 'epoch', 'epoch_time', 'loss/train',
       'loss/val', 'lr', 'precip_val_loss', 'temp_val_loss', 'tmax_val_loss',
       'tmin_val_loss', 'total_val_loss', 'trial', 'weights',
       'initial_weights'],
      dtype='object')

In [None]:
# Objectives for the 2-D trade-off plot, as plain NumPy arrays so that array
# positions line up with DataFrame row positions (used later with .iloc).
x = trials_data['loss/val'].to_numpy()     # horizontal axis
y = trials_data['RhiresD/val'].to_numpy()  # vertical axis

def pareto_front_2d(x, y):
    """Return the Pareto front of a 2-D minimisation problem.

    A point is on the front when no other point is at least as good in both
    objectives and strictly better in one. Smaller is better for both
    ``x`` and ``y``.

    Parameters
    ----------
    x, y : array-like of shape (n,)
        Values of the first and second objective, one entry per candidate.
        Entries where either value is NaN are ignored.

    Returns
    -------
    pareto_points : numpy.ndarray of shape (k, 2)
        ``(x, y)`` coordinates of the front, ordered by increasing ``x``.
        Shape is ``(0, 2)`` when there are no valid candidates.
    pareto_idx : numpy.ndarray of shape (k,)
        Integer positions of the front points in the input arrays.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` are not 1-D arrays of the same length.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError("x and y must be 1-D arrays of equal length")

    # NaN sorts to the end and never compares as smaller, but a NaN x paired
    # with a small y would still pass the sweep below, so drop such rows first.
    candidates = np.flatnonzero(~(np.isnan(x) | np.isnan(y)))

    # Sort by x, breaking ties by y (lexsort treats the LAST key as primary).
    # Without the tie-break, of two points sharing the same x the worse one
    # could be visited first and wrongly kept on the front.
    order = candidates[np.lexsort((y[candidates], x[candidates]))]

    # Sweep in increasing x: a point is non-dominated only if its y is
    # strictly below every y seen so far.
    pareto_idx = []
    min_y = np.inf
    for idx in order:
        if y[idx] < min_y:
            pareto_idx.append(idx)
            min_y = y[idx]

    pareto_idx = np.array(pareto_idx, dtype=np.intp)
    # column_stack keeps the (k, 2) shape even for k == 0.
    pareto_points = np.column_stack((x[pareto_idx], y[pareto_idx]))
    return pareto_points, pareto_idx

pareto_points, pareto_idx = pareto_front_2d(x, y)

# Derive the legend count from the data instead of hard-coding it, so the
# label cannot drift when the CSV changes. Only finite pairs are counted,
# since matplotlib does not draw points with NaN coordinates.
n_trials = int(np.count_nonzero(np.isfinite(x) & np.isfinite(y)))

fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(x, y, label=f'Trials ({n_trials})', color='blue', alpha=0.5, s=30)

# Annotate Pareto "elbow": the front point with the smallest Euclidean
# distance to the origin.
# NOTE(review): the distance is taken on the raw, un-normalised axes, so the
# objective with the larger numeric range dominates the choice — confirm this
# is intended.
if len(pareto_points) > 0:  # nothing to annotate when the front is empty
    pareto_distances = np.hypot(pareto_points[:, 0], pareto_points[:, 1])
    elbow_idx = np.argmin(pareto_distances)
    elbow_x, elbow_y = pareto_points[elbow_idx]
    # pareto_idx holds positions in x / y, which are row positions in
    # trials_data, hence .iloc rather than .loc.
    elbow_trial_idx = pareto_idx[elbow_idx]
    elbow_trial_name = trials_data['Name'].iloc[elbow_trial_idx]
    ax.scatter(elbow_x, elbow_y, marker='*', s=250, color='red', label='Pareto elbow')
    ax.annotate(f'{elbow_trial_name}', (elbow_x, elbow_y),
                textcoords="offset points", xytext=(10, -20), ha='left',
                color='red', fontsize=11, fontweight='bold')

ax.set_xlabel('Total validation loss')
ax.set_ylabel('Precipitation Val Loss')
ax.set_title('Total validation loss vs Precipitation Val Loss for "constrained" trials')
ax.legend()
ax.grid(False)
fig.tight_layout()
# Write the figure next to the notebook at print resolution.
fig.savefig('pareto_front_plot.png', dpi=500, bbox_inches='tight')