In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import seaborn as sns

In [None]:
# label of the analysis; it is embedded in the names of the files this notebook writes
analysis = '3m_mrs_02_with_imaging'
# set to False to skip writing the selected best configuration to disk
save_data = True

In [None]:
# Directory that receives the files exported by this notebook.
# Resolved relative to the current user's home directory instead of a
# hard-coded user name, so the notebook also runs on other machines/accounts.
output_dir = os.path.join(os.path.expanduser('~'), 'Downloads')

In [None]:
# paths for 3M mrs 02
# The gridsearch logs live in a single directory below the user's home; the
# directory is built once (home-relative rather than tied to one user name)
# and shared by both log files.
gridsearch_dir = os.path.join(
    os.path.expanduser('~'),
    'temp/opsum_prediction_output/transformer/with_imaging/with_imaging/training',
)
gridsearch_path_v1 = os.path.join(gridsearch_dir, 'gridsearch_v1.jsonl')
gridsearch_path_v2 = os.path.join(gridsearch_dir, 'gridsearch_v2.jsonl')

In [None]:
def _read_gridsearch_log(jsonl_path):
    """Read one gridsearch JSON-lines log and drop the row labelled 0.

    The 'timestamp' column is kept as a plain object column and date
    conversion is switched off, so timestamps are not parsed into datetimes.
    """
    # NOTE(review): row 0 is presumably a header/placeholder record - confirm
    records = pd.read_json(
        jsonl_path,
        lines=True,
        dtype={'timestamp': 'object'},
        convert_dates=False,
    )
    return records.drop(0)


df_v1 = _read_gridsearch_log(gridsearch_path_v1)
df_v2 = _read_gridsearch_log(gridsearch_path_v2)

In [None]:
# stack the runs of both gridsearch logs row-wise and renumber the rows 0..n-1
df = pd.concat([df_v1, df_v2], ignore_index=True)
df.shape

In [None]:
# rank all runs by median validation score (highest first) and keep the top one
runs_by_median_score = df.sort_values(by='median_val_scores', ascending=False)
best_df = runs_by_median_score.iloc[:1]
best_df

In [None]:
# export the selected configuration as a JSON-lines file (one record per line);
# the file name carries the analysis label and the timestamp of the selected run
if save_data:
    best_timestamp = best_df["timestamp"].values[0]
    selected_config_name = f'hyperopt_selected_transformer_{analysis}_{best_timestamp}.json'
    selected_config_path = os.path.join(output_dir, selected_config_name)
    best_df.to_json(selected_config_path, orient='records', lines=True)

In [None]:
# display the combined gridsearch results
df

In [None]:
# distribution of the median validation score over all gridsearch runs
ax = sns.histplot(data=df, x='median_val_scores')
ax.set(title='Median validation scores')
ax.figure.set_size_inches(10, 10)
plt.show()

In [None]:
# gridsearch hyperparameters with a discrete set of values
cat_gs_variables = [
    'num_layers',
    'model_dim',
    'batch_size',
    'balanced',
    'num_head',
]
# gridsearch hyperparameters with continuous values
cont_gs_variables = [
    'dropout',
    'train_noise',
    'lr',
    'weight_decay',
    'grad_clip_value',
]

In [None]:
# overview grid: median validation score against each searched hyperparameter
score = 'median_val_scores'
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(25, 25))

# first row (left, middle): box plots split by feature aggregation
for col, param in enumerate(('num_layers', 'model_dim')):
    sns.boxplot(data=df, x=param, y=score, hue='feature_aggregation', ax=axes[0, col])

# second row: box plots of the remaining discrete hyperparameters
for col, param in enumerate(('batch_size', 'balanced', 'num_head')):
    sns.boxplot(data=df, x=param, y=score, ax=axes[1, col])

# third row (left, middle): regressions; train noise is fitted on log(x)
# and drawn on a logarithmic x axis
sns.regplot(data=df, x='dropout', y=score, ax=axes[2, 0])
noise_ax = axes[2, 1]
sns.regplot(data=df, x='train_noise', y=score, logx=True, ax=noise_ax)
noise_ax.set_xscale('log')

# scatter plots: hyperparameter -> (panel, x-limits or None for automatic limits)
scatter_panels = {
    'lr': (axes[2, 2], (0.0001, 0.0003)),
    'weight_decay': (axes[0, 2], (0, 0.0002)),
    'grad_clip_value': (axes[3, 0], None),
}
for param, (panel, x_limits) in scatter_panels.items():
    sns.scatterplot(data=df, x=param, y=score, ax=panel)
    if x_limits is not None:
        panel.set_xlim(*x_limits)

# optional: uncomment to apply a shared y-range of 0.88 to 0.915 to every panel
# for ax in axes.flat:
    # ax.set_ylim(0.88, 0.915)

plt.show()


In [None]:
# fig.savefig('/Users/jk1/Downloads/gridsearch_results.png', dpi=300)

In [None]:
# interaction between number of layers and model dimension:
# one box-plot facet per model dimension, wrapped after three facets per row
grid = sns.catplot(
    data=df,
    x='num_layers',
    y='median_val_scores',
    col='model_dim',
    col_wrap=3,
    kind='box',
)
(
    grid
    .set_titles('Model dimension: {col_name}')
    .set_axis_labels('Number of layers', 'Median validation score')
)
# grid.set(ylim=(0.88, 0.92))
plt.show()

## Focus on best model dimension

In [None]:
# restrict the results to the runs that used the chosen model dimension
# NOTE(review): 1024 is set by hand - confirm it still matches the data
best_model_dimension = 1024
uses_best_dimension = df['model_dim'].eq(best_model_dimension)
df_best_model_dim = df.loc[uses_best_dimension]

In [None]:
# hyperparameter grid restricted to the runs with the chosen model dimension
score = 'median_val_scores'
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(25, 25))

# box plots: discrete hyperparameter -> panel
box_panels = {
    'num_layers': axes[0, 0],
    'batch_size': axes[1, 0],
    'balanced': axes[1, 1],
    'num_head': axes[1, 2],
}
for param, panel in box_panels.items():
    sns.boxplot(data=df_best_model_dim, x=param, y=score, ax=panel)

# regressions; train noise is fitted on log(x) and drawn on a logarithmic x axis
sns.regplot(data=df_best_model_dim, x='dropout', y=score, ax=axes[2, 0])
noise_ax = axes[2, 1]
sns.regplot(data=df_best_model_dim, x='train_noise', y=score, logx=True, ax=noise_ax)
noise_ax.set_xscale('log')

# scatter plots: hyperparameter -> (panel, x-limits or None for automatic limits)
scatter_panels = {
    'lr': (axes[2, 2], None),
    'weight_decay': (axes[0, 2], (0, 0.0002)),
    'grad_clip_value': (axes[3, 0], (0, 0.5)),
}
for param, (panel, x_limits) in scatter_panels.items():
    sns.scatterplot(data=df_best_model_dim, x=param, y=score, ax=panel)
    if x_limits is not None:
        panel.set_xlim(*x_limits)

# shared y-range of 0.88 to 0.915 on every panel of the grid
for panel in axes.ravel():
    panel.set_ylim(0.88, 0.915)

plt.show()

## Model weight cleaning

Select only the best models for model weight cleaning.

In [None]:
# Collect the timestamps of the top_n runs under each selection metric:
# median validation score, median rolling validation score and worst
# cross-validation fold validation score (higher is treated as better).
top_n = 3
selection_metrics = ['median_val_scores', 'median_rolling_val_scores', 'worst_cv_fold_val_score']

top_timestamps = []
for metric in selection_metrics:
    top_runs = df.sort_values(by=metric, ascending=False).head(top_n)
    top_timestamps.extend(top_runs['timestamp'].tolist())

# A run can rank in the top_n for several metrics and would then be listed
# several times; keep only its first occurrence (dicts preserve insertion order).
model_timestamps_to_retain = list(dict.fromkeys(top_timestamps))
model_timestamps_to_retain

In [None]:
# Write the timestamps of the models to retain as a one-column CSV.
# The export honours the save_data switch and goes to output_dir, like the
# JSON export of the selected configuration, instead of a hard-coded path.
if save_data:
    timestamps_csv_path = os.path.join(output_dir, f'{analysis}_model_timestamps_to_retain.csv')
    pd.DataFrame(model_timestamps_to_retain, columns=['timestamp']).to_csv(timestamps_csv_path, index=False)

In [None]:
# show the analysis label used in the exported file names
analysis