In [1]:
# NOTE(review): hard-coded absolute path; the relative "log/..." and "configs/..."
# paths used further down only resolve after this directory change.
%cd /home/dev/24/es-bench

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
import optuna
from ebes.pipeline.utils import optuna_df
from optuna.trial import TrialState

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/dev/24/es-bench


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import warnings
from omegaconf import OmegaConf

# Suppress every warning category for the remainder of the kernel session.
warnings.filterwarnings(action="ignore")

# Dataset key -> metric name (used to build the test-metric column name).
METRIC_FOR_DS = dict(
    mimic3="MulticlassAUROC",
    physionet2012="MulticlassAUROC",
    age="MulticlassAccuracy",
    x5="MulticlassAccuracy",
    pendulum="R2Score",
    taobao="MulticlassAUROC",
    mbd="MultiLabelMeanAUROC",
)

# Pretty dataset name -> human-readable metric label.
METRIC_PRETTY = dict(
    [
        ("MIMIC-III", "ROC AUC"),
        ("PhysioNet2012", "ROC AUC"),
        ("Age", "Accuracy"),
        ("Retail", "Accuracy"),
        ("Pendulum", "R2"),
        ("Taobao", "ROC AUC"),
        ("MBD", "Mean ROC AUC"),
    ]
)

# Method key -> display name used as a table header.
METHODS_PRETTY = dict(
    mtand="mTAND",
    gru="GRU",
    mlp="MLP",
    mamba="Mamba",
    coles="CoLES",
    primenet="PrimeNet",
    mlem="MLEM",
    transformer="Transformer",
)
# Alphabetically ordered method keys (same keys as METHODS_PRETTY).
METHODS = sorted(METHODS_PRETTY)

# Dataset key -> display name used as a table header.
DATASETS_PRETTY = dict(
    x5="Retail",
    age="Age",
    physionet2012="PhysioNet2012",
    pendulum="Pendulum",
    mimic3="MIMIC-III",
    mbd="MBD",
    taobao="Taobao",
)
# Alphabetically ordered dataset keys (same keys as DATASETS_PRETTY).
DATASETS = sorted(DATASETS_PRETTY)


def print_latex(df):
    """Print ``df`` as a LaTeX ``tabular`` with bold row labels.

    The column format uses one right-aligned column per index level followed by
    one centred column per data column, so it matches both single-index and
    MultiIndex frames (a two-level index still yields the former ``"rr"``).

    Args:
        df: frame to render; it is only read, never modified.

    Returns:
        None. The LaTeX source is written to stdout.
    """
    index_spec = "r" * df.index.nlevels
    value_spec = "c" * len(df.columns)
    print(df.to_latex(bold_rows=True, column_format=index_spec + value_spec))

In [14]:
def get_meta(which):
    """Look up the column name, option values and table labels for a hyper-parameter.

    Args:
        which: one of ``"time"``, ``"agg"`` or ``"norm"``.

    Returns:
        A tuple ``(param_name, options, option_name, options_in_table)`` where
        ``param_name`` is the trials-frame column to filter on, ``options`` are the
        values to compare (a nested list groups several values into one option),
        ``option_name`` is the index-level title and ``options_in_table`` are the
        row labels, aligned one-to-one with ``options``.

    Raises:
        ValueError: if ``which`` is not a supported key (previously this surfaced
            as an ``UnboundLocalError`` at the ``return``).
    """
    # Built per call so callers always receive fresh, unshared lists.
    meta = {
        "time": (
            "params_model.preprocess.params.time_process",
            ["none", ["cat", "diff"]],
            "Time process",
            ["w/o time", "with time"],
        ),
        "agg": (
            "params_model.aggregation.name",
            ["TakeLastHidden", "ValidHiddenMean"],
            "Aggregation",
            ["Last hidden", "Mean hidden"],
        ),
        "norm": (
            "params_model.preprocess.params.num_norm",
            [True, False],
            "Normalization",
            ["with norm", "w/o norm"],
        ),
    }
    try:
        return meta[which]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown parameter group {which!r}; expected one of {sorted(meta)}"
        ) from None


def _option_cells(trials, param_name, metric_column, options, options_in_table, top_k):
    """Return ``{table label: LaTeX cell}`` comparing the options of one hyper-parameter."""
    means, stds = [], []
    for option in options:
        # A nested list groups several raw values into a single table option.
        accepted = option if isinstance(option, list) else [option]
        metrics = trials.loc[trials[param_name].isin(accepted), metric_column].iloc[
            -top_k:
        ]
        means.append(metrics.mean())
        stds.append(metrics.std())
    means = pd.Series(means, index=options_in_table, dtype=float)
    stds = pd.Series(stds, index=options_in_table, dtype=float)

    # Declare a winner only when the gap between the best and worst option is
    # larger than twice the average standard deviation.
    highlight = (means.max() - means.min()) > 2 * stds.mean()
    best_label = means.idxmax() if highlight else None

    cells = {}
    for label in options_in_table:
        value = f"{means[label]:.3f} \\pm {stds[label]:.3f}"
        if label == best_label:
            value = f"\\cellcolor{{lightgray}}\\bm{{{value}}}"
        cells[label] = f"${value}$"
    return cells


def param_importance(datasets, methods, which="time", top_k=3):
    """Tabulate the test metric per option of one hyper-parameter.

    For every (method, dataset) pair the trials stored under
    ``log/{dataset}/{method}/optuna`` are loaded, restricted to ``COMPLETE``
    ones and sorted by the objective ``value``. For each option of the selected
    hyper-parameter the last ``top_k`` trials in that order (i.e. the ones with
    the highest objective value) are summarised as a LaTeX "mean pm std" string
    of the dataset's test metric. The best option is shaded and bolded when it
    is clearly ahead of the other one(s).

    NOTE(review): taking the *last* rows assumes a maximised objective - confirm.

    Args:
        datasets: dataset keys (keys of ``DATASETS_PRETTY``).
        methods: method keys (keys of ``METHODS_PRETTY``).
        which: hyper-parameter group understood by ``get_meta``.
        top_k: positive number of top trials averaged per option.

    Returns:
        DataFrame indexed by (dataset, option) with one column per method; cells
        are LaTeX math strings. Methods for which nothing could be computed are
        dropped (and reported on stdout).
    """
    param_name, options, option_name, options_in_table = get_meta(which)

    index = pd.MultiIndex.from_product(
        [[DATASETS_PRETTY[dataset] for dataset in datasets], options_in_table],
        names=["Dataset", option_name],
    )
    # Cells hold LaTeX strings, so the frame is object-typed; writing strings into
    # a float-typed frame is deprecated in recent pandas.
    res = pd.DataFrame(
        index=index,
        columns=[METHODS_PRETTY[method] for method in methods],
        dtype=object,
    )

    for method in methods:
        for dataset in datasets:
            try:
                trials, _ = optuna_df(Path(f"log/{dataset}/{method}/optuna"))
                trials = trials[trials["state"] == "COMPLETE"].sort_values("value")
                metric_column = f"user_attrs_test_{METRIC_FOR_DS[dataset]}_mean"
                cells = _option_cells(
                    trials, param_name, metric_column, options, options_in_table, top_k
                )
                for label, cell in cells.items():
                    res.loc[
                        (DATASETS_PRETTY[dataset], label), METHODS_PRETTY[method]
                    ] = cell
            except Exception as error:
                # Best effort: a missing or broken study leaves the cells empty,
                # but the reason is reported instead of being silently discarded.
                print(method, dataset, repr(error))

    na_cols = res.isna().all()
    if na_cols.any():
        print("DROP NA COLS\n", res.columns[na_cols])
    return res.loc[:, ~na_cols]


def optuna_importance(datasets, methods, which="lr"):
    """Rank one hyper-parameter among Optuna's parameter importances.

    For every (method, dataset) pair the study under
    ``log/{dataset}/{method}/optuna`` is loaded, parameter importances are
    computed with ``optuna.importance.get_param_importances`` and the rank of the
    selected parameter (1 = most important) is stored as a LaTeX math string.

    Args:
        datasets: dataset keys (keys of ``DATASETS_PRETTY``).
        methods: method keys (keys of ``METHODS_PRETTY``).
        which: ``"time"``, ``"agg"``, ``"norm"`` or ``"lr"``.

    Returns:
        DataFrame with one row per method and one column per dataset. Pairs whose
        study could not be processed stay NaN and are reported on stdout.

    Raises:
        ValueError: if ``which`` is not a supported key (previously this produced
            an all-NaN table without any message).
    """
    # Names as they appear in the study, i.e. without the "params_" column prefix.
    study_param_names = {
        "time": "model.preprocess.params.time_process",
        "agg": "model.aggregation.name",
        "norm": "model.preprocess.params.num_norm",
        "lr": "optimizer.params.lr",
    }
    try:
        param_name = study_param_names[which]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown parameter {which!r}; expected one of {sorted(study_param_names)}"
        ) from None

    # Cells hold strings such as "$3$", hence an object-typed frame.
    res = pd.DataFrame(
        index=[DATASETS_PRETTY[dataset] for dataset in datasets],
        columns=[METHODS_PRETTY[method] for method in methods],
        dtype=object,
    )
    for method in methods:
        for dataset in datasets:
            try:
                _, study = optuna_df(Path(f"log/{dataset}/{method}/optuna"))
                importance = optuna.importance.get_param_importances(study)
                # Ties receive an averaged rank, which int() truncates.
                ranks = pd.Series(importance).rank(ascending=False)
                rank = int(ranks[param_name])
                res.loc[DATASETS_PRETTY[dataset], METHODS_PRETTY[method]] = f"${rank}$"
            except Exception as error:
                # Best effort: keep going, but say which pair failed and why.
                print(method, dataset, repr(error))
    return res.T


def best_param(datasets, methods, param_name="time"):
    """Collect one setting from each method's best config.

    Reads ``configs/specify/{dataset}/{method}/best.yaml`` for every
    (method, dataset) pair and selects ``param_name`` from it.

    Args:
        datasets: dataset keys (keys of ``DATASETS_PRETTY``).
        methods: method keys (keys of ``METHODS_PRETTY``).
        param_name: ``"time"`` or ``"agg"`` as shortcuts, otherwise a dotted
            config path passed to ``OmegaConf.select`` unchanged.

    Returns:
        DataFrame with one row per method and one column per dataset. Pairs whose
        config could not be read stay NaN and are reported on stdout.
    """
    shortcuts = {
        "time": "model.preprocess.params.time_process",
        "agg": "model.aggregation.name",
    }
    param_name = shortcuts.get(param_name, param_name)

    # Use the same display-name tables as the sibling helpers; values are arbitrary
    # config entries (often strings), hence an object-typed frame.
    res = pd.DataFrame(
        index=[DATASETS_PRETTY[dataset] for dataset in datasets],
        columns=[METHODS_PRETTY[method] for method in methods],
        dtype=object,
    )
    for method in methods:
        for dataset in datasets:
            try:
                config = OmegaConf.load(
                    Path(f"configs/specify/{dataset}/{method}/best.yaml")
                )
                value = OmegaConf.select(config, param_name)
                res.loc[DATASETS_PRETTY[dataset], METHODS_PRETTY[method]] = value
            except Exception as error:
                # Best effort: keep going, but say which pair failed and why.
                print(method, dataset, repr(error))
    return res.T


# Alternative tables that can be built instead of the importance ranking:
#   res = param_importance(DATASETS, METHODS, "norm")
#   res = best_param(DATASETS, METHODS, "time")
res = optuna_importance(datasets=DATASETS, methods=METHODS, which="lr")
res

Unnamed: 0,Age,MBD,MIMIC-III,Pendulum,PhysioNet2012,Taobao,Retail
CoLES,$1$,$10$,$2$,$7$,$3$,$1$,$4$
GRU,$2$,$1$,$2$,$3$,$4$,$1$,$1$
Mamba,$1$,$1$,$1$,$1$,$1$,$1$,$1$
MLEM,$2$,$3$,$6$,$11$,$4$,$1$,$1$
MLP,$2$,$1$,$3$,$5$,$1$,$2$,$1$
mTAND,$1$,$1$,$1$,$1$,$3$,$1$,$1$
PrimeNet,$11$,$1$,$1$,$2$,$1$,$1$,$1$
Transformer,$1$,$4$,$7$,$1$,$10$,$3$,$8$


In [15]:
# One right-aligned column per index level, then one centred column per data
# column; a fixed "rr" declares one column too many for a single-level index.
print(
    res.to_latex(
        bold_rows=True,
        column_format="r" * res.index.nlevels + "c" * len(res.columns),
    )
)

\begin{tabular}{rrccccccc}
\toprule
 & Age & MBD & MIMIC-III & Pendulum & PhysioNet2012 & Taobao & Retail \\
\midrule
\textbf{CoLES} & $1$ & $10$ & $2$ & $7$ & $3$ & $1$ & $4$ \\
\textbf{GRU} & $2$ & $1$ & $2$ & $3$ & $4$ & $1$ & $1$ \\
\textbf{Mamba} & $1$ & $1$ & $1$ & $1$ & $1$ & $1$ & $1$ \\
\textbf{MLEM} & $2$ & $3$ & $6$ & $11$ & $4$ & $1$ & $1$ \\
\textbf{MLP} & $2$ & $1$ & $3$ & $5$ & $1$ & $2$ & $1$ \\
\textbf{mTAND} & $1$ & $1$ & $1$ & $1$ & $3$ & $1$ & $1$ \\
\textbf{PrimeNet} & $11$ & $1$ & $1$ & $2$ & $1$ & $1$ & $1$ \\
\textbf{Transformer} & $1$ & $4$ & $7$ & $1$ & $10$ & $3$ & $8$ \\
\bottomrule
\end{tabular}



### Params influence

In [6]:
# NOTE(review): `study` is not assigned at top level in any visible cell (it is
# only a local inside optuna_importance) - this cell relies on leftover kernel state.
# Keep only the trials that ran to completion.
trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
# Ascending by objective value; slice this list to plot only a subset.
plotted_trials = sorted(trials, key=lambda t: t.value)
# Copy the selected trials into a fresh in-memory study used for plotting.
plotted_study = optuna.create_study()
for trial in plotted_trials:
    plotted_study.add_trial(trial)

[I 2024-06-26 08:02:21,266] A new study created in memory with name: no-name-73480d79-6fa5-4e2a-9df9-09b16cb763d7


In [98]:
# NOTE(review): `df` is not assigned at top level in any visible cell, so this
# relies on leftover kernel state. The recorded output suggests a Series indexed
# by parameter name - TODO confirm and define it explicitly.
# Rank the entries (ascending) and show the ranks as integers.
df.rank().astype(int)

model.encoder.params.num_layers         10
model.encoder.params.hidden_size         9
optimizer.params.lr                      8
model.preprocess.params.num_norm         7
model.preprocess.params.cat_emb_dim      6
model.preprocess.params.num_emb_dim      5
model.encoder.params.dropout             4
model.aggregation.name                   3
model.preprocess.params.time_process     2
optimizer.params.weight_decay            1
dtype: int64

3.0

In [11]:
# Slice plot over the study copied into `plotted_study`.
# NOTE(review): `target`, `target_name`, `params` and `not_imp` are not defined in
# any visible cell - this cell relies on leftover kernel state.
# The commented-out calls below are alternative visualisations kept for reference.
# fig = optuna.visualization.plot_parallel_coordinate(plotted_study, target=target, target_name=target_name, params=['model.encoder.params.pooling', 'pretrain_model.encoder.params.pooling',])
# fig = optuna.visualization.plot_contour(study, target=target, target_name=target_name, params=params+not_imp)
fig = optuna.visualization.plot_slice(
    plotted_study,
    target=target,
    target_name=target_name,
    params=params + not_imp,
)
# fig = optuna.visualization.plot_optimization_history(study, target=target, target_name=target_name, error_bar=False)
# targets = lambda t: (t.user_attrs["memory_after_mean"], t.user_attrs["val_metric_mean"])
# target_names = ["memory_after_mean", "val_metric_mean"]
# fig = optuna.visualization.plot_pareto_front(study, targets=targets, target_names=target_names)
# Display the figure as the cell output.
fig