In [1]:
%cd /home/dev/24/es-bench

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
import optuna
from ebes.pipeline.utils import optuna_df
from optuna.trial import TrialState

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/dev/24/es-bench


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import warnings
from omegaconf import OmegaConf

# Silence every warning category for the rest of the session
# (this includes deprecation and future warnings from any library).
warnings.filterwarnings("ignore")

# Dataset key -> metric name. param_importance() builds the Optuna column
# "user_attrs_test_<metric>_mean" from these values.
METRIC_FOR_DS = {
    "mimic3": "MulticlassAUROC",
    "physionet2012": "MulticlassAUROC",
    "age": "MulticlassAccuracy",
    "x5": "MulticlassAccuracy",
    # "pendulum": "R2Score",
    "pendulum_cls": "MulticlassAccuracy",
    "taobao": "MulticlassAUROC",
    "mbd": "MultiLabelMeanAUROC",
    "arabic": "MulticlassAccuracy",
    "electric_devices": "MulticlassAccuracy",
    "bpi_17": "MulticlassAUROC",
}
# Pretty dataset name -> metric label shown in the third header row of the
# LaTeX table (looked up by print_latex()). Keys are the *values* of DATASETS_PRETTY.
METRIC_PRETTY = {
    "MIMIC-III": "ROC AUC",
    "PhysioNet2012": "ROC AUC",
    "Age": "Accuracy",
    "Retail": "Accuracy",
    # "Pendulum": "$R^2$",
    "Pendulum": "Accuracy",
    "Taobao": "ROC AUC",
    "MBD": "Mean ROC AUC",
    "ArabicDigits": "Accuracy",
    "ElectricDevices": "Accuracy",
    "BPI17": "ROC AUC",
}
# Method key (directory name under log/<dataset>/) -> display name used as a
# column label in the result tables.
METHODS_PRETTY = {
    "coles": "CoLES",
    "gru": "GRU",
    "mlem": "MLEM",
    "transformer": "Transformer",
    "mamba": "Mamba",
    "convtran": "ConvTran",
    "mtand": "mTAND",
    "primenet": "PrimeNet",
    "mlp": "MLP",
}
# Dataset keys in table order. The order follows the groups of category_mapping
# (5 discrete, 3 continuous, 2 time series), which is the layout the column
# format in print_latex() is written for.
DATASETS = [
    "mbd",
    "x5",
    "age",
    "taobao",
    "bpi_17",
    "physionet2012",
    "mimic3",
    "pendulum_cls",
    "arabic",
    "electric_devices",
]
# Dataset key -> display name used as a row/column label in the result tables.
DATASETS_PRETTY = {
    "x5": "Retail",
    "age": "Age",
    "physionet2012": "PhysioNet2012",
    # "pendulum": "Pendulum",
    "pendulum_cls": "Pendulum",
    "mimic3": "MIMIC-III",
    "mbd": "MBD",
    "taobao": "Taobao",
    "arabic": "ArabicDigits",
    "electric_devices": "ElectricDevices",
    "bpi_17": "BPI17",
}
# Pretty dataset name -> top-level header group in the LaTeX table.
# NOTE(review): "\ES" is emitted verbatim; presumably a macro defined in the
# target LaTeX document -- confirm.
category_mapping = {
    "MBD": "Discrete \\ES",
    "Retail": "Discrete \\ES",
    "Age": "Discrete \\ES",
    "Taobao": "Discrete \\ES",
    "BPI17": "Discrete \\ES",
    "PhysioNet2012": "Continuous \\ES",
    "MIMIC-III": "Continuous \\ES",
    "Pendulum": "Continuous \\ES",
    "ArabicDigits": "Time Series",
    "ElectricDevices": "Time Series",
}


def print_latex(df, no_metrics=False, column_format=None):
    """Print ``df`` as a LaTeX tabular with a grouped dataset header.

    The columns of ``df`` must be pretty dataset names (keys of the
    module-level ``category_mapping`` and, unless ``no_metrics`` is set, of
    ``METRIC_PRETTY``). They are replaced by a MultiIndex of
    (category, bold dataset name[, footnotesize metric label]).

    Args:
        df: Frame with one column per dataset; it is copied, not modified.
        no_metrics: If True, omit the metric row from the header.
        column_format: Optional LaTeX column spec. When omitted it is derived
            from the columns: one right-aligned index column followed by one
            centred column per dataset, with a ``|`` between consecutive
            category groups (``r|ccccc|ccc|cc`` for the full dataset list).

    Returns:
        None. The LaTeX source is written to stdout.
    """
    from itertools import groupby  # local import keeps this cell self-contained

    df = df.copy()
    categories = [category_mapping[col] for col in df.columns]
    if no_metrics:
        header = [
            (category, f"\\textbf{{{col}}}")
            for category, col in zip(categories, df.columns)
        ]
        level_names = ["Category", "Dataset"]
    else:
        header = [
            (
                category,
                f"\\textbf{{{col}}}",
                f"\\footnotesize{{{METRIC_PRETTY[col]}}}",
            )
            for category, col in zip(categories, df.columns)
        ]
        level_names = ["Category", "Dataset", "\\footnotesize{Metric}"]

    if column_format is None:
        # One run of "c" per block of consecutive columns sharing a category,
        # so the vertical rules line up with the \multicolumn groups.
        group_specs = ["c" * len(list(run)) for _, run in groupby(categories)]
        column_format = "|".join(["r"] + group_specs)

    df.columns = pd.MultiIndex.from_tuples(header, names=level_names)
    df.index.name = None

    lines = df.to_latex(
        bold_rows=True,
        column_format=column_format,
        multicolumn_format="c|",
    ).splitlines()
    # Lines 0-2 are \begin{tabular}, \toprule and the Category row; add a rule
    # right below the Category row. The backslash is escaped explicitly:
    # "\m" is not a valid Python escape sequence.
    lines.insert(3, "\\midrule")
    print("\n".join(lines))

In [70]:
def get_meta(which):
    """Return the lookup data for one ablation axis.

    Args:
        which: One of ``"time"``, ``"agg"`` or ``"norm"``.

    Returns:
        A tuple ``(param_name, options, option_name, options_in_table)``:
        the column name of the hyper-parameter in the Optuna trials frame,
        the raw values to compare (an entry may itself be a list of values
        that count as the same option), the axis title, and the table label
        for each option (parallel to ``options``).

    Raises:
        ValueError: If ``which`` is not a known axis.
    """
    # Built on every call so callers never share (and mutate) the same lists.
    # Each entry is (param_name, options, option_name, options_in_table).
    meta = {
        "time": (
            "params_model.preprocess.params.time_process",
            ["none", ["cat", "diff"]],
            "Time process",
            ["w/o time", "with time"],
        ),
        "agg": (
            "params_model.aggregation.name",
            ["TakeLastHidden", "ValidHiddenMean"],
            "Aggregation",
            ["Last hidden", "Mean hidden"],
        ),
        "norm": (
            "params_model.preprocess.params.num_norm",
            [False, True],
            "Normalization",
            ["w/o norm", "with norm"],
        ),
    }
    if which not in meta:
        raise ValueError(
            f"Unknown option {which!r}; expected one of {sorted(meta)}"
        )
    return meta[which]


def param_importance(datasets, methods, which="time", percent=False):
    """Compare the two settings of one hyper-parameter across Optuna runs.

    For every (method, dataset) pair the trials under
    ``log/<dataset>/<method>/optuna`` are loaded, restricted to COMPLETE
    trials, and for each option of the chosen axis the five trials with the
    highest test metric are summarised by mean and standard deviation.

    Args:
        datasets: Iterable of dataset keys (keys of ``DATASETS_PRETTY``).
        methods: Iterable of method keys (keys of ``METHODS_PRETTY``).
        which: Ablation axis understood by ``get_meta``.
        percent: If False, return LaTeX strings ``$mean \\pm std$`` indexed by
            (dataset, option); the better option is highlighted when the gap
            between the means exceeds twice the average std. If True, return
            the relative difference in percent of the first option against
            the second, indexed by dataset.

    Returns:
        A DataFrame with one column per method. Columns that stayed entirely
        empty are dropped (and reported on stdout).

    Pairs that fail for any reason are reported on stdout and left as NaN.
    """
    param_name, options, option_name, options_in_table = get_meta(which)
    dataset_labels = [DATASETS_PRETTY[dataset] for dataset in datasets]

    if percent:
        index = dataset_labels
    else:
        index = pd.MultiIndex.from_product(
            [dataset_labels, options_in_table],
            names=["Dataset", option_name],
        )
    res = pd.DataFrame(
        index=index,
        columns=[METHODS_PRETTY[method] for method in methods],
        # The non-percent table holds LaTeX strings; a float frame would need
        # an implicit upcast on the first assignment.
        dtype=float if percent else object,
    )

    for method in methods:
        # MLEM keeps its aggregation choice under a different parameter.
        method_param = param_name
        if method == "mlem" and which == "agg":
            method_param = "params_model.preprocess.params.enc_aggregation"

        for dataset in datasets:
            row_label = DATASETS_PRETTY[dataset]
            col_label = METHODS_PRETTY[method]
            try:
                trials, _ = optuna_df(Path(f"log/{dataset}/{method}/optuna"))  # type: ignore
                trials = trials[trials["state"] == "COMPLETE"]
                test = f"user_attrs_test_{METRIC_FOR_DS[dataset]}_mean"
                ranked = trials.sort_values(test)  # type: ignore

                option_dict = pd.DataFrame(columns=["mean", "std", "str"])
                for option, option_in_t in zip(options, options_in_table):
                    # An option may bundle several raw values (e.g. ["cat", "diff"]).
                    accepted = option if isinstance(option, list) else [option]
                    # Ascending sort: the last five rows are the best trials.
                    top = ranked[ranked[method_param].isin(accepted)].iloc[-5:][test]
                    option_dict.loc[option_in_t] = [
                        top.mean(),
                        top.std(),
                        f"{top.mean():.3f} \\pm {top.std():.3f}",
                    ]

                if percent:
                    res.loc[row_label, col_label] = (
                        option_dict.loc[options_in_table[0], "mean"]
                        / option_dict.loc[options_in_table[1], "mean"]
                        * 100
                        - 100
                    )
                    continue

                highlight = (
                    option_dict["mean"].max() - option_dict["mean"].min()
                ) > 2 * option_dict["std"].mean()
                max_option = option_dict.sort_values("mean").index[-1]
                for table_option in option_dict.index:
                    value = option_dict.loc[table_option, "str"]
                    if highlight and max_option == table_option:
                        # Backslashes are escaped explicitly ("\c" is not a
                        # valid Python escape sequence).
                        value = f"\\cellcolor{{lightgray}}\\bm{{{value}}}"
                    res.loc[(row_label, table_option), col_label] = f"${value}$"
            except Exception as e:
                # Best effort: a missing log or column must not abort the table.
                print(method, dataset, e)

    na_cols = res.isna().all()
    if na_cols.any():
        print("DROP NA COLS\n", res.columns[na_cols])
    res = res.loc[:, ~na_cols]
    return res


def optuna_importance(datasets, methods, which="lr"):
    """Rank one hyper-parameter by Optuna importance for every method/dataset.

    Args:
        datasets: Iterable of dataset keys (keys of ``DATASETS_PRETTY``).
        methods: Iterable of method keys (keys of ``METHODS_PRETTY``).
        which: ``"time"``, ``"agg"``, ``"norm"`` or ``"lr"``.

    Returns:
        A DataFrame with methods as rows and datasets as columns. Each filled
        cell is the rank (1 = most important) as a string; pairs that could
        not be evaluated stay NaN.

    Raises:
        ValueError: If ``which`` is not a known parameter alias. Previously
            this case was swallowed and produced an all-NaN table.
    """
    param_names = {
        "time": "params_model.preprocess.params.time_process",
        "agg": "params_model.aggregation.name",
        "norm": "params_model.preprocess.params.num_norm",
        "lr": "params_optimizer.params.lr",
    }
    if which not in param_names:
        raise ValueError(
            f"Unknown parameter {which!r}; expected one of {sorted(param_names)}"
        )
    # The importance mapping is looked up without the "params_" prefix.
    importance_key = param_names[which].replace("params_", "")

    res = pd.DataFrame(
        index=[DATASETS_PRETTY[dataset] for dataset in datasets],
        columns=[METHODS_PRETTY[method] for method in methods],
        # Cells hold rank strings, so the frame is object-typed from the start.
        dtype=object,
    )
    for method in methods:
        for dataset in datasets:
            print(dataset)
            try:
                _, study = optuna_df(Path(f"log/{dataset}/{method}/optuna"))
                importance = optuna.importance.get_param_importances(study)
                rank = int(pd.Series(importance).rank(ascending=False)[importance_key])
                res.loc[DATASETS_PRETTY[dataset], METHODS_PRETTY[method]] = f"{rank}"
            except Exception as e:
                # Best effort, but say why a cell stays empty instead of
                # hiding the error.
                print(method, dataset, e)
    return res.T


def best_param(datasets, methods, param_name="time"):
    """Collect one setting from each method's best config.

    Reads ``configs/specify/<dataset>/<method>/best.yaml`` for every pair and
    selects ``param_name`` from it.

    Args:
        datasets: Iterable of dataset keys (keys of ``DATASETS_PRETTY``).
        methods: Iterable of method keys (keys of ``METHODS_PRETTY``).
        param_name: ``"time"`` or ``"agg"`` as shorthand, or any dotted key
            understood by ``OmegaConf.select``.

    Returns:
        A DataFrame with methods as rows and datasets as columns. Pairs whose
        config could not be read stay NaN and are reported on stdout.
    """
    shorthand = {
        "time": "model.preprocess.params.time_process",
        "agg": "model.aggregation.name",
    }
    param_name = shorthand.get(param_name, param_name)

    # Use the same label mappings as the sibling table builders; the names
    # dataset_names / method_names are not defined anywhere in this notebook.
    res = pd.DataFrame(
        index=[DATASETS_PRETTY[dataset] for dataset in datasets],
        columns=[METHODS_PRETTY[method] for method in methods],
        # Config values are arbitrary objects (strings, booleans, ...).
        dtype=object,
    )
    for method in methods:
        for dataset in datasets:
            try:
                config = OmegaConf.load(
                    Path(f"configs/specify/{dataset}/{method}/best.yaml")
                )
                value = OmegaConf.select(config, param_name)
                res.loc[DATASETS_PRETTY[dataset], METHODS_PRETTY[method]] = value
            except Exception as e:
                # Best effort, but report the reason instead of hiding it.
                print(method, dataset, e)
    return res.T


# Build the table that the following cells render. Exactly one of the three
# builders is active at a time; `res` is the global those cells read.
# Passing the METHODS_PRETTY dict iterates over its keys (the method names).
res = param_importance(DATASETS, METHODS_PRETTY, "agg", percent=True)
# res = best_param(DATASETS, METHODS_PRETTY, "time")
# res = optuna_importance(DATASETS, METHODS_PRETTY, "lr")
res

convtran mbd 'params_model.aggregation.name'
convtran x5 'params_model.aggregation.name'
convtran age 'params_model.aggregation.name'
convtran taobao 'params_model.aggregation.name'
convtran bpi_17 'params_model.aggregation.name'
convtran physionet2012 'params_model.aggregation.name'
convtran mimic3 'params_model.aggregation.name'
convtran pendulum_cls 'params_model.aggregation.name'
convtran arabic 'params_model.aggregation.name'
convtran electric_devices 'params_model.aggregation.name'
primenet mbd 'params_model.aggregation.name'
primenet x5 'params_model.aggregation.name'
primenet age 'params_model.aggregation.name'
primenet taobao 'params_model.aggregation.name'
primenet bpi_17 'params_model.aggregation.name'
primenet physionet2012 'params_model.aggregation.name'
primenet mimic3 'params_model.aggregation.name'
primenet pendulum_cls 'params_model.aggregation.name'
primenet arabic 'params_model.aggregation.name'
primenet electric_devices 'params_model.aggregation.name'
DROP NA COLS
 

Unnamed: 0,CoLES,GRU,MLEM,Transformer,Mamba,mTAND,MLP
MBD,0.577791,0.554612,0.935085,-0.500976,-0.011003,1.261042,-6.697061
Retail,0.79415,0.473483,1.137783,-0.586098,-2.263153,-0.188058,-34.857057
Age,0.491085,-1.648698,1.953751,-3.462494,-5.162872,1.145189,-43.349727
Taobao,1.083814,0.69313,1.309293,-0.031093,-4.204109,0.666898,-11.952346
BPI17,2.951984,0.667424,1.942124,0.139527,-0.968516,-0.951959,-3.910872
PhysioNet2012,2.218798,3.650888,2.170041,0.812639,3.912241,-0.069278,3.657035
MIMIC-III,1.124623,0.778432,0.388929,0.451687,-1.188456,0.527866,0.613642
Pendulum,0.005515,-4.53194,1.698072,9.102975,-5.936522,9.750331,22.846846
ArabicDigits,0.257257,-0.183284,-0.045715,-0.541581,-0.448634,0.6266,-38.438966
ElectricDevices,2.915493,0.38106,2.313389,-0.418515,-2.728877,-0.076977,83.374234


In [36]:
# print_latex expects datasets as columns, so transpose the
# dataset-indexed result (methods become the rows).
print_latex(res.T)

\begin{tabular}{r|ccccc|ccc|cc}
\toprule
Category & \multicolumn{5}{c|}{Discrete \ES} & \multicolumn{3}{c|}{Continuous \ES} & \multicolumn{2}{c|}{Time Series} \\
\midrule
Dataset & \textbf{MBD} & \textbf{Retail} & \textbf{Age} & \textbf{Taobao} & \textbf{BPI17} & \textbf{PhysioNet2012} & \textbf{MIMIC-III} & \textbf{Pendulum} & \textbf{ArabicDigits} & \textbf{ElectricDevices} \\
\footnotesize{Metric} & \footnotesize{Mean ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{Accuracy} \\
\midrule
\textbf{CoLES} & -0.910752 & 0.112570 & -0.248629 & -1.201488 & -0.266050 & -0.401366 & -0.463234 & -61.177328 & -0.320464 & -1.749505 \\
\textbf{GRU} & -1.071905 & 0.158791 & -0.697168 & -1.777865 & -0.360726 & -0.270535 & -0.473769 & -60.484611 & 0.015271 & -0.402115 \\
\textbf{Transformer} & -1.244187 & -2.067255 & -0.5

In [11]:
# Plain LaTeX dump of `res` without the grouped header.
# Column spec: one "r" per value column plus two extra.
# NOTE(review): the two extra columns match a two-level (Dataset, option)
# index, i.e. a `res` built with percent=False -- confirm `res` was rebuilt
# that way before running this cell.
print(
    res.to_latex(
        bold_rows=True,
        column_format="r" * len(res.columns) + "rr",
    )
)

\begin{tabular}{rrrrrrrrrr}
\toprule
 &  & CoLES & GRU & Transformer & Mamba & ConvTran & mTAND & PrimeNet & MLP \\
Dataset & Time process &  &  &  &  &  &  &  &  \\
\midrule
\multirow[t]{2}{*}{\textbf{MBD}} & \textbf{w/o time} & $0.817 \pm 0.002$ & $0.817 \pm 0.002$ & $0.811 \pm 0.002$ & $0.814 \pm 0.001$ & $0.811 \pm 0.001$ & $0.719 \pm 0.123$ & $0.743 \pm 0.021$ & $0.801 \pm 0.001$ \\
\textbf{} & \textbf{with time} & $\cellcolor{lightgray}\bm{0.825 \pm 0.000}$ & $\cellcolor{lightgray}\bm{0.826 \pm 0.000}$ & $\cellcolor{lightgray}\bm{0.822 \pm 0.001}$ & $\cellcolor{lightgray}\bm{0.822 \pm 0.000}$ & $\cellcolor{lightgray}\bm{0.816 \pm 0.001}$ & $0.795 \pm 0.002$ & $\cellcolor{lightgray}\bm{0.779 \pm 0.004}$ & $\cellcolor{lightgray}\bm{0.809 \pm 0.000}$ \\
\cline{1-10}
\multirow[t]{2}{*}{\textbf{Retail}} & \textbf{w/o time} & $0.551 \pm 0.001$ & $0.544 \pm 0.001$ & $0.529 \pm 0.002$ & $0.540 \pm 0.001$ & $0.533 \pm 0.002$ & $0.517 \pm 0.002$ & $0.518 \pm 0.001$ & $0.525 \pm 0.001$ \\
\

In [None]:
def get_grayscale_color_lr(x):
    """Format an importance rank as a LaTeX cell shaded by rank.

    Lower ranks get a darker background: ranks of 5 and above are left
    unshaded, values below 1 get the darkest shade.

    Args:
        x: The rank, as a number or a numeric string (``optuna_importance``
            stores ranks as strings). ``None``/NaN marks a missing value.

    Returns:
        ``"${x}$"``, prefixed with ``\\cellcolor{gray!N}`` when shaded, or
        ``"--"`` for a missing value. Previously NaN failed every comparison
        and was rendered in the darkest shade.
    """
    # NaN is the only value that is not equal to itself.
    if x is None or x != x:
        return "--"
    rank = float(x)  # accepts numeric strings; the text keeps the original form
    text = f"${x}$"
    # (lower bound, gray percentage); first match wins.
    # NOTE(review): ranks 3 and 4 share gray!25 in the original -- confirm
    # that this is intended.
    for lower_bound, shade in ((5, None), (4, 25), (3, 25), (2, 50), (1, 75)):
        if rank >= lower_bound:
            return text if shade is None else f"\\cellcolor{{gray!{shade}}}{text}"
    return f"\\cellcolor{{gray!100}}{text}"

def get_grayscale_color_time(x):
    """Format a relative difference (in percent) as a shaded LaTeX cell.

    Values of -0.3 and above are left unshaded; the more negative the value,
    the darker the background (bucket edges at -0.5, -0.9, -2 and -10).

    Args:
        x: The difference in percent. ``None``/NaN marks a missing value.

    Returns:
        ``"$<x with two decimals> \\%$"``, prefixed with
        ``\\cellcolor{gray!N}`` when shaded, or ``"--"`` for a missing value.
        Previously NaN failed every comparison and was rendered in the
        darkest shade.
    """
    # NaN is the only value that is not equal to itself.
    if x is None or x != x:
        return "--"
    text = f"${x:.2f} \\%$"
    # (lower bound, gray percentage); first match wins.
    for lower_bound, shade in ((-0.3, None), (-0.5, 15), (-0.9, 25), (-2, 50), (-10, 75)):
        if x >= lower_bound:
            return text if shade is None else f"\\cellcolor{{gray!{shade}}}{text}"
    return f"\\cellcolor{{gray!100}}{text}"

def get_grayscale_color_norm(x):
    """Format a relative difference (in percent) as a shaded LaTeX cell.

    Non-negative values are left unshaded; the more negative the value, the
    darker the background (bucket edges at -0.5, -1, -2 and -5).

    Args:
        x: The difference in percent. ``None``/NaN marks a missing value.

    Returns:
        ``"$<x with two decimals> \\%$"``, prefixed with
        ``\\cellcolor{gray!N}`` when shaded, or ``"--"`` for a missing value.
        Previously NaN failed every comparison and was rendered in the
        darkest shade.
    """
    # NaN is the only value that is not equal to itself.
    if x is None or x != x:
        return "--"
    text = f"${x:.2f} \\%$"
    # (lower bound, gray percentage); first match wins.
    for lower_bound, shade in ((0, None), (-0.5, 15), (-1, 25), (-2, 50), (-5, 75)):
        if x >= lower_bound:
            return text if shade is None else f"\\cellcolor{{gray!{shade}}}{text}"
    return f"\\cellcolor{{gray!100}}{text}"

def get_grayscale_color_pool(x):
    """Format a relative difference (in percent) as a shaded LaTeX cell.

    Non-negative values are left unshaded; the more negative the value, the
    darker the background (bucket edges at -0.5, -1, -2 and -5, the same
    buckets as ``get_grayscale_color_norm``).

    Args:
        x: The difference in percent. ``None``/NaN marks a missing value.

    Returns:
        ``"$<x with two decimals> \\%$"``, prefixed with
        ``\\cellcolor{gray!N}`` when shaded, or ``"--"`` for a missing value.
        Previously NaN failed every comparison and was rendered in the
        darkest shade.
    """
    # NaN is the only value that is not equal to itself.
    if x is None or x != x:
        return "--"
    text = f"${x:.2f} \\%$"
    # (lower bound, gray percentage); first match wins.
    for lower_bound, shade in ((0, None), (-0.5, 15), (-1, 25), (-2, 50), (-5, 75)):
        if x >= lower_bound:
            return text if shade is None else f"\\cellcolor{{gray!{shade}}}{text}"
    return f"\\cellcolor{{gray!100}}{text}"

# Shade every cell of the numeric `res` element-wise, then transpose so that
# datasets become the columns print_latex expects.
print_latex(res.map(get_grayscale_color_pool).T)

\begin{tabular}{r|ccccc|ccc|cc}
\toprule
Category & \multicolumn{5}{c|}{Discrete \ES} & \multicolumn{3}{c|}{Continuous \ES} & \multicolumn{2}{c|}{Time Series} \\
\midrule
Dataset & \textbf{MBD} & \textbf{Retail} & \textbf{Age} & \textbf{Taobao} & \textbf{BPI17} & \textbf{PhysioNet2012} & \textbf{MIMIC-III} & \textbf{Pendulum} & \textbf{ArabicDigits} & \textbf{ElectricDevices} \\
\footnotesize{Metric} & \footnotesize{Mean ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{Accuracy} \\
\midrule
\textbf{CoLES} & \cellcolor{gray!15}$-0.25 \%$ & \cellcolor{gray!75}$-3.30 \%$ & \cellcolor{gray!25}$-0.68 \%$ & \cellcolor{gray!50}$-1.36 \%$ & \cellcolor{gray!75}$-3.53 \%$ & \cellcolor{gray!100}$-7.28 \%$ & \cellcolor{gray!75}$-2.54 \%$ & $11.78 \%$ & \cellcolor{gray!15}$-0.06 \%$ & \cellcolor{gray!15}$-0.46 \%$ \\
\tex

In [71]:
# Min, quintile boundaries and max over all cells of `res` pooled together.
# NOTE(review): presumably used to choose the bucket edges of the
# get_grayscale_color_* helpers -- confirm.
pd.DataFrame(res.values.flatten()).quantile([0., 0.2, 0.4, 0.6, 0.8, 1.0])

Unnamed: 0,0
0.0,-43.349727
0.2,-1.280505
0.4,-0.019039
0.6,0.642719
0.8,1.746883
1.0,83.374234
