# Tutorial 2: Augment real-world data with TabEBM


## Environment setup


In [1]:
# Load the autoreload extension and use mode 2, so imported modules are
# re-imported before each cell executes (handy while editing library code).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
!pip install tabebm
!pip install tabcamel



In [3]:
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from tabcamel.data.dataset import TabularDataset
from tabcamel.data.transform import CategoryTransform, NumericTransform, TargetTransform

from tabebm.TabEBM import TabEBM, seed_everything

In [4]:
# Fix the seed once, up front, via the helper shipped with TabEBM.
SEED = 42
seed_everything(SEED)

## Prepare real-world data


- Load the dataset


In [5]:
# Build the "adult" classification dataset object and print its summary.
dataset = TabularDataset(dataset_name="adult", task_type="classification")
print(dataset)

Dataset: adult
Task type: classification
Status (is_tensor): False
Number of samples: 48842
Number of features: 14 (Numerical: 2, Categorical: 12)
Number of classes: 2
Class distribution: {'<=50K': 0.7607182343065395, '>50K': 0.23928176569346055}


- Subsample a small subset of the data to simulate a low-sample-size scenario


In [6]:
# Draw a stratified subsample of 200 rows with a fixed random state.
subsample_dict = dataset.sample(sample_mode="stratified", sample_size=200, random_state=42)

# NOTE: `dataset` is rebound to the subsample here; the full dataset is no
# longer reachable under this name after the cell has run.
dataset = subsample_dict["dataset_sampled"]

print("Subsampled dataset:", dataset, sep="\n")

Subsampled dataset:
Dataset: adult
Task type: classification
Status (is_tensor): False
Number of samples: 200
Number of features: 14 (Numerical: 2, Categorical: 12)
Number of classes: 2
Class distribution: {'<=50K': 0.765, '>50K': 0.235}


- Preprocess the data


In [7]:
# --- Train/test split: stratified, 80% train, fixed random state ---
split_dict = dataset.split(split_mode="stratified", train_size=0.8, random_state=42)
train_set, test_set = split_dict["train_set"], split_dict["test_set"]
X_train, y_train = train_set.X_df, pd.DataFrame(train_set.y_s)
X_test, y_test = test_set.X_df, pd.DataFrame(test_set.y_s)

# --- Categorical features: one-hot encoding, fitted on the training split only ---
feature_encoder = CategoryTransform(
    categorical_feature_list=train_set.categorical_feature_list,
    strategy="onehot",
)
feature_encoder.fit(X_train)
X_train, X_test = feature_encoder.transform(X_train), feature_encoder.transform(X_test)

# --- Numerical features: standard scaling, fitted on the encoded training split ---
feature_scaler = NumericTransform(
    numerical_feature_list=train_set.numerical_feature_list,
    strategy="standard",
    include_categorical=False,
)
feature_scaler.fit(X_train)
X_train, X_test = feature_scaler.transform(X_train), feature_scaler.transform(X_test)

# --- Target: encode the labels, again fitted on the training split only ---
target_encoder = TargetTransform(
    task="classification",
    target_feature=train_set.target_col,
)
target_encoder.fit(y_train)
y_train, y_test = target_encoder.transform(y_train), target_encoder.transform(y_test)

## Prepare synthetic data

* Train TabEBM for synthetic data generation 

In [8]:
# === Fit TabEBM and generate synthetic samples ===
# Number of synthetic samples requested per class.
NUM_SYN_SAMPLES = 50

tabebm = TabEBM()
data_syn = tabebm.generate(X_train, y_train, num_samples=NUM_SYN_SAMPLES)

* Combine synthetic data with real data

In [9]:
# === Combine the synthetic samples with the real samples ===
# `data_syn` is indexed by keys of the form "class_<i>". Build the features and
# the labels from the SAME ordered key sequence so that every synthetic row is
# paired with its own class label, regardless of the dict's iteration order.
class_ids = range(len(data_syn))
X_syn = np.concatenate([data_syn[f"class_{i}"] for i in class_ids])
y_syn = np.concatenate([np.full(len(data_syn[f"class_{i}"]), i) for i in class_ids])

# Stack real rows first, then synthetic rows; the real target is flattened to
# 1-D so it can be concatenated with the 1-D synthetic label vector.
X_train_augmented = np.concatenate([X_train, X_syn])
y_train_augmented = np.concatenate([y_train.to_numpy().reshape(-1), y_syn])

## Train a downstream predictor


### Original dataset: only real data


In [10]:
# Baseline: k-NN with default hyperparameters, trained on real data only.
model_vanilla = KNeighborsClassifier()
# Flatten the target to 1-D, mirroring how the augmented target is built,
# instead of relying on scikit-learn's implicit column-vector conversion.
model_vanilla.fit(X_train, y_train.to_numpy().reshape(-1))

### Augmented dataset: real + synthetic data


In [11]:
# Same default k-NN classifier, trained on the real + synthetic rows.
model_augmented = KNeighborsClassifier()
model_augmented.fit(X=X_train_augmented, y=y_train_augmented)

### Evaluate the predictive accuracy


In [12]:
# Predict on the same held-out test split with both classifiers.
y_pred_vanilla = model_vanilla.predict(X_test)
y_pred_augmented = model_augmented.predict(X_test)

# Balanced accuracy, scaled by 100 so it is reported as a percentage.
acc_vanilla = balanced_accuracy_score(y_test, y_pred_vanilla) * 100
acc_augmented = balanced_accuracy_score(y_test, y_pred_augmented) * 100

print(f"Vanilla model's balanced accuracy: {acc_vanilla:.2f}")
print(f"Augmented model's balanced accuracy: {acc_augmented:.2f}")

Vanilla model's balanced accuracy: 71.33
Augmented model's balanced accuracy: 75.99
