<a href="https://colab.research.google.com/github/JSK2022/RandomForest/blob/main/JS_simulation_code_230808.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scikit-survival

In [None]:
from math import ceil
from numbers import Integral, Real

import numpy as np
from scipy.sparse import issparse
from sklearn.base import BaseEstimator
from sklearn.tree import _tree
from sklearn.tree._classes import DENSE_SPLITTERS, SPARSE_SPLITTERS
from sklearn.tree._splitter import Splitter
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.validation import check_is_fitted, check_random_state

import sksurv
from sksurv.base import SurvivalAnalysisMixin
from sksurv.functions import StepFunction
from sksurv.util import check_array_survival
from sksurv.tree._criterion import LogrankCriterion, get_unique_times

# Only SurvivalTree is declared public here; the other names defined in this
# notebook are not listed in __all__.
__all__ = ["SurvivalTree"]

# sklearn tree module's feature dtype; _validate_X_predict converts X to it.
DTYPE = _tree.DTYPE


def _array_to_step_function(x, array):
    """Wrap each row of ``array`` in a StepFunction evaluated on the grid ``x``.

    Returns a 1-D object array with one StepFunction per row (first axis).
    """
    step_functions = np.empty(array.shape[0], dtype=np.object_)
    for row_idx, row_values in enumerate(array):
        step_functions[row_idx] = StepFunction(x=x, y=row_values)
    return step_functions


class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
    """Survival tree split with the log-rank criterion (adapted from scikit-survival).

    Each leaf stores two outputs per unique training time: the cumulative hazard
    (read back by ``predict_cumulative_hazard_function``) and the survival
    function (read back by ``predict_survival_function``).

    Parameters
    ----------
    splitter : {"best", "random"}
        Split-search strategy, looked up in sklearn's splitter tables.
    max_depth : int or None
        Maximum tree depth; None means no explicit limit.
    min_samples_split : int or float
        Minimum number of samples required to split a node (a float is a
        fraction of n_samples). Always raised to at least ``2 * min_samples_leaf``.
    min_samples_leaf : int or float
        Minimum number of samples per leaf (a float is a fraction of n_samples,
        rounded up).
    min_weight_fraction_leaf : float in [0, 0.5]
        Minimum per-leaf weight, as a fraction of n_samples.
    max_features : int, float, {"sqrt", "log2", "auto"} or None
        Number of features considered per split ("auto" is deprecated and is
        treated like "sqrt").
    random_state : int, RandomState instance or None
        Seed / generator passed to the splitter.
    max_leaf_nodes : int or None
        If given, grow best-first up to this many leaves; otherwise depth-first.
    """

    _parameter_constraints = {
        "splitter": [StrOptions({"best", "random"})],
        "max_depth": [Interval(Integral, 1, None, closed="left"), None],
        "min_samples_split": [
            Interval(Integral, 2, None, closed="left"),
            Interval(Real, 0.0, 1.0, closed="neither"),
        ],
        "min_samples_leaf": [
            Interval(Integral, 1, None, closed="left"),
            Interval(Real, 0.0, 0.5, closed="right"),
        ],
        "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
        "max_features": [
            Interval(Integral, 1, None, closed="left"),
            Interval(Real, 0.0, 1.0, closed="right"),
            StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}),
            None,
        ],
        "random_state": ["random_state"],
        "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
    }

    def __init__(
        self,
        *,
        splitter="best",
        max_depth=None,
        min_samples_split=6,
        min_samples_leaf=3,
        min_weight_fraction_leaf=0.0,
        max_features=None,
        random_state=None,
        max_leaf_nodes=None,
    ):
        self.splitter = splitter
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.random_state = random_state
        self.max_leaf_nodes = max_leaf_nodes

    def fit(self, X, y, sample_weight=None, check_input=True):
        """Build the survival tree from training data.

        Parameters
        ----------
        X : array-like or sparse matrix of shape (n_samples, n_features)
        y : survival target
            When ``check_input`` is True: a target accepted by
            ``sksurv.util.check_array_survival`` (event indicator and time).
            When ``check_input`` is False: a pre-computed tuple
            ``(y_numeric, unique_times, is_event_time)`` where
            ``y_numeric[:, 0]`` is time and ``y_numeric[:, 1]`` the event flag.
        sample_weight : array-like or None
            Passed unchanged to the sklearn tree builder.
        check_input : bool
            Whether to validate ``X`` and ``y``.

        Returns
        -------
        self
        """
        random_state = check_random_state(self.random_state)

        if check_input:
            X = self._validate_data(X, ensure_min_samples=2, accept_sparse="csc")
            event, time = check_array_survival(X, y)
            time = time.astype(np.float64)
            self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
            if issparse(X):
                X.sort_indices()

            # The Cython criterion reads time from column 0 and the event from column 1.
            y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
            y_numeric[:, 0] = time
            y_numeric[:, 1] = event.astype(np.float64)
        else:
            y_numeric, self.unique_times_, self.is_event_time_ = y

        n_samples, self.n_features_in_ = X.shape
        params = self._check_params(n_samples)

        self.n_outputs_ = self.unique_times_.shape[0]
        # one "class" for CHF, one for survival function
        self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2

        # Build tree
        self.criterion = "logrank"
        criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_)

        SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS

        splitter = self.splitter
        if not isinstance(self.splitter, Splitter):
            splitter = SPLITTERS[self.splitter](
                criterion, self.max_features_, params["min_samples_leaf"], params["min_weight_leaf"], random_state
            )

        self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)

        # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
        if params["max_leaf_nodes"] < 0:
            builder = DepthFirstTreeBuilder(
                splitter,
                params["min_samples_split"],
                params["min_samples_leaf"],
                params["min_weight_leaf"],
                params["max_depth"],
                0.0,  # min_impurity_decrease
            )
        else:
            builder = BestFirstTreeBuilder(
                splitter,
                params["min_samples_split"],
                params["min_samples_leaf"],
                params["min_weight_leaf"],
                params["max_depth"],
                params["max_leaf_nodes"],
                0.0,  # min_impurity_decrease
            )

        builder.build(self.tree_, X, y_numeric, sample_weight)

        return self

    def _check_params(self, n_samples):
        """Validate hyper-parameters and resolve them to the builder's integer/float form.

        Also sets ``self.max_features_`` (via ``_check_max_features``).

        Returns
        -------
        dict with keys "max_depth", "max_leaf_nodes" (-1 means unlimited),
        "min_samples_leaf", "min_samples_split" and "min_weight_leaf".
        """
        self._validate_params()

        # Check parameters
        max_depth = (2**31) - 1 if self.max_depth is None else self.max_depth

        max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

        if isinstance(self.min_samples_leaf, (Integral, np.integer)):
            min_samples_leaf = self.min_samples_leaf
        else:  # float
            min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))

        if isinstance(self.min_samples_split, Integral):
            min_samples_split = self.min_samples_split
        else:  # float
            min_samples_split = int(ceil(self.min_samples_split * n_samples))
            min_samples_split = max(2, min_samples_split)

        # A split must be able to give both children at least min_samples_leaf samples.
        min_samples_split = max(min_samples_split, 2 * min_samples_leaf)

        self._check_max_features()

        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")

        min_weight_leaf = self.min_weight_fraction_leaf * n_samples

        return {
            "max_depth": max_depth,
            "max_leaf_nodes": max_leaf_nodes,
            "min_samples_leaf": min_samples_leaf,
            "min_samples_split": min_samples_split,
            "min_weight_leaf": min_weight_leaf,
        }

    def _check_max_features(self):
        """Resolve ``max_features`` to an int in (0, n_features_in_] and store it in ``max_features_``.

        Raises
        ------
        ValueError
            If the resolved value falls outside (0, n_features_in_].
        """
        if isinstance(self.max_features, str):
            if self.max_features in ("auto", "sqrt"):
                max_features = max(1, int(np.sqrt(self.n_features_in_)))
            elif self.max_features == "log2":
                max_features = max(1, int(np.log2(self.n_features_in_)))

        elif self.max_features is None:
            max_features = self.n_features_in_
        elif isinstance(self.max_features, (Integral, np.integer)):
            max_features = self.max_features
        else:  # float
            if self.max_features > 0.0:
                max_features = max(1, int(self.max_features * self.n_features_in_))
            else:
                max_features = 0

        if not 0 < max_features <= self.n_features_in_:
            raise ValueError("max_features must be in (0, n_features]")

        self.max_features_ = max_features

    def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
        """Validate X whenever one tries to predict.

        With ``check_input`` the data is converted to ``DTYPE`` (and the given
        sparse format); otherwise only the feature count is checked. The
        (possibly converted) X is returned and must be used by the caller.
        """
        if check_input:
            X = self._validate_data(X, dtype=DTYPE, accept_sparse=accept_sparse, reset=False)
        else:
            # The number of features is checked regardless of `check_input`
            self._check_n_features(X, reset=False)

        return X

    def predict(self, X, check_input=True):
        """Return one risk score per sample.

        The score is the predicted cumulative hazard summed over the unique
        training times flagged in ``is_event_time_``.
        """
        chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
        return chf[:, self.is_event_time_].sum(1)

    def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
        """Predict the cumulative hazard on the grid ``unique_times_``.

        Returns an array of shape (n_samples, n_unique_times) when
        ``return_array`` is True, otherwise an object array of StepFunctions.
        """
        check_is_fitted(self, "tree_")
        X = self._validate_X_predict(X, check_input, accept_sparse="csr")

        pred = self.tree_.predict(X)
        arr = pred[..., 0]  # output "class" 0 holds the cumulative hazard
        if return_array:
            return arr
        return _array_to_step_function(self.unique_times_, arr)

    def predict_survival_function(self, X, check_input=True, return_array=False):
        """Predict the survival function on the grid ``unique_times_``.

        Returns an array of shape (n_samples, n_unique_times) when
        ``return_array`` is True, otherwise an object array of StepFunctions.
        """
        check_is_fitted(self, "tree_")
        X = self._validate_X_predict(X, check_input, accept_sparse="csr")

        pred = self.tree_.predict(X)
        arr = pred[..., 1]  # output "class" 1 holds the survival function
        if return_array:
            return arr
        return _array_to_step_function(self.unique_times_, arr)

    def apply(self, X, check_input=True):
        """Return the index of the leaf each sample ends up in."""
        check_is_fitted(self, "tree_")
        # Use the validated X: tree_.apply needs the DTYPE-converted array, so
        # passing the raw input (e.g. float64 or a list) would fail.
        X = self._validate_X_predict(X, check_input)
        return self.tree_.apply(X)

    def decision_path(self, X, check_input=True):
        """Return the node-indicator matrix of the decision path of each sample."""
        check_is_fitted(self, "tree_")
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)

In [None]:
from collections import defaultdict
import numpy as np

from collections import defaultdict
import numpy as np

def get_unique_times_by_id(ids, time_start, time_stop, event):
    """Group stop times and event flags by subject ID.

    Rows are visited in the given order. For each subject, a row whose stop time
    equals that subject's most recently recorded stop time is merged into it
    (the stored flag becomes True when the row's event == 1); any other row
    appends its stop time and event value. ``time_start`` is accepted but unused.

    Returns
    -------
    (times_by_id, events_by_id) : two ``defaultdict(list)`` objects mapping each
        ID to a numpy array of stop times and a boolean numpy array of flags.
    """
    times_for = defaultdict(list)
    events_for = defaultdict(list)

    for subject, _start, stop, happened in zip(ids, time_start, time_stop, event):
        seen = times_for[subject]
        repeats_last = bool(seen) and not (stop != seen[-1])
        if not repeats_last:
            seen.append(stop)
            events_for[subject].append(happened)
        elif happened == 1:
            events_for[subject][-1] = True

    for subject in times_for:
        times_for[subject] = np.asarray(times_for[subject])
        events_for[subject] = np.asarray(events_for[subject], dtype=np.bool_)

    return times_for, events_for

class RisksetCounter:
    """Risk-set and event counts over the unique stop times of counting-process data.

    Each row is an interval ``(time_start, time_stop]`` for subject ``ids[row]``.
    A row adds 1 to ``n_at_risk`` at every grid time ``t`` with
    ``time_start < t <= time_stop``, and (when its event flag is truthy) adds 1 to
    ``n_events`` at its stop time. With non-overlapping intervals per subject
    this makes ``n_at_risk[t]`` the number of subjects under observation at ``t``.

    ``set_data`` adds all rows given to the constructor; ``update`` removes the
    rows passed to it, so it is the exact inverse of adding those rows.
    """

    def __init__(self, ids, time_start, time_stop, event):
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event

        # Time grid: sorted unique stop times.
        self.all_unique_times = np.unique(time_stop)
        self.n_unique_times = len(self.all_unique_times)

        self.n_at_risk = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events = np.zeros(self.n_unique_times, dtype=np.int64)

        # Per-subject view (de-duplicated stop times / event flags); not used by the
        # counting itself but kept as public attributes.
        self.unique_times_by_id, self.has_event_by_id = get_unique_times_by_id(ids, time_start, time_stop, event)
        self.set_data()

    def reset(self):
        """Zero the counts and re-add every row given to the constructor."""
        self.n_at_risk.fill(0)
        self.n_events.fill(0)
        self.set_data()

    def _accumulate(self, time_start, time_stop, event, sign):
        """Add (``sign=+1``) or remove (``sign=-1``) the given intervals from the counts.

        Raises
        ------
        ValueError
            If a row with an event stops at a time that is not on the grid.
        """
        grid = self.all_unique_times
        for t_start, t_stop, happened in zip(time_start, time_stop, event):
            # first grid index with time > start, and one past the last index with time <= stop
            first = np.searchsorted(grid, t_start, side="right")
            last = np.searchsorted(grid, t_stop, side="right")
            self.n_at_risk[first:last] += sign
            if happened:
                if last == 0 or grid[last - 1] != t_stop:
                    raise ValueError(f"event time {t_stop!r} is not one of the counter's unique stop times")
                self.n_events[last - 1] += sign

    def set_data(self):
        """Add every row given to the constructor to the counts."""
        self._accumulate(self.time_start, self.time_stop, self.event, +1)

    def update(self, ids, time_start, time_stop, event):
        """Remove the given rows from the counts.

        ``ids`` is accepted for interface compatibility; counting is per interval
        and does not need it.
        """
        self._accumulate(time_start, time_stop, event, -1)

# Now, let's proceed with the updated class and demonstration
class RisksetCounterUpdated(RisksetCounter):
    """RisksetCounter with a per-subject, per-time look-up."""

    def at_id_time(self, id_, t_idx):
        """Return ``(at_risk, events)`` for subject ``id_`` at ``all_unique_times[t_idx]``.

        ``at_risk`` is 1 if any of the subject's intervals ``(start, stop]``
        contains that time, else 0. ``events`` is the number of the subject's
        rows with a nonzero event indicator whose stop time equals that time.
        """
        t = self.all_unique_times[t_idx]
        rows = np.asarray(self.ids) == id_
        starts = np.asarray(self.time_start)[rows]
        stops = np.asarray(self.time_stop)[rows]
        flags = np.asarray(self.event)[rows]

        at_risk = int(np.any((starts < t) & (t <= stops)))
        events = int(np.count_nonzero(flags[stops == t]))
        return at_risk, events



In [None]:
import numpy as np

# Toy recurrent-event data in counting-process form: one row per interval with
# columns (start, stop, event); ``ids`` gives the subject of each row.
ids = np.array([1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 9, 9, 10, 11, 11, 11, 12, 12, 12])

data = np.array(
    [
        [0, 1, 0],  # subject 1
        [0, 4, 0],  # subject 2
        [0, 7, 0],  # subject 3
        [0, 10, 0],  # subject 4
        [0, 6, 1],  # subject 5
        [6, 10, 0],
        [0, 14, 0],  # subject 6
        [0, 18, 0],  # subject 7
        [0, 5, 1],  # subject 8
        [5, 18, 0],
        [0, 12, 1],  # subject 9
        [12, 16, 1],
        [16, 18, 0],
        [0, 23, 0],  # subject 10
        [0, 10, 1],  # subject 11
        [10, 15, 1],
        [15, 23, 0],
        [0, 3, 1],  # subject 12
        [3, 16, 1],
        [16, 23, 1],
    ]
)


In [None]:
# Two subjects, two intervals each: (start, stop, event) per row.
ids_dummy = np.array([1, 1, 2, 2])
time_start_dummy = np.array([0, 1, 0, 2])
time_stop_dummy = np.array([1, 2, 2, 3])
event_dummy = np.array([0, 1, 0, 1])

riskset_counter = RisksetCounterUpdated(ids_dummy, time_start_dummy, time_stop_dummy, event_dummy)
unique_times_dummy = np.unique(time_stop_dummy)

# One (id, time, at_risk, events) tuple per subject and grid time.
results = [
    (id_, t, *riskset_counter.at_id_time(id_, t_idx))
    for id_ in np.unique(ids_dummy)
    for t_idx, t in enumerate(unique_times_dummy)
]
results

주요 변경 사항:

1. 데이터셋은 ID와 시간으로 정렬되어 있다고 가정하였습니다.
2. 이전 ID 값을 추적하여 현재 ID가 변경되면 위험 집합을 업데이트하는 시간을 재설정합니다. 이렇게 하면 동일한 ID에 대해 동일한 시간에 여러 번 위험 집합을 업데이트하지 않습니다.
3. 이벤트 발생 여부와 관계없이 위험 집합을 업데이트하며, 이벤트가 발생한 경우에만 이벤트 카운터를 증가시킵니다.

In [None]:
def argbinsearch(arr, key_val):
    """Return the number of leading elements of sorted ``arr`` that are <= ``key_val``.

    Equivalently, the index at which ``key_val`` would be inserted to keep
    ``arr`` sorted, placed after any elements equal to it. ``arr`` must be sorted
    ascending; an empty ``arr`` gives 0.
    """
    import bisect

    # bisect_right implements exactly this search; the previous hand-written
    # loop also carried an unreachable "return -1" branch.
    return bisect.bisect_right(arr, key_val)

In [None]:
from sklearn.utils import check_random_state

class PseudoScoreCriterion:
    """Split criterion for recurrent-event data in counting-process ``(start, stop]`` form.

    The full training data (features ``x`` and the per-row ``ids``,
    ``time_start``, ``time_stop`` and ``event`` arrays) is stored once. ``init``
    selects the rows of the current node, ``update`` evaluates one candidate
    split ``x[:, feature] <= threshold`` of that node, and
    ``proxy_impurity_improvement`` scores it. All counts live on the grid
    ``unique_times``.
    """

    def __init__(self, n_outputs, n_samples, unique_times, x, ids, time_start, time_stop, event, random_state=None):
        self.n_outputs = n_outputs
        self.n_samples = n_samples
        self.n_unique_times = len(unique_times)

        self.x = x
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event
        self.unique_times = unique_times
        self.random_state = check_random_state(random_state)

        # Node / split bookkeeping
        self.n_node_samples = 0
        self.weighted_n_node_samples = 0.0
        self.weighted_n_left = 0.0
        self.weighted_n_right = 0.0

        # Risk set built from all rows; node_value() reads its counts.
        self.riskset_total = RisksetCounter(ids, time_start, time_stop, event)
        self.delta_n_at_risk_left = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events_left = np.zeros(self.n_unique_times, dtype=np.int64)
        self.delta_n_at_risk_right = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events_right = np.zeros(self.n_unique_times, dtype=np.int64)

        # Current node: rows samples[start:end] (set by init)
        self.start = 0
        self.pos = 0
        self.end = 0
        self.samples = None

        # For children impurity computation
        self.left_impurity = np.empty(self.n_unique_times, dtype=np.float64)
        self.right_impurity = np.empty(self.n_unique_times, dtype=np.float64)

        # Grid position of every row's interval (start, stop]:
        #   samples_time_idx[i]  -> index of time_stop[i] in unique_times (left insertion point)
        #   samples_start_idx[i] -> first grid index whose time is strictly after time_start[i]
        self.samples_time_idx = np.searchsorted(unique_times, np.asarray(time_stop)[:n_samples]).astype(np.int64)
        self.samples_start_idx = np.searchsorted(
            unique_times, np.asarray(time_start)[:n_samples], side="right"
        ).astype(np.int64)

    def init(self, y, sample_weight, n_samples, samples, start, end):
        """Select the node made of rows ``samples[start:end]``.

        ``y`` holds ``(time_start, time_stop, event)`` in columns 0-2 and is indexed
        by the entries of ``samples``. ``sample_weight`` may be None (unit weights).
        """
        self.n_node_samples = end - start
        self.weighted_n_node_samples = 0.0
        self.samples = samples
        # Remember the node range so update() scans exactly this node.
        self.start = start
        self.end = end
        self.pos = start

        start_times = y[:, 0]
        stop_times = y[:, 1]
        event = y[:, 2]

        for idx in samples[start:end]:
            # NOTE(review): RisksetCounter.update *removes* these rows from
            # riskset_total, which node_value() reads -- confirm this is the
            # intended node risk set.
            self.riskset_total.update([self.ids[idx]], [start_times[idx]], [stop_times[idx]], [event[idx]])
            w = sample_weight[idx] if sample_weight is not None else 1.0
            self.weighted_n_node_samples += w

    def update(self, new_pos, split_feature, split_threshold):
        """Accumulate child statistics for the split ``x[:, split_feature] <= split_threshold``.

        Every row of the node (``samples[start:end]``) is assigned to a child by
        the threshold. Per child, ``Y_*[t]`` counts rows at risk at grid time t
        (``start < t <= stop``) and ``dN_*[t]`` sums the event indicators of rows
        stopping at t. ``new_pos`` is stored as ``pos``.
        """
        n_times = self.n_unique_times
        self.Y_left = np.zeros(n_times, dtype=np.int64)
        self.Y_right = np.zeros(n_times, dtype=np.int64)
        self.dN_left = np.zeros(n_times, dtype=np.int64)
        self.dN_right = np.zeros(n_times, dtype=np.int64)

        self.delta_n_at_risk_left[:] = 0
        self.n_events_left[:] = 0

        n_left = 0
        for i in range(self.start, self.end):
            idx = self.samples[i]
            first = self.samples_start_idx[idx]
            last = self.samples_time_idx[idx]

            if self.x[idx, split_feature] <= split_threshold:
                Y, dN = self.Y_left, self.dN_left
                n_left += 1
            else:
                Y, dN = self.Y_right, self.dN_right

            # The row is at risk at every grid time in (start, stop].
            Y[first : last + 1] += 1
            dN[last] += self.event[idx]

        self.n_samples_left = n_left
        # Unit weights per row, as before; the right child gets the remainder.
        self.weighted_n_left = float(n_left)
        self.weighted_n_right = self.weighted_n_node_samples - self.weighted_n_left
        self.pos = new_pos

    def proxy_impurity_improvement(self):
        """Return the standardized two-sample statistic of the current split.

        Larger is better. The absolute value is returned because a split separates
        the children equally well whichever side has the higher event rate.
        Returns ``-inf`` when the variance estimate is zero.
        """
        Y_total = self.Y_left + self.Y_right
        dN_total = self.dN_left + self.dN_right

        # Log-rank-type weight w(t) = Y_L Y_R / (Y_L + Y_R)
        w = (self.Y_left * self.Y_right) / (Y_total + 1e-7)

        numer = np.sum(w * (self.dN_left / (self.Y_left + 1e-7) - self.dN_right / (self.Y_right + 1e-7)))

        # NOTE(review): confirm this variance estimator against the intended
        # pseudo-score reference; it is kept unchanged from the original.
        var_estimate = 0.0
        for t in range(self.n_unique_times):
            for Y, dN in ((self.Y_left, self.dN_left), (self.Y_right, self.dN_right)):
                if Y[t] == 0:
                    continue
                term = w[t] * Y[t] / Y_total[t] * (dN[t] - dN_total[t] / Y_total[t])
                var_estimate += term**2

        if var_estimate == 0.0:
            return -np.inf
        return np.abs(numer) / np.sqrt(var_estimate + 1e-7)

    def node_value(self):
        """Return ``[cumhaz_0, hazard_0, cumhaz_1, hazard_1, ...]`` from ``riskset_total``.

        ``hazard_t = dN_t / Y_t`` is the Nelson-Aalen increment (with a 1e-7
        guard) and ``cumhaz_t`` is its running sum. Entries from the first grid
        time whose risk set is empty onward are left at 0.
        """
        Y = np.asarray(self.riskset_total.n_at_risk)[: self.n_unique_times]
        dN_bar = np.asarray(self.riskset_total.n_events)[: self.n_unique_times]

        dest = np.zeros(2 * self.n_unique_times)

        empty = np.flatnonzero(Y == 0)
        n_filled = empty[0] if empty.size else Y.shape[0]

        # cumsum replaces the per-time prefix sum (O(T) instead of O(T^2)).
        increments = dN_bar[:n_filled] / (Y[:n_filled] + 1e-7)
        dest[0 : 2 * n_filled : 2] = np.cumsum(increments)
        dest[1 : 2 * n_filled : 2] = increments
        return dest

    def reset(self):
        """Restore ``riskset_total`` to the counts of all rows."""
        self.riskset_total.reset()


In [None]:
!pip install scikit-survival

In [None]:
from math import ceil
from numbers import Integral, Real

import numpy as np
from scipy.sparse import issparse
from sklearn.base import BaseEstimator
from sklearn.tree import _tree
from sklearn.tree._classes import DENSE_SPLITTERS, SPARSE_SPLITTERS
from sklearn.tree._splitter import Splitter
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.validation import check_is_fitted, check_random_state

import sksurv
from sksurv.base import SurvivalAnalysisMixin
from sksurv.functions import StepFunction
from sksurv.util import check_array_survival
from sksurv.tree._criterion import LogrankCriterion, get_unique_times

from sklearn.utils import check_random_state

from queue import LifoQueue
import pandas as pd
class PseudoScoreTreeBuilder:
    """Grow a recurrent-event tree depth-first using PseudoScoreCriterion.

    Nodes are numbered heap-style: the root is 0 and the children of node ``i``
    are ``2 * i + 1`` (left, ``x <= threshold``) and ``2 * i + 2`` (right).
    Leaves have ``threshold == feature == TREE_UNDEFINED``.
    """

    TREE_UNDEFINED = -np.inf

    def __init__(self, max_depth=4, min_leaf=20, random_state=None):
        self.max_depth = max_depth  # None means no depth limit
        self.min_leaf = min_leaf
        self.random_state = random_state
        self.random_state_ = check_random_state(self.random_state)

    def build(self, X, ids, time_start, time_stop, event):
        """Grow the tree and return its node table.

        Returns
        -------
        pandas.DataFrame indexed by heap node id, with columns ``feature``,
        ``threshold``, ``n_node_samples``, ``statistic`` and ``depth``.
        """
        n_samples, n_features = X.shape
        data = np.column_stack((time_start, time_stop, event))

        # Building the tree using the specified hyperparameters
        criterion = PseudoScoreCriterion(
            n_outputs=n_features,
            n_samples=n_samples,
            unique_times=np.unique(time_stop),
            x=X,
            ids=ids,
            time_start=time_start,
            time_stop=time_stop,
            event=event,
            random_state=self.random_state_,
        )

        root = np.arange(n_samples)
        val, feat, stat = self._get_best_split(X, data, criterion, samples=root)
        splits = LifoQueue()
        splits.put((val, feat, stat, 0, root, 0))

        node_stats = {}
        while splits.qsize() > 0:
            val, feat, stat, lvl, idx, node_id = splits.get()
            node_stats[node_id] = {
                "feature": feat,
                "threshold": val,
                "n_node_samples": idx.shape[0],
                "statistic": stat,
                "depth": lvl,
            }

            if val == self.TREE_UNDEFINED:
                continue

            go_left = X[idx, feat] <= val
            right = idx[~go_left]
            left = idx[go_left]
            left_id, right_id = 2 * node_id + 1, 2 * node_id + 2

            if self.max_depth is not None and lvl >= self.max_depth - 1:
                # Children would reach max_depth: make them leaves.
                splits.put((self.TREE_UNDEFINED, self.TREE_UNDEFINED, [-np.inf], lvl + 1, right, right_id))
                splits.put((self.TREE_UNDEFINED, self.TREE_UNDEFINED, [-np.inf], lvl + 1, left, left_id))
                continue

            # Children are evaluated on their original row indices so they line
            # up with the full-data arrays held by the criterion.
            s_right = self._get_best_split(X, data, criterion, samples=right)
            splits.put(tuple(s_right) + (lvl + 1, right, right_id))

            s_left = self._get_best_split(X, data, criterion, samples=left)
            splits.put(tuple(s_left) + (lvl + 1, left, left_id))

        return pd.DataFrame.from_dict(node_stats, orient="index")

    def _get_best_split(self, X, y, criterion, samples=None):
        """Return ``(threshold, feature, [statistic])`` for the best split of one node.

        Parameters
        ----------
        X, y : ndarray
            When ``samples`` is given: the full feature matrix and the full
            ``(time_start, time_stop, event)`` matrix; the node is rows ``samples``.
            When ``samples`` is None (original calling convention): the node is
            every row of ``X`` / ``y``.
        criterion : PseudoScoreCriterion
            Criterion built on the full training data.
        samples : ndarray of int, optional
            Row indices of the node in the full training data.

        A candidate ``X[:, feature] <= value`` is considered only when both sides
        keep at least ``min_leaf`` rows. If none qualifies, threshold and feature
        are ``TREE_UNDEFINED`` and the statistic is ``[-inf]``.
        """
        if samples is None:
            samples = np.arange(X.shape[0])
        X_node = X[samples]
        n_node, n_features = X_node.shape
        min_leaf = self.min_leaf

        best_val = self.TREE_UNDEFINED
        best_feat = self.TREE_UNDEFINED
        best_stat = [-np.inf]

        # Initialize the criterion for the whole node
        criterion.init(y, None, n_node, samples, 0, n_node)

        for feat in range(n_features):
            candidates = np.unique(X_node[:, feat])
            if candidates.shape[0] < 2:
                continue

            for val in candidates:
                # Reset criterion for each potential split
                criterion.reset()
                left = X_node[:, feat] <= val
                n_left = int(np.sum(left))

                if n_left < min_leaf or n_node - n_left < min_leaf:
                    continue

                criterion.update(n_left, feat, val)
                stat = criterion.proxy_impurity_improvement()

                if stat > best_stat[0]:
                    best_feat = feat
                    best_val = val
                    best_stat = [stat]

        return best_val, best_feat, best_stat


In [None]:
class RecurrentTree:
    """A single recurrent-event tree built with PseudoScoreTreeBuilder.

    After ``fit``, ``tree_`` is the builder's node table (a pandas DataFrame with
    one row per node). ``predict`` maps samples to the index of the leaf they
    reach, assuming heap-style node numbering (children of node ``i`` are
    ``2 * i + 1`` and ``2 * i + 2``).
    """

    def __init__(self, max_depth=4, min_leaf=20, random_state=None):
        self.max_depth = max_depth
        self.min_leaf = min_leaf
        self.random_state = random_state
        self.tree_ = None

    def fit(self, X, ids, time_start, time_stop, event):
        """Grow the tree on counting-process data (one row per interval) and return self."""
        # Ensure input is in the expected format
        X = np.array(X)
        ids = np.array(ids)
        time_start = np.array(time_start)
        time_stop = np.array(time_stop)
        event = np.array(event)

        builder = PseudoScoreTreeBuilder(max_depth=self.max_depth, min_leaf=self.min_leaf, random_state=self.random_state)
        self.tree_ = builder.build(X, ids=ids, time_start=time_start, time_stop=time_stop, event=event)

        return self

    def get_tree(self):
        """
        Return the tree as a pandas DataFrame.
        Each row represents a node with its associated statistics.
        """
        return self.tree_

    def _traverse_tree(self, x, node_idx):
        """Follow the split rules from ``node_idx`` down to a leaf and return the leaf's index."""
        while True:
            node = self.tree_.loc[node_idx]

            if node["threshold"] == PseudoScoreTreeBuilder.TREE_UNDEFINED:
                return node_idx

            # The "feature" column also holds -inf for leaves, so pandas stores it
            # as float; cast before using it as an array index.
            if x[int(node["feature"])] <= node["threshold"]:
                node_idx = 2 * node_idx + 1
            else:
                node_idx = 2 * node_idx + 2

    def predict(self, X):
        """Return, for each row of ``X``, the index of the leaf it falls into.

        Raises
        ------
        AttributeError
            If the tree has not been fitted.
        """
        if self.tree_ is None:
            raise AttributeError("This RecurrentTree instance is not fitted yet; call fit() first.")

        X = np.array(X)
        return np.array([self._traverse_tree(row, 0) for row in X], dtype=int)


In [None]:
import numpy as np
from sklearn.model_selection import train_test_split

# --- Synthetic recurrent-event data (global seed fixed for reproducibility) ---
np.random.seed(42)
n_samples, n_features = 1000, 5

X = np.random.randn(n_samples, n_features)  # covariates
ids = np.array([f"id_{i}" for i in range(1, n_samples + 1)])  # one subject per row

time_start = np.random.randint(0, 10, size=n_samples)
time_stop = time_start + np.random.randint(1, 5, size=n_samples)  # intervals last 1-4 units
event = np.random.randint(0, 2, size=n_samples)

# --- 80/20 split applied jointly to every per-row array ---
(
    X_train, X_test,
    ids_train, ids_test,
    time_start_train, time_start_test,
    time_stop_train, time_stop_test,
    event_train, event_test,
) = train_test_split(X, ids, time_start, time_stop, event, test_size=0.2, random_state=42)

# --- Fit a shallow tree and map the test rows to leaves ---
tree = RecurrentTree(max_depth=3, min_leaf=5)
tree.fit(
    X=X_train,
    ids=ids_train,
    time_start=time_start_train,
    time_stop=time_stop_train,
    event=event_train,
)
predictions = tree.predict(X_test)

tree_summary = tree.get_tree()
tree_summary, predictions[:5]

In [None]:
import numpy as np
from numbers import Integral, Real
from sklearn.utils import check_random_state

def _get_n_samples_bootstrap(n_ids, max_samples):
    """
    Modified for recurrent events. Get the number of IDs in a bootstrap sample.
    """
    if max_samples is None:
        return n_ids

    if isinstance(max_samples, Integral):
        if max_samples > n_ids:
            msg = "`max_samples` must be <= n_ids={} but got value {}"
            raise ValueError(msg.format(n_ids, max_samples))
        return max_samples

    if isinstance(max_samples, Real):
        return max(round(n_ids * max_samples), 1)

def _generate_sample_indices(random_state, ids, n_ids_bootstrap):
    """
    Sample unique IDs (with replacement); callers expand them to all associated events.

    Parameters
    ----------
    random_state : None, int or RandomState
        Resolved with ``check_random_state``. The draw uses that generator, so an
        int seed reproduces the same IDs (needed to rebuild out-of-bag sets).
    ids : array-like
        Pool of IDs to draw from.
    n_ids_bootstrap : int
        Number of IDs to draw.

    Returns
    -------
    ndarray of drawn IDs (not row indices).
    """
    random_instance = check_random_state(random_state)
    # Draw from the resolved generator; the previous code used the global
    # np.random module and so ignored random_state.
    return random_instance.choice(ids, n_ids_bootstrap, replace=True)

def _generate_unsampled_indices(random_state, ids, n_ids_bootstrap):
    """Return the IDs in ``ids`` that a bootstrap draw did *not* select.

    The draw is redone with ``_generate_sample_indices(random_state, ids,
    n_ids_bootstrap)``. Whether it matches the draw used for training depends
    on that function being reproducible for the given ``random_state``.
    The result is sorted and unique (``np.setdiff1d``). IDs are returned, not
    event rows.
    """
    drawn_ids = _generate_sample_indices(random_state, ids, n_ids_bootstrap)
    out_of_bag_ids = np.setdiff1d(ids, drawn_ids)

    # Expanding IDs to event rows depends on how events are stored, e.g. with an
    # events_by_id mapping:
    #     np.concatenate([events_by_id[i] for i in out_of_bag_ids])
    return out_of_bag_ids



from warnings import catch_warnings, simplefilter
from sklearn.utils.class_weight import compute_sample_weight

def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,
    ids,  # New parameter: a list/array of IDs corresponding to each event in X and y
    sample_weight,
    tree_idx,
    n_trees,
    verbose=0,
    class_weight=None,
    n_ids_bootstrap=None,  # Instead of n_samples_bootstrap
):
    """
    Private function used to fit a single tree in parallel for recurrent events.

    With ``bootstrap``, subject IDs (not rows) are drawn with replacement and
    every row of a drawn ID is kept. A row's sample weight is multiplied by how
    many times its ID was drawn. ``n_ids_bootstrap=None`` draws as many IDs as
    there are unique IDs. Without ``bootstrap`` the tree is fit on all rows.
    Returns the fitted ``tree``.
    """
    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if not bootstrap:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
        return tree

    unique_ids = np.unique(ids)
    n_ids = len(unique_ids)
    if n_ids_bootstrap is None:
        # Standard bootstrap size; passing None to the sampler would draw a
        # single scalar ID instead of an array.
        n_ids_bootstrap = n_ids

    # Generate bootstrap samples using IDs
    sampled_ids = _generate_sample_indices(tree.random_state, unique_ids, n_ids_bootstrap)

    # Expand sampled IDs to all their associated events
    indices = np.where(np.isin(ids, sampled_ids))[0]

    if sample_weight is None:
        curr_sample_weight = np.ones((X.shape[0],), dtype=np.float64)
    else:
        # Float copy: in-place products below must neither touch the caller's
        # array nor fail on integer weights.
        curr_sample_weight = np.array(sample_weight, dtype=np.float64)

    # Adjust the sample weight based on how many times each ID was sampled
    sample_counts_for_ids = np.bincount(np.searchsorted(unique_ids, sampled_ids), minlength=n_ids)
    curr_sample_weight *= sample_counts_for_ids[np.searchsorted(unique_ids, ids)]

    if class_weight == "subsample":
        with catch_warnings():
            simplefilter("ignore", DeprecationWarning)
            curr_sample_weight *= compute_sample_weight("auto", y, indices=indices)
    elif class_weight == "balanced_subsample":
        curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

    tree.fit(X[indices], y[indices], sample_weight=curr_sample_weight[indices], check_input=False)
    return tree


In [None]:
def compute_C_index(mu_OOB_i, mu_OOB_iprime, N_i, N_iprime):
    """
    Computes the concordance index for recurrent events.

    Every ordered pair of positions (a, b) with 0 <= a, b < len(mu_OOB_i) is
    examined. The pair is comparable when N_i[a] > N_iprime[b], and concordant
    when it is comparable and mu_OOB_i[a] > mu_OOB_iprime[b].

    Parameters:
    - mu_OOB_i: List of cumulative hazards/risk scores for individuals i.
    - mu_OOB_iprime: List of cumulative hazards/risk scores for individuals i'.
    - N_i: List of number of events experienced by individuals i up to a given time.
    - N_iprime: List of number of events experienced by individuals i' up to a given time.

    Returns:
    - C_index: concordant / comparable pairs, or 0 when no pair is comparable.
    """
    n_subjects = len(mu_OOB_i)
    n_comparable = 0
    n_concordant = 0

    for a in range(n_subjects):
        for b in range(n_subjects):
            if not N_i[a] > N_iprime[b]:
                continue
            n_comparable += 1
            if mu_OOB_i[a] > mu_OOB_iprime[b]:
                n_concordant += 1

    if n_comparable == 0:
        return 0
    return n_concordant / n_comparable

def prediction_error_rate(C_index):
    """Return the prediction error rate, i.e. ``1 - C_index``."""
    error_rate = 1 - C_index
    return error_rate

In [None]:
import numpy as np

# Small synthetic dataset: 5 subject IDs, each with 1-5 event rows.
n_ids = 5
events_per_id = np.random.randint(1, 6, size=n_ids)

# Draw each subject's rows in turn (features, then labels), consuming the global
# RNG in the same order as a per-subject loop would.
per_subject = [
    (np.random.rand(count, 3), np.random.randint(0, 2, count))  # 3 features, binary target
    for count in events_per_id
]

X = np.array([row for feats, _ in per_subject for row in feats])
y = np.array([label for _, labels in per_subject for label in labels])
ids = np.array([subject for subject, count in enumerate(events_per_id, 1) for _ in range(count)])

X.shape, y.shape, ids.shape

In [None]:
# Draw 3 subject IDs with replacement from a freshly seeded generator.
sampled_indices = _generate_sample_indices(
    np.random.RandomState(None),
    np.unique(ids),
    3,  # number of IDs to draw
)

# Keep every event row whose ID was drawn, then show which IDs those rows belong to.
bootstrap_indices = np.flatnonzero(np.isin(ids, sampled_indices))
bootstrapped_ids = ids[bootstrap_indices]

bootstrapped_ids

In [None]:
from warnings import catch_warnings, simplefilter
from sklearn.utils.class_weight import compute_sample_weight

from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

tree = DecisionTreeClassifier()


def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,
    ids,  # New parameter: a list/array of IDs corresponding to each event in X and y
    sample_weight,
    tree_idx,
    n_trees,
    verbose=0,
    class_weight=None,
    n_ids_bootstrap=None,  # Instead of n_samples_bootstrap
):
    """
    Private function used to fit a single tree in parallel for recurrent events.

    With ``bootstrap``, subject IDs (not rows) are drawn with replacement and
    every row of a drawn ID is kept. A row's sample weight is multiplied by how
    many times its ID was drawn. ``n_ids_bootstrap=None`` draws as many IDs as
    there are unique IDs. Without ``bootstrap`` the tree is fit on all rows.
    Returns the fitted ``tree``.
    """
    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if not bootstrap:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
        return tree

    unique_ids = np.unique(ids)
    n_ids = len(unique_ids)
    if n_ids_bootstrap is None:
        # Standard bootstrap size; passing None to the sampler would draw a
        # single scalar ID instead of an array.
        n_ids_bootstrap = n_ids

    # Generate bootstrap samples using IDs
    sampled_ids = _generate_sample_indices(tree.random_state, unique_ids, n_ids_bootstrap)

    # Expand sampled IDs to all their associated events
    indices = np.where(np.isin(ids, sampled_ids))[0]

    if sample_weight is None:
        curr_sample_weight = np.ones((X.shape[0],), dtype=np.float64)
    else:
        # Float copy: in-place products below must neither touch the caller's
        # array nor fail on integer weights.
        curr_sample_weight = np.array(sample_weight, dtype=np.float64)

    # Adjust the sample weight based on how many times each ID was sampled
    sample_counts_for_ids = np.bincount(np.searchsorted(unique_ids, sampled_ids), minlength=n_ids)
    curr_sample_weight *= sample_counts_for_ids[np.searchsorted(unique_ids, ids)]

    if class_weight == "subsample":
        with catch_warnings():
            simplefilter("ignore", DeprecationWarning)
            curr_sample_weight *= compute_sample_weight("auto", y, indices=indices)
    elif class_weight == "balanced_subsample":
        curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

    tree.fit(X[indices], y[indices], sample_weight=curr_sample_weight[indices], check_input=False)
    return tree

# Fit the classifier on a bootstrap sample of 3 subject IDs (all rows of each drawn ID).
_parallel_build_trees(
    tree,
    True,  # bootstrap
    X,
    y,
    ids,
    None,  # sample_weight
    0,  # tree_idx
    1,  # n_trees
    verbose=1,
    class_weight=None,
    n_ids_bootstrap=3,
)

# Visualise the fitted tree.
plt.figure(figsize=(15, 10))
plot_tree(
    tree,
    filled=True,
    feature_names=['Feature 1', 'Feature 2', 'Feature 3'],
    class_names=['Class 0', 'Class 1'],
)
plt.show()