# TODO

In [None]:
import optuna
import os
from torch import optim
import pickle
import datetime

In [None]:
# In-memory cache: parameter-set key -> objective value of a configuration
# already trained in this kernel session, so objective() can skip
# configurations the sampler proposes again. Not persisted anywhere; it is
# empty again after a kernel restart.
trials_params: dict[str, float] = {}


def objective(trial: optuna.Trial) -> float:
    """Sample one hyper-parameter configuration, train it and return its score.

    The configuration (embedding sizes, LSTM/linear sizes and depth, learning
    rate, dropout, ``flatten`` and batch size) is drawn from ``trial``. If the
    exact same configuration was already evaluated in this session, its
    cached score is returned instead of retraining.

    Args:
        trial: The Optuna trial that supplies the hyper-parameter suggestions.

    Returns:
        ``val_metrics["flatten_mse"]`` as reported by ``validate_model`` on the
        test loader; the study created below minimizes this value.
    """
    # Define the hyperparameters to tune
    sku_emb_dim = trial.suggest_categorical("sku_emb_dim", choices=[32, 64, 128, 256])
    cat_emb_dims = trial.suggest_categorical("cat_emb_dims", choices=[32, 64, 128, 256])
    lstm_hidden_size = trial.suggest_categorical(
        "lstm_hidden_size", choices=[32, 64, 128, 256]
    )
    linear_hidden_size = trial.suggest_categorical(
        "linear_hidden_size", choices=[64, 128, 192, 256, 512]
    )
    lstm_bidirectional = trial.suggest_categorical(
        "lstm_bidirectional", choices=[True, False]
    )
    lstm_layers = trial.suggest_categorical("lstm_layers", choices=[2, 3, 4, 5, 6])
    learning_rate = trial.suggest_categorical(
        "learning_rate", choices=[0.01, 0.001, 0.0001]
    )
    dropout = trial.suggest_categorical("dropout", choices=[0, 0.1, 0.3, 0.5])
    flatten = trial.suggest_categorical("flatten", choices=[True, False])
    batch_size = trial.suggest_categorical("batch_size", choices=[64, 128, 256, 512])

    # Key the cache on *every* sampled parameter (sorted by name). The old
    # hand-written list left out ``flatten``, so configurations differing only
    # in ``flatten`` (which is passed to train_model) were wrongly skipped.
    key = "_".join(
        f"{name}={value}" for name, value in sorted(trial.params.items())
    )
    if key in trials_params:
        print(f"Skipping trial {trial.params} [already run]")
        return trials_params[key]

    num_epochs = 20
    early_stop = {"patience": 10, "min_delta": 0.5}
    try:
        # NOTE(review): the second init_ds argument is presumably the number
        # of loader workers — confirm against init_ds.
        device, dl_train, dl_test, _, _ = init_ds(batch_size, 4)
        model = DemandForecastingModel(
            sku_vocab_size,
            sku_emb_dim,
            cat_features_shapes,
            cat_emb_dims,
            time_features_dim,
            lstm_bidirectional,
            lstm_hidden_size,
            lstm_layers,
            linear_hidden_size,
            dropout,
            n_out,
        ).to(device)
        # Define the loss function and optimizer
        regression_criterion = nn.MSELoss()
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        # threshold_mode="abs": under the default "rel" mode a threshold of
        # 1.0 means "improved only if below best * (1 - 1.0) == 0", which a
        # non-negative loss never is, so (assuming train_model steps the
        # scheduler with an MSE-like loss) every epoch counted as bad and the
        # LR was halved on a fixed schedule regardless of progress.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            patience=5,
            factor=0.5,
            threshold=early_stop["min_delta"] * 2,
            threshold_mode="abs",
        )

        # Train the model
        train_model(
            model,
            dl_train,
            dl_test,
            regression_criterion,
            optimizer,
            scheduler,
            num_epochs,
            batch_size,
            device,
            early_stop=early_stop,
            flatten=flatten,
        )

        # Validate the model
        # NOTE(review): the positional True presumably enables flattened
        # metrics (the "flatten_mse" key read below) — confirm.
        val_metrics = validate_model(
            model,
            dl_test,
            regression_criterion,
            batch_size,
            True,
        )
    finally:
        # Run the cleanup helper even when the trial raises: optimize() is
        # called with catch=[Exception] and moves straight on to the next trial.
        _collect()

    score = val_metrics["flatten_mse"]
    trials_params[key] = score
    return score


_collect()
study_name = "study-v12"  # Unique identifier of the study.
storage_name = f"sqlite:///{study_name}.db"
sampler_fname = f"{study_name}-sampler.pkl"
optimize_timeout = datetime.timedelta(minutes=60)

# Resume with the pickled sampler (including its internal RNG state) when a
# previous run saved one; otherwise start from a freshly seeded TPE sampler.
if os.path.exists(sampler_fname):
    # NOTE: pickle.load can execute arbitrary code — only load sampler files
    # this notebook wrote itself.
    with open(sampler_fname, "rb") as fin:
        sampler = pickle.load(fin)
else:
    sampler = optuna.samplers.TPESampler(seed=seed)

# Created outside the try: without a study there is nothing to report below,
# so a failure here should surface directly.
study = optuna.create_study(
    direction="minimize",
    study_name=study_name,
    storage=storage_name,
    load_if_exists=True,
    sampler=sampler,
)
try:
    # Optimize the study
    study.optimize(
        objective,
        # total_seconds(): timedelta.seconds is only the seconds *component*
        # (it wraps at one day), not the full duration.
        timeout=optimize_timeout.total_seconds(),
        gc_after_trial=True,
        show_progress_bar=True,
        catch=[Exception],
    )
except Exception as e:
    # Best effort: report and fall through so the partial results below are
    # still written and displayed.
    print(f"Optimization stopped early: {e!r}")
finally:
    # Always persist the sampler state (also after success or an interrupt)
    # so the next run continues the same sampling sequence.
    with open(sampler_fname, "wb") as fout:
        pickle.dump(sampler, fout)

# Print the best parameters and value
df_optuna = study.trials_dataframe().sort_values(["value"])
df_optuna.to_csv(f"{study_name}.csv", index=False)
param_column = [col for col in df_optuna.columns if col.startswith("params_")]
# Keep only the best-scoring row of each distinct configuration (rows are
# already sorted by value, and drop_duplicates keeps the first occurrence).
df_optuna.drop_duplicates(param_column, inplace=True)
display(df_optuna.head(20))
df_param = df_optuna.head(1)[param_column]
df_param.columns = [c.removeprefix("params_") for c in df_param.columns]
best_params = df_param.iloc[0].to_dict()

In [None]:
# Rebuild the ranked trial table (lowest objective value first), save it to
# "<study_name>.csv", show the top 20 distinct configurations, and pull the
# winning configuration out as a plain dict of parameter name -> value.
df_optuna = study.trials_dataframe().sort_values(["value"])
df_optuna.to_csv(f"{study_name}.csv", index=False)
param_column = list(filter(lambda column: "param" in column, df_optuna.columns))
# Sorted ascending above, so the first occurrence of each configuration is its best.
df_optuna = df_optuna.drop_duplicates(param_column)
display(df_optuna.head(20))
df_param = df_optuna.head(1)[param_column].rename(
    columns=lambda column: column.removeprefix("params_")
)
best_params = df_param.iloc[0].to_dict()

In [None]:
from optuna.visualization import (
    plot_contour,
    plot_edf,
    plot_optimization_history,
    plot_parallel_coordinate,
    plot_param_importances,
    plot_rank,
    plot_slice,
    plot_timeline,
)

# Export every Optuna visualization of the study as a standalone HTML file
# under plotly/<study_name>/<plot function name>.html.
base_folder = f"plotly/{study_name}/"
os.makedirs(base_folder, exist_ok=True)
for _plot in (
    plot_contour,
    plot_edf,
    plot_optimization_history,
    plot_parallel_coordinate,
    plot_param_importances,
    plot_rank,
    plot_slice,
    plot_timeline,
):
    # __name__ instead of parsing str(func) ("<function plot_x at 0x...>");
    # resolved before the try so the error message can always use it.
    _name = _plot.__name__
    try:
        fig = _plot(study)
        # os.path.join avoids the double slash base_folder's trailing "/" caused.
        fig.write_html(os.path.join(base_folder, f"{_name}.html"))
    except Exception as e:
        # Best effort: a plot that cannot be built for this study is reported
        # and skipped so the remaining plots are still exported.
        print(f"unable to plot {_name} due to {e}")