Is early stopping better than tuning the number of trees in LightGBM? Benchmarked on the [bank-account-fraud dataset](https://www.kaggle.com/datasets/sgpjesus/bank-account-fraud-dataset-neurips-2022) (published at NeurIPS'22).

**Executive Summary**
- There does not seem to be a noticeable difference in performance between tuning the number of boosting iterations and using early stopping with a high maximum number of trees.
- The early-stopping strategy more than halves the average trial time (training plus validation/test scoring).

In [None]:
import warnings
import numpy as np
import pandas as pd
import optuna
import lightgbm as lgb
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Notebook-wide settings: suppress Python warnings, white-grid seaborn theme,
# and only WARNING-level (or worse) messages from Optuna.
warnings.filterwarnings('ignore')
sns.set_style('whitegrid')
optuna.logging.set_verbosity(optuna.logging.WARNING)

In [None]:
# Seed for the Optuna samplers and number of trials run per study.
SEED, N_TRIALS = 42, 100

# Data Loading

In [None]:
# Load Base.csv from the Kaggle input directory and display it.
df = pd.read_csv(
    filepath_or_buffer='/kaggle/input/bank-account-fraud-dataset-neurips-2022/Base.csv',
)
df

In [None]:
TIMESTAMP_COL = 'month'
LABEL_COL = 'fraud_bool'
CATEGORICAL_COLS = ['payment_type', 'employment_status', 'housing_status', 'source', 'device_os']

# Cast to pandas 'category' dtype so LightGBM can treat these columns as categorical.
# Casting the whole frame before splitting gives every split identical categories.
df[CATEGORICAL_COLS] = df[CATEGORICAL_COLS].astype('category')

# Temporal split on the month column: months <= 5 train, 6 validation, 7 test.
train, val, test = (
    df[month_mask].drop(columns=TIMESTAMP_COL)
    for month_mask in (
        df[TIMESTAMP_COL] <= 5,
        df[TIMESTAMP_COL] == 6,
        df[TIMESTAMP_COL] == 7,
    )
)

# Separate features from the binary label in each split.
X_train, y_train = train.drop(columns=[LABEL_COL]), train[LABEL_COL]
X_val, y_val = val.drop(columns=[LABEL_COL]), val[LABEL_COL]
X_test, y_test = test.drop(columns=[LABEL_COL]), test[LABEL_COL]

In [None]:
# Shared lgb.Dataset options: declare the categorical columns explicitly and keep the
# raw pandas frames (free_raw_data=False) so get_data() can hand them back for scoring.
dataset_params = {
    'categorical_feature': CATEGORICAL_COLS,
    'free_raw_data': False,
}
# The training Dataset is left lazy (constructed inside lgb.train, presumably so the
# training params can still influence construction); validation/test are built eagerly.
dtrain = lgb.Dataset(X_train, label=y_train, **dataset_params)
dval, dtest = (
    lgb.Dataset(features, label=labels, **dataset_params).construct()
    for features, labels in ((X_val, y_val), (X_test, y_test))
)

# Hyperparameter Tuning

In [None]:
# aux functions for binary classification evaluation
def calc_threshold_at_fpr(y_true: np.ndarray, y_score: np.ndarray, fpr: float) -> float:
    """Return the score threshold whose strict ``>`` cut keeps the FPR below ``fpr``.

    Rows are ranked by descending score and a running false-positive rate is
    computed along that ranking. The threshold is the score of the last row whose
    running FPR is still strictly below ``fpr``. Labelling ``y_score > threshold``
    as positive therefore yields an FPR strictly below the target.

    Args:
        y_true: Binary labels (0 = negative, anything else = positive).
        y_score: Predicted scores; higher means more likely positive.
        fpr: Target false-positive rate.

    Returns:
        The threshold as a float. Returns ``inf`` (so that nothing is flagged)
        when no row satisfies the constraint. Examples: empty input, no negatives,
        or the top-ranked row is a negative that already reaches ``fpr``.
    """
    temp_df = pd.DataFrame({'y_true': y_true, 'y_score': y_score})
    temp_df = temp_df.sort_values(by='y_score', ascending=False, ignore_index=True)

    # Vectorized running count of negatives divided by the total number of negatives.
    # With zero negatives this is 0/0 = NaN, and NaN never passes the `< fpr` filter.
    is_negative = temp_df['y_true'] == 0
    temp_df['pseudo_fpr'] = is_negative.cumsum() / is_negative.sum()

    candidate_scores = temp_df.loc[temp_df['pseudo_fpr'] < fpr, 'y_score']
    if candidate_scores.empty:
        return float('inf')
    return float(candidate_scores.iloc[-1])

def predict_at_fpr(y_true: np.ndarray, y_score: np.ndarray, fpr: float):
    """Binarize ``y_score`` at the threshold targeting false-positive rate ``fpr``.

    The cut-off comes from ``calc_threshold_at_fpr``. Scores strictly above it
    map to 1 and everything else maps to 0.
    """
    cutoff = calc_threshold_at_fpr(y_true=y_true, y_score=y_score, fpr=fpr)
    is_flagged = y_score > cutoff
    return is_flagged.astype(int)

In [None]:
# LightGBM parameters shared by every trial and not searched over by Optuna.
NON_TUNED_PARAMS = dict(
    objective='binary',
    verbosity=-1,  # LightGBM documents verbosity as an int; < 0 keeps only fatal messages
    enable_bundle=True,
    feature_pre_filter=False,  # to enable min_child_samples exploration
)

In [None]:
def _objective(trial, dtrain, dval, dtest, categorical_cols, early_stopping, optimization_logs):
    """Train one LightGBM model for an Optuna trial and return its validation log loss.

    Hyperparameters are sampled from ``trial``. With ``early_stopping=True`` the model
    trains for up to 4000 rounds and stops once the validation metric on ``dval`` has
    not improved for 100 rounds. Otherwise the number of rounds is itself sampled
    (10-4000) and no validation set is monitored during training.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        dtrain: Training ``lgb.Dataset``.
        dval: Validation ``lgb.Dataset``; used for the objective value and, when
            ``early_stopping`` is set, as the early-stopping monitor.
        dtest: Test ``lgb.Dataset``; scored for analysis only, never for selection.
        categorical_cols: Categorical column names. Kept for interface compatibility.
            Categorical features are taken from the ``lgb.Dataset`` objects, so they
            must be declared there (as ``dataset_params`` does in this notebook).
        early_stopping: Use early stopping instead of tuning the number of rounds.
        optimization_logs: List that receives one dict of per-trial artifacts.

    Returns:
        Validation log loss, the value the study minimizes.
    """
    # Suggestion order matters for seeded samplers: keep these before num_boost_rounds.
    params = {
        'boosting_type': trial.suggest_categorical('boosting_type', ['gbdt', 'goss']),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.5, log=True),
        'num_leaves': trial.suggest_int('num_leaves', 2, 256, log=True),
        'max_depth': trial.suggest_int('max_depth', 2, 10),
        'min_child_samples': trial.suggest_int('min_child_samples', 2, 256, log=True),
        'bagging_freq': trial.suggest_categorical('bagging_freq', [0, 1]),
        'pos_bagging_fraction': trial.suggest_float('pos_bagging_fraction', 0.1, 1),
        'neg_bagging_fraction': trial.suggest_float('neg_bagging_fraction', 0.1, 1),
        'reg_alpha': trial.suggest_float('reg_alpha', 0.00001, 0.1, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 0.00001, 0.1, log=True),
    }

    if early_stopping:
        num_boost_round = 4000
        valid_sets = [dval]
        callbacks = [lgb.early_stopping(stopping_rounds=100, verbose=False)]
    else:
        num_boost_round = trial.suggest_int('num_boost_rounds', 10, 4000)
        valid_sets = None
        callbacks = None

    # categorical_feature is not passed here: dtrain already carries it, and newer
    # LightGBM releases no longer accept that argument in lgb.train().
    model = lgb.train(
        {**NON_TUNED_PARAMS, **params},
        dtrain,
        num_boost_round=num_boost_round,
        valid_sets=valid_sets,
        callbacks=callbacks,
    )

    # Per the LightGBM docs, Booster.predict uses the best iteration when early stopping
    # recorded one, and all trees otherwise.
    y_val = dval.get_label()
    y_val_score = model.predict(dval.get_data())
    y_test = dtest.get_label()  # test eval for experimental purposes only
    y_test_score = model.predict(dtest.get_data())

    artifacts = {
        'num_boosting_rounds': model.num_trees(),
        'log_loss': metrics.log_loss(y_true=y_val, y_pred=y_val_score),
        'recall_at_fpr': metrics.recall_score(y_true=y_val, y_pred=predict_at_fpr(y_val, y_val_score, fpr=0.05)),
        'test_log_loss': metrics.log_loss(y_true=y_test, y_pred=y_test_score),
        'test_recall_at_fpr': metrics.recall_score(y_true=y_test, y_pred=predict_at_fpr(y_test, y_test_score, fpr=0.05)),
    }
    optimization_logs.append(artifacts)

    return artifacts['log_loss']

In [None]:
# Strategy A: no early stopping; num_boost_round is sampled like any other
# hyperparameter. _objective appends one artifact dict per trial to the log list.
NO_ES_OPTIMIZATION_LOGS = []
no_es_study = optuna.create_study(
    directions=['minimize'],  # not used by RandomSampler
    sampler=optuna.samplers.RandomSampler(seed=SEED),
)
no_es_study.optimize(
    lambda trial: _objective(
        trial,
        dtrain=dtrain,
        dval=dval,
        dtest=dtest,
        categorical_cols=CATEGORICAL_COLS,
        early_stopping=False,
        optimization_logs=NO_ES_OPTIMIZATION_LOGS,
    ),
    n_trials=N_TRIALS,
)

In [None]:
# Strategy B: early stopping on the validation set with a 4000-round cap.
# Same seeded RandomSampler as strategy A; artifacts go to ES_OPTIMIZATION_LOGS.
ES_OPTIMIZATION_LOGS = []
es_study = optuna.create_study(
    directions=['minimize'],  # not used by RandomSampler
    sampler=optuna.samplers.RandomSampler(seed=SEED),
)
es_study.optimize(
    lambda trial: _objective(
        trial,
        dtrain=dtrain,
        dval=dval,
        dtest=dtest,
        categorical_cols=CATEGORICAL_COLS,
        early_stopping=True,
        optimization_logs=ES_OPTIMIZATION_LOGS,
    ),
    n_trials=N_TRIALS,
)

# Analysis

In [None]:
# Join Optuna's per-trial records with the artifacts logged by _objective. They are
# aligned by position, since one log entry is appended per trial. Then add running-best
# curves.
# NOTE: test_best_* is the running best of the *test* scores themselves (an oracle
# view), not the test score of the model selected on validation.
no_es_results = pd.concat([no_es_study.trials_dataframe(), pd.DataFrame(NO_ES_OPTIMIZATION_LOGS)], axis=1)
no_es_results['best_log_loss'] = no_es_results['log_loss'].expanding().min()
no_es_results['best_recall_at_fpr'] = no_es_results['recall_at_fpr'].expanding().max()
no_es_results['test_best_log_loss'] = no_es_results['test_log_loss'].expanding().min()
no_es_results['test_best_recall_at_fpr'] = no_es_results['test_recall_at_fpr'].expanding().max()
no_es_results

In [None]:
# Join Optuna's per-trial records with the artifacts logged by _objective. They are
# aligned by position, since one log entry is appended per trial. Then add running-best
# curves.
# NOTE: test_best_* is the running best of the *test* scores themselves (an oracle
# view), not the test score of the model selected on validation.
es_results = pd.concat([es_study.trials_dataframe(), pd.DataFrame(ES_OPTIMIZATION_LOGS)], axis=1)
es_results['best_log_loss'] = es_results['log_loss'].expanding().min()
es_results['best_recall_at_fpr'] = es_results['recall_at_fpr'].expanding().max()
es_results['test_best_log_loss'] = es_results['test_log_loss'].expanding().min()
es_results['test_best_recall_at_fpr'] = es_results['test_recall_at_fpr'].expanding().max()
es_results

In [None]:
# Stack both strategies into one long frame, tag each row with its strategy,
# and add a 1-based trial counter for plotting.
results = (
    pd.concat([
        no_es_results.assign(early_stopping=False),
        es_results.assign(early_stopping=True),
    ])
    .reset_index(drop=True)
    .assign(trial=lambda frame: frame['number'] + 1)
)
results

## Validation set

### Cross-entropy loss

In [None]:
# Validation log loss per trial for both strategies. Dots are individual trials and
# lines are the running best (lowest) value so far. The second figure zooms the
# y-axis, cutting trials above the 80th percentile of loss.
sns.lineplot(
    data=results,
    x='trial', y='best_log_loss', hue='early_stopping',
    linewidth=2, alpha=0.9, legend=False,  # False is seaborn's documented "no legend" value
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='log_loss',
    s=20, alpha=0.9,
)
plt.xlabel('Trial')
plt.ylabel('Cross-entropy loss')
plt.legend(title='Early stopping')
plt.show()

sns.lineplot(
    data=results, hue='early_stopping',
    x='trial', y='best_log_loss',
    linewidth=2, alpha=0.9, legend=False,
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='log_loss',
    s=20, alpha=0.9,
)
plt.ylim(0.05225, results['log_loss'].quantile(0.8))  # fixed lower bound chosen for this run
plt.xlabel('Trial')
plt.ylabel('Cross-entropy loss')
plt.legend(title='Early stopping')
plt.show()

In [None]:
# Distribution of validation log loss per strategy: first over all trials, then
# over trials at or below the 80th loss percentile for a zoomed view.
sns.stripplot(
    data=results, x='early_stopping', y='log_loss',
    color='black', alpha=0.7,
)
sns.violinplot(data=results, x='early_stopping', y='log_loss')
plt.gca().set(xlabel='Early stopping', ylabel='Cross-entropy loss')
plt.show()

plot_data = results.loc[results['log_loss'].le(results['log_loss'].quantile(0.8))]
sns.stripplot(
    data=plot_data, x='early_stopping', y='log_loss',
    color='black', alpha=0.7,
)
sns.violinplot(data=plot_data, x='early_stopping', y='log_loss')
plt.gca().set_ylim(0.05225, results['log_loss'].quantile(0.8))
plt.gca().set(xlabel='Early stopping', ylabel='Cross-entropy loss')
plt.show()

In [None]:
results.groupby(by='early_stopping')['log_loss'].agg('mean')  # mean validation log loss per strategy

### Recall at 5% FPR

In [None]:
# Validation recall at 5% FPR per trial. Dots are individual trials and lines are
# the running best (highest) value so far. The second figure zooms the y-axis,
# cutting trials below the 20th percentile of recall.
sns.lineplot(
    data=results,
    x='trial', y='best_recall_at_fpr', hue='early_stopping',
    linewidth=2, alpha=0.9, legend=False,  # False is seaborn's documented "no legend" value
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='recall_at_fpr',
    s=20, alpha=0.9,
)
plt.xlabel('Trial')
plt.ylabel('Recall at 5% FPR')
plt.legend(title='Early stopping')
plt.show()

sns.lineplot(
    data=results, hue='early_stopping',
    x='trial', y='best_recall_at_fpr',
    linewidth=2, alpha=0.9, legend=False,
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='recall_at_fpr',
    s=20, alpha=0.9,
)
plt.ylim(results['recall_at_fpr'].quantile(0.2), 0.54)  # fixed upper bound chosen for this run
plt.xlabel('Trial')
plt.ylabel('Recall at 5% FPR')
plt.legend(title='Early stopping')
plt.show()

In [None]:
# Distribution of validation recall at 5% FPR per strategy, with the y-axis anchored at 0.
sns.stripplot(
    data=results, x='early_stopping', y='recall_at_fpr',
    color='black', alpha=0.7,
)
sns.violinplot(data=results, x='early_stopping', y='recall_at_fpr')
plt.gca().set_ylim(bottom=0)
plt.gca().set(xlabel='Early stopping', ylabel='Recall at 5% FPR')
plt.show()

In [None]:
results.groupby(by='early_stopping')['recall_at_fpr'].agg('mean')  # mean validation recall per strategy

## Test set

### Cross-entropy loss

In [None]:
# Test log loss per trial. Dots are individual trials and lines are the running best
# (lowest) test value so far. The second figure zooms the y-axis, cutting trials
# above the 80th percentile of loss.
sns.lineplot(
    data=results,
    x='trial', y='test_best_log_loss', hue='early_stopping',
    linewidth=2, alpha=0.9, legend=False,  # False is seaborn's documented "no legend" value
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='test_log_loss',
    s=20, alpha=0.9,
)
plt.xlabel('Trial')
plt.ylabel('Cross-entropy loss')
plt.legend(title='Early stopping')
plt.show()

sns.lineplot(
    data=results, hue='early_stopping',
    x='trial', y='test_best_log_loss',
    linewidth=2, alpha=0.9, legend=False,
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='test_log_loss',
    s=20, alpha=0.9,
)
plt.ylim(0.0545, results['test_log_loss'].quantile(0.8))  # fixed lower bound chosen for this run
plt.xlabel('Trial')
plt.ylabel('Cross-entropy loss')
plt.legend(title='Early stopping')
plt.show()

In [None]:
# Distribution of test log loss per strategy: first over all trials, then over
# trials at or below the 80th test-loss percentile for a zoomed view.
sns.stripplot(data=results, x='early_stopping', y='test_log_loss', color='black', alpha=0.7)
sns.violinplot(data=results, x='early_stopping', y='test_log_loss')
plt.xlabel('Early stopping')
plt.ylabel('Cross-entropy loss')
plt.show()

plot_data = results[results['test_log_loss'] <= results['test_log_loss'].quantile(0.8)]
sns.stripplot(data=plot_data, x='early_stopping', y='test_log_loss', color='black', alpha=0.7)
sns.violinplot(data=plot_data, x='early_stopping', y='test_log_loss')
# Lower bound matches the other test-set log-loss zooms (0.0545), not the validation plots' value.
plt.ylim(0.0545, results['test_log_loss'].quantile(0.8))
plt.xlabel('Early stopping')
plt.ylabel('Cross-entropy loss')
plt.show()

In [None]:
results.groupby(by='early_stopping')['test_log_loss'].agg('mean')  # mean test log loss per strategy

### Recall at 5% FPR

In [None]:
# Test recall at 5% FPR per trial. Dots are individual trials and lines are the
# running best (highest) test value so far. The second figure zooms the y-axis,
# cutting trials below the 20th percentile of recall.
sns.lineplot(
    data=results,
    x='trial', y='test_best_recall_at_fpr', hue='early_stopping',
    linewidth=2, alpha=0.9, legend=False,  # False is seaborn's documented "no legend" value
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='test_recall_at_fpr',
    s=20, alpha=0.9,
)
plt.xlabel('Trial')
plt.ylabel('Recall at 5% FPR')
plt.legend(title='Early stopping')
plt.show()

sns.lineplot(
    data=results, hue='early_stopping',
    x='trial', y='test_best_recall_at_fpr',
    linewidth=2, alpha=0.9, legend=False,
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='test_recall_at_fpr',
    s=20, alpha=0.9,
)
plt.ylim(results['test_recall_at_fpr'].quantile(0.2), 0.60)  # fixed upper bound chosen for this run
plt.xlabel('Trial')
plt.ylabel('Recall at 5% FPR')
plt.legend(title='Early stopping')
plt.show()

In [None]:
# Distribution of test recall at 5% FPR per strategy, with the y-axis anchored at 0.
sns.stripplot(
    data=results, x='early_stopping', y='test_recall_at_fpr',
    color='black', alpha=0.7,
)
sns.violinplot(data=results, x='early_stopping', y='test_recall_at_fpr')
plt.gca().set_ylim(bottom=0)
plt.gca().set(xlabel='Early stopping', ylabel='Recall at 5% FPR')
plt.show()

In [None]:
results.groupby(by='early_stopping')['test_recall_at_fpr'].agg('mean')  # mean test recall per strategy

## Training Time (Optuna trial duration, including validation/test scoring)

In [None]:
# Wall-clock duration of each Optuna trial, in seconds. It covers the whole objective
# call (training plus validation/test scoring). .dt.total_seconds() keeps fractional
# seconds; .dt.seconds would truncate to whole seconds and ignore any days component.
results['seconds'] = results['duration'].dt.total_seconds()

In [None]:
# Trial-duration distribution per strategy (violin + individual trials), y-axis from 0.
sns.stripplot(
    data=results, x='early_stopping', y='seconds',
    color='black', alpha=0.7,
)
sns.violinplot(data=results, x='early_stopping', y='seconds')
plt.gca().set_ylim(bottom=0)
plt.gca().set(xlabel='Early stopping', ylabel='Training time (s)')
plt.show()

In [None]:
# Same data as a box plot. Outlier markers are hidden on the boxes but every trial
# still appears in the strip plot underneath.
sns.stripplot(
    data=results, x='early_stopping', y='seconds',
    color='black', alpha=0.7,
)
sns.boxplot(data=results, x='early_stopping', y='seconds', showfliers=False)
plt.gca().set_ylim(bottom=0)
plt.gca().set(xlabel='Early stopping', ylabel='Training time (s)')
plt.show()

In [None]:
results.groupby(by='early_stopping')['seconds'].agg('mean')  # mean trial duration per strategy

# Plots for Medium

In [None]:
# Publication version of the zoomed test log-loss plot, without a legend.
sns.lineplot(
    data=results, hue='early_stopping',
    x='trial', y='test_best_log_loss',
    linewidth=2, alpha=0.9, legend=False,  # False is seaborn's documented "no legend" value
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='test_log_loss',
    s=20, alpha=0.9, legend=False,
)
plt.ylim(0.0545, results['test_log_loss'].quantile(0.8))
plt.xlabel('Trial')
plt.ylabel('Cross-entropy loss')
plt.show()

In [None]:
# Publication version of the zoomed test recall plot, without a legend.
sns.lineplot(
    data=results, hue='early_stopping',
    x='trial', y='test_best_recall_at_fpr',
    linewidth=2, alpha=0.9, legend=False,  # False is seaborn's documented "no legend" value
)
sns.scatterplot(
    data=results, hue='early_stopping',
    x='trial', y='test_recall_at_fpr',
    s=20, alpha=0.9, legend=False,
)
plt.ylim(results['test_recall_at_fpr'].quantile(0.2), 0.60)
plt.xlabel('Trial')
plt.ylabel('Recall at 5% FPR')
plt.show()