# Imports

In [None]:
# standard library imports
import os

# related third party imports
import torch
import optuna
import structlog

# local application/library specific imports
from tools.plotter import plot_history
from tools.optuna_analyzer import get_study_object, get_history, inspect_trial

# Module-level structlog logger named after this module.
# NOTE(review): not referenced by any of the cells shown here — keep only if
# later cells or helpers are expected to log through it.
logger = structlog.get_logger(__name__)

# Inputs

In [None]:
# ---------------------------------------------------------------------------
# Inputs: which Optuna study to analyse and what to plot from it.
# ---------------------------------------------------------------------------
# Experiment name and configuration ID, passed to get_study_object / get_history.
EXP_NAME = "tune_race_pp_bert_logit_20250521"
CONFIG_ID = "bert_ordinal_logit_SL512_BALFalse_FRFalse_ESTrue"
# Metric name: shown as `target_name` on the Optuna plots and used as the
# `metric` key for plot_history in the learning-convergence cell.
TARGET_NAME = "val_bal_rps"
# Hyperparameters displayed on the contour / parallel-coordinate / rank /
# slice plots.
PARAMS_PLOT = [
    "lr",
    "weight_decay",
]

In [None]:
# Fetch the Optuna study for the configured experiment/config via
# tools.optuna_analyzer.get_study_object.
study = get_study_object(
    config_id=CONFIG_ID,
    exp_name=EXP_NAME,
)

# Summary

In [None]:
# `study.trials` returns every trial in storage, including ones that are
# still RUNNING or WAITING (e.g. while the tuning job is in progress), so
# count only trials in a terminal state (COMPLETE, PRUNED or FAIL) to match
# the label. `deepcopy=False` avoids copying every trial just to count them.
n_finished = sum(
    trial.state.is_finished() for trial in study.get_trials(deepcopy=False)
)
print("Number of finished trials: ", n_finished)
# `best_trial` only considers COMPLETE trials (Optuna raises ValueError when
# there are none).
best_trial = study.best_trial
print(f"Hyperparameters tuned: {list(best_trial.params.keys())}")
# NOTE(review): trial_number=None presumably means "the best trial" —
# confirm in tools.optuna_analyzer.inspect_trial.
inspect_trial(study=study, trial_number=None)

In [None]:
# inspect_trial(study=study, trial_number=9)

# Tuning progress

In [None]:
# Objective value of each trial together with the best value so far. Pass
# target_name so the y-axis is labelled with the metric name, consistent
# with the other objective plots in this notebook (the default label is
# "Objective Value").
fig = optuna.visualization.plot_optimization_history(
    study, target_name=TARGET_NAME
)
fig.show()

# Contour plot

In [None]:
# Contour of the objective over pairs of the PARAMS_PLOT hyperparameters.
fig = optuna.visualization.plot_contour(
    study=study,
    target_name=TARGET_NAME,
    params=PARAMS_PLOT,
)
fig.show()

# Parallel coordinates

In [None]:
# Parallel-coordinate view relating the PARAMS_PLOT hyperparameters to the
# objective, one line per trial.
fig = optuna.visualization.plot_parallel_coordinate(
    study=study,
    target_name=TARGET_NAME,
    params=PARAMS_PLOT,
)
fig.show()

# Hyperparameter importance

In [None]:
# Hyperparameter importances; no `params` filter is passed, so every tuned
# hyperparameter is included (not just PARAMS_PLOT).
fig = optuna.visualization.plot_param_importances(
    study=study,
    target_name=TARGET_NAME,
)
fig.show()

# Rank

In [None]:
# Rank plot: trials placed by their PARAMS_PLOT values, coloured by rank of
# the objective.
fig = optuna.visualization.plot_rank(
    study=study,
    target_name=TARGET_NAME,
    params=PARAMS_PLOT,
)
fig.show()

# Slice

In [None]:
# Slice plot: objective against each PARAMS_PLOT hyperparameter individually.
fig = optuna.visualization.plot_slice(
    study=study,
    target_name=TARGET_NAME,
    params=PARAMS_PLOT,
)
fig.show()

# Timeline

In [None]:
# Timeline of when each trial ran, by trial state.
fig = optuna.visualization.plot_timeline(study=study)
fig.show()

# Terminator improvement

In [None]:
# Terminator's estimated potential improvement per trial, drawn without the
# error curve (plot_error=False).
fig = optuna.visualization.plot_terminator_improvement(
    study=study,
    plot_error=False,
)
fig.show()

In [None]:
# fig = optuna.visualization.plot_terminator_improvement(study, plot_error=True)
# fig.show()

# Learning convergence

In [None]:
# Training history of TARGET_NAME for one trial, fetched via
# tools.optuna_analyzer.get_history and drawn with tools.plotter.plot_history.
# NOTE(review): trial_n=None presumably selects the best trial — confirm.
lines = get_history(
    study=study,
    exp_name=EXP_NAME,
    config_id=CONFIG_ID,
    trial_n=None,
)
plot_history(lines, skip_warmup=0, metric=TARGET_NAME)