In [None]:
import json
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
results_folder = '../../MamaMia/hyperparameter_tuning/tune_results/3d_tuning_run_4'

# Build one flat record per trial sub-folder: the contents of the trial's
# result.json with its nested 'config' dict merged into the top level, so
# that metrics and hyperparameters end up as sibling columns of results_df.
results = []
skipped_trials = []  # sub-folders whose result.json is missing or not valid JSON

# os.listdir returns entries in arbitrary order; sorting keeps the row order
# of results_df (and therefore the .head() preview) identical on every run.
for subfolder in sorted(os.listdir(results_folder)):
    trial_dir = os.path.join(results_folder, subfolder)
    if not os.path.isdir(trial_dir):
        continue
    result_file = os.path.join(trial_dir, 'result.json')
    if not os.path.exists(result_file):
        skipped_trials.append(subfolder)
        continue
    with open(result_file, 'r', encoding='utf-8') as file:
        try:
            result = json.load(file)
        except json.JSONDecodeError:
            # NOTE(review): json.load rejects empty files and files holding
            # several JSON objects (one per line). If a trial can report more
            # than once, its result.json presumably has that shape and the
            # trial ends up here -- confirm every trial reports exactly once.
            skipped_trials.append(subfolder)
            continue
    # Hoist the hyperparameters next to the metrics, then drop the nested dict.
    config = result['config']
    result.update(config)
    del result['config']
    results.append(result)

# Surface dropped trials instead of discarding them silently.
if skipped_trials:
    print(f'Skipped {len(skipped_trials)} trial folder(s) without a readable result.json')

results_df = pd.DataFrame(results)
results_df.head()

In [None]:
# Replace a score of exactly 0 with 0.5 in both metric columns; every other
# value (including NaN) is left as it is.
# NOTE(review): presumably 0 marks a trial without a usable score and 0.5 is
# the chance-level value -- confirm.
for metric_column in ('mean_5_fold_ranking_score', 'balanced_accuracy'):
    results_df[metric_column] = results_df[metric_column].map(
        lambda score: score if score != 0 else 0.5
    )

In [None]:
hyperparameters = ['optimizer', 'learning_rate', 'final_learning_rate', 'momentum', 'weight_decay', 'batch_size', 'label_smoothing', 'x_y_resolution', 'z_resolution', 'model_key']
# Hyperparameters drawn on a logarithmic x-axis.
# NOTE(review): a value of exactly 0 (e.g. no label smoothing or weight decay)
# has no position on a log axis -- confirm the search space excludes 0.
log_scaled_hyperparameters = {'learning_rate', 'weight_decay', 'final_learning_rate', 'label_smoothing'}
# One subplot row per metric, one column per hyperparameter.
metrics = ['mean_5_fold_ranking_score', 'balanced_accuracy']

# squeeze=False keeps `ax` two-dimensional even when only one hyperparameter
# is listed; the figure size follows the grid (4 x 4 inches per panel).
fig, ax = plt.subplots(
    len(metrics),
    len(hyperparameters),
    figsize=(4 * len(hyperparameters), 4 * len(metrics)),
    squeeze=False,
)
for row, metric in enumerate(metrics):
    for i, hyperparameter in enumerate(hyperparameters):
        panel = ax[row][i]
        # Only the bottom-left panel carries the (shared) model_key legend;
        # all other panels are drawn without one instead of creating a legend
        # and removing it again.
        keeps_legend = row == len(metrics) - 1 and i == 0
        legend_kwargs = {} if keeps_legend else {'legend': False}
        sns.scatterplot(data=results_df, x=hyperparameter, y=metric, hue='model_key', ax=panel, **legend_kwargs)
        if hyperparameter in log_scaled_hyperparameters:
            panel.set_xscale('log')
        if keeps_legend:
            panel.legend(loc='lower center')
plt.tight_layout()
plt.show()

In [None]:
# Plot the two tuning metrics against each other, leaving out trials that
# carry only a placeholder balanced accuracy.
# An earlier cell rewrites every 0 in 'balanced_accuracy' to 0.5, so on a
# top-to-bottom run a plain `!= 0` check no longer removes anything; the 0.5
# placeholder therefore has to be excluded as well.
# NOTE(review): this also drops trials whose real balanced accuracy is exactly
# 0.5 (chance level) -- confirm that is acceptable for this plot.
placeholder_scores = [0, 0.5]
results_df_filtered = results_df[~results_df['balanced_accuracy'].isin(placeholder_scores)].copy()
sns.scatterplot(data=results_df_filtered, x='mean_5_fold_ranking_score', y='balanced_accuracy')
plt.show()

In [None]:
# TODO: Plot training and validation loss curves to check whether the number of epochs is adequate