In [1]:
%cd ..

/home/dev/24/es-bench


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from pathlib import Path
from datetime import datetime
import itertools
import seaborn as sns

import yaml
import pandas as pd
import numpy as np
from scipy import stats
from tqdm.auto import tqdm
from IPython.display import display
import matplotlib.pyplot as plt
import networkx as nx
from omegaconf import OmegaConf

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
logs_root = Path("log")

In [4]:
METRIC_FOR_DS = {
    "mimic3": "MulticlassAUROC",
    "physionet2012": "MulticlassAUROC",
    "age": "MulticlassAccuracy",
    "x5": "MulticlassAccuracy",
    # "pendulum": "R2Score",
    "pendulum_cls": "MulticlassAccuracy",
    "taobao": "MulticlassAUROC",
    "mbd": "MultiLabelMeanAUROC",
    "arabic": "MulticlassAccuracy",
    "electric_devices": "MulticlassAccuracy",
    "bpi_17": "MulticlassAUROC",
}
METRIC_PRETTY = {
    "MIMIC-III": "ROC AUC",
    "PhysioNet2012": "ROC AUC",
    "Age": "Accuracy",
    "Retail": "Accuracy",
    # "Pendulum": "$R^2$",
    "Pendulum": "Accuracy",
    "Taobao": "ROC AUC",
    "MBD": "Mean ROC AUC",
    "ArabicDigits": "Accuracy",
    "ElectricDevices": "Accuracy",
    "BPI17": "ROC AUC"
}
METHODS_PRETTY = {
    "mtand": "mTAND",
    "gru": "GRU",
    "mlp": "MLP",
    "mamba": "Mamba",
    "coles": "CoLES",
    "primenet": "PrimeNet",
    "mlem": "MLEM",
    "transformer": "Transformer",
    "convtran": "ConvTran",
}
DATASETS = [
    "x5",
    "mbd",
    "bpi_17",
    "age",
    "physionet2012",
    "mimic3",
    "pendulum_cls",
    "taobao",
    "arabic",
    "electric_devices",

]
DATASETS_PRETTY = {
    "x5": "Retail",
    "age": "Age",
    "physionet2012": "PhysioNet2012",
    # "pendulum": "Pendulum",
    "pendulum_cls": "Pendulum",
    "mimic3": "MIMIC-III",
    "mbd": "MBD",
    "taobao": "Taobao",
    "arabic": "ArabicDigits",
    "electric_devices": "ElectricDevices",
    "bpi_17": "BPI17",
}

category_mapping = {
    'MBD': 'Discrete \\ES',
    'Retail': 'Discrete \\ES',
    'Age': 'Discrete \\ES',
    'Taobao': 'Discrete \\ES',
    'BPI17': 'Discrete \\ES',
    'PhysioNet2012': 'Continuous \\ES',
    'MIMIC-III': 'Continuous \\ES',
    'Pendulum': 'Continuous \\ES',
    'ArabicDigits': 'Time Series',
    'ElectricDevices': 'Time Series',
}


def print_latex(df, no_metrics=False):
    df = df.copy()[['MBD', 'Retail', 'Age', 'Taobao', 'BPI17', 'PhysioNet2012', 'MIMIC-III', 'Pendulum', 'ArabicDigits', 'ElectricDevices']]
    columns_with_categories = [
        (category_mapping[col], f"\\textbf{{{col}}}", f"\\footnotesize{{{METRIC_PRETTY[col]}}}")
        for col in df.columns
    ]
    if no_metrics:
        columns_with_categories = [
            i[:2]
            for i in columns_with_categories
        ]

    df.columns = pd.MultiIndex.from_tuples(
        columns_with_categories,
        names=(
            ["Category", "Dataset", "\\footnotesize{Metric}"]
            if not no_metrics
            else ["Category", "Dataset"]
        ),
    )

    df.index.name = None
    res = df.to_latex(
        bold_rows=True,
        column_format="r|" + 5 * "c" + "|" + 3 * "c" + "|" + 2 * "c",
        multicolumn_format="c|",
    )
    res = res.splitlines()
    res.insert(3, '\midrule')
    res = "\n".join(res)
    print(res)


def get_ranks(df, crit_level: float = 0.01):
    rankings = []
    for dataset in df["dataset"].unique():
        ddf = df.query("dataset == @dataset")
        pvals = pd.Series(
            index=pd.Index(list(itertools.combinations(ddf.method.unique(), 2))),
            dtype=float,
        )
        
        for m1, m2 in pvals.index:
            x = ddf.query("method == @m1").metric
            y = ddf.query("method == @m2").metric
            pvals.at[(m1, m2)] = stats.mannwhitneyu(x, y).pvalue
        pvals = pvals.sort_values()

        # Benjamini–Hochberg
        crit_lvl_adj = np.arange(1, len(pvals) + 1) / len(pvals) * crit_level
        largest_reject = np.where(pvals <= crit_lvl_adj)[0].max()
        indistinguishable = pd.Series(
            data=np.arange(len(pvals)) > largest_reject,
            index=pvals.index,
        )
        
        for m in METHODS_PRETTY.values():
            indistinguishable.at[(m, m)] = False

        metrics = ddf.groupby("method").metric.median().sort_values(ascending=False)
        idx = metrics.index

        adj = indistinguishable.unstack(fill_value=False)
        adj = adj.loc[idx][idx]
        adj += adj.T
        
        cliques = list(nx.find_cliques(nx.Graph(adj.values)))
        cliques.sort(key=lambda ixs: metrics.iloc[ixs].mean(), reverse=True)
        ranking = pd.Series(index=idx, data="", name=dataset)
        for order, c in enumerate(cliques, 1):
            for i in c:
                if not ranking.iat[i]:
                    ranking.iat[i] = str(order)
                else:
                    ranking.iat[i] += f",{order}"

        rankings.append(ranking)
    
    return pd.concat(rankings, axis=1)

# Pick config details

In [5]:
opts = []

for d in DATASETS:
    for m in METHODS_PRETTY:
        path = logs_root / d / m / "correlation" / "seed_0" / "config.yaml"
        with open(path) as f:
            config = yaml.load(f, yaml.SafeLoader)
        try:
            opt = config["model"]["preprocess"]["params"]["time_process"]
        except KeyError:
            opt = None
            
        opts.append(dict(
            dataset=d,
            method=m,
            option=opt
        ))

df_options = (
    pd.DataFrame(opts)
    .assign(method=lambda df: df.method.replace(METHODS_PRETTY))
    .assign(dataset=lambda df: df.dataset.replace(DATASETS_PRETTY))
)

In [6]:
df_options.pivot(index="method", columns="dataset", values="option")

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,cat,diff,cat,diff,cat,diff,cat,cat,none,diff
ConvTran,cat,diff,cat,cat,cat,diff,cat,cat,cat,none
GRU,cat,cat,none,none,cat,cat,diff,none,diff,diff
MLEM,cat,diff,cat,diff,cat,diff,cat,cat,none,diff
MLP,cat,cat,cat,cat,cat,cat,cat,diff,diff,diff
Mamba,diff,diff,none,cat,cat,cat,cat,cat,diff,diff
PrimeNet,diff,diff,cat,cat,cat,diff,cat,diff,none,diff
Transformer,diff,cat,none,none,cat,cat,cat,cat,cat,cat
mTAND,diff,cat,diff,none,cat,diff,cat,cat,none,diff


# Collect metrics from experiments

In [7]:
ms = []
for d in DATASETS:

    # correlations
    for m in METHODS_PRETTY:
        path = logs_root / d / m / "correlation" / "results.csv"
        try:
            res = pd.read_csv(path, index_col=0)
        except Exception as e:
            print(f"Error: {e}, Skipping {path}")
        else:
            ms.append(pd.DataFrame(dict(
                dataset=d,
                method=m,
                exp="main",
                seed=res.columns[:-2],
                metric=res.loc["test_" + METRIC_FOR_DS[d]].values[:-2],
            )))

    # train with events permutation
    path = logs_root / d / "gru" / "permutation_keep_last" / "results.csv"
    try:
        res = pd.read_csv(path, index_col=0)
    except Exception as e:
        print(f"Error: {e}, Skipping {path}")
    else:
        ms.append(pd.DataFrame(dict(
            dataset=d,
            method="gru",
            exp="train_perm_kl",
            seed=res.columns[:-2],
            metric=res.loc["test_" + METRIC_FOR_DS[d]].values[:-2],
        )))
        # train with events permutation
    path = logs_root / d / "gru" / "correlation(1)" / "results.csv"
    assert (
        OmegaConf.load(
            logs_root / d / "gru" / "correlation(1)/seed_0/config.yaml"
        ).model.preprocess.params.time_process
        == "none"
    )
    try:
        res = pd.read_csv(path, index_col=0)
    except Exception as e:
        print(f"Error: {e}, Skipping {path}")
    else:
        ms.append(pd.DataFrame(dict(
            dataset=d,
            method="gru",
            exp="train_NO_perm_kl",
            seed=res.columns[:-2],
            metric=res.loc["test_" + METRIC_FOR_DS[d]].values[:-2],
        )))


df_eval = (
    pd.concat(ms)
    .assign(method=lambda df: df.method.replace(METHODS_PRETTY))
    .assign(dataset=lambda df: df.dataset.replace(DATASETS_PRETTY))
)

mtand_mask = ((df_eval["dataset"] == "Pendulum") & (df_eval["method"] == "mTAND") & (df_eval["metric"] < 0.19) & (df_eval["exp"] == "main"))
print("THROW AWAY MTAND FOR PENDULUM:", df_eval[mtand_mask].shape[0])
df_eval = df_eval[~mtand_mask]

df_eval.query("exp == 'main'").pivot_table(index="method", columns="dataset", values="seed", aggfunc="count") 

THROW AWAY MTAND FOR PENDULUM: 4


dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,20,20,20,20,20,20,20,20,20,20
ConvTran,20,20,20,20,20,20,20,20,20,20
GRU,20,20,20,20,20,20,20,20,20,20
MLEM,20,20,20,20,5,20,20,20,20,20
MLP,20,20,20,20,20,20,20,20,20,20
Mamba,20,20,20,20,20,20,20,20,20,20
PrimeNet,20,20,20,20,20,20,20,20,20,20
Transformer,20,20,20,20,20,20,20,20,20,20
mTAND,20,20,20,20,20,20,16,20,20,20


## Time permutation eval

In [8]:
from glob import glob

df_time_initial = pd.read_csv("log/Ablations/time.csv")
df_time = pd.concat(
    [pd.read_csv(f) for f in glob("log/Ablations/time_*-*1.csv")]
).dropna()

# merged = pd.merge(df_time, df_time_initial, how="inner", on=['dataset', 'method', 'seed'], suffixes=('_new', '_initial'))
# mismatched_metrics = merged[merged['metric_new'] != merged['metric_initial']][['dataset', 'method', 'seed', 'metric_initial', 'metric_new']]
# mismatched_metrics["diff"] = (mismatched_metrics["metric_new"] - mismatched_metrics["metric_initial"]).abs() < 0.002

merged_df = pd.merge(
    df_time_initial[["dataset", "method", "seed", "metric"]],
    df_time,
    on=["dataset", "method", "seed"],
    how="outer",
    suffixes=("_initial", "_new"),
)
merged_df['metric'] = merged_df['metric_initial'].combine_first(merged_df['metric_new'])
merged_df = merged_df.drop(columns=['metric_initial', 'metric_new'])
df_time = (
    merged_df
    .assign(exp="time")
    .assign(method=lambda df: df.method.replace(METHODS_PRETTY))
    .assign(dataset=lambda df: df.dataset.replace(DATASETS_PRETTY))
)

df_time = df_time.query("dataset != 'pendulum'")
(
    df_time
    .groupby(["method", "dataset"])
    .apply(lambda df: df.metric.count(), include_groups=False)
    .unstack()
).loc[["mTAND", "PrimeNet"]]

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
mTAND,20,20,20,20,20,20,20,20,20,20
PrimeNet,20,20,20,20,20,20,20,20,20,20


## Events permutation eval

In [9]:
df_perm_initial = pd.read_csv("log/Ablations/permutation_keep_last.csv")
df_perm = pd.concat(
    [pd.read_csv(f) for f in glob("log/Ablations/permutation_keep_last_*.csv")]
).dropna()

merged = pd.merge(df_perm, df_perm_initial, how="inner", on=['dataset', 'method', 'seed'], suffixes=('_new', '_initial'))
mismatched_metrics = merged[merged['metric_new'] != merged['metric_initial']][['dataset', 'method', 'seed', 'metric_initial', 'metric_new']]
mismatched_metrics["diff"] = (mismatched_metrics["metric_new"] - mismatched_metrics["metric_initial"]).abs() < 0.002


merged_df = pd.merge(
    df_perm_initial[["dataset", "method", "seed", "metric"]],
    df_perm,
    on=["dataset", "method", "seed"],
    how="outer",
    suffixes=("_initial", "_new"),
)
merged_df['metric'] = merged_df['metric_initial'].combine_first(merged_df['metric_new'])
merged_df = merged_df.drop(columns=['metric_initial', 'metric_new'])
merged_df = merged_df.drop_duplicates()
df_perm = (
    merged_df
    .assign(exp="perm")
    .assign(method=lambda df: df.method.replace(METHODS_PRETTY))
    .assign(dataset=lambda df: df.dataset.replace(DATASETS_PRETTY))
)
df_perm = df_perm[df_perm["dataset"] != "pendulum"]

(
    df_perm
    .groupby(["method", "dataset"])
    .apply(lambda df: df.metric.count(), include_groups=False)
    .unstack()
)

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,20,20,20,20,20,20,20,20,20,20
ConvTran,20,20,20,20,20,20,20,20,20,20
GRU,20,20,20,20,20,20,20,20,20,20
MLEM,20,20,20,20,8,20,20,20,20,20
MLP,20,20,20,20,20,20,20,20,20,20
Mamba,20,20,20,20,20,20,20,20,20,20
PrimeNet,20,20,20,20,20,20,20,20,20,20
Transformer,20,20,20,20,20,20,20,20,20,20
mTAND,20,20,20,20,20,20,20,20,20,20


## Putting all together

In [10]:
df = (
    pd.concat((df_eval, df_time, df_perm))
    .assign(method=lambda df: df["method"].replace(METHODS_PRETTY))
    .assign(dataset=lambda df: df["dataset"].replace(DATASETS_PRETTY))
)

# Main result

In [11]:
main_res = df.query("exp == 'main'").pivot_table(
    index="method",
    columns="dataset",
    values="metric",
    aggfunc=["mean", "std"],
)

In [12]:
ranks = get_ranks(df.query("exp == 'main'"))
# ranks[ranks.isna()] = "6"
ranks

Unnamed: 0_level_0,Retail,MBD,BPI17,Age,PhysioNet2012,MIMIC-III,Pendulum,Taobao,ArabicDigits,ElectricDevices
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,1,2,34,1,23,1,2,1,12,12
MLEM,2,3,12,1,1,2,3,1,3,1
GRU,2,1,1,2,1,1,3,1,4,1
Mamba,3,4,45,3,34,3,3,23,2,2
Transformer,34,4,123,2,234,3,4,34,12,2
ConvTran,4,5,23,4,234,34,34,2,1,2
MLP,5,6,4,5,4,6,6,5,6,4
PrimeNet,6,8,5,5,234,5,5,4,5,3
mTAND,6,7,4,5,2,45,1,5,5,3


In [13]:
best_method = ranks.map(lambda s: "1" in s.split(","))
best_method

Unnamed: 0_level_0,Retail,MBD,BPI17,Age,PhysioNet2012,MIMIC-III,Pendulum,Taobao,ArabicDigits,ElectricDevices
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,True,False,False,True,False,True,False,True,True,True
MLEM,False,False,True,True,True,False,False,True,False,True
GRU,False,True,True,False,True,True,False,True,False,True
Mamba,False,False,False,False,False,False,False,False,False,False
Transformer,False,False,True,False,False,False,False,False,True,False
ConvTran,False,False,False,False,False,False,False,False,True,False
MLP,False,False,False,False,False,False,False,False,False,False
PrimeNet,False,False,False,False,False,False,False,False,False,False
mTAND,False,False,False,False,False,False,True,False,False,False


In [14]:
mean_method_rank = (
    ranks
    .map(lambda s: np.mean(list(map(int, s.split(",")))))
    .mean(1)
    .sort_values()
)
mean_method_rank

method
CoLES          1.70
GRU            1.70
MLEM           1.85
Transformer    2.85
Mamba          3.05
ConvTran       3.05
mTAND          4.25
PrimeNet       4.90
MLP            5.10
dtype: float64

In [15]:
mean_method_rank = (
    ranks
    .map(lambda s: np.mean(list(map(int, s.split(",")))))
    .mean(1)
    .sort_values()
)
mean_method_rank

method
CoLES          1.70
GRU            1.70
MLEM           1.85
Transformer    2.85
Mamba          3.05
ConvTran       3.05
mTAND          4.25
PrimeNet       4.90
MLP            5.10
dtype: float64

In [16]:
main_res_clean = df.query("exp == 'main'").pivot_table(
    index="method",
    columns="dataset",
    values="metric",
    aggfunc=["mean", "std"],
)

res = main_res_clean

In [17]:
main_res_latex = (
    "" # to be able to comment some any rows below
    + best_method.map(lambda flag: "\\cellcolor{lightgray} " if flag else "")
    + "$"
    + best_method.map(lambda flag: "\\mathbf{" if flag else "")
    + main_res.loc[:, "mean"].map(lambda x: f"{x:.3f}")
    + " \pm "
    + main_res.loc[:, "std"].map(lambda x: f"{x:.3f}")
    + ranks.map(lambda s: f"^{{{s}}}")
    + best_method.map(lambda flag: "}" if flag else "")
    + "$"
).loc[mean_method_rank.index, ['MBD', 'Retail', 'Age', 'Taobao', 'BPI17', 'PhysioNet2012', 'MIMIC-III', 'Pendulum', 'ArabicDigits', 'ElectricDevices']]

In [18]:
main_res_latex

Unnamed: 0_level_0,MBD,Retail,Age,Taobao,BPI17,PhysioNet2012,MIMIC-III,Pendulum,ArabicDigits,ElectricDevices
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,$0.826 \pm 0.001^{2}$,\cellcolor{lightgray} $\mathbf{0.553 \pm 0.002...,\cellcolor{lightgray} $\mathbf{0.634 \pm 0.005...,\cellcolor{lightgray} $\mathbf{0.713 \pm 0.002...,"$0.742 \pm 0.010^{3,4}$","$0.840 \pm 0.004^{2,3}$",\cellcolor{lightgray} $\mathbf{0.902 \pm 0.001...,$0.740 \pm 0.013^{2}$,\cellcolor{lightgray} $\mathbf{0.983 \pm 0.004...,\cellcolor{lightgray} $\mathbf{0.729 \pm 0.019...
GRU,\cellcolor{lightgray} $\mathbf{0.827 \pm 0.001...,$0.543 \pm 0.002^{2}$,$0.626 \pm 0.004^{2}$,\cellcolor{lightgray} $\mathbf{0.713 \pm 0.004...,\cellcolor{lightgray} $\mathbf{0.754 \pm 0.004...,\cellcolor{lightgray} $\mathbf{0.846 \pm 0.004...,\cellcolor{lightgray} $\mathbf{0.901 \pm 0.002...,$0.683 \pm 0.031^{3}$,$0.975 \pm 0.003^{4}$,\cellcolor{lightgray} $\mathbf{0.741 \pm 0.013...
MLEM,$0.824 \pm 0.001^{3}$,$0.544 \pm 0.002^{2}$,\cellcolor{lightgray} $\mathbf{0.634 \pm 0.003...,\cellcolor{lightgray} $\mathbf{0.713 \pm 0.004...,\cellcolor{lightgray} $\mathbf{0.753 \pm 0.005...,\cellcolor{lightgray} $\mathbf{0.846 \pm 0.007...,$0.899 \pm 0.002^{2}$,$0.676 \pm 0.017^{3}$,$0.978 \pm 0.002^{3}$,\cellcolor{lightgray} $\mathbf{0.736 \pm 0.014...
Transformer,$0.821 \pm 0.002^{4}$,"$0.536 \pm 0.006^{3,4}$",$0.621 \pm 0.006^{2}$,"$0.692 \pm 0.013^{3,4}$",\cellcolor{lightgray} $\mathbf{0.749 \pm 0.006...,"$0.838 \pm 0.008^{2,3,4}$",$0.894 \pm 0.002^{3}$,$0.658 \pm 0.019^{4}$,\cellcolor{lightgray} $\mathbf{0.986 \pm 0.004...,$0.710 \pm 0.024^{2}$
Mamba,$0.820 \pm 0.003^{4}$,$0.538 \pm 0.003^{3}$,$0.609 \pm 0.006^{3}$,"$0.693 \pm 0.023^{2,3}$","$0.737 \pm 0.012^{4,5}$","$0.835 \pm 0.006^{3,4}$",$0.895 \pm 0.002^{3}$,$0.687 \pm 0.017^{3}$,$0.983 \pm 0.005^{2}$,$0.716 \pm 0.022^{2}$
ConvTran,$0.816 \pm 0.002^{5}$,$0.534 \pm 0.005^{4}$,$0.603 \pm 0.006^{4}$,$0.703 \pm 0.009^{2}$,"$0.748 \pm 0.006^{2,3}$","$0.837 \pm 0.006^{2,3,4}$","$0.892 \pm 0.005^{3,4}$","$0.674 \pm 0.028^{3,4}$",\cellcolor{lightgray} $\mathbf{0.986 \pm 0.003...,$0.711 \pm 0.019^{2}$
mTAND,$0.798 \pm 0.002^{7}$,$0.519 \pm 0.003^{6}$,$0.582 \pm 0.009^{5}$,$0.672 \pm 0.010^{5}$,$0.738 \pm 0.005^{4}$,$0.841 \pm 0.005^{2}$,"$0.888 \pm 0.003^{4,5}$",\cellcolor{lightgray} $\mathbf{0.777 \pm 0.031...,$0.951 \pm 0.010^{5}$,$0.631 \pm 0.019^{3}$
PrimeNet,$0.780 \pm 0.006^{8}$,$0.521 \pm 0.003^{6}$,$0.583 \pm 0.011^{5}$,$0.681 \pm 0.010^{4}$,$0.730 \pm 0.006^{5}$,"$0.839 \pm 0.004^{2,3,4}$",$0.887 \pm 0.004^{5}$,$0.600 \pm 0.026^{5}$,$0.958 \pm 0.009^{5}$,$0.636 \pm 0.016^{3}$
MLP,$0.809 \pm 0.001^{6}$,$0.526 \pm 0.002^{5}$,$0.581 \pm 0.007^{5}$,$0.659 \pm 0.035^{5}$,$0.737 \pm 0.004^{4}$,$0.835 \pm 0.004^{4}$,$0.881 \pm 0.001^{6}$,$0.186 \pm 0.006^{6}$,$0.760 \pm 0.011^{6}$,$0.437 \pm 0.019^{4}$


In [19]:
print_latex(main_res_latex)

\begin{tabular}{r|ccccc|ccc|cc}
\toprule
Category & \multicolumn{5}{c|}{Discrete \ES} & \multicolumn{3}{c|}{Continuous \ES} & \multicolumn{2}{c|}{Time Series} \\
\midrule
Dataset & \textbf{MBD} & \textbf{Retail} & \textbf{Age} & \textbf{Taobao} & \textbf{BPI17} & \textbf{PhysioNet2012} & \textbf{MIMIC-III} & \textbf{Pendulum} & \textbf{ArabicDigits} & \textbf{ElectricDevices} \\
\footnotesize{Metric} & \footnotesize{Mean ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{Accuracy} \\
\midrule
\textbf{CoLES} & $0.826 \pm 0.001^{2}$ & \cellcolor{lightgray} $\mathbf{0.553 \pm 0.002^{1}}$ & \cellcolor{lightgray} $\mathbf{0.634 \pm 0.005^{1}}$ & \cellcolor{lightgray} $\mathbf{0.713 \pm 0.002^{1}}$ & $0.742 \pm 0.010^{3,4}$ & $0.840 \pm 0.004^{2,3}$ & \cellcolor{lightgray} $\mathbf{0.902 \pm 0.001^{1}}$ & $0.740 \pm 

# Ablations

In [20]:
def mark_sign_diff(
    df: pd.DataFrame,
    exp1: str,
    exp2: str,
    correction: bool = True,
    crit_level: float = 0.01,
):
    rows = []
    name = exp1 + " VS " + exp2
    for d in df.dataset.unique():
        for m in df.method.unique():
            a = df.query("method == @m and dataset == @d and exp == @exp1")["metric"]
            b = df.query("method == @m and dataset == @d and exp == @exp2")["metric"]
            rows.append({
                "dataset": d,
                "method": m,
                name: stats.mannwhitneyu(a, b).pvalue
            })
    df_pvals = pd.DataFrame(rows).pivot(index="method", columns="dataset", values=name).dropna(axis=0, how="all")
    if not correction:
        return df_pvals

    pvals = df_pvals.values.flatten()
    idx_sort = np.argsort(pvals)
    idx_restore = np.argsort(idx_sort)
    pvals = pvals[idx_sort]
    pvals *= np.arange(len(pvals), 0, -1)
    accept = (pvals >= crit_level).cumsum() > 0
    return pd.DataFrame(
        index=df_pvals.index,
        columns=df_pvals.columns,
        data=~accept[idx_restore].reshape(df_pvals.shape),
    )

## Train with permutation

In [21]:
df_train_perm = df.query("method == 'GRU' and exp in ('main', 'train_NO_perm_kl', 'train_perm_kl')")

(
    df_train_perm
    .groupby(["exp", "dataset"])
    .apply(lambda df: df.metric.count(), include_groups=False)
    .unstack()
)

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
exp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
main,20,20,20,20,20,20,20,20,20,20
train_NO_perm_kl,20,20,20,20,20,20,20,20,20,20
train_perm_kl,20,20,20,20,20,20,20,20,20,20


In [22]:
abl_train_perm_kl_res = df_train_perm.pivot_table(
    index="exp",
    columns="dataset",
    values="metric",
    aggfunc=["mean", "std"],
)
abl_train_perm_kl_res.index = ["GRU", "GRU w/o time", "GRU w/o time w/ perm."]
abl_train_perm_kl_res

Unnamed: 0_level_0,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,std,std,std,std,std,std,std,std,std,std
dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
GRU,0.62592,0.975011,0.753785,0.741363,0.827015,0.901112,0.683183,0.845822,0.54313,0.713456,0.004143,0.003401,0.003927,0.013486,0.001093,0.002243,0.030507,0.004297,0.002488,0.003833
GRU w/o time,0.623148,0.975421,0.753785,0.741363,0.819625,0.898658,0.277172,0.845822,0.543104,0.685977,0.004417,0.004577,0.003927,0.013486,0.000786,0.002315,0.005666,0.004297,0.002431,0.016372
GRU w/o time w/ perm.,0.629805,0.962551,0.750378,0.62274,0.819044,0.889876,0.24686,0.843994,0.54582,0.702445,0.004122,0.005979,0.002858,0.013795,0.000882,0.002258,0.009337,0.004677,0.003093,0.006212


In [23]:
sign_diff = mark_sign_diff(df_train_perm, "main", "train_perm_kl")
sign_diff.index = ["GRU w/o time w/ perm."]
sign_diff.loc["GRU w/o time"] = True #mark_sign_diff(df_train_perm, "main", "train_NO_perm_kl")
sign_diff = sign_diff.iloc[::-1]
sign_diff

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
GRU w/o time,True,True,True,True,True,True,True,True,True,True
GRU w/o time w/ perm.,False,True,False,True,True,True,True,False,False,True


In [24]:
abl_perm_rel_res = (
    abl_train_perm_kl_res.loc[
        ["GRU w/o time", "GRU w/o time w/ perm."], "mean"
    ]
    / abl_train_perm_kl_res.loc["GRU", "mean"]
    * 100
    - 100
).map(lambda s: -0.0 if np.abs(s) < np.finfo(s).eps else s)



def get_grayscale_color(x):
    if x >= -0.1: 
        return f"${x:.2f} \\%$"
    elif x >= -0.5:
        return f"\\cellcolor{{gray!15}}${x:.2f} \\%$"
    elif x >= -1:
        return f"\\cellcolor{{gray!25}}${x:.2f} \\%$"
    elif x >= -5:
        return f"\\cellcolor{{gray!50}}${x:.2f} \\%$"
    elif x >= -25: 
        return f"\\cellcolor{{gray!75}}${x:.2f} \\%$"
    else:
        return f"\\cellcolor{{gray!100}}${x:.2f} \\%$"


abl_res = abl_perm_rel_res.map(get_grayscale_color)
abl_res[~sign_diff] = abl_res.map(lambda x : x.replace("\cellcolor{gray!10}", "").replace("\cellcolor{gray!15}", ""))
abl_res

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
GRU w/o time,\cellcolor{gray!15}$-0.44 \%$,$0.04 \%$,$-0.00 \%$,$-0.00 \%$,\cellcolor{gray!25}$-0.89 \%$,\cellcolor{gray!15}$-0.27 \%$,\cellcolor{gray!100}$-59.43 \%$,$-0.00 \%$,$-0.00 \%$,\cellcolor{gray!50}$-3.85 \%$
GRU w/o time w/ perm.,$0.62 \%$,\cellcolor{gray!50}$-1.28 \%$,$-0.45 \%$,\cellcolor{gray!75}$-16.00 \%$,\cellcolor{gray!25}$-0.96 \%$,\cellcolor{gray!50}$-1.25 \%$,\cellcolor{gray!100}$-63.87 \%$,$-0.22 \%$,$0.50 \%$,\cellcolor{gray!50}$-1.54 \%$


In [25]:
print_latex(abl_res)

\begin{tabular}{r|ccccc|ccc|cc}
\toprule
Category & \multicolumn{5}{c|}{Discrete \ES} & \multicolumn{3}{c|}{Continuous \ES} & \multicolumn{2}{c|}{Time Series} \\
\midrule
Dataset & \textbf{MBD} & \textbf{Retail} & \textbf{Age} & \textbf{Taobao} & \textbf{BPI17} & \textbf{PhysioNet2012} & \textbf{MIMIC-III} & \textbf{Pendulum} & \textbf{ArabicDigits} & \textbf{ElectricDevices} \\
\footnotesize{Metric} & \footnotesize{Mean ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{Accuracy} \\
\midrule
\textbf{GRU w/o time} & \cellcolor{gray!25}$-0.89 \%$ & $-0.00 \%$ & \cellcolor{gray!15}$-0.44 \%$ & \cellcolor{gray!50}$-3.85 \%$ & $-0.00 \%$ & $-0.00 \%$ & \cellcolor{gray!15}$-0.27 \%$ & \cellcolor{gray!100}$-59.43 \%$ & $0.04 \%$ & $-0.00 \%$ \\
\textbf{GRU w/o time w/ perm.} & \cellcolor{gray!25}$-0.96 \%$ & $0.50 \%

In [26]:
# abl_train_perm_kl_latex = (
#     ""
#     + sign_diff.map(lambda flag: "\\cellcolor{lightgray} " if flag else "")
#     + abl_train_perm_kl_res.loc[:, "mean"].map(lambda x: f"${x:.3f} \pm ")
#     + abl_train_perm_kl_res.loc[:, "std"].map(lambda x: f"{x:.3f}")
#     + sign_diff.map(lambda flag: "^*" if flag else "")
#     + "$"
# )
# print_latex(abl_train_perm_kl_latex)

## Random order

In [29]:
abl_perm_res = df_perm.pivot_table(
    index="method",
    columns="dataset",
    values="metric",
    aggfunc=["mean", "std"],
)
# sign_diff = mark_sign_diff(
#     df,
#     "main",
#     "perm",
# )
# main_res_latex_raw = (
#     "$"
#     + main_res.loc[:, "mean"].map(lambda x: f"{x:.3f}")
#     + " \pm "
#     + main_res.loc[:, "std"].map(lambda x: f"{x:.3f}")
#     + "$"
# )
# abl_perm_latex = (
#     ""
#     + sign_diff.map(lambda flag: "\\cellcolor{lightgray} " if flag else "")
#     + abl_perm_res.loc[:, "mean"].map(lambda x: f"${x:.3f} \pm ")
#     + abl_perm_res.loc[:, "std"].map(lambda x: f"{x:.3f}")
#     + sign_diff.map(lambda flag: "^*" if flag else "")
#     + "$"
# )
# abl_perm_comp = (
#     pd.concat((main_res_latex_raw, abl_perm_latex), keys=["Real", "Random"], names=["Order", "Method"])
#     .swaplevel()
#     .sort_index(ascending=[True, False])
# )
# abl_perm_comp.columns = pd.MultiIndex.from_tuples(
#     [(f"\\textbf{{{col}}}", f"\\footnotesize{{{METRIC_PRETTY[col]}}}") for col in abl_perm_comp.columns],
#     names=["Dataset", "\\footnotesize{Metric}"],
# )
# print(
#     abl_perm_comp
#     .to_latex(bold_rows=True, column_format="rr" + "c" * len(abl_perm_comp.columns))
#     .replace("cline{1-9}", "midrule")
# )

In [30]:
abl_perm_rel_res = (
    (abl_perm_res["mean"] / main_res["mean"] * 100 - 100)
    .map(lambda s: -0.0 if np.abs(s) < np.finfo(s).eps else s)
)
abl_perm_rel_res

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,-1.634568,-33.856612,-4.655397,-68.78766,-0.09332018,-1.855977,-84.48731,-2.361443,-1.568942,-0.4894
ConvTran,-9.545386,-60.451821,-17.039679,-68.659752,-7.282701,-8.210439,-77.60795,-0.467545,-29.019969,-4.513193
GRU,-1.15191,-46.876239,-4.4556,-69.461135,-0.1017218,-4.241931,-76.08545,-1.494247,-2.251215,-0.671498
MLEM,-1.517596,-37.811258,-3.799606,-65.17256,-0.2986552,-1.434684,-81.84025,-1.707983,-2.57446,-0.891083
MLP,-0.0,-0.0,1e-06,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
Mamba,-1.202666,-53.367876,-9.555549,-54.175991,-0.06146505,-3.044557,-82.13945,-0.647394,-2.440772,-0.001307
PrimeNet,-7.820438,-53.377464,-4.725307,-54.380865,-4.0752,-3.723431,-75.87658,-3.952115,-26.405886,-2.117629
Transformer,-0.0,-15.124055,2.7e-05,-25.264542,5.446395e-09,-1.4e-05,-4.52854e-07,0.027571,-0.08851,-0.050214
mTAND,-8.945188,-59.120024,-9.067676,-56.042365,-5.047486,-5.047173,-82.56998,-4.132835,-28.092774,-4.13177


In [31]:
# abl_perm_rel_latex = (
#     ""
#     + sign_diff.map(lambda flag: "\\cellcolor{lightgray} " if flag else "")
#     + abl_perm_rel_res.map(lambda x: f"${x:.2f} \\%")
#     + sign_diff.map(lambda flag: "^*" if flag else "")
#     + "$"
# ).loc[mean_method_rank.index]
# abl_perm_rel_latex

def get_grayscale_color(x):
    if x >= -0.48: 
        return f"${x:.2f} \\%$"
    elif x >= -1:
        return f"\\cellcolor{{gray!15}}${x:.2f} \\%$"
    elif x >= -5:
        return f"\\cellcolor{{gray!25}}${x:.2f} \\%$"
    elif x >= -25:
        return f"\\cellcolor{{gray!50}}${x:.2f} \\%$"
    elif x >= -50: 
        return f"\\cellcolor{{gray!75}}${x:.2f} \\%$"
    else:
        return f"\\cellcolor{{gray!100}}${x:.2f} \\%$"


abl_res = abl_perm_rel_res.map(get_grayscale_color).loc[mean_method_rank.index]
abl_res[~sign_diff] = abl_res.map(lambda x : x.replace("\cellcolor{gray!10}", "").replace("\cellcolor{gray!25}", ""))
abl_res

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
CoLES,\cellcolor{gray!25}$-1.63 \%$,\cellcolor{gray!75}$-33.86 \%$,\cellcolor{gray!25}$-4.66 \%$,\cellcolor{gray!100}$-68.79 \%$,$-0.09 \%$,\cellcolor{gray!25}$-1.86 \%$,\cellcolor{gray!100}$-84.49 \%$,\cellcolor{gray!25}$-2.36 \%$,\cellcolor{gray!25}$-1.57 \%$,\cellcolor{gray!15}$-0.49 \%$
GRU,\cellcolor{gray!25}$-1.15 \%$,\cellcolor{gray!75}$-46.88 \%$,\cellcolor{gray!25}$-4.46 \%$,\cellcolor{gray!100}$-69.46 \%$,$-0.10 \%$,\cellcolor{gray!25}$-4.24 \%$,\cellcolor{gray!100}$-76.09 \%$,\cellcolor{gray!25}$-1.49 \%$,\cellcolor{gray!25}$-2.25 \%$,\cellcolor{gray!15}$-0.67 \%$
MLEM,\cellcolor{gray!25}$-1.52 \%$,\cellcolor{gray!75}$-37.81 \%$,\cellcolor{gray!25}$-3.80 \%$,\cellcolor{gray!100}$-65.17 \%$,$-0.30 \%$,\cellcolor{gray!25}$-1.43 \%$,\cellcolor{gray!100}$-81.84 \%$,\cellcolor{gray!25}$-1.71 \%$,\cellcolor{gray!25}$-2.57 \%$,\cellcolor{gray!15}$-0.89 \%$
Transformer,$-0.00 \%$,\cellcolor{gray!50}$-15.12 \%$,$0.00 \%$,\cellcolor{gray!75}$-25.26 \%$,$0.00 \%$,$-0.00 \%$,$-0.00 \%$,$0.03 \%$,$-0.09 \%$,$-0.05 \%$
Mamba,\cellcolor{gray!25}$-1.20 \%$,\cellcolor{gray!100}$-53.37 \%$,\cellcolor{gray!50}$-9.56 \%$,\cellcolor{gray!100}$-54.18 \%$,$-0.06 \%$,\cellcolor{gray!25}$-3.04 \%$,\cellcolor{gray!100}$-82.14 \%$,\cellcolor{gray!15}$-0.65 \%$,\cellcolor{gray!25}$-2.44 \%$,$-0.00 \%$
ConvTran,\cellcolor{gray!50}$-9.55 \%$,\cellcolor{gray!100}$-60.45 \%$,\cellcolor{gray!50}$-17.04 \%$,\cellcolor{gray!100}$-68.66 \%$,\cellcolor{gray!50}$-7.28 \%$,\cellcolor{gray!50}$-8.21 \%$,\cellcolor{gray!100}$-77.61 \%$,$-0.47 \%$,\cellcolor{gray!75}$-29.02 \%$,\cellcolor{gray!25}$-4.51 \%$
mTAND,\cellcolor{gray!50}$-8.95 \%$,\cellcolor{gray!100}$-59.12 \%$,\cellcolor{gray!50}$-9.07 \%$,\cellcolor{gray!100}$-56.04 \%$,\cellcolor{gray!50}$-5.05 \%$,\cellcolor{gray!50}$-5.05 \%$,\cellcolor{gray!100}$-82.57 \%$,\cellcolor{gray!25}$-4.13 \%$,\cellcolor{gray!75}$-28.09 \%$,\cellcolor{gray!25}$-4.13 \%$
PrimeNet,\cellcolor{gray!50}$-7.82 \%$,\cellcolor{gray!100}$-53.38 \%$,\cellcolor{gray!25}$-4.73 \%$,\cellcolor{gray!100}$-54.38 \%$,\cellcolor{gray!25}$-4.08 \%$,\cellcolor{gray!25}$-3.72 \%$,\cellcolor{gray!100}$-75.88 \%$,\cellcolor{gray!25}$-3.95 \%$,\cellcolor{gray!75}$-26.41 \%$,\cellcolor{gray!25}$-2.12 \%$
MLP,$-0.00 \%$,$-0.00 \%$,$0.00 \%$,$-0.00 \%$,$-0.00 \%$,$-0.00 \%$,$-0.00 \%$,$-0.00 \%$,$-0.00 \%$,$-0.00 \%$


In [32]:
print_latex(abl_res)

\begin{tabular}{r|ccccc|ccc|cc}
\toprule
Category & \multicolumn{5}{c|}{Discrete \ES} & \multicolumn{3}{c|}{Continuous \ES} & \multicolumn{2}{c|}{Time Series} \\
\midrule
Dataset & \textbf{MBD} & \textbf{Retail} & \textbf{Age} & \textbf{Taobao} & \textbf{BPI17} & \textbf{PhysioNet2012} & \textbf{MIMIC-III} & \textbf{Pendulum} & \textbf{ArabicDigits} & \textbf{ElectricDevices} \\
\footnotesize{Metric} & \footnotesize{Mean ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{Accuracy} \\
\midrule
\textbf{CoLES} & $-0.09 \%$ & \cellcolor{gray!25}$-1.57 \%$ & \cellcolor{gray!25}$-1.63 \%$ & \cellcolor{gray!15}$-0.49 \%$ & \cellcolor{gray!25}$-4.66 \%$ & \cellcolor{gray!25}$-2.36 \%$ & \cellcolor{gray!25}$-1.86 \%$ & \cellcolor{gray!100}$-84.49 \%$ & \cellcolor{gray!75}$-33.86 \%$ & \cellcolor{gray!100}$-68.79 \%$ \\


## Random time

In [27]:
main_res_latex_raw = (
    "$"
    + main_res.loc[:, "mean"].map(lambda x: f"{x:.3f}")
    + " \pm "
    + main_res.loc[:, "std"].map(lambda x: f"{x:.3f}")
    + "$"
).loc[["PrimeNet", "mTAND"]]
main_res_latex_raw

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
PrimeNet,$0.583 \pm 0.011$,$0.958 \pm 0.009$,$0.730 \pm 0.006$,$0.636 \pm 0.016$,$0.780 \pm 0.006$,$0.887 \pm 0.004$,$0.600 \pm 0.026$,$0.839 \pm 0.004$,$0.521 \pm 0.003$,$0.681 \pm 0.010$
mTAND,$0.582 \pm 0.009$,$0.951 \pm 0.010$,$0.738 \pm 0.005$,$0.631 \pm 0.019$,$0.798 \pm 0.002$,$0.888 \pm 0.003$,$0.777 \pm 0.031$,$0.841 \pm 0.005$,$0.519 \pm 0.003$,$0.672 \pm 0.010$


In [28]:
abl_time_res = df_time.pivot_table(
    index="method",
    columns="dataset",
    values="metric",
    aggfunc=["mean", "std"],
).loc[["PrimeNet", "mTAND"]]
abl_time_res

Unnamed: 0_level_0,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,std,std,std,std,std,std,std,std,std,std
dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
PrimeNet,0.582239,0.681764,0.727874,0.600019,0.774628,0.88375,0.20201,0.839798,0.520654,0.679744,0.010346,0.081513,0.007071,0.013216,0.00573,0.00398,0.012842,0.003869,0.002726,0.010618
mTAND,0.581183,0.880582,0.737593,0.592679,0.794668,0.886383,0.33579,0.84031,0.518914,0.666003,0.00874,0.022516,0.005007,0.01562,0.001804,0.002733,0.102737,0.005297,0.003533,0.010163


In [29]:
sign_diff = mark_sign_diff(
    df.query("method in ('mTAND', 'PrimeNet')"),
    "main",
    "time",
)
sign_diff

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
PrimeNet,False,True,False,True,False,False,True,False,False,False
mTAND,False,True,False,True,True,False,True,False,False,False


In [49]:
import re
# abl_time_latex = (
#     ""
#     + sign_diff.map(lambda flag: "\\cellcolor{lightgray} " if flag else "")
#     + abl_time_res.loc[:, "mean"].map(lambda x: f"${x:.3f} \pm ")
#     + abl_time_res.loc[:, "std"].map(lambda x: f"{x:.3f}")
#     + sign_diff.map(lambda flag: "^*" if flag else "")
#     + "$"
# )
# abl_time_latex

abl_perm_rel_res = (
    abl_time_res["mean"] / main_res.loc[["PrimeNet", "mTAND"], "mean"]
    * 100
    - 100
).map(lambda s: -0.0 if np.abs(s) < np.finfo(s).eps else s)



def get_grayscale_color(x):
    if x >= -0.1: 
        return f"${x:.2f} \\%$"
    elif x >= -0.5:
        return f"\\cellcolor{{gray!25}}${x:.2f} \\%$"
    elif x >= -1:
        return f"\\cellcolor{{gray!25}}${x:.2f} \\%$"
    elif x >= -5:
        return f"\\cellcolor{{gray!50}}${x:.2f} \\%$"
    elif x >= -25: 
        return f"\\cellcolor{{gray!75}}${x:.2f} \\%$"
    else:
        return f"\\cellcolor{{gray!100}}${x:.2f} \\%$"


abl_res = abl_perm_rel_res.map(get_grayscale_color)
abl_res[~sign_diff] = abl_res.map(lambda x : re.sub(r"\\cellcolor\{gray!\d+\}", "", x))
abl_res

dataset,Age,ArabicDigits,BPI17,ElectricDevices,MBD,MIMIC-III,Pendulum,PhysioNet2012,Retail,Taobao
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
PrimeNet,$-0.12 \%$,\cellcolor{gray!100}$-28.86 \%$,$-0.30 \%$,\cellcolor{gray!75}$-5.62 \%$,$-0.72 \%$,$-0.40 \%$,\cellcolor{gray!100}$-66.34 \%$,$0.09 \%$,$-0.07 \%$,$-0.15 \%$
mTAND,$-0.06 \%$,\cellcolor{gray!75}$-7.44 \%$,$-0.00 \%$,\cellcolor{gray!75}$-6.11 \%$,\cellcolor{gray!25}$-0.45 \%$,$-0.23 \%$,\cellcolor{gray!100}$-56.79 \%$,$-0.08 \%$,$-0.01 \%$,$-0.91 \%$


In [32]:
# abl_time_comp = (
#     pd.concat((main_res_latex_raw, abl_time_latex), keys=["Real", "Random"], names=["Time", "Method"])
#     .swaplevel()
#     .sort_index(ascending=False)
# )
# abl_time_comp

In [50]:
print_latex(abl_res)

\begin{tabular}{r|ccccc|ccc|cc}
\toprule
Category & \multicolumn{5}{c|}{Discrete \ES} & \multicolumn{3}{c|}{Continuous \ES} & \multicolumn{2}{c|}{Time Series} \\
\midrule
Dataset & \textbf{MBD} & \textbf{Retail} & \textbf{Age} & \textbf{Taobao} & \textbf{BPI17} & \textbf{PhysioNet2012} & \textbf{MIMIC-III} & \textbf{Pendulum} & \textbf{ArabicDigits} & \textbf{ElectricDevices} \\
\footnotesize{Metric} & \footnotesize{Mean ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{ROC AUC} & \footnotesize{Accuracy} & \footnotesize{Accuracy} & \footnotesize{Accuracy} \\
\midrule
\textbf{PrimeNet} & $-0.72 \%$ & $-0.07 \%$ & $-0.12 \%$ & $-0.15 \%$ & $-0.30 \%$ & $0.09 \%$ & $-0.40 \%$ & \cellcolor{gray!100}$-66.34 \%$ & \cellcolor{gray!100}$-28.86 \%$ & \cellcolor{gray!75}$-5.62 \%$ \\
\textbf{mTAND} & \cellcolor{gray!25}$-0.45 \%$ & $-0.01 \%$ & $-0.06 \%$ & $-0.91 \%$ & $-0.00 \%$ & $-0.08 \%$ &