In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import seaborn as sns

In [None]:
# Input/output locations. They default to the original author's machine but can be
# overridden through environment variables, so the notebook runs elsewhere without edits.
log_folder_path = os.environ.get(
    "OPSUM_GRIDSEARCH_LOG_DIR", "/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch"
)
output_dir = os.environ.get("OPSUM_OUTPUT_DIR", "/Users/jk1/Downloads")

In [None]:
# Load every grid-search log (*.jsonl) found under log_folder_path into one frame.
# The first record of each file (row 0) is dropped, as in the original loader.
# NOTE(review): row 0 is presumably a header/metadata line — confirm against the writer.
gs_frames = []
for root, dirs, files in os.walk(log_folder_path):
    dirs.sort()  # walk subdirectories in a deterministic order
    for file in sorted(files):  # deterministic row order across machines/filesystems
        if not file.endswith(".jsonl"):
            continue
        run_df = (
            pd.read_json(os.path.join(root, file),
                         lines=True, dtype={"timestamp": "object"}, convert_dates=False)
            .iloc[1:]  # same as .drop(0) on the default RangeIndex, but safe on empty frames
            .assign(file_name=file)  # remember which log each row came from
        )
        gs_frames.append(run_df)

# Concatenate once (concat inside the loop is quadratic). With no logs, keep the
# original behaviour of an empty frame; pd.concat([]) would raise instead.
gs_df = pd.concat(gs_frames, ignore_index=True) if gs_frames else pd.DataFrame()


In [None]:
# Some logs have NaN in these columns. Fill each with a default value —
# NOTE(review): presumably the value runs used before the setting was logged; confirm.
# "nan_or_bce" marks runs whose loss function was not recorded.
default_fill_values = {
    "loss_function": "nan_or_bce",
    "scheduler": "exponential",
    "restrict_to_first_event": 1,
    "target_interval": 0,
}
for column_name, fill_value in default_fill_values.items():
    gs_df[column_name] = gs_df[column_name].fillna(fill_value)

# TODO(review): an earlier note said to fill NaN oversampling_ratio with 1, but that
# fill was never implemented — confirm whether it is needed.

In [None]:
# Distinct split files referenced by the loaded runs.
gs_df["split_file"].unique()

In [None]:
# Preview the combined runs; use gs_df.shape / gs_df.info() for the full extent.
gs_df.head()

In [None]:
# Best single run, ranked by median validation score (highest first).
best_df = gs_df.sort_values(by="median_val_scores", ascending=False).iloc[:1]
best_df

In [None]:
# best_df.to_csv(os.path.join(output_dir, "end_transformer_best_hyperparameters.csv"), index=False)

In [None]:
# Distribution of median validation scores, coloured by split file.
fig, ax = plt.subplots(figsize=(10, 10))
sns.histplot(data=gs_df, x="median_val_scores", hue="split_file", ax=ax)
ax.set_title("Median validation scores")
plt.show()


In [None]:
# Overview grid: median validation score against each searched hyperparameter,
# pooled over all loaded runs. Box plots are used for discrete settings and
# scatter/regression plots for continuous ones.
grid_metric = "median_val_scores"
grid_y_limits = (0.75, 0.915)  # zoomed score range; widen if points get clipped

# (column, plot kind, (row, col)) — positions match the previous hand-placed layout.
grid_panels = [
    ("num_layers", "box", (0, 0)),
    ("tau", "box", (0, 1)),
    ("weight_decay", "scatter", (0, 2)),
    ("batch_size", "box", (1, 0)),
    ("gamma", "box", (1, 1)),
    ("num_head", "box", (1, 2)),
    ("dropout", "reg", (2, 0)),
    ("train_noise", "reg_logx", (2, 1)),
    ("lr", "scatter", (2, 2)),
    ("grad_clip_value", "scatter", (3, 0)),
    ("oversampling_ratio", "box", (3, 1)),
    ("loss_function", "box", (3, 2)),
    ("alpha", "box", (4, 0)),
    ("n_lr_warm_up_steps", "box", (4, 1)),
    ("model_dim", "box", (4, 2)),
    ("scheduler", "box", (5, 0)),
]

fig, axes = plt.subplots(6, 3, figsize=(25, 25))
for column, kind, (row, col) in grid_panels:
    ax = axes[row, col]
    if kind == "box":
        sns.boxplot(x=column, y=grid_metric, data=gs_df, ax=ax)
    elif kind == "scatter":
        sns.scatterplot(x=column, y=grid_metric, data=gs_df, ax=ax)
    elif kind == "reg":
        sns.regplot(x=column, y=grid_metric, data=gs_df, ax=ax)
    elif kind == "reg_logx":
        # Fit against log(x) and draw on a log axis (values span orders of magnitude).
        sns.regplot(x=column, y=grid_metric, data=gs_df, logx=True, ax=ax)
        ax.set_xscale("log")
    else:
        raise ValueError(f"unknown panel kind: {kind}")
    ax.set_title(column)
    ax.set_ylim(*grid_y_limits)

# Restrict the weight-decay panel to x in [0, 2e-4].
axes[0, 2].set_xlim(0, 0.0002)

# Hide grid cells without an assigned panel instead of showing empty axes.
used_positions = {position for _, _, position in grid_panels}
for index, ax in enumerate(axes.flat):
    if divmod(index, axes.shape[1]) not in used_positions:
        ax.set_visible(False)

fig.tight_layout()
plt.show()

In [None]:
def plot_hyperparameter_vs_metric(hyperparameters, metric, df, ylim=(0.75, 0.85),
                                  n_cols=3, max_categories=10, figsize=(25, 35)):
    """Plot ``metric`` against each hyperparameter, one panel per hyperparameter.

    A hyperparameter with more than ``max_categories`` distinct values (NaN counted
    as a value) is drawn as a scatter plot; otherwise a box plot is used.

    Args:
        hyperparameters: column names of ``df`` to place on the x-axes.
        metric: column name of ``df`` to place on the y-axes.
        df: DataFrame containing the hyperparameter and metric columns.
        ylim: (bottom, top) y-limits applied to every panel, or None to keep
            matplotlib's automatic limits.
        n_cols: number of panels per grid row.
        max_categories: largest distinct-value count still drawn as a box plot.
        figsize: overall figure size in inches.

    Returns:
        None. The figure is left open so the notebook renders it.
    """
    n_panels = len(hyperparameters)
    # Ceil division avoids an extra empty row when n_panels is a multiple of n_cols;
    # at least one row is always created.
    n_rows = max(1, -(-n_panels // n_cols))
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[r, c] always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for i, hyperparameter in enumerate(hyperparameters):
        ax = axes[i // n_cols, i % n_cols]
        # nunique(dropna=False) == len(unique()): NaN counts as its own value.
        if df[hyperparameter].nunique(dropna=False) > max_categories:
            sns.scatterplot(data=df, x=hyperparameter, y=metric, ax=ax)
        else:
            sns.boxplot(data=df, x=hyperparameter, y=metric, ax=ax)
        ax.set_title(f'{metric} vs {hyperparameter}')
        ax.set_xlabel(hyperparameter)
        ax.set_ylabel(metric)
        if ylim is not None:
            ax.set_ylim(*ylim)

    # Hide leftover grid cells that received no panel.
    for ax in axes.flat[n_panels:]:
        ax.set_visible(False)

    fig.tight_layout()


In [None]:
# Hyperparameter columns for the per-cohort panels below, in plotting order.
# NOTE(review): "tau" is plotted in the overview grid and set in `config`, but is not listed here.
hyperparameters = [
    'batch_size',
    'num_layers',
    'model_dim',
    'train_noise',
    'weight_decay',
    'dropout',
    'num_head',
    'lr',
    'n_lr_warm_up_steps',
    'grad_clip_value',
    'early_stopping_step_limit',
    'scheduler',
    'alpha',
    'gamma',
    'model_type',
    'loss_function',
    'imbalance_factor',
    'oversampling_ratio',
    'n_trials',
    'target_interval',
    'restrict_to_first_event',
    'max_epochs',
]

In [None]:
# Split the pooled runs into two cohorts by the split file each run used.
# NOTE(review): the all-event vs first-event mapping is taken from the variable
# names — confirm against the corresponding preprocessing runs.
ALL_EVENT_SPLIT_FILE = '/home/users/k/klug/data/opsum/gsu_Extraction_20220815_prepro_09052025_220520/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
FIRST_EVENT_SPLIT_FILE = '/home/users/k/klug/data/opsum/gsu_Extraction_20220815_prepro_08062024_083500/early_neurological_deterioration_train_data_splits/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'

all_event_gs_df = gs_df.loc[gs_df["split_file"] == ALL_EVENT_SPLIT_FILE]
first_event_gs_df = gs_df.loc[gs_df["split_file"] == FIRST_EVENT_SPLIT_FILE]

# Fail fast: a mistyped path would otherwise silently produce an empty cohort.
assert not all_event_gs_df.empty, f"no runs use split file {ALL_EVENT_SPLIT_FILE}"
assert not first_event_gs_df.empty, f"no runs use split file {FIRST_EVENT_SPLIT_FILE}"

In [None]:
# Score column used on the y-axis of every per-hyperparameter panel.
metric = "median_val_scores"
print("all_event_gs_df")
plot_hyperparameter_vs_metric(hyperparameters=hyperparameters, metric=metric, df=all_event_gs_df)

In [None]:
# Same panels for the first-event cohort.
print("first_event_gs_df")
plot_hyperparameter_vs_metric(hyperparameters=hyperparameters, metric=metric, df=first_event_gs_df)

In [None]:
# Manual hyperparameter tuning:
# - grad clip to 1
# - lr to 1e-5
# - wd to 5e-4
# - dropout to 0.5
# - loss function to focal
# - num layers to 2
# - model_dim to 256
# - num_head to 32

In [None]:
# Grid-search configuration. Scalar entries appear to be fixed settings and list
# entries the candidate values per hyperparameter.
# NOTE(review): scalar-vs-list semantics are inferred from the values — confirm with the search script.
config = dict(
    # run-level settings
    n_trials=1000,
    target_interval=1,
    restrict_to_first_event=0,
    # model / optimisation search space
    batch_size=[256],
    num_layers=[2, 6],
    model_dim=[256, 1024],
    train_noise=[1e-5, 1e-6],
    weight_decay=[1e-5, 1e-4, 5e-4],
    dropout=[0.3, 0.2, 0.5],
    num_head=[16, 32],
    lr=[1e-5, 1e-4],
    n_lr_warm_up_steps=[0, 100],
    grad_clip_value=[0.01, 0.75, 1],
    early_stopping_step_limit=[10],
    scheduler=["exponential", "cosine"],
    # class imbalance / loss
    imbalance_factor=[62],
    loss_function=["focal", "bce"],
    alpha=[0.25, 0.6],
    gamma=[2.0, 3.0],
    tau=[1.0],
    oversampling_ratio=[1, 10],
    max_epochs=100,
)