# FPBoost: Fully Parametric Gradient Boosting for Survival Analysis

Source code of the paper "FPBoost: Fully Parametric Gradient Boosting for Survival Analysis" for AAAI 2025.

## Imports

In [None]:
from typing import Optional
from abc import ABC, abstractmethod

import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

from sklearn.model_selection import StratifiedKFold
from sklearn.compose import make_column_selector
from sklearndf.pipeline import PipelineDF
from sklearndf.transformation import (
    ColumnTransformerDF,
    OneHotEncoderDF,
    SimpleImputerDF,
    StandardScalerDF,
)
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split

from sksurv.metrics import integrated_brier_score
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest as RSF, GradientBoostingSurvivalAnalysis

try:
    from auton_survival.models.dsm import DeepSurvivalMachines as DSM
except ImportError:
    os.system("git clone https://github.com/autonlab/auton-survival.git")
    os.system("mv auton-survival/auton_survival .")
    os.system("rm -r auton-survival")
    from auton_survival.models.dsm import DeepSurvivalMachines as DSM

import numba
import torch
from torch import Tensor
from torch.autograd import Variable
import torch.nn.functional as F

import torchtuples as tt
from pycox.models import DeepHitSingle, CoxPH

import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch
import optuna

from utils.data_loader import load_dataframe


# Global random seed: used for the NumPy and PyTorch generators seeded below
# and as `random_state` of the cross-validation splitter.
SEED = 42

np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ";" presumably suppresses the notebook cell output

## Data Loading

This section defines the data loading and preprocessing functions alongside the cross-validation code.

In [None]:
# Each cross-validation iteration uses one fold for validation, the following
# fold for testing and all remaining folds for training.
FOLDS = 10  # Number of folds for cross-validation

In [None]:
def get_preprocess_transformer():
    """Build the column-wise preprocessing transformer.

    Columns whose name starts with ``fac_`` are one-hot encoded (dense output,
    unknown categories ignored); columns whose name starts with ``num_`` are
    median-imputed and then standardized.

    Returns:
        An unfitted ``ColumnTransformerDF`` combining both pipelines.
    """
    categorical_pipeline = PipelineDF(
        steps=[("ohe", OneHotEncoderDF(sparse_output=False, handle_unknown="ignore"))]
    )
    numerical_pipeline = PipelineDF(
        steps=[
            ("impute", SimpleImputerDF(strategy="median")),
            ("scale", StandardScalerDF()),
        ]
    )
    return ColumnTransformerDF(
        transformers=[
            ("ohe", categorical_pipeline, make_column_selector(pattern="^fac\\_")),
            ("s", numerical_pipeline, make_column_selector(pattern="^num\\_")),
        ]
    )


def get_k_fold_splits(df):
    """Yield the train/validation/test splits of a rotating k-fold scheme.

    The data is split into ``FOLDS`` folds stratified on the event indicator.
    Iteration ``i`` uses fold ``i`` for validation, fold ``(i + 1) % FOLDS``
    for testing and every other sample for training. Times are divided by the
    largest time in ``df`` before splitting, and the preprocessing transformer
    is fitted on the training part only.

    Validation and test samples whose time does not lie strictly between the
    smallest and largest training time are discarded (presumably so that the
    time-dependent metrics stay within the training follow-up range).

    Args:
        df: DataFrame with an ``event`` column, a ``time`` column and the
            feature columns expected by ``get_preprocess_transformer``.

    Yields:
        ``(X_train, y_train), (X_val, y_val), (X_test, y_test)`` where each
        ``X`` is a float32 array and each ``y`` is a structured array with
        dtype ``[("event", bool), ("time", float)]``.
    """
    events = df["event"].values.astype(bool)
    times = df["time"].values
    times = times / times.max()

    skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
    # Only the held-out part of each split is needed: these are the folds.
    folds = [held_out for _, held_out in skf.split(df, events)]

    sksurv_type = [("event", bool), ("time", float)]

    def to_structured(e, t):
        """Pack event/time arrays into an sksurv-style structured array."""
        y = np.empty(len(e), dtype=sksurv_type)
        y["event"] = e
        y["time"] = t
        return y

    n_samples = len(df)
    for i in range(FOLDS):
        val_idx = folds[i]
        test_idx = folds[(i + 1) % FOLDS]

        # A boolean mask gives the (sorted) training indices in linear time;
        # a per-index membership test against the fold arrays is quadratic.
        in_train = np.ones(n_samples, dtype=bool)
        in_train[val_idx] = False
        in_train[test_idx] = False
        train_idx = np.flatnonzero(in_train)

        tr = get_preprocess_transformer()
        X_train = tr.fit_transform(df.iloc[train_idx]).to_numpy().astype(np.float32)
        X_val = tr.transform(df.iloc[val_idx]).to_numpy().astype(np.float32)
        X_test = tr.transform(df.iloc[test_idx]).to_numpy().astype(np.float32)

        e_train, t_train = events[train_idx], times[train_idx]
        e_val, t_val = events[val_idx], times[val_idx]
        e_test, t_test = events[test_idx], times[test_idx]

        # Keep only held-out samples strictly inside the training time range.
        min_time, max_time = t_train.min(), t_train.max()
        keep_val = (min_time < t_val) & (t_val < max_time)
        keep_test = (min_time < t_test) & (t_test < max_time)
        X_val, e_val, t_val = X_val[keep_val], e_val[keep_val], t_val[keep_val]
        X_test, e_test, t_test = X_test[keep_test], e_test[keep_test], t_test[keep_test]

        yield (
            (X_train, to_structured(e_train, t_train)),
            (X_val, to_structured(e_val, t_val)),
            (X_test, to_structured(e_test, t_test)),
        )

## Models

This section defines the FPBoost model and the base learners, all implementing the `SurvModel` abstract class.

In [None]:
@numba.njit
def concordance_index_td(
    events: np.ndarray, times: np.ndarray, risks: np.ndarray, percentile: float = 1.0
) -> float:
    """Compute a truncated concordance index by exhaustive pair comparison.

    A pair ``(i, j)`` is comparable when sample ``i`` experienced the event,
    ``times[i] < times[j]`` and ``times[i]`` is below the requested percentile
    of ``times``. A comparable pair is concordant when ``risks[i] > risks[j]``;
    tied risks are not counted as concordant.

    Args:
        events: Event indicators, one per sample.
        times: Observed times, one per sample.
        risks: Risk scores, one per sample (higher means earlier event).
        percentile: Fraction in [0, 1] selecting the truncation time.

    Returns:
        Concordant pairs divided by comparable pairs, or 0.0 if no pair is
        comparable.
    """
    cutoff = np.percentile(times, percentile * 100)
    n = len(times)
    comparable = 0
    concordant = 0
    for i in range(n):
        # Sample i can only anchor a pair if it is an event before the cutoff.
        if events[i] != 1 or not times[i] < cutoff:
            continue
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1
    if comparable > 0:
        return concordant / comparable
    return 0.0

In [None]:
class SurvModel(ABC):
    """Common interface implemented by every survival model in this notebook."""

    @abstractmethod
    def fit(self, X_train, y_train):
        """Train the model.

        Args:
            X_train: Training data of shape (n_samples, n_features).
            y_train: Training labels of shape (n_samples,) with dtype=[("event", bool), ("time", float)].
        """

    @abstractmethod
    def predict(self, X_test, times) -> np.array:
        """Evaluate the predicted survival function on a time grid.

        Args:
            X_test: Test data of shape (n_samples, n_features).
            times: Times at which to predict the survival function of shape (n_times,).

        Returns:
            Survival function of shape (n_samples, n_times).
        """

    def evaluate(self, X_test, y_test, y_train) -> dict[str, float]:
        """Compute concordance and Brier-score metrics on held-out data.

        The survival function is predicted on 100 equally spaced times
        covering the central 80% of the test time range. The negated sum of
        each survival curve is used as the risk score for the concordance
        indices. If the integrated Brier score raises ``ValueError`` it is
        replaced by 0.25.

        Args:
            X_test: Test data of shape (n_samples, n_features).
            y_test: Test labels of shape (n_samples,) with dtype=[("event", bool), ("time", float)].
            y_train: Training labels of shape (n_samples,) with dtype=[("event", bool), ("time", float)].

        Returns:
            Dictionary with keys "cid", "ibs", "c25", "c50" and "c75".
        """
        observed = y_test["event"]
        durations = y_test["time"]
        t_low, t_high = durations.min(), durations.max()
        margin = 0.1 * (t_high - t_low)
        grid = np.linspace(t_low + margin, t_high - margin, 100)

        curves = self.predict(X_test, grid)
        risk_scores = -curves.sum(axis=1)

        truncated = {
            name: concordance_index_td(observed, durations, risk_scores, quantile)
            for name, quantile in (("c25", 0.25), ("c50", 0.50), ("c75", 0.75))
        }
        cid = concordance_index_td(observed, durations, risk_scores)
        try:
            ibs = integrated_brier_score(y_train, y_test, curves, grid)
        except ValueError:
            ibs = 0.25
        return {"cid": cid, "ibs": ibs, **truncated}

In [None]:
class Cox(SurvModel):
    """Cox proportional hazards model (Cox, 1972).

    If fitting raises ``ValueError`` the failure is remembered and
    ``predict`` returns a constant survival probability of 0.5.
    """

    def __init__(self) -> None:
        self.failed_opt = False
        self.model = CoxPHSurvivalAnalysis(alpha=0.01)

    def fit(self, X_train, y_train):
        """Fit the Cox model, recording (not raising) a ``ValueError``."""
        try:
            self.model.fit(X_train, y_train)
        except ValueError:
            self.failed_opt = True

    def predict(self, X_test, times) -> np.array:
        """Return survival probabilities of shape (n_samples, n_times)."""
        if not self.failed_opt:
            surv_fns = self.model.predict_survival_function(X_test)
            return np.array([fn(times) for fn in surv_fns])
        # Fallback after a failed fit: uninformative constant prediction.
        return np.full((X_test.shape[0], len(times)), 0.5)

In [None]:
class RandomSurvivalForest(SurvModel):
    """Random survival forest model (Ishwaran et al., 2008)."""

    def __init__(self) -> None:
        # n_jobs=-1 is forwarded unchanged to the sksurv estimator.
        self.model = RSF(n_jobs=-1)

    def fit(self, X_train, y_train):
        """Fit the wrapped forest on the training data."""
        self.model.fit(X_train, y_train)

    def predict(self, X_test, times) -> np.array:
        """Return survival probabilities of shape (n_samples, n_times)."""
        surv_fns = self.model.predict_survival_function(X_test)
        rows = [fn(times) for fn in surv_fns]
        return np.array(rows)

In [None]:
class DeepSurv(SurvModel):
    """DeepSurv model (Katzman et al., 2018)."""

    def __init__(self) -> None:
        self.net = None
        self.model = None

    def fit(self, X_train, y_train):
        """Train a CoxPH network with early stopping on a 15% hold-out."""
        n_features = X_train.shape[1]
        # Hidden widths and dropout as in Katzman et al., 2018.
        hidden_sizes = [factor * n_features for factor in (3, 5, 3)]
        net = tt.practical.MLPVanilla(
            n_features,
            hidden_sizes,
            1,  # outputs
            False,  # batch norm
            0.6,  # dropout
        )
        X_fit, X_val, y_fit, y_val = train_test_split(
            X_train, y_train, test_size=0.15, stratify=y_train["event"]
        )
        val_target = (y_val["time"].copy(), y_val["event"].copy())
        self.model = CoxPH(net, tt.optim.Adam)
        fit_target = (y_fit["time"].copy(), y_fit["event"].copy())
        # The two positional 256 values are presumably batch size and epochs;
        # confirm against the pycox/torchtuples `fit` signature.
        self.model.fit(
            X_fit,
            fit_target,
            256,
            256,
            val_data=(X_val, val_target),
            callbacks=[tt.callbacks.EarlyStopping()],
            verbose=False,
        )
        self.model.compute_baseline_hazards()

    def predict(self, X_test, times) -> np.array:
        """Return survival probabilities of shape (n_samples, n_times).

        Each requested time is mapped to the closest time in the index of the
        predicted survival DataFrame (nearest-neighbour lookup).
        """
        surv_df = self.model.predict_surv_df(X_test)
        grid = surv_df.index.to_numpy()
        per_sample = np.array([surv_df.iloc[:, col].values for col in range(surv_df.shape[1])])
        out = np.zeros((X_test.shape[0], len(times)))
        nearest = [np.abs(grid - t).argmin() for t in times]
        for col, grid_idx in enumerate(nearest):
            out[:, col] = per_sample[:, grid_idx]
        return out

In [None]:
class DeepHit(SurvModel):
    """DeepHit model (Lee et al., 2018).

    Args:
        num_durations: Number of discrete time bins, also used as the number
            of network outputs.
    """

    def __init__(self, num_durations: int = 5) -> None:
        self.net = None
        self.model = None
        self.labtrans = None
        self.num_durations = num_durations

    def fit(self, X_train, y_train):
        """Train a DeepHit network with early stopping on a 15% hold-out."""
        X_fit, X_val, y_fit, y_val = train_test_split(
            X_train, y_train, test_size=0.15, stratify=y_train["event"]
        )
        # The label transform is fitted on the training part only.
        self.labtrans = DeepHitSingle.label_transform(self.num_durations)
        fit_target = self.labtrans.fit_transform(y_fit["time"], y_fit["event"])
        val_target = self.labtrans.transform(y_val["time"], y_val["event"])

        n_features = X_fit.shape[1]
        net = tt.practical.MLPVanilla(
            n_features,
            [factor * n_features for factor in (3, 5, 3)],
            self.num_durations,  # outputs
            False,  # batch norm
            0.6,  # dropout
        )
        self.model = DeepHitSingle(net, tt.optim.Adam, alpha=0.5, duration_index=self.labtrans.cuts)
        # The two positional 256 values are presumably batch size and epochs;
        # confirm against the pycox/torchtuples `fit` signature.
        self.model.fit(
            X_fit,
            fit_target,
            256,
            256,
            val_data=(X_val, val_target),
            callbacks=[tt.callbacks.EarlyStopping()],
            verbose=False,
        )

    def predict(self, X_test, times) -> np.array:
        """Return survival probabilities of shape (n_samples, n_times).

        Each requested time is mapped to the closest time in the index of the
        predicted survival DataFrame (nearest-neighbour lookup).
        """
        surv_df = self.model.predict_surv_df(X_test)
        grid = surv_df.index.to_numpy()
        per_sample = np.array([surv_df.iloc[:, col].values for col in range(surv_df.shape[1])])
        out = np.zeros((X_test.shape[0], len(times)))
        nearest = [np.abs(grid - t).argmin() for t in times]
        for col, grid_idx in enumerate(nearest):
            out[:, col] = per_sample[:, grid_idx]
        return out

In [None]:
class DeepSurvivalMachines(SurvModel):
    """Deep survival machines model (Nagpal et al., 2021)."""

    def __init__(self) -> None:
        self.model = None

    def fit(self, X_train, y_train):
        """Fit a DSM whose hidden widths are 3x, 5x and 3x the feature count."""
        width = X_train.shape[1]
        self.model = DSM(layers=[3 * width, 5 * width, 3 * width])
        self.model.fit(X_train, y_train["time"], y_train["event"])

    def predict(self, X_test, times) -> np.array:
        """Return ``exp(-risk)`` for every sample and requested time.

        NOTE(review): this treats the output of ``predict_risk`` as a
        cumulative hazard. If ``predict_risk`` returns an event probability
        (1 - S(t)) instead, ``exp(-risk)`` is not the survival probability;
        confirm against the auton-survival documentation.
        """
        features = X_test.astype(np.float64)
        risk_columns = []
        for t in times:
            risk_columns.append(self.model.predict_risk(features, t.astype(np.float64)))
        risks = np.concatenate(risk_columns, axis=1)
        return np.exp(-risks)

In [None]:
class CoxBoost(SurvModel):
    """CoxBoost model (Ridgeway, 1999)."""

    def __init__(self) -> None:
        # Library defaults are used for every hyperparameter.
        self.model = GradientBoostingSurvivalAnalysis()

    def fit(self, X_train, y_train):
        """Fit the wrapped gradient boosting model on the training data."""
        self.model.fit(X_train, y_train)

    def predict(self, X_test, times) -> np.array:
        """Return survival probabilities of shape (n_samples, n_times)."""
        surv_fns = self.model.predict_survival_function(X_test)
        rows = [fn(times) for fn in surv_fns]
        return np.array(rows)

In [None]:
class FPBoost(SurvModel):
    """FPBoost model for AAAI submission.

    The cumulative hazard is a weighted sum of Weibull and log-logistic heads.
    Each head has three parameters (``eta``, ``k`` and the weight ``w``), and
    each parameter is predicted by its own additive ensemble of regression
    trees fitted on the normalized negative gradient of the regularized
    negative log-likelihood.

    Args:
        weibull_heads: Number of Weibull heads.
        loglogistic_heads: Number of log-logistic heads.
        n_estimators: Number of base learners per estimated parameter.
        max_depth: Maximum depth of the base learners.
        learning_rate: Learning rate for the boosting algorithm.
        alpha: ElasticNet regularization strength.
        l1_ratio: Ratio between L1 and L2 regularization.
        uniform_heads: Whether to use uniform weights for the heads.
        heads_activation: Activation function for the heads. Can be "relu" or "softmax".
        patience: Patience for early stopping.
        verbose: Whether to print progress.

    Raises:
        ValueError: If ``heads_activation`` is not a supported name.
    """

    def __init__(
        self,
        weibull_heads: int,
        loglogistic_heads: int,
        n_estimators: int,
        max_depth: int,
        learning_rate: float,
        alpha: float,
        l1_ratio: float,
        uniform_heads: bool,
        heads_activation: str,
        patience: Optional[int],
        verbose: bool = False,
    ):
        self.weibull_heads = weibull_heads
        self.loglogistic_heads = loglogistic_heads
        self.heads = weibull_heads + loglogistic_heads
        if self.heads == 0:
            # A model without heads cannot predict: fall back to one Weibull head.
            self.weibull_heads = 1
            self.heads = 1
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.uniform_heads = uniform_heads
        self.heads_activation = heads_activation
        self.patience = patience
        self.verbose = verbose

        # Random initialization of the parameters
        self.init_eta = np.random.rand(self.heads) + 0.5
        self.eta_heads = [[] for _ in range(self.heads)]
        self.init_k = np.random.rand(self.heads) * 2
        self.k_heads = [[] for _ in range(self.heads)]
        self.init_w = np.random.rand(self.heads)
        self.w_heads = [[] for _ in range(self.heads)]

        heads_activation_fns = {
            "relu": lambda w: F.relu(w),
            "softmax": lambda w: F.softmax(w, dim=1),
        }
        if heads_activation not in heads_activation_fns:
            raise ValueError(f"Heads activation function not in {heads_activation_fns.keys()}")
        self.heads_activation_fn = heads_activation_fns[heads_activation]

    def _predict_heads(self, X: np.ndarray, init: np.ndarray, heads: list) -> np.ndarray:
        """Return the raw value of one parameter for every sample and head.

        The result is the per-head initial value plus ``learning_rate`` times
        the sum of the predictions of that head's base learners.

        Args:
            X: Features of shape (n_samples, n_features).
            init: Initial value of the parameter for each head, shape (heads,).
            heads: One list of fitted base learners per head.

        Returns:
            Array of shape (n_samples, heads).
        """
        output = np.zeros((len(X), self.heads)) + init.reshape((1, -1))
        for i, regs in enumerate(heads):
            if len(regs) == 0:
                continue
            preds = np.concatenate([reg.predict(X).reshape((-1, 1)) for reg in regs], axis=1)
            output[:, i] += self.learning_rate * np.sum(preds, axis=1)
        return output

    def _predict_etas(self, X: np.ndarray) -> np.ndarray:
        """Return the raw (pre-activation) ``eta`` of every head."""
        return self._predict_heads(X, self.init_eta, self.eta_heads)

    def _predict_ks(self, X: np.ndarray) -> np.ndarray:
        """Return the raw (pre-activation) ``k`` of every head."""
        return self._predict_heads(X, self.init_k, self.k_heads)

    def _predict_ws(self, X: np.ndarray) -> np.ndarray:
        """Return the raw head weights, or ``1 / heads`` if weights are uniform."""
        if self.uniform_heads:
            return np.ones((len(X), self.heads)) / self.heads
        return self._predict_heads(X, self.init_w, self.w_heads)

    def _predict_params(self, X: np.ndarray) -> np.ndarray:
        """Stack the raw parameters into shape (n_samples, heads, 3).

        The last axis is ordered ``(eta, k, w)``.
        """
        etas = self._predict_etas(X).reshape((-1, self.heads, 1))
        ks = self._predict_ks(X).reshape((-1, self.heads, 1))
        ws = self._predict_ws(X).reshape((-1, self.heads, 1))
        return np.concatenate([etas, ks, ws], -1)

    def _weibull_hazard(self, eta, k, times):
        """Weibull hazard ``k * eta * t^(k - 1)``."""
        return k * eta * times ** (k - 1)

    def _weibull_cum_hazard(self, eta, k, times):
        """Weibull cumulative hazard ``eta * t^k``."""
        return eta * times**k

    def _loglogistic_hazard(self, eta, k, times):
        """Log-logistic hazard ``eta * k * t^(k - 1) / (1 + eta * t^k)``."""
        return eta * k * times ** (k - 1) / (1 + eta * times**k)

    def _loglogistic_cum_hazard(self, eta, k, times):
        """Log-logistic cumulative hazard ``log(1 + eta * t^k)``.

        Works on both torch tensors (training) and NumPy arrays (prediction).
        """
        if torch.is_tensor(times):
            return torch.log1p(eta * times**k)
        return np.log1p(eta * times**k)

    def _get_neg_grads(self, params: np.ndarray, events: Tensor, times: Tensor) -> np.ndarray:
        """Return the normalized negative gradient of the loss w.r.t. ``params``.

        The loss is the negative mean log-likelihood plus an ElasticNet
        penalty on the raw parameters. The gradient is divided by its largest
        absolute entry, so the result lies in [-1, 1].

        Args:
            params: Raw parameters of shape (n_samples, heads, 3).
            events: Event indicators of shape (n_samples,).
            times: Observed times of shape (n_samples, 1).

        Returns:
            Array with the same shape as ``params``. It is all zeros when no
            finite, non-zero gradient is available.
        """
        params_torch = torch.tensor(params).float().requires_grad_(True)

        etas = F.relu(params_torch[:, :, 0])
        ks = F.relu(params_torch[:, :, 1])
        ws = self.heads_activation_fn(params_torch[:, :, 2])

        hazard = torch.zeros(len(times))
        cum_hazard = torch.zeros(len(times))

        if self.weibull_heads > 0:
            weibull_hazard = self._weibull_hazard(
                etas[:, : self.weibull_heads], ks[:, : self.weibull_heads], times
            )
            weibull_cum_hazard = self._weibull_cum_hazard(
                etas[:, : self.weibull_heads], ks[:, : self.weibull_heads], times
            )
            hazard += (weibull_hazard * ws[:, : self.weibull_heads]).sum(dim=1)
            cum_hazard += (weibull_cum_hazard * ws[:, : self.weibull_heads]).sum(dim=1)

        if self.loglogistic_heads > 0:
            loglogistic_hazard = self._loglogistic_hazard(
                etas[:, self.weibull_heads :], ks[:, self.weibull_heads :], times
            )
            loglogistic_cum_hazard = self._loglogistic_cum_hazard(
                etas[:, self.weibull_heads :], ks[:, self.weibull_heads :], times
            )
            hazard += (loglogistic_hazard * ws[:, self.weibull_heads :]).sum(dim=1)
            cum_hazard += (loglogistic_cum_hazard * ws[:, self.weibull_heads :]).sum(dim=1)

        log_likelihood = (events * torch.log(hazard) - cum_hazard).mean()
        l1_reg = torch.abs(params_torch).mean()
        l2_reg = (params_torch**2).mean()
        elastic_net_reg = self.l1_ratio * l1_reg + (1 - self.l1_ratio) * l2_reg
        loss = -log_likelihood + self.alpha * elastic_net_reg

        loss.backward()
        grad = params_torch.grad.numpy()
        # Non-finite entries (e.g. from log(0)) carry no usable direction.
        grad[~np.isfinite(grad)] = 0.0
        scale = np.abs(grad).max()
        if scale == 0.0:
            # Normalizing would be 0 / 0 and yield NaN regression targets,
            # which the base learners reject; return a zero step instead.
            return np.zeros_like(grad)
        return -(grad / scale)

    def _fit_base_learner(self, X: np.ndarray, y: np.ndarray):
        """Fit one regression tree of depth ``max_depth`` on ``(X, y)``."""
        reg = DecisionTreeRegressor(max_depth=self.max_depth)
        reg.fit(X, y)
        return reg

    def fit(self, X_train: np.ndarray, y_train: np.ndarray) -> None:
        """Fit the boosted ensembles.

        When ``patience`` is not ``None``, 20% of the training data is held
        out; boosting stops once the hold-out concordance has not improved
        for ``patience`` consecutive iterations, and the ensembles are
        truncated to the best iteration.

        Args:
            X_train: Training data of shape (n_samples, n_features).
            y_train: Training labels of shape (n_samples,) with dtype=[("event", bool), ("time", float)].
        """
        if self.verbose:
            print(f"Fitting a Survival Boosting model with {self.heads} heads...")

        patience_counter, best_num_base_learners, best_cid = 0, 0, 0.0
        if self.patience is not None:
            X_train, X_val, y_train, y_val = train_test_split(
                X_train,
                y_train,
                test_size=0.2,
                stratify=y_train["event"],
            )

        events = torch.tensor(y_train["event"].copy()).float().reshape((-1,))
        times = torch.tensor(y_train["time"].copy()).float().reshape((-1, 1))
        # Time grid used to score the hold-out set during early stopping.
        timeline = np.linspace(np.min(y_train["time"]), np.max(y_train["time"]), 100)

        for j in range(self.n_estimators):
            params = self._predict_params(X_train)

            neg_grads = self._get_neg_grads(params, events, times)
            eta_grads = neg_grads[:, :, 0]
            k_grads = neg_grads[:, :, 1]
            w_grads = neg_grads[:, :, 2]

            for i in range(self.heads):
                self.eta_heads[i].append(self._fit_base_learner(X_train, eta_grads[:, i]))
                self.k_heads[i].append(self._fit_base_learner(X_train, k_grads[:, i]))
                if not self.uniform_heads:
                    self.w_heads[i].append(self._fit_base_learner(X_train, w_grads[:, i]))

            if self.patience is not None:
                survs = self.predict(X_val, timeline)
                mean_times = survs.sum(axis=1)
                cid = concordance_index_td(y_val["event"], y_val["time"], -mean_times)
                if self.verbose:
                    print(f"[Iteration {j:04}] Concordance: {cid:.4f}")
                if cid > best_cid:
                    best_cid = cid
                    best_num_base_learners = len(self.eta_heads[0])
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= self.patience:
                        break

        if self.patience is not None:
            # Drop the learners added after the best hold-out iteration.
            self.eta_heads = [heads[:best_num_base_learners] for heads in self.eta_heads]
            self.k_heads = [heads[:best_num_base_learners] for heads in self.k_heads]
            self.w_heads = [heads[:best_num_base_learners] for heads in self.w_heads]

    def predict(self, X_test, times) -> np.array:
        """Predict the survival function ``exp(-cumulative hazard)``.

        Args:
            X_test: Test data of shape (n_samples, n_features).
            times: NumPy array of times of shape (n_times,).

        Returns:
            Survival function of shape (n_samples, n_times).
        """
        times = times.reshape((1, 1, -1))
        params = torch.tensor(self._predict_params(X_test)).float()

        etas = F.relu(params[:, :, 0]).numpy().reshape((-1, self.heads, 1))
        ks = F.relu(params[:, :, 1]).numpy().reshape((-1, self.heads, 1))
        ws = self.heads_activation_fn(params[:, :, 2]).numpy().reshape((-1, self.heads, 1))

        cum_hazard = np.zeros((len(X_test), len(times[0][0])))

        if self.weibull_heads > 0:
            weibull_cum_hazard = self._weibull_cum_hazard(
                etas[:, : self.weibull_heads], ks[:, : self.weibull_heads], times
            )
            cum_hazard += (weibull_cum_hazard * ws[:, : self.weibull_heads]).sum(axis=1)

        if self.loglogistic_heads > 0:
            loglogistic_cum_hazard = self._loglogistic_cum_hazard(
                etas[:, self.weibull_heads :], ks[:, self.weibull_heads :], times
            )
            cum_hazard += (loglogistic_cum_hazard * ws[:, self.weibull_heads :]).sum(axis=1)
        surv = np.exp(-cum_hazard)
        return surv

In [None]:
def init_model(model: str, params: dict) -> SurvModel:
    """Instantiate the survival model registered under ``model``.

    Args:
        model: String with the model name.
        params: Dictionary with the model parameters, forwarded as keyword
            arguments to the model constructor.

    Returns:
        Initialized survival model.

    Raises:
        KeyError: If ``model`` is not a registered name.
    """
    registry = {
        "cox": Cox,
        "rsf": RandomSurvivalForest,
        "deepsurv": DeepSurv,
        "deephit": DeepHit,
        "dsm": DeepSurvivalMachines,
        "coxboost": CoxBoost,
        "fpboost": FPBoost,
    }
    model_cls = registry[model]
    return model_cls(**params)

## Training

In [None]:
def atomic_training(dataset: str, model: str, params: dict) -> dict[str, float]:
    """Cross-validate one model on one dataset and summarize its metrics.

    The NumPy and PyTorch generators are re-seeded first. A fresh model is
    then fitted on every split and evaluated on both the validation and the
    test part.

    Args:
        dataset: Name passed to ``load_dataframe``.
        model: Model name understood by ``init_model``.
        params: Keyword arguments for the model constructor.

    Returns:
        Dictionary mapping ``"<metric>_<val|test>_<mean|std>"`` to the mean or
        standard deviation of that metric over the folds.
    """
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    df = load_dataframe(dataset)

    per_fold = {}
    for (X_train, y_train), (X_val, y_val), (X_test, y_test) in get_k_fold_splits(df):
        estimator = init_model(model, params)
        estimator.fit(X_train, y_train)
        scores_by_split = (
            ("val", estimator.evaluate(X_val, y_val, y_train)),
            ("test", estimator.evaluate(X_test, y_test, y_train)),
        )
        for split, scores in scores_by_split:
            for name, value in scores.items():
                per_fold.setdefault(f"{name}_{split}", []).append(value)

    summary = {}
    for key, values in per_fold.items():
        summary[f"{key}_mean"] = np.mean(values).item()
        summary[f"{key}_std"] = np.std(values).item()
    return summary

## Results

Train and evaluate the baseline models and FPBoost on the datasets.

In [None]:
NUM_CPUS = 8  # Number of CPUs handed to ray.init for parallel training
OBJ_MEMORY_GB = 2  # Size of the Ray object store in GB (passed to ray.init as object_store_memory)

RESULTS_PATH = "results"  # Directory in which all result files are written

# Datasets on which to train the models (names are passed to load_dataframe)
DATASETS = [
    "aids",
    "breast_cancer",
    "gbsg",
    "metabric",
    "support",
    "veterans",
    "whas",
]

# Baseline models to train (keys of the model registry in init_model)
BASELINE_MODELS = ["rsf", "cox", "coxboost", "deepsurv", "dsm", "deephit"]

# Search space for the hyperparameter optimization of the FPBoost model.
# The keys match the keyword arguments of FPBoost.__init__. If both head
# counts are sampled as 0, FPBoost falls back to a single Weibull head.
SEARCH_SPACE = {
    "weibull_heads": tune.randint(0, 8),
    "loglogistic_heads": tune.randint(0, 8),
    "n_estimators": tune.randint(1, 256),
    "max_depth": tune.randint(1, 8),
    "learning_rate": tune.uniform(1e-2, 1),
    "alpha": tune.uniform(0.0, 1.0),
    "l1_ratio": tune.uniform(0, 1),
    "uniform_heads": tune.choice([True, False]),
    "heads_activation": tune.choice(["relu", "softmax"]),
    "patience": tune.choice([None, 4, 16]),
}

ITERATIONS = 8  # Number of hyperparameter configurations tried per dataset (num_samples of tune.run)

In [None]:
# Create the output directory and start Ray with the configured CPU count and
# object store size (converted from GB to bytes).
os.makedirs(RESULTS_PATH, exist_ok=True)
ray.init(
    num_cpus=NUM_CPUS,
    object_store_memory=OBJ_MEMORY_GB * 1024 * 1024 * 1024,
    ignore_reinit_error=True,  # allows re-running this cell in the same session
)

### Baselines

In [None]:
# Per-run result files are cached here, one JSON file per (dataset, model).
tempdir = os.path.join(RESULTS_PATH, "temp_baselines")
os.makedirs(tempdir, exist_ok=True)


@ray.remote
def remote_baseline_training(dataset: str, model: str) -> dict[str, float]:
    """Train one baseline with default parameters and cache its metrics.

    Nothing is computed when the JSON file for this (dataset, model) pair
    already exists in ``tempdir``; ``None`` is returned in that case.

    Args:
        dataset: Name of the dataset to train on.
        model: Name of the baseline model.

    Returns:
        The metrics of ``atomic_training`` extended with the "dataset",
        "model" and "params" entries, or ``None`` on a cache hit.
    """
    cache_file = os.path.join(tempdir, f"{dataset}_{model}.json")
    if os.path.exists(cache_file):
        return None
    ret = {
        **atomic_training(dataset, model, {}),
        "dataset": dataset,
        "model": model,
        "params": {},
    }
    with open(cache_file, "w") as f:
        json.dump(ret, f)
    return ret


BASELINE_RESULTS_FILE = os.path.join(RESULTS_PATH, "baseline_results.csv")

# Launch one remote task per (dataset, model) pair and wait for all of them.
# The tasks write their metrics to JSON files in `tempdir`; the values
# returned here are not used (they are None for cached runs).
baseline_results = []
for dataset in DATASETS:
    for model in BASELINE_MODELS:
        baseline_results.append(remote_baseline_training.remote(dataset, model))
baseline_results = ray.get(baseline_results)

# Collect the cached runs row by row. Sorting makes the CSV row order
# deterministic, and entries that are not JSON result files are skipped
# instead of being parsed.
records = []
for filename in sorted(os.listdir(tempdir)):
    if not filename.endswith(".json"):
        continue
    with open(os.path.join(tempdir, filename), "r") as file:
        records.append(json.load(file))

# Building the frame from records keeps every value on its own row even if
# two result files do not share exactly the same keys.
baseline_results_df = pd.DataFrame(records)
baseline_results_df.to_csv(BASELINE_RESULTS_FILE, index=False)

### FPBoost

In [None]:
def objective(config):
    """Ray Tune trainable: cross-validate FPBoost with the sampled parameters.

    Args:
        config: Dictionary with the "dataset" name and the sampled "params".

    Returns:
        Metrics dictionary of ``atomic_training`` (reported to Ray Tune).
    """
    return atomic_training(config["dataset"], "fpboost", config["params"])


os.makedirs(RESULTS_PATH, exist_ok=True)

for dataset in DATASETS:
    print(f"Training FPBoost on {dataset}...")

    # The experiment is named after the model actually being tuned. The name
    # must not depend on the `model` variable left over from the baselines
    # cell, which would mislabel the files (or fail if that cell was not run).
    experiment_name = f"{dataset}_fpboost_tune_optuna_experiment"

    # Seed the sampler so that the search is reproducible, like the rest of
    # the notebook.
    search_alg = OptunaSearch(
        metric="cid_val_mean", mode="max", sampler=optuna.samplers.TPESampler(seed=SEED)
    )

    analysis = tune.run(
        objective,
        config={"dataset": dataset, "params": SEARCH_SPACE},
        num_samples=ITERATIONS,
        search_alg=search_alg,
        name=experiment_name,
        storage_path=f"file://{os.path.abspath(RESULTS_PATH)}",
    )

    # One CSV per dataset; the next cell picks up every file with this suffix.
    df = analysis.results_df
    csv_filename = f"{experiment_name}.csv"
    df.to_csv(os.path.join(RESULTS_PATH, csv_filename), index=False)

In [None]:
# Load the per-dataset tuning results written by the previous cell.
fpboost_results = []
for file in sorted(os.listdir(RESULTS_PATH)):
    if file.endswith("_tune_optuna_experiment.csv"):
        df = pd.read_csv(os.path.join(RESULTS_PATH, file))
        fpboost_results.append(df)

if not fpboost_results:
    # Fail with an actionable message instead of pandas' generic concat error.
    raise FileNotFoundError(
        f"No '*_tune_optuna_experiment.csv' files found in {RESULTS_PATH!r}; "
        "run the FPBoost tuning cell first."
    )

df = pd.concat(fpboost_results)
df["model"] = "fpboost"
# Drop every column that has a missing value in at least one trial.
df = df.dropna(axis=1, how="any")
df.columns = df.columns.str.replace("config/", "", regex=False)
df.reset_index(drop=True, inplace=True)

# Select the best hyperparameters for each dataset according to the C-Index - IBS difference on the validation set
df["sel_col"] = df["cid_val_mean"] - df["ibs_val_mean"]
idx = df.groupby(["dataset", "model"])["sel_col"].idxmax()
df = df.loc[idx]
assert df.groupby(["dataset", "model"]).size().eq(1).all()

# Remove the selection helper and Ray Tune's bookkeeping columns. Columns
# that are absent (not produced by the installed Ray version, or already
# removed by the dropna above) are ignored rather than raising a KeyError.
df = df.drop(
    columns=[
        "sel_col",
        "timestamp",
        "time_since_restore",
        "pid",
        "time_total_s",
        "date",
        "training_iteration",
        "time_this_iter_s",
        "done",
        "hostname",
        "node_ip",
        "iterations_since_restore",
        "experiment_tag",
    ],
    errors="ignore",
)

FPBOOST_RESULTS_FILE = os.path.join(RESULTS_PATH, "fpboost_results.csv")

df.to_csv(FPBOOST_RESULTS_FILE, index=False)

### Results Collection

Load the results returned by the previous cells.

In [None]:
# Reload both result tables from disk so this cell can run on its own once
# the two CSV files exist.
baseline_results_df = pd.read_csv(BASELINE_RESULTS_FILE)
fpboost_results_df = pd.read_csv(FPBOOST_RESULTS_FILE)

# Stack baselines and FPBoost, order the rows by dataset and model, and save.
combined_frames = [baseline_results_df, fpboost_results_df]
results_df = pd.concat(combined_frames).sort_values(["dataset", "model"])
results_path = os.path.join(RESULTS_PATH, "results.csv")
results_df.to_csv(results_path, index=False)