In [None]:
import json
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Build `results_df`: one row per trial sub-folder, taken from that folder's
# `result.json`, with the nested `config` dict flattened into top-level columns.
results_folder = '../../MamaMia/hyperparameter_tuning/tune_results/2d_tuning_run_2'
results = []
# Sorted so the row order of `results_df` is reproducible; os.listdir returns
# entries in arbitrary order.
for subfolder in sorted(os.listdir(results_folder)):
    result_file = os.path.join(results_folder, subfolder, 'result.json')
    # Skip stray files and trial folders that have no result.json (yet) instead
    # of aborting the whole load with FileNotFoundError.
    if not os.path.isfile(result_file):
        continue
    with open(result_file, 'r') as file:
        try:
            result = json.load(file)
        except json.JSONDecodeError:
            # Unparsable file (empty, truncated, ...) -> trial is left out.
            # NOTE(review): json.load also rejects files holding more than one
            # JSON document (e.g. one object per line); such trials are skipped
            # silently here -- confirm that is intended.
            continue
    # Promote every config entry to a top-level key so each one becomes its own
    # DataFrame column; config values win on key clashes.
    result.update(result.pop('config'))
    results.append(result)
results_df = pd.DataFrame(results)
results_df.head()

In [None]:
# Replace scores that are exactly 0 with 0.5; every other value (including NaN)
# is kept as is.
# NOTE(review): presumably 0 marks a trial without a usable score and 0.5 is the
# chance-level placeholder -- confirm.
ranking_scores = results_df['mean_5_fold_ranking_score']
results_df['mean_5_fold_ranking_score'] = ranking_scores.mask(ranking_scores == 0, 0.5)

In [None]:
# One scatter panel per hyperparameter: hyperparameter value on x, the
# `mean_5_fold_ranking_score` column on y, points coloured by `model_key`.
hyperparameters = ['learning_rate', 'weight_decay', 'batch_size', 'label_smoothing', 'x_y_resolution', 'normalization', 'model_key']
# Hyperparameters drawn on a logarithmic x axis.
log_scaled = {'learning_rate', 'weight_decay', 'final_learning_rate', 'label_smoothing'}

n_rows = 2
n_cols = (len(hyperparameters) + 1) // 2
# squeeze=False keeps `ax` two-dimensional even when the grid has a single
# column, so the [row][col] indexing below works for any list length >= 1.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(22, 8), squeeze=False)
for i, hyperparameter in enumerate(hyperparameters):
    # Panels are filled column by column: top, bottom, next column, ...
    panel = ax[i % n_rows][i // n_rows]
    # Only the first panel carries the (shared) hue legend; the others are
    # drawn without one rather than creating a legend just to remove it.
    legend_kwargs = {} if i == 0 else {'legend': False}
    sns.scatterplot(data=results_df, x=hyperparameter, y='mean_5_fold_ranking_score', hue='model_key', ax=panel, **legend_kwargs)
    if hyperparameter in log_scaled:
        panel.set_xscale('log')
    if i == 0:
        panel.legend(loc='lower left')
# With an odd number of hyperparameters the last grid cell stays unused; hide
# it so no empty axes frame shows up in the figure.
for j in range(len(hyperparameters), n_rows * n_cols):
    ax[j % n_rows][j // n_rows].set_axis_off()
plt.tight_layout()
plt.show()

# Plot Metric Curves

In [None]:
# Build `curves_df`: the rows of every trial's `loss_log_detailed.csv`, stacked
# into one frame and tagged with the trial folder name in a `trial` column.
trials_folder = '../../MamaMia/hyperparameter_tuning/tune_trials/2d_tuning_run_2'
curve_dfs = []
# Sorted so row order (and therefore hue/colour assignment in later plots) is
# reproducible; os.listdir returns entries in arbitrary order.
for subfolder in sorted(os.listdir(trials_folder)):
    loss_file = os.path.join(trials_folder, subfolder, 'loss_log_detailed.csv')
    # Covers both stray non-directory entries and trials without a log file.
    if not os.path.isfile(loss_file):
        continue
    curve_df = pd.read_csv(loss_file)
    curve_df['trial'] = subfolder
    curve_dfs.append(curve_df)
if not curve_dfs:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(f"no 'loss_log_detailed.csv' found in any sub-folder of {trials_folder!r}")
# ignore_index=True gives the combined frame a unique 0..n-1 index; without it
# every trial contributes the same index labels, which breaks index-aligned
# operations downstream.
curves_df = pd.concat(curve_dfs, ignore_index=True)
curves_df.head()

In [None]:
# Keep only rows whose ranking score is above 0.5 and not exactly 0.75, then
# draw one ranking-score-vs-epoch line per trial.
# NOTE(review): 0.5 and 0.75 look like placeholder/degenerate score values --
# confirm why exactly these are excluded.
scores = curves_df['ranking_score']
keep_rows = scores.gt(0.5) & scores.ne(0.75)
curves_df_filtered = curves_df.loc[keep_rows]

fig = plt.figure(figsize=(20, 10))
curve_ax = sns.lineplot(data=curves_df_filtered, x='epoch', y='ranking_score', hue='trial')
# One legend entry per trial would swamp the figure, so the legend is dropped.
curve_ax.legend().remove()
plt.show()

In [None]:
# Balanced accuracy plotted against the fairness score (left panel) and against
# the ranking score (right panel), coloured by trial.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for panel, x_metric in zip(ax, ('fairness_score', 'ranking_score')):
    sns.scatterplot(data=curves_df_filtered, x=x_metric, y='balanced_accuracy', hue='trial', ax=panel)
    # Horizontal reference line at balanced accuracy 0.5, spanning x in [0, 1].
    panel.hlines(0.5, 0, 1)
    # Per-trial legend is not shown.
    panel.legend().remove()
plt.show()