In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import seaborn as sns

In [None]:
# Configuration.
# log_folder_path: directory searched (recursively) for the grid-search *.jsonl logs.
# output_dir: directory where exported CSV tables are written.
# Both can be overridden through environment variables so the notebook can run
# outside the original machine; the defaults are the original local paths.
log_folder_path = os.environ.get(
    'TTE_GRIDSEARCH_LOG_DIR',
    '/Users/jk1/temp/opsum_end/training/hyperopt/tte_gridsearch',
)
output_dir = os.environ.get('TTE_GRIDSEARCH_OUTPUT_DIR', '/Users/jk1/Downloads')

In [None]:
# Collect every grid-search log (*.jsonl) found under log_folder_path into one DataFrame.
# Each file is read as JSON Lines, its first record is skipped, and the source file
# name is stored in a `file_name` column so runs can be traced back to their log.
run_frames = []
for root, dirs, files in os.walk(log_folder_path):
    dirs.sort()  # visit sub-directories in a deterministic order
    for file in sorted(files):
        if not file.endswith('.jsonl'):
            continue
        temp_df = pd.read_json(os.path.join(root, file),
                               lines=True, dtype={'timestamp': 'object'}, convert_dates=False)
        # Skip the first record of each log (equivalent to the previous drop(0) on the
        # default RangeIndex). NOTE(review): presumably a header/config line rather
        # than a run result — confirm. Unlike drop(0), iloc[1:] does not raise when
        # the file yields no rows.
        temp_df = temp_df.iloc[1:].copy()
        temp_df['file_name'] = file
        run_frames.append(temp_df)

if not run_frames:
    # Fail here with a clear message instead of a KeyError in a later cell.
    raise FileNotFoundError(f'No .jsonl grid-search logs found under {log_folder_path}')

# Concatenate once; growing a frame with concat inside the loop copies it every time.
gs_df = pd.concat(run_frames, ignore_index=True)


In [None]:
# Summarise the loaded runs rather than dumping the entire frame:
# the run/column counts, then the first rows.
print(f'{gs_df.shape[0]} runs loaded, {gs_df.shape[1]} columns')
gs_df.head()

In [None]:
# Best run = lowest median validation MAE. nsmallest keeps the first occurrence on
# ties (sort_values' default quicksort is not stable) and skips NaN scores.
best_df = gs_df.nsmallest(1, 'median_val_mae')
best_df

In [None]:
# Name of the log file that contains the best run.
best_df['file_name'].iloc[0]

In [None]:
# Optional: uncomment to export the best run's hyperparameters to output_dir.
# best_df.to_csv(os.path.join(output_dir, 'tte_end_transformer_best_hyperparameters.csv'), index=False)

In [None]:
# Keep the 5 runs with the lowest median validation MAE and record, for each,
# its run timestamp and the log file it came from.
top_5_df = gs_df.nsmallest(5, 'median_val_mae')[['timestamp', 'file_name']]

# Save the table to output_dir, tagged with today's date (YYYYMMDD).
# Previously the file name was built but the table was never written.
timestamp = pd.Timestamp.now().strftime('%Y%m%d')
filename = f'tte_end_transformer_top_5_models_{timestamp}.csv'
top_5_df.to_csv(os.path.join(output_dir, filename), index=False)
top_5_df


In [None]:
# Distribution of the two validation error metrics across all grid-search runs,
# one histogram per metric, side by side.
metric_columns = ['median_val_mae', 'median_val_mape']
fig, axes = plt.subplots(1, len(metric_columns), figsize=(15, 5))
for metric_ax, metric in zip(axes, metric_columns):
    sns.histplot(gs_df[metric], ax=metric_ax)
plt.show()

In [None]:
# One panel per hyperparameter, showing its relationship to the median validation MAE.
# Panels are filled in order, so there are no empty subplots; unused grid cells are hidden.
target_col = 'median_val_mae'

# (seaborn plotting function, hyperparameter column, extra kwargs, log-scaled x axis)
# NOTE(review): logx / log x-scale for train_noise assume all values are > 0 — confirm.
panel_specs = [
    (sns.boxplot, 'num_layers', {}, False),
    (sns.boxplot, 'batch_size', {}, False),
    (sns.boxplot, 'num_head', {}, False),
    (sns.regplot, 'dropout', {}, False),
    (sns.regplot, 'train_noise', {'logx': True}, True),
    (sns.scatterplot, 'lr', {}, False),
    (sns.scatterplot, 'weight_decay', {}, False),
    (sns.scatterplot, 'grad_clip_value', {}, False),
]
# Optional per-column x limits, e.g. {'weight_decay': (0, 0.0002), 'grad_clip_value': (0, 0.5)}
panel_xlims = {}
# Optional shared y limits for every panel, e.g. (0.88, 0.915)
panel_ylim = None

n_cols = 3
n_rows = -(-len(panel_specs) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(25, 6.25 * n_rows), squeeze=False)
flat_axes = axes.ravel()

for ax, (plot_fn, column, plot_kwargs, log_x) in zip(flat_axes, panel_specs):
    plot_fn(x=column, y=target_col, data=gs_df, ax=ax, **plot_kwargs)
    if log_x:
        ax.set_xscale('log')
    if column in panel_xlims:
        ax.set_xlim(*panel_xlims[column])
    if panel_ylim is not None:
        ax.set_ylim(*panel_ylim)
    ax.set_title(f'{target_col} vs {column}')

# Hide grid cells that received no plot instead of rendering empty axes.
for ax in flat_axes[len(panel_specs):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()