In [1]:
from pathlib import Path
import xarray as xr
from tabularbench.core.enums import BenchmarkOrigin
from tabularbench.results.ranking_table import make_ranking_table_, process_benchmark_results, process_sweep_results
from tabularbench.results.results_sweep import ResultsSweep
from tabularbench.config.config_benchmark_sweep import ConfigBenchmarkSweep

cfg_general = ConfigBenchmarkSweep.load(Path("outputs_done/tabzilla_tabpfn_orig_finetune/tabpfn-default-tabzilla_has_completed_runs/config_benchmark_sweep.yaml"))
ds_benchmark = process_benchmark_results(cfg_general)

def prepare_ds(txt: str, path: Path):
    cfg = cfg_general
    cfg_general.model_plot_name = txt
    results_sweep = ResultsSweep.load(path / "results_sweep.nc")
    ds = process_sweep_results(cfg, results_sweep)
    ds.coords['model_name'] = [cfg.model_plot_name]
    return ds


dss = [
    prepare_ds("TabPFN (original) - Zero-shot", Path("outputs_done/tabzilla_tabpfn_orig_zeroshot/tabpfn-default-tabzilla_has_completed_runs")),
    prepare_ds("TabPFN (original) - Fine-tune", Path("outputs_done/tabzilla_tabpfn_orig_finetune/tabpfn-default-tabzilla_has_completed_runs")),
    prepare_ds("TabPFN (retrained) - Zero-shot", Path("outputs_done/tabzilla_tabpfn_foun_zeroshot/foundation-default-tabzilla_has_completed_runs")),
    prepare_ds("TabPFN (retrained) - Fine-tune", Path("outputs_done/tabzilla_tabpfn_foun_finetune/foundation-default-tabzilla_has_completed_runs")),
    prepare_ds("TabForest - Zero-shot", Path("outputs_done/tabzilla_tabsgfd_zeroshot/foundation-default-tabzilla_has_completed_runs")),
    prepare_ds("TabForest - Fine-tune", Path("outputs_done/tabzilla_tabsgfd_finetune/foundation-default-tabzilla_has_completed_runs")),
    prepare_ds("TabForestPFN - Zero-shot", Path("outputs_done/foundation_mix_600k_zeroshot/foundation-default-tabzilla_has_completed_runs")),
    prepare_ds("TabForestPFN - Fine-tune", Path("outputs_done/foundation_mix_600k_finetune/test_tabzilla_has_completed_runs")),
]

ds = xr.merge([
    ds_benchmark,
    *dss
])


df = make_ranking_table_(cfg_general, ds)
df

[32m2024-04-09 14:51:22.397[0m | [1mINFO    [0m | [36mtabularbench.results.results_sweep[0m:[36mload[0m:[36m44[0m - [1mLoaded results from outputs_done/tabzilla_tabpfn_orig_zeroshot/tabpfn-default-tabzilla_has_completed_runs/results_sweep.nc[0m
[32m2024-04-09 14:51:22.482[0m | [1mINFO    [0m | [36mtabularbench.results.results_sweep[0m:[36mload[0m:[36m44[0m - [1mLoaded results from outputs_done/tabzilla_tabpfn_orig_finetune/tabpfn-default-tabzilla_has_completed_runs/results_sweep.nc[0m
[32m2024-04-09 14:51:22.560[0m | [1mINFO    [0m | [36mtabularbench.results.results_sweep[0m:[36mload[0m:[36m44[0m - [1mLoaded results from outputs_done/tabzilla_tabpfn_foun_zeroshot/foundation-default-tabzilla_has_completed_runs/results_sweep.nc[0m
[32m2024-04-09 14:51:22.652[0m | [1mINFO    [0m | [36mtabularbench.results.results_sweep[0m:[36mload[0m:[36m44[0m - [1mLoaded results from outputs_done/tabzilla_tabpfn_foun_finetune/foundation-default-tabzilla_has_

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
TabForestPFN - Fine-tune,1.0,24.0,8.0,6.0,0.68,0.663
TabPFN (retrained) - Fine-tune,1.0,26.0,8.3,7.75,0.678,0.7
CatBoost,1.0,22.0,8.9,7.75,0.676,0.663
XGBoost,1.0,23.0,9.2,8.25,0.674,0.671
TabPFN (original) - Fine-tune,1.0,25.0,9.3,8.0,0.67,0.677
TabForestPFN - Zero-shot,1.0,26.0,10.6,10.0,0.633,0.65
TabForest - Fine-tune,1.0,24.5,10.6,9.5,0.65,0.667
TabPFN (original) - Zero-shot,1.0,25.0,10.9,10.25,0.632,0.606
lightGBM,1.0,26.0,11.1,11.0,0.646,0.653
RandomForest,1.0,25.0,11.3,11.0,0.636,0.637


In [2]:
df[['rank_min', 'rank_max']] = df[['rank_min', 'rank_max']].astype(int)
# df['rank_mean'] = df['rank_mean'].apply(lambda x: f"{x:.1f}")
# df[['acc_mean', 'acc_median']] = df[['acc_mean', 'acc_median']].apply(lambda x: f"{x:.3f}")

format_mapping = {
    'acc_mean': '{:.3f}',
    'acc_median': '{:.3f}',
    'rank_mean': '{:.1f}',
    'rank_median': '{:.1f}',
}


print(df.to_latex(formatters=format_mapping))

\begin{tabular}{lrrrrrr}
\toprule
 & rank_min & rank_max & rank_mean & rank_median & acc_mean & acc_median \\
\midrule
TabForestPFN - Fine-tune & 1 & 24 & 8.0 & 6.0 & 0.680 & 0.663 \\
TabPFN (retrained) - Fine-tune & 1 & 26 & 8.3 & 7.8 & 0.678 & 0.700 \\
CatBoost & 1 & 22 & 8.9 & 7.8 & 0.676 & 0.663 \\
XGBoost & 1 & 23 & 9.2 & 8.2 & 0.674 & 0.671 \\
TabPFN (original) - Fine-tune & 1 & 25 & 9.3 & 8.0 & 0.670 & 0.677 \\
TabForestPFN - Zero-shot & 1 & 26 & 10.6 & 10.0 & 0.633 & 0.650 \\
TabForest - Fine-tune & 1 & 24 & 10.6 & 9.5 & 0.650 & 0.667 \\
TabPFN (original) - Zero-shot & 1 & 25 & 10.9 & 10.2 & 0.632 & 0.606 \\
lightGBM & 1 & 26 & 11.1 & 11.0 & 0.646 & 0.653 \\
RandomForest & 1 & 25 & 11.3 & 11.0 & 0.636 & 0.637 \\
Resnet & 1 & 26 & 11.9 & 10.0 & 0.602 & 0.613 \\
NODE & 1 & 26 & 12.1 & 12.0 & 0.611 & 0.596 \\
SAINT & 1 & 26 & 12.2 & 12.8 & 0.600 & 0.614 \\
SVM & 1 & 25 & 12.5 & 13.2 & 0.589 & 0.573 \\
FT-Transformer & 1 & 23 & 12.6 & 12.5 & 0.599 & 0.601 \\
DANet & 3 & 25 & 14.5 &

In [3]:
ds = xr.merge([
    ds_benchmark,
    dss[0]
])

make_ranking_table_(cfg_general, ds)

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
CatBoost,1.0,15.0,6.1,5.0,0.676,0.663
XGBoost,1.0,16.0,6.2,5.0,0.674,0.671
TabPFN (original) - Zero-shot,1.0,19.0,7.4,6.0,0.632,0.606
lightGBM,1.0,19.0,7.5,6.25,0.646,0.653
RandomForest,1.0,18.0,7.7,8.0,0.636,0.637
NODE,1.0,19.0,8.3,8.0,0.611,0.596
Resnet,1.0,19.0,8.3,8.0,0.602,0.613
SAINT,1.0,19.0,8.4,7.5,0.6,0.614
SVM,1.0,18.5,8.5,8.0,0.589,0.573
FT-Transformer,1.0,16.0,8.7,8.5,0.599,0.601


In [4]:
ds = xr.merge([
    ds_benchmark,
    dss[1]
])

make_ranking_table_(cfg_general, ds)

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
CatBoost,1.0,15.0,6.1,5.25,0.676,0.663
XGBoost,1.0,16.0,6.3,5.0,0.674,0.671
TabPFN (original) - Fine-tune,1.0,19.0,6.4,5.0,0.67,0.677
lightGBM,1.0,19.0,7.6,6.0,0.646,0.653
RandomForest,1.0,18.0,7.8,8.0,0.636,0.637
Resnet,1.0,19.0,8.3,8.0,0.602,0.613
SAINT,1.0,19.0,8.4,7.5,0.6,0.614
NODE,1.0,19.0,8.4,8.0,0.611,0.596
SVM,1.0,18.5,8.6,8.0,0.589,0.573
FT-Transformer,1.0,16.0,8.8,8.5,0.599,0.601


In [5]:
ds = xr.merge([
    ds_benchmark,
    dss[2]
])

make_ranking_table_(cfg_general, ds)

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
CatBoost,1.0,15.0,5.8,5.0,0.676,0.663
XGBoost,1.0,17.0,6.0,5.0,0.674,0.671
lightGBM,1.0,19.0,7.4,6.0,0.646,0.653
RandomForest,1.0,18.0,7.5,7.0,0.636,0.637
Resnet,1.0,19.0,8.0,8.0,0.602,0.613
NODE,1.0,19.0,8.1,8.0,0.611,0.596
SAINT,1.0,19.0,8.2,7.5,0.6,0.614
SVM,1.0,18.5,8.3,8.0,0.589,0.573
FT-Transformer,1.0,16.0,8.5,8.0,0.599,0.601
DANet,2.0,19.0,9.6,9.0,0.596,0.608


In [6]:
ds = xr.merge([
    ds_benchmark,
    dss[3]
])

make_ranking_table_(cfg_general, ds)

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
TabPFN (retrained) - Fine-tune,1.0,19.0,5.9,5.0,0.678,0.7
CatBoost,1.0,15.0,6.2,5.25,0.676,0.663
XGBoost,1.0,16.0,6.3,5.0,0.674,0.671
lightGBM,1.0,19.0,7.7,6.25,0.646,0.653
RandomForest,1.0,18.0,7.8,8.0,0.636,0.637
Resnet,1.0,19.0,8.3,8.0,0.602,0.613
SAINT,1.0,19.0,8.4,7.5,0.6,0.614
NODE,1.0,19.0,8.4,8.0,0.611,0.596
SVM,1.0,18.5,8.6,8.0,0.589,0.573
FT-Transformer,1.0,16.0,8.8,8.5,0.599,0.601


In [7]:
ds = xr.merge([
    ds_benchmark,
    dss[4]
])

make_ranking_table_(cfg_general, ds)

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
CatBoost,1.0,15.0,5.8,5.0,0.676,0.663
XGBoost,1.0,17.0,6.0,5.0,0.674,0.671
lightGBM,1.0,19.0,7.4,6.0,0.646,0.653
RandomForest,1.0,18.0,7.5,7.0,0.636,0.637
Resnet,1.0,19.0,8.0,8.0,0.602,0.613
NODE,1.0,19.0,8.1,8.0,0.611,0.596
SAINT,1.0,19.0,8.2,7.5,0.6,0.614
SVM,1.0,18.5,8.3,8.0,0.589,0.573
FT-Transformer,1.0,16.0,8.5,8.0,0.599,0.601
DANet,2.0,19.0,9.6,9.0,0.596,0.608


In [8]:
ds = xr.merge([
    ds_benchmark,
    dss[5]
])

make_ranking_table_(cfg_general, ds)

Unnamed: 0,rank_min,rank_max,rank_mean,rank_median,acc_mean,acc_median
CatBoost,1.0,15.0,6.1,5.0,0.676,0.663
XGBoost,1.0,16.0,6.2,5.0,0.674,0.671
TabForest - Fine-tune,1.0,19.0,7.4,7.0,0.65,0.667
lightGBM,1.0,19.0,7.5,6.0,0.646,0.653
RandomForest,1.0,18.0,7.7,8.0,0.636,0.637
Resnet,1.0,19.0,8.2,8.0,0.602,0.613
SAINT,1.0,19.0,8.3,7.5,0.6,0.614
NODE,1.0,19.0,8.3,8.5,0.611,0.596
SVM,1.0,18.5,8.5,8.0,0.589,0.573
FT-Transformer,1.0,16.0,8.7,9.0,0.599,0.601
