# Drift-Resilient TabPFN

## Installation (this takes about 5 minutes)

In [None]:
# GitHub owner and repository name; both are interpolated into the pip URL below.
username = "automl"
repo = "Drift-Resilient_TabPFN"

# Uninstall the pre-installed torch stack (torch, torchaudio, torchvision,
# torchtext) as well as numpy, tsfresh, transformers, sentence-transformers
# and peft to prevent version conflicts with the packages installed below.
!pip uninstall -y torch torchaudio torchvision torchtext numpy tsfresh transformers sentence-transformers peft
# Install our package straight from GitHub, including code to run the experiments.
!pip install git+https://github.com/{username}/{repo}.git

# Due to a numpy revert to <2.0.0 we have to restart the runtime before we move on
# See https://github.com/googlecolab/colabtools/issues/5238
print('Stopping runtime due to numpy downgrade!')
import os
# Send signal 9 to the current process, i.e. the kernel executing this cell.
# Nothing after this line runs; continue with the next cell once the runtime
# is available again.
os.kill(os.getpid(), 9)

## GPU Check

To get fast training/inference times, enable GPU processing for this notebook by navigating to Edit → Notebook Settings and selecting GPU as the Hardware accelerator.

In [None]:
import torch

# Stop the cell with an exception when PyTorch cannot see a CUDA device.
_no_gpu_message = "No GPU was found. Change the notebook settings for faster training/inference as described above."
if not torch.cuda.is_available():
    raise RuntimeWarning(_no_gpu_message)

## Common Helpers

Helper function used throughout the notebook. Execute this cell before moving on to the sections below.

In [None]:
from contextlib import contextmanager, nullcontext
import functools
import warnings
from urllib3.exceptions import InsecureRequestWarning
import requests

@contextmanager
def temporary_no_ssl_verify():
    """
    Disable certificate verification for all ``requests`` HTTPS calls.

    While the context is active, ``requests.Session.request`` is replaced by
    a variant that passes ``verify=False`` and ``InsecureRequestWarning`` is
    silenced. The original method is put back on exit, also when the body
    raises.
    """
    original_request = requests.Session.request
    insecure_request = functools.partialmethod(original_request, verify=False)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", InsecureRequestWarning)

        requests.Session.request = insecure_request
        try:
            yield
        finally:
            # Always undo the monkey patch.
            requests.Session.request = original_request

## Load Models

In this section, we will first load all the pre-trained models that we will use throughout this demo.

In [None]:
from importlib import resources

import tabpfn
from tabpfn.scripts.tabular_baselines import transformer_metric
from tabpfn.best_models import get_best_tabpfn, TabPFNModelPathsConfig

# Get the library path for the tabpfn package; get_model() below looks up the
# checkpoints in its "model_cache" subfolder.
libpath = resources.files(tabpfn)


# Helper function to load each pre-trained model with the corresponding configuration.
def get_model(task_type, model_path, model_type):
    """Load one pre-trained model from the package's model cache.

    Args:
        task_type: Task identifier forwarded to the paths config and to
            ``get_best_tabpfn``.
        model_path: Checkpoint file name without its ``.cpkt`` extension,
            looked up in ``<libpath>/model_cache``.
        model_type: Model variant forwarded to ``get_best_tabpfn``.

    Returns:
        The loaded model with ``show_progress`` disabled and ``seed`` set to 1.
    """
    checkpoint_file = f"{libpath}/model_cache/{model_path}.cpkt"
    paths_config = TabPFNModelPathsConfig(paths=[checkpoint_file], task_type=task_type)

    loader_kwargs = dict(
        task_type=task_type,
        model_type=model_type,
        paths_config=paths_config,
        debug=False,
        device="auto",
    )
    loaded_model = get_best_tabpfn(**loader_kwargs)

    # Silence progress output and fix the seed on the returned model.
    loaded_model.show_progress = False
    loaded_model.seed = 1

    return loaded_model


task_type = "dist_shift_multiclass"

# (checkpoint name, model type) pairs of all pre-trained models used in this demo.
models_to_load = [
    # Drift-resilient models (three initializations) and the ablation model.
    ("tabpfn_dist_model_1", "best_dist"),
    ("tabpfn_dist_model_2", "best_dist"),
    ("tabpfn_dist_model_3", "best_dist"),
    ("tabpfn_dist_ablation_no_t2v_model_1", "best_dist"),
    # Base models (three initializations).
    ("tabpfn_base_model_1", "best_base"),
    ("tabpfn_base_model_2", "best_base"),
    ("tabpfn_base_model_3", "best_base"),
]

# Load each model, keyed by its checkpoint name.
models = dict(
    (checkpoint_name, get_model(task_type, checkpoint_name, variant))
    for checkpoint_name, variant in models_to_load
)

In [None]:
# Display one of the loaded distribution shift models. As the last expression
# of the cell, the notebook renders its representation as the cell output.
models["tabpfn_dist_model_1"]

## Use Drift-Resilient TabPFN

This section briefly shows how to use our models through their scikit-learn-style interface.

In [None]:
import numpy as np

from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split

from tabpfn.datasets.dist_shift_datasets import get_rotated_moons_drain

# Pick one of the pre-trained models loaded above.
clf = models["tabpfn_dist_model_1"]

# Load the Rotated Two Moons dataset.
dataset = get_rotated_moons_drain()

# Split features, labels and domain indices into a train and a test half.
# shuffle=False keeps the original order, so the test half is the last 50%
# of the samples.
(
    X_train,
    X_test,
    y_train,
    y_test,
    dist_shift_domain_train,
    dist_shift_domain_test,
) = train_test_split(
    dataset.x,
    dataset.y,
    dataset.dist_shift_domain,
    test_size=0.50,
    shuffle=False,
    random_state=42,
)

# Fit on the training half; the domain indices are passed alongside the features.
clf.fit(
    X_train,
    y_train,
    additional_x={"dist_shift_domain": dist_shift_domain_train},
)

# Class probabilities for the test half.
preds = clf.predict_proba(
    X_test,
    additional_x={"dist_shift_domain": dist_shift_domain_test},
)

# Predicted class = column with the highest probability.
y_eval = np.argmax(preds, axis=1)

# Report ROC AUC (from the probability in column 1) and accuracy.
print("")
roc_auc = roc_auc_score(y_test, preds[:, 1], multi_class='ovo')
accuracy = accuracy_score(y_test, y_eval)
print('ROC AUC: ', roc_auc, 'Accuracy', accuracy)


## Qualitative Analysis

Here you can visualize the shifts in our synthetic 2D datasets and analyze the decision boundaries of TabPFN$_\text{dist}$ compared to TabPFN$_\text{base}$.

### Common Setup

Execute this cell before moving on to the sections below.

In [None]:
from tabpfn.datasets import get_benchmark_for_task
from tabpfn.utils import default_task_settings

# Default size limits for the benchmark; max_times is unpacked but not used in this cell.
max_samples, max_features, max_times, max_classes = default_task_settings()

# Maps the split name ("valid" / "test") to the datasets loaded for that split.
datasets_dict = {}

# SSL certificate verification is disabled while loading, because fetching the
# folktables dataset from the US census can fail with a certificate error.
# To keep verification enabled, swap the `with` line for the commented-out
# nullcontext() line below.
with temporary_no_ssl_verify():
# with nullcontext():
    # Load the validation and test datasets.
    for split in ["valid", "test"]:
        # The second return value is not needed here.
        datasets, _ = get_benchmark_for_task(
            task_type=task_type,
            split=split,
            max_samples=max_samples,
            max_features=max_features,
            max_classes=max_classes,
            return_as_lists=False,
        )

        datasets_dict[split] = datasets

# Map the names of our models to a display string (written with matplotlib
# math markup); the decision boundary cell uses them as row labels.
name_mapping = {
    "tabpfn_dist_model_1": "TabPFN$_{\\mathrm{dist}_1}$",
    "tabpfn_dist_model_2": "TabPFN$_{\\mathrm{dist}_2}$",
    "tabpfn_dist_model_3": "TabPFN$_{\\mathrm{dist}_3}$",
    "tabpfn_dist_ablation_no_t2v_model_1": "TabPFN$_{\\mathrm{no t2v}}$",
    "tabpfn_base_model_1": "TabPFN$_{\\mathrm{base}_1}$",
    "tabpfn_base_model_2": "TabPFN$_{\\mathrm{base}_2}$",
    "tabpfn_base_model_3": "TabPFN$_{\\mathrm{base}_3}$",
}

### Dataset Plots

Plot the datapoints of our 2D datasets across domains.

In [None]:
# List the datasets that can be plotted in 2D (i.e. those with exactly
# 2 features) together with their index in the respective split.
for heading, split_key in (
    ("Validation Datasets with 2 Features:", "valid"),
    ("Test Datasets with 2 Features:", "test"),
):
    print(heading)
    for i, dataset in enumerate(datasets_dict[split_key]):
        if dataset.x.shape[1] == 2:
            print(f"{i}: {dataset}")

In [None]:
# Rotating Two Moons
# NOTE(review): index 0 is assumed to match the listing printed by the cell
# above; confirm it whenever the benchmark composition changes.
datasets_dict["test"][0].plot(animate=True)

In [None]:
# Intersecting Blobs Dataset
# NOTE(review): index 2 is assumed to match the listing printed by the cell
# above; confirm it whenever the benchmark composition changes.
datasets_dict["test"][2].plot(animate=True)

### Decision Boundaries

Plot and compare the decision boundaries of our pre-trained models given a range of training domains on the unseen out-of-distribution test domains.

In [None]:
# List the datasets whose decision boundaries can be plotted (i.e. those with
# exactly 2 features) together with their index in the respective split.
for heading, split_key in (
    ("Validation Datasets with 2 Features:", "valid"),
    ("Test Datasets with 2 Features:", "test"),
):
    print(heading)
    for i, dataset in enumerate(datasets_dict[split_key]):
        if dataset.x.shape[1] == 2:
            print(f"{i}: {dataset}")

In [None]:
from matplotlib import pyplot as plt
import torch

from tabpfn.scripts.decision_boundary import plot_decision_boundary

# Select a subset of the models to compare here
models_to_eval = {
    name: model
    for name, model in models.items()
    if name in {"tabpfn_dist_model_1", "tabpfn_base_model_1"}
}

# Check that N_ensemble_configurations is not None: with None, a different number of ensemble configurations
# is applied to the grid to be predicted than to the foreground samples of the dataset visualized, leading to
# a mismatch in the plots.
assert all(model.N_ensemble_configurations is not None for model in models_to_eval.values()), "Detected N_ensemble_configurations=None in one of the models, aborting."

# Select the index of the dataset to plot according to the cell above
current_dataset = datasets_dict["test"][2]

# Set the number of domains to train with
num_train_domains = 4

# Set the number of domains to predict into the future, depending on the dataset
max_predict_domains = 3

# Define the plot: one row per model, one column per predicted domain.
plt.rcParams.update({"font.size": 16})
plt.rcParams["axes.linewidth"] = 0.5
fig, axs = plt.subplots(
    len(models_to_eval),
    max_predict_domains,
    figsize=(6 * (max_predict_domains + 0.6), 4 * (len(models_to_eval))),
    # Always return a 2D (rows, columns) array of axes, so that axs[k, j]
    # also works with a single model and/or a single predicted domain.
    squeeze=False,
)
fig.set_dpi(300)
fig.subplots_adjust(
    bottom=0.05, top=0.95, left=0.1, right=0.8, wspace=0.03, hspace=0.03
)
# Add the colorbar and the legend as separate axes to the right of the grid
cbar_ax = fig.add_axes([0.82, 0.1, 0.03, 0.6])
legend_ax = fig.add_axes([0.81, 0.75, 0.03, 0.15])

def process_model(model, label, k):
    """Fit ``model`` and draw its decision boundaries into row ``k`` of the figure.

    Besides its arguments, the function reads the notebook-level variables
    ``current_dataset``, ``num_train_domains``, ``max_predict_domains``,
    ``models_to_eval``, ``axs``, ``cbar_ax`` and ``legend_ax`` defined earlier
    in this cell.

    Args:
        model: Estimator to evaluate. ``fit`` is called on it, and the object
            returned by ``fit`` is used for prediction and plotting.
        label: Text annotated to the left of the row.
        k: Row index into ``axs``.
    """
    # Split the data into the train and id/ood test sets.
    train_ds, test_portions = current_dataset.generate_valid_split(
        test_set_start_index=num_train_domains, num_predict_domains=max_predict_domains
    )
    ood_ds, id_ds = test_portions["ood"], test_portions["id"]

    # To be comparable to DRAIN and GI we add the id test set back to the training data.
    train_ds.x = torch.concat([train_ds.x, id_ds.x], dim=0)
    train_ds.dist_shift_domain = torch.concat(
        [train_ds.dist_shift_domain, id_ds.dist_shift_domain], dim=0
    )
    train_ds.y = torch.concat([train_ds.y, id_ds.y], dim=0)

    # Fit the train set and predict the ood test set.
    model = model.fit(
        train_ds.x,
        train_ds.y,
        additional_x={"dist_shift_domain": train_ds.dist_shift_domain},
    )
    pred = model.predict(
        ood_ds.x, additional_x={"dist_shift_domain": ood_ds.dist_shift_domain}
    )

    # Get the unique domains we predicted
    unique_vals = torch.unique(ood_ds.dist_shift_domain)

    # Plot the decision boundary for each domain separately and display it alongside
    # the predicted samples for each domain.
    for j, domain in enumerate(unique_vals[:max_predict_domains]):
        # Get the axis to plot into
        ax = axs[k, j]

        # Define some bool conditions for labels
        first_col = j == 0
        last_col = j == max_predict_domains - 1
        first_row = k == 0
        last_row = k == len(models_to_eval) - 1
        last_plot = last_col and last_row

        # Filter our test set as well as our predictions for the current domain.
        # The mask is computed once and shared by all three selections.
        in_current_domain = ood_ds.dist_shift_domain == domain
        test_of_current_domain_x = ood_ds.x[in_current_domain]
        test_of_current_domain_y = ood_ds.y[in_current_domain]
        pred_of_current_domain = pred[in_current_domain]

        # Display the decision boundary for this domain. Colorbar and legend
        # are only drawn once, for the bottom-right plot.
        plot_decision_boundary(
            estimator=model,
            all_X=current_dataset.x,
            X=test_of_current_domain_x,
            y_gt=test_of_current_domain_y,
            y_pred=pred_of_current_domain,
            dist_shift_domain=domain,
            xlabel="$x_1$" if last_row else "",
            ylabel="$x_2$" if first_col else "",
            grid_resolution=100,
            ax=ax,
            eps=0.1,
            show_colorbar=last_plot,
            show_legend=last_plot,
            cbar_ax=cbar_ax if last_plot else None,
            legend_ax=legend_ax if last_plot else None,
        )

        # Column titles go on the first row only.
        # NOTE(review): the title derives the test domain from the column
        # index, i.e. it assumes the ood domains directly follow the training
        # domains without gaps - confirm for the selected dataset.
        if first_row:
            ax.set_title(f"Train Domains: 0-{num_train_domains-1} | Test Domain {num_train_domains+j}", fontsize=20)

        # Row label (model name), placed left of the y-axis label of the first column.
        if first_col:
            ax.annotate(
                label,
                xy=(0, 0.5),
                xytext=(-ax.yaxis.labelpad - 10, 0),
                xycoords=ax.yaxis.label,
                textcoords="offset points",
                size=20,
                ha="right",
                va="center",
                rotation=90,
            )

        # Hide all tick labels.
        ax.tick_params(
            labelbottom=False, labelleft=False, labelright=False, labeltop=False
        )


# Fill one subplot row per selected model. The row label falls back to the raw
# model name when no display string is defined in name_mapping.
for k, (name, model) in enumerate(models_to_eval.items()):
    print(f"Evaluating {name}...")
    process_model(model, name_mapping.get(name, name), k)

plt.show()

## Quantitative Analysis

This section allows you to reproduce the quantitative results reported in our paper.



### Common Setup

Execute this cell before moving on to the sections below.


In [None]:
from itertools import product
import pandas as pd

from tabpfn.datasets import get_benchmark_for_task
from tabpfn.utils import default_task_settings

# Default size limits for the benchmark; max_times is unpacked but not used in this cell.
max_samples, max_features, max_times, max_classes = default_task_settings()

# Task identifier used for loading the datasets and selecting the metrics.
task_type = "dist_shift_multiclass"

# Maps the split name ("valid" / "test") to the datasets loaded for that split.
datasets_dict = {}

# SSL certificate verification is disabled while loading, because fetching the
# folktables dataset from the US census can fail with a certificate error.
# To keep verification enabled, swap the `with` line for the commented-out
# nullcontext() line below.
with temporary_no_ssl_verify():
# with nullcontext():
    # Load the validation and test datasets.
    for split in ["valid", "test"]:
        # The second return value is not needed here.
        datasets, _ = get_benchmark_for_task(
            task_type=task_type,
            split=split,
            max_samples=max_samples,
            max_features=max_features,
            max_classes=max_classes,
            return_as_lists=False,
        )

        datasets_dict[split] = datasets


# Define the evaluation parameters for the different evaluation scenarios used
# in our paper.
def get_eval_kwargs(setting):
    """Return the evaluation keyword arguments for a named evaluation scenario.

    Args:
        setting: One of ``"w_indices"``, ``"wo_indices"`` or
            ``"l_dom_wo_indices"``.

    Returns:
        A new dict with the keys ``minimize_num_train_domains`` and
        ``append_domain_as_feature``.

    Raises:
        ValueError: If ``setting`` is none of the known scenarios.
    """
    # (setting name, minimize_num_train_domains, append_domain_as_feature)
    known_settings = (
        ("w_indices", False, True),
        ("wo_indices", False, False),
        ("l_dom_wo_indices", True, False),
    )

    for name, minimize_domains, domain_as_feature in known_settings:
        if setting == name:
            return {
                "minimize_num_train_domains": minimize_domains,
                "append_domain_as_feature": domain_as_feature,
            }

    raise ValueError(f"Unknown setting: {setting}")


# Helper function to simplify metric names
def simplify_name(metric, includes_model_type=False):
    """Shorten a slash-separated metric key to ``domain/benchmark/metric``.

    Args:
        metric: Full metric key, e.g.
            ``"test/ood/3_splits/per_task/benchmark_synthetic/mean_acc"``.
        includes_model_type: Set to True when the key carries one additional
            leading component, which is then skipped.

    Returns:
        ``"<domain>/<benchmark>/<metric name>"``. The benchmark is
        ``"overall"`` unless the key names a real-world/synthetic benchmark
        or a single dataset; the metric name has ``"mean_"`` removed.
    """
    segments = metric.split("/")
    if includes_model_type:
        # Drop the extra leading component so the positions below line up.
        segments = segments[1:]

    domain = segments[1]
    metric_name = segments[-1].replace("mean_", "")

    benchmark = "overall"
    if len(segments) >= 5:
        group, candidate = segments[3], segments[4]
        if "real-world" in candidate or "synthetic" in candidate:
            # e.g. "benchmark_synthetic" -> "synthetic"
            benchmark = candidate.split("_")[1]
        elif "dataset" in group:
            # Per-dataset entry: keep the part before the first underscore.
            benchmark = candidate.split("_")[0].replace(" ", "_")

    return "/".join((domain, benchmark, metric_name))


# Helper function to filter the evaluation results
def generate_vis_metrics(split):
    """Build the metric keys to report for ``split`` and their short names.

    Args:
        split: Name of the dataset split, used as the first key component.

    Returns:
        A tuple ``(vis_metrics, simplified_metrics)`` of two equally long
        lists: the full keys (overall, real-world and synthetic benchmark,
        each for four metrics and the ood/id domains) and the corresponding
        names produced by ``simplify_name``.
    """
    benchmark_suffixes = ("", "/per_task/benchmark_real-world", "/per_task/benchmark_synthetic")
    metric_names = ("mean_acc", "mean_f1", "mean_roc", "mean_ece")
    domain_names = ("ood", "id")

    # The domain varies fastest, then the metric, then the benchmark.
    combinations = product(benchmark_suffixes, metric_names, domain_names)
    vis_metrics = [
        f"{split}/{domain}/3_splits{task}/{metric}"
        for task, metric, domain in combinations
    ]
    simplified_metrics = list(map(simplify_name, vis_metrics))

    return vis_metrics, simplified_metrics


# Set pandas display options: three decimals for floats, and no truncation of
# rows or columns so that the complete results table is shown.
pd.set_option("display.float_format", "{:.3f}".format)
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 2000)


# Helper function to create and sort a pandas dataframe of the results
def get_df(results, metrics):
    """Turn the collected evaluation results into a sorted DataFrame.

    Args:
        results: Dict whose values are per-run dicts containing the keys
            ``name``, ``initialization``, ``setting`` and one entry per metric.
        metrics: List of metric column names to include.

    Returns:
        A DataFrame with the columns ``name``, ``initialization``,
        ``setting`` followed by ``metrics``, sorted by setting, name and
        initialization.
    """
    id_columns = ["name", "initialization", "setting"]
    rows = list(results.values())
    frame = pd.DataFrame(rows, columns=id_columns + metrics)
    return frame.sort_values(by=["setting", "name", "initialization"])


# Initialize results to be None. The evaluation cells below replace None by a
# dict on first use and otherwise keep adding to the existing dict, so
# re-running this cell discards all results collected so far.
results = None

### Evaluate TabPFN

Evaluate TabPFN$_\text{dist}$ and TabPFN$_\text{base}$ on our test datasets to reproduce the performance metrics in our paper.

> **NOTE:** Due to ensembling and preprocessing this evaluation took us about 30min per model and setting on 8 CPUs and 1 GPU. Intermediate results are saved in the evaluation folder in the left sidebar. Download and restore them to continue where you left off or distribute these tasks on a cluster you have access to.

In [None]:
from functools import partial
import pandas as pd
import re

import ipywidgets as widgets
from IPython.display import display, clear_output

from tabpfn.scripts.tabular_evaluation import evaluate_and_score
from tabpfn.scripts.tabular_baselines import transformer_metric

from tabpfn.scripts.tabular_metrics import (
    get_standard_eval_metrics,
    get_main_eval_metric,
)

# Create an accordion widget with two output areas: the results table
# (index 0) and the scrollable evaluation log (index 1).
accordion = widgets.Accordion(
    children=[
        widgets.Output(),
        widgets.Output(layout={"height": "800px", "overflow": "auto"}),
    ]
)
accordion.set_title(0, "Current Results")
accordion.set_title(1, "Evaluation Output")
display(accordion)

# For the TabPFN base models you can define this to be 'w_indices',
# 'wo_indices', 'l_dom_wo_indices'. TabPFN dist is implemented to only work
# correctly for the setting 'w_indices'.
settings = ["w_indices", "wo_indices", "l_dom_wo_indices"]

# You can switch between the 'test' and 'valid' datasets here.
split = "test"

# Get result strings to filter for after the evaluation is done.
vis_metrics, simplified_metrics = generate_vis_metrics(split)

# Start the evaluation. Results already collected by an earlier run of an
# evaluation cell are kept and extended.
results = {} if results is None else results

# Evaluate every loaded model in every applicable setting.
for setting in settings:
    for model_name, model in models.items():
        # Skip the settings 'wo_indices' and 'l_dom_wo_indices' for TabPFN dist.
        if "tabpfn_dist" in model_name and setting != "w_indices":
            continue

        # Everything printed inside this block goes to the "Evaluation Output" area.
        with accordion.children[1]:
            print(
                f"Evaluation of {model_name} in setting {setting} on 3 splits of {split} datasets."
            )

            # For our methods we need to extract the model initialization from the filename.
            # From here on model_name holds the name without its numeric suffix.
            matches = re.match(r"^(.*?)(_)(\d+)$", model_name)
            if matches:
                model_name, init = matches.group(1), int(matches.group(3))
                # init is off by one in comparison to other baselines
                init -= 1
            else:
                init = 1

            # NOTE(review): with save=True and overwrite=False this presumably
            # reuses results already stored below ./evaluation/ - confirm
            # against evaluate_and_score.
            log_msg, _ = evaluate_and_score(
                method_name=model_name,
                valid_datasets=datasets_dict[split],
                valid_metrics=get_standard_eval_metrics(task_type),
                metric_with_model=partial(transformer_metric, classifier=model),
                metric_used=get_main_eval_metric(task_type),
                split_name=split,
                log_per_dataset_metrics=True,
                log_per_split_metrics=True,
                num_splits=3,
                base_path="./evaluation/",
                eval_kwargs=get_eval_kwargs(setting),
                device="cuda",
                rename_gpu_runs=True,
                save=True,
                overwrite=False,
                path_interfix=f"{setting}_{init}",
            )

        # Store the evaluation result: identifying columns plus the selected
        # metrics under their simplified names.
        results[f"{setting}_{model_name}_{init}"] = {
            **{"name": model_name, "initialization": init, "setting": setting},
            **{
                simplified_metric: log_msg[metric]
                for metric, simplified_metric in zip(vis_metrics, simplified_metrics)
            },
        }

        # Print the current state of the results.
        with accordion.children[0]:
            clear_output(wait=True)
            print("Updated results: ")
            display(get_df(results, simplified_metrics))
            accordion.selected_index = 0


# Print the final state of the results.
with accordion.children[0]:
    clear_output(wait=True)
    print("Evaluation done. Final results: ")
    display(get_df(results, simplified_metrics))
    accordion.selected_index = 0

### Evaluate Baselines

Evaluate the tree-based baselines as well as the methods of the WildTime-Benchmark on our test datasets to reproduce the performance metrics in our paper. Note, however, that we used 8 CPUs and 1 GPU for model training. Also, newer versions of the packages for CatBoost, XGBoost, and LightGBM were released since the release of our paper, leading to slight deviations in the results.

> **NOTE:** With a `max_time` of 1200s for hpo per dataset split, this will take about 18h per method and setting. Intermediate results are saved in the evaluation folder in the left sidebar. Download and restore them to continue where you left off or distribute these tasks on a cluster you have access to.

In [None]:
from functools import partial
import pandas as pd

import ipywidgets as widgets
from IPython.display import display, clear_output

from tabpfn.scripts.tabular_evaluation import evaluate_and_score
from tabpfn.scripts.tabular_baselines import transformer_metric, get_clf_dict

from tabpfn.scripts.tabular_metrics import get_standard_eval_metrics, get_main_eval_metric

# Create an accordion widget with two output areas: the results table
# (index 0) and the scrollable evaluation log (index 1).
accordion = widgets.Accordion(children=[widgets.Output(), widgets.Output(layout={'height': '800px', 'overflow': 'auto'})])
accordion.set_title(0, 'Current Results')
accordion.set_title(1, 'Evaluation Output')
display(accordion)

# Evaluation settings to run for the baselines; get_eval_kwargs() maps each
# name to its evaluation parameters. The loop below skips 'l_dom_wo_indices'
# for all WildTime methods except 'wildtime_MLP_erm'.
settings = ["w_indices", "wo_indices", "l_dom_wo_indices"]

# You can switch between the 'test' and 'valid' datasets here.
split = "test"

# The number of model initializations to train and evaluate.
num_initializations = 3

# The maximum time budget in seconds for hpo of each baseline method on each dataset split.
max_time = 1200

# Get result strings to filter for after the evaluation is done.
vis_metrics, simplified_metrics = generate_vis_metrics(split)

# Baseline methods to train and evaluate.
tree_methods = ["catboost", "xgb", "lightgbm"]
wildtime_methods = ['wildtime_MLP_erm', 'wildtime_MLP_ft', 'wildtime_MLP_ewc',
                    'wildtime_MLP_si', 'wildtime_MLP_agem', 'wildtime_MLP_coral',
                    'wildtime_MLP_groupdro', 'wildtime_MLP_irm', 'wildtime_MLP_erm_mixup',
                    'wildtime_MLP_erm_lisa', 'wildtime_MLP_swa']

# Filter here for the methods you like to evaluate. By default only the
# tree-based methods run; uncomment the second part to add the WildTime methods.
baseline_methods = tree_methods # + wildtime_methods

# Start the evaluation. Results already collected by an earlier run of an
# evaluation cell are kept and extended.
results = {} if results is None else results
# Train and evaluate every baseline for every initialization and setting.
# init doubles as the random state passed to the evaluation.
for init in range(num_initializations):
    for setting in settings:
        for baseline_method in baseline_methods:
            # Skip the setting 'l_dom_wo_indices' for WildTime non-ERM methods.
            if baseline_method.startswith("wildtime") and baseline_method != "wildtime_MLP_erm" and setting == 'l_dom_wo_indices':
                continue

            # Everything printed inside this block goes to the "Evaluation Output" area.
            with accordion.children[1]:
                print(f"Training and evaluation {baseline_method} (init {init}) in setting {setting} on 3 splits of {split} datasets.")

                # NOTE(review): with save=True and overwrite=False this
                # presumably reuses results already stored below
                # ./evaluation/ - confirm against evaluate_and_score.
                log_msg, _ = evaluate_and_score(
                            method_name       = baseline_method,
                            valid_datasets    = datasets_dict[split],
                            valid_metrics     = get_standard_eval_metrics(task_type),
                            metric_with_model = get_clf_dict(task_type)[baseline_method],
                            metric_used       = get_main_eval_metric(task_type),
                            split_name        = split,
                            log_per_dataset_metrics = True,
                            log_per_split_metrics   = True,
                            num_splits        = 3,
                            base_path         = "./evaluation/",
                            eval_kwargs       = get_eval_kwargs(setting),
                            device            = "cuda",
                            rename_gpu_runs   = True,
                            save              = True,
                            overwrite         = False,
                            path_interfix     = f"{setting}_{init}",
                            random_state      = init,
                            max_time          = max_time
                          )

            # Store the evaluation result: identifying columns plus the
            # selected metrics under their simplified names.
            results[f"{setting}_{baseline_method}_{init}"] = ({
                      **{'name': baseline_method, 'initialization': init, 'setting': setting},
                      **{simplified_metric: log_msg[metric] for metric, simplified_metric in zip(vis_metrics, simplified_metrics)}
                  })

            # Print the current state of the results.
            with accordion.children[0]:
                clear_output(wait=True)
                print("Updated results: ")
                display(get_df(results, simplified_metrics))
                accordion.selected_index = 0

# Print the final state of the results.
with accordion.children[0]:
    clear_output(wait=True)
    print("Evaluation done. Final results: ")
    display(get_df(results, simplified_metrics))
    accordion.selected_index = 0

### Calculate Mean and Confidence Intervals

Once the results table is complete, you can calculate the mean value and the confidence intervals for all methods and settings across initializations as follows.

In [None]:
import numpy as np
import scipy.stats as st

def compute_mean_and_conf_interval(accuracies, confidence=.95, expected_n=3):
    """Return the mean and the confidence interval half-width of ``accuracies``.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student-t quantile with ``n - 1`` degrees of freedom.

    Args:
        accuracies: Sequence of metric values, one per model initialization.
        confidence: Confidence level of the interval (default 0.95).
        expected_n: Number of values required for a result; any other count
            yields ``(nan, nan)`` so that incomplete runs do not show up as
            results. Pass ``None`` to accept any count.

    Returns:
        A tuple ``(mean, half_width)``. The half-width is 0 for a single
        value; both entries are ``nan`` when the count check fails or when
        there are no values.
    """
    accuracies = np.asarray(accuracies)
    n = len(accuracies)

    # Only show results if we have the expected number of initializations.
    if expected_n is not None and n != expected_n:
        return np.nan, np.nan
    if n == 0:
        return np.nan, np.nan

    m = np.mean(accuracies)
    if n < 2:
        # A single value has no spread to estimate an interval from.
        return m, 0

    h = st.sem(accuracies) * st.t.ppf((1 + confidence) / 2., n - 1)

    return m, h

def _mean_over_initializations(values):
    # First entry of the (mean, confidence interval) pair.
    return compute_mean_and_conf_interval(values)[0]


def _ci_over_initializations(values):
    # Second entry of the (mean, confidence interval) pair.
    return compute_mean_and_conf_interval(values)[1]


# Define a dictionary of aggregation functions for each metric: every metric
# column is reduced to a "mean" and a "ci" column.
agg_dict = {}
for metric in simplified_metrics:
    agg_dict[metric] = [
        ("mean", _mean_over_initializations),
        ("ci", _ci_over_initializations),
    ]

# Get the results per method and initialization
results_per_initialization = get_df(results, simplified_metrics)

# Aggregate across initializations for each method and setting
grouped_results = results_per_initialization.groupby(['name', 'setting'])
result_df = grouped_results.agg(agg_dict)

display(result_df)