<a href="https://colab.research.google.com/github/JSK2022/RandomForest-for-Recurrent-Events/blob/Thesis-Code/JS_simulation_code_230913.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook shell command (IPython `!` syntax): install the scikit-survival package.
!pip install scikit-survival

In [None]:
import numpy as np
import pandas as pd


# NOTE(review): machine-specific absolute path; point this at the folder that
# contains simuDat.csv before running the notebook.
path="/Users/jeongsookim/Downloads"
data = pd.read_csv(f"{path}/simuDat.csv")

# Row-level arrays taken from the CSV columns ID, start, stop and event.
ids = data['ID'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
# Covariate matrix with the columns group, x1 and gender.
x = data[['group','x1','gender']].values

In [None]:
# Update the RisksetCounter class again with the above mentioned changes
import numpy as np
from collections import defaultdict
#from functools import lru_cache

class RisksetCounter:
    """Risk-set and event counts per time point for recurrent-event data.

    Each input row describes one interval of a subject identified by ``ids``,
    with ``event == 1`` marking an event at that row's ``time_stop``.  For every
    time on the grid ``all_unique_times``:

    * ``n_at_risk[k]`` is the number of distinct subjects that have at least one
      row with ``time_stop >= all_unique_times[k]``;
    * ``n_events[k]`` is the number of distinct subjects that have a row with
      ``time_stop == all_unique_times[k]`` and ``event == 1``.

    The grid is built from ``time_stop`` by the constructor and by
    :meth:`set_data`.  :meth:`update` keeps the existing grid, so counters
    created from the same data keep arrays of the same length and alignment.
    """

    def __init__(self, ids, time_start, time_stop, event):
        """Store the data, build the time grid and compute the counts.

        Parameters
        ----------
        ids, time_start, time_stop, event : array-like of equal length
            Subject identifier, interval start, interval stop and event
            indicator of every row.
        """
        # np.asarray does not copy ndarray input, but makes list input work with
        # the element-wise comparisons used in Y_i / dN_bar_i.
        self.ids = np.asarray(ids)
        self.time_start = np.asarray(time_start)
        self.time_stop = np.asarray(time_stop)
        self.event = np.asarray(event)
        self.set_data()

    def set_data(self):
        """Rebuild the time grid from ``time_stop`` and recompute both counts."""
        self.all_unique_times = np.unique(self.time_stop)
        self.n_unique_times = len(self.all_unique_times)
        self.n_at_risk, self.n_events = self._count()

    def _count(self):
        """Return ``(n_at_risk, n_events)`` as int64 arrays on ``all_unique_times``.

        Vectorized equivalent of summing :meth:`Y_i` / :meth:`dN_bar_i` over
        every distinct subject at every grid time.
        """
        times = self.all_unique_times
        n_times = len(times)
        n_at_risk = np.zeros(n_times, dtype=np.int64)
        n_events = np.zeros(n_times, dtype=np.int64)
        if self.ids.size == 0 or n_times == 0:
            return n_at_risk, n_events

        subjects, subject_of_row = np.unique(self.ids, return_inverse=True)
        subject_of_row = subject_of_row.ravel()

        # A subject is at risk at t iff its largest stop time is >= t.
        last_stop = np.full(len(subjects), -np.inf)
        np.maximum.at(last_stop, subject_of_row, self.time_stop)
        last_stop.sort()
        n_at_risk[:] = len(last_stop) - np.searchsorted(last_stop, times, side="left")

        # Grid position of each event row's stop time; rows off the grid are ignored.
        is_event = self.event == 1
        event_stop = self.time_stop[is_event]
        pos = np.searchsorted(times, event_stop)
        on_grid = pos < n_times
        on_grid[on_grid] = times[pos[on_grid]] == event_stop[on_grid]
        # Count each (subject, time) pair once, like the np.any in dN_bar_i.
        pair_keys = np.unique(subject_of_row[is_event][on_grid] * n_times + pos[on_grid])
        n_events[:] = np.bincount(pair_keys % n_times, minlength=n_times)
        return n_at_risk, n_events

    def Y_i(self, id_, t_idx):
        """Return whether subject ``id_`` has a row with ``time_stop >= all_unique_times[t_idx]``."""
        time_at_t_idx = self.all_unique_times[t_idx]
        indices = (self.ids == id_) & (time_at_t_idx <= self.time_stop)
        return np.any(indices)

    def dN_bar_i(self, id_, t_idx):
        """Return whether subject ``id_`` has an event row stopping exactly at ``all_unique_times[t_idx]``."""
        time_at_t_idx = self.all_unique_times[t_idx]
        indices = (self.ids == id_) & (time_at_t_idx == self.time_stop) & (self.event == 1)
        return np.any(indices)

    def update(self, ids, time_start, time_stop, event):
        """Replace the data and recompute the counts on the existing grid.

        The grid (``all_unique_times``) is not rebuilt. Counters created from the
        same data therefore keep count arrays of equal length and alignment.
        """
        self.ids = np.array(ids)
        self.time_start = np.array(time_start)
        self.time_stop = np.array(time_stop)
        self.event = np.array(event)
        self.n_at_risk, self.n_events = self._count()

    def reset(self):
        """Zero ``n_at_risk`` and ``n_events`` in place (data and grid are kept)."""
        self.n_at_risk.fill(0)
        self.n_events.fill(0)

    def copy(self):
        """Return an independent copy with the same data, grid and counts."""
        clone = self.__class__.__new__(self.__class__)
        clone.ids = self.ids.copy()
        clone.time_start = self.time_start.copy()
        clone.time_stop = self.time_stop.copy()
        clone.event = self.event.copy()
        clone.all_unique_times = self.all_unique_times.copy()
        clone.n_unique_times = self.n_unique_times
        clone.n_at_risk = self.n_at_risk.copy()
        clone.n_events = self.n_events.copy()
        return clone

    def __reduce__(self):
        """Pickle support that restores the exact grid and counts.

        The constructor is re-run on the stored data, and then the saved
        instance ``__dict__`` is applied on top of it. A grid kept by
        :meth:`update` therefore survives pickling.
        """
        return (
            self.__class__,
            (self.ids, self.time_start, self.time_stop, self.event),
            self.__dict__.copy(),
        )


In [None]:
# Reload the data (same statements as the first loading cell of the notebook).
import numpy as np
import pandas as pd


path="/Users/jeongsookim/Downloads"
data = pd.read_csv(f"{path}/simuDat.csv")

ids = data['ID'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
x = data[['group','x1','gender']].values

In [None]:
# Build a counter on the full data and keep copies of its initial counts.
counter = RisksetCounter(ids, time_start, time_stop, event)
initial_n_at_risk = counter.n_at_risk.copy()
initial_n_events = counter.n_events.copy()


In [None]:
initial_n_at_risk ## size of the risk set at each unique stop time

In [None]:
initial_n_events ## number of events at each unique stop time

In [None]:
def argbinsearch(arr, key_val):
    """Return the index of the first element of sorted ``arr`` greater than ``key_val``.

    Elements equal to ``key_val`` are skipped. The result is therefore the
    insertion point to the right of any run of equal values. It is ``len(arr)``
    when every element is ``<= key_val``, and ``0`` for an empty ``arr``. For
    totally ordered values this matches ``bisect.bisect_right(arr, key_val)``.

    Parameters
    ----------
    arr : sequence
        Sorted ascending; must support ``len`` and integer indexing.
    key_val :
        Value to locate.
    """
    min_idx = 0
    max_idx = len(arr)

    # Invariant: every element before min_idx is <= key_val and every element
    # from max_idx on is > key_val.
    while min_idx < max_idx:
        # mid_idx always lies in [min_idx, max_idx), so it is a valid index.
        mid_idx = min_idx + (max_idx - min_idx) // 2
        if arr[mid_idx] <= key_val:
            min_idx = mid_idx + 1
        else:
            max_idx = mid_idx

    return min_idx

이 함수는 argbinsearch라는 이름의 함수로, 정렬된 배열에서 주어진 키 값보다 큰(키 값과 같은 원소는 건너뜀) 첫 번째 원소의 인덱스를 이진 탐색으로 찾아 반환합니다. 그런 원소가 없으면 배열의 길이를 반환하며, 이는 Python 표준 라이브러리의 `bisect.bisect_right`와 같은 동작입니다.

자세한 코드 설명을 아래에 제공합니다:

1. 입력:

  * arr: 탐색 대상인 정렬된 배열
  * key_val: 찾고자 하는 키 값

2. 초기 변수 설정:

  * arr_len: 배열의 길이를 저장합니다.
  * min_idx: 탐색 범위의 최솟값으로, 처음에는 배열의 시작 인덱스인 0으로 설정됩니다.
  * max_idx: 탐색 범위의 최댓값으로, 처음에는 배열의 길이로 설정됩니다.

3. 이진 탐색:

  * while 루프를 사용하여 min_idx가 max_idx보다 작은 동안 탐색을 반복합니다.
  * mid_idx: 현재 탐색 범위의 중간 인덱스를 계산합니다.
  * mid_val: 중간 인덱스에 해당하는 배열의 원소 값을 가져옵니다.

4. 키 값과 중간 값을 비교합니다:
  * 만약 중간 값이 키 값보다 작거나 같으면, min_idx를 mid_idx + 1로 업데이트합니다. 이렇게 하면 탐색 범위의 왼쪽 부분을 제외하게 됩니다.
  * 그렇지 않으면, max_idx를 mid_idx로 업데이트합니다. 이렇게 하면 탐색 범위의 오른쪽 부분을 제외하게 됩니다.

5. 결과 반환:

  * 루프가 종료되면, min_idx는 키 값보다 큰 첫 번째 원소의 인덱스(그런 원소가 없으면 배열의 길이)를 가리키게 됩니다. 따라서 min_idx를 반환합니다.

이 함수는 정렬된 배열에서 주어진 키 값보다 큰 첫 번째 원소의 위치를 효율적으로 찾기 위해 사용됩니다. 이진 탐색은 배열의 중간 값을 반복적으로 확인하면서 탐색 범위를 절반씩 줄여나가므로, 큰 배열에서도 빠르게 원하는 값을 찾을 수 있습니다.

In [None]:
# 필요한 라이브러리 및 함수 임포트
import numpy as np

class PseudoScoreCriterion:
    """Split criterion for recurrent-event trees based on a pseudo-score statistic.

    Three ``RisksetCounter`` objects are kept. ``riskset_total`` holds the node
    loaded by :meth:`init`. ``riskset_left`` and ``riskset_right`` hold the
    candidate partition built by :meth:`update`. A partition is scored by
    :meth:`proxy_impurity_improvement`.
    """

    def __init__(self, n_outputs, n_samples, unique_times, x, ids, time_start, time_stop, event):
        """Store the data and build the three risk-set counters.

        Parameters
        ----------
        n_outputs, n_samples : int
            Stored as attributes; not used by the computations in this class.
        unique_times : array-like
            Sorted time grid used to compute ``samples_time_idx``.
        x : array-like
            Covariates (stored; not read by the methods of this class).
        ids, time_start, time_stop, event : array-like of equal length
            Row-level recurrent-event data.
        """
        self.n_outputs = n_outputs
        self.n_samples = n_samples
        self.unique_times = unique_times
        self.x = x
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event

        self.unique_ids = set(self.ids)  # distinct subject ids of the full data

        self.riskset_left = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_right = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_total = RisksetCounter(ids, time_start, time_stop, event)

        # Position of every row's stop time on the time grid.
        self.samples_time_idx = np.searchsorted(unique_times, time_stop)

        self.split_pos = 0
        self.split_time_idx = 0

    def init(self, y, sample_weight, n_samples, samples, start, end):
        """Reset all counters and load the node's rows into ``riskset_total``.

        Parameters
        ----------
        y : array-like with at least three columns
            Columns 0, 1, 2 are ``time_start``, ``time_stop`` and ``event``.
        sample_weight, n_samples :
            Accepted for interface compatibility; not used here.
        samples : array-like of int
            Row indices; ``samples[start:end]`` selects the node's rows, which
            are used to index both ``y`` and ``self.ids``.
        start, end : int
            Slice bounds into ``samples``.
        """
        self.samples = samples
        self.riskset_left.reset()
        self.riskset_right.reset()
        self.riskset_total.reset()

        y = np.asarray(y)
        node_rows = np.asarray(samples[start:end], dtype=np.intp)
        self.riskset_total.update(
            np.asarray(self.ids)[node_rows],
            y[node_rows, 0],
            y[node_rows, 1],
            y[node_rows, 2],
        )

    def set_unique_times(self, unique_times):
        """Set the time grid attribute ``unique_times`` for the current node."""
        self.unique_times = unique_times

    def update(self, split_count):
        """Partition the subjects into the left and right risk sets.

        Subjects are visited in ascending id order. A subject goes left with
        all of its rows while fewer than ``split_count`` rows have already been
        assigned left; every later subject goes right. Each subject therefore
        ends up in exactly one child.

        NOTE(review): the partition depends only on subject order and row
        counts, not on any covariate value. Confirm this is what the caller's
        threshold search expects.

        Parameters
        ----------
        split_count : int
            Row budget for the left node.
        """
        self.riskset_left.reset()
        self.riskset_right.reset()

        ids = np.asarray(self.ids)
        time_start = np.asarray(self.time_start)
        time_stop = np.asarray(self.time_stop)
        event = np.asarray(self.event)

        _, subject_of_row, rows_per_subject = np.unique(ids, return_inverse=True, return_counts=True)
        subject_of_row = subject_of_row.ravel()
        # Rows already sent left before each subject (in sorted id order) is considered.
        rows_before = np.cumsum(rows_per_subject) - rows_per_subject
        left_rows = (rows_before < split_count)[subject_of_row]
        right_rows = ~left_rows

        self.riskset_left.update(ids[left_rows], time_start[left_rows],
                                 time_stop[left_rows], event[left_rows])
        self.riskset_right.update(ids[right_rows], time_start[right_rows],
                                  time_stop[right_rows], event[right_rows])

    def Y_left_value(self, id_, t):
        """At-risk indicator of subject ``id_`` at grid index ``t`` in the left risk set."""
        return self.riskset_left.Y_i(id_, t)

    def Y_right_value(self, id_, t):
        """At-risk indicator of subject ``id_`` at grid index ``t`` in the right risk set."""
        return self.riskset_right.Y_i(id_, t)

    def dN_bar_left_value(self, id_, t):
        """Event indicator of subject ``id_`` at grid index ``t`` in the left risk set."""
        return self.riskset_left.dN_bar_i(id_, t)

    def dN_bar_right_value(self, id_, t):
        """Event indicator of subject ``id_`` at grid index ``t`` in the right risk set."""
        return self.riskset_right.dN_bar_i(id_, t)

    def _subject_indicators(self, riskset):
        """Return 0/1 matrices ``(at_risk, has_event)`` of shape (subjects, times).

        Entry ``[i, k]`` equals ``riskset.Y_i`` / ``riskset.dN_bar_i`` for the
        i-th distinct subject of ``riskset`` at grid index ``k``. Only subjects
        in ``self.unique_ids`` are included.
        """
        times = np.asarray(riskset.all_unique_times)
        n_times = len(times)
        ids = np.asarray(riskset.ids)
        stop = np.asarray(riskset.time_stop)
        event = np.asarray(riskset.event)

        keep = np.isin(ids, list(self.unique_ids))
        ids, stop, event = ids[keep], stop[keep], event[keep]
        if ids.size == 0:
            empty = np.zeros((0, n_times))
            return empty, empty

        subjects, subject_of_row = np.unique(ids, return_inverse=True)
        subject_of_row = subject_of_row.ravel()

        # At risk at t iff the subject's largest stop time is >= t.
        last_stop = np.full(len(subjects), -np.inf)
        np.maximum.at(last_stop, subject_of_row, stop)
        at_risk = (last_stop[:, None] >= times[None, :]).astype(float)

        has_event = np.zeros((len(subjects), n_times))
        is_event = event == 1
        event_stop = stop[is_event]
        pos = np.searchsorted(times, event_stop)
        on_grid = pos < n_times
        on_grid[on_grid] = times[pos[on_grid]] == event_stop[on_grid]
        has_event[subject_of_row[is_event][on_grid], pos[on_grid]] = 1.0
        return at_risk, has_event

    def _side_variance(self, riskset, n_at_risk, w):
        """Sum over subjects and times of ``w/n * Y_i * (dN_i - n_events/n)**2`` for one side."""
        at_risk, has_event = self._subject_indicators(riskset)
        expected = riskset.n_events / n_at_risk
        return np.sum((w / n_at_risk) * at_risk * (has_event - expected) ** 2)

    def calculate_variance_estimate(self):
        """Return the variance estimate of the split statistic.

        With ``w = n_L * n_R / (n_L + n_R)`` per time (``n`` = at-risk counts,
        offset by 1e-7 to avoid division by zero), each side contributes
        ``sum_i sum_t w/n * Y_i(t) * (dN_i(t) - n_events(t)/n(t))**2``.
        """
        left_n_at_risk = self.riskset_left.n_at_risk + 1e-7
        right_n_at_risk = self.riskset_right.n_at_risk + 1e-7

        w = (left_n_at_risk * right_n_at_risk) / (left_n_at_risk + right_n_at_risk)

        var_estimate_L = self._side_variance(self.riskset_left, left_n_at_risk, w)
        var_estimate_R = self._side_variance(self.riskset_right, right_n_at_risk, w)
        return var_estimate_L + var_estimate_R

    def proxy_impurity_improvement(self):
        """Return the standardized pseudo-score statistic of the current partition.

        ``sum_t w(t) * (e_L/n_L - e_R/n_R)`` divided by the square root of
        :meth:`calculate_variance_estimate`, where ``e`` are event counts and
        ``n`` at-risk counts. The small 1e-7 offsets avoid division by zero.
        """
        left_n_at_risk = self.riskset_left.n_at_risk + 1e-7
        right_n_at_risk = self.riskset_right.n_at_risk + 1e-7

        w = (left_n_at_risk * right_n_at_risk) / (left_n_at_risk + right_n_at_risk)
        term = (self.riskset_left.n_events / left_n_at_risk) - (self.riskset_right.n_events / right_n_at_risk)
        numer = np.sum(w * term)
        var_estimate = self.calculate_variance_estimate()

        return numer / (np.sqrt(var_estimate) + 1e-7)

    def node_value(self):
        """Return the Nelson-Aalen mean-function estimate of the pooled left+right data.

        The pooled per-time increments ``events / at_risk`` are accumulated,
        i.e. ``cumsum(dN(t) / Y(t))``. The 1e-7 offset keeps times with an
        empty risk set (and thus no events) at a zero increment.
        """
        total_n_at_risk = self.riskset_left.n_at_risk + self.riskset_right.n_at_risk + 1e-7
        total_n_events = self.riskset_left.n_events + self.riskset_right.n_events
        return np.cumsum(total_n_events / total_n_at_risk)

    def reset(self):
        """Zero the counts of all three risk-set counters."""
        self.riskset_total.reset()
        self.riskset_left.reset()
        self.riskset_right.reset()

    def copy(self):
        """Return a new criterion with copies of the counters and sample bookkeeping."""
        new_criterion = PseudoScoreCriterion(self.n_outputs, self.n_samples, self.unique_times,
                                             self.x, self.ids, self.time_start, self.time_stop,
                                             self.event)
        new_criterion.riskset_left = self.riskset_left.copy()
        new_criterion.riskset_right = self.riskset_right.copy()
        new_criterion.riskset_total = self.riskset_total.copy()
        new_criterion.samples_time_idx = self.samples_time_idx.copy()
        if hasattr(self, 'samples'):
            new_criterion.samples = self.samples.copy()

        return new_criterion

# PseudoScoreCriterion class, revised from the given code.


In [None]:
# Inputs for a manual check of PseudoScoreCriterion: the same rows as `data`,
# but with a single covariate (gender).
ids_test = data['ID'].values
time_start_test = data['start'].values
time_stop_test = data['stop'].values
event_test = data['event'].values
x_test = data['gender'].values
# NOTE(review): hard-coded; it is used below as the number of rows taken from
# y_test, so it must not exceed len(data) -- confirm.
n_samples_test = 500

In [None]:
# 2. Initialize a PseudoScoreCriterion object (using the revised interface).
criterion_test = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples_test, x=x_test,
                                      unique_times=np.unique(time_stop_test),
                                      ids=ids_test, time_start=time_start_test,
                                      time_stop=time_stop_test, event=event_test)


In [None]:
# y columns are (start, stop, event); rows 0..n_samples_test-1 form the node.
y_test = np.column_stack((time_start_test, time_stop_test, event_test))
sample_weight_test = np.ones(n_samples_test)
sample_indices_test = np.arange(n_samples_test)
weighted_n_samples_test = sum(sample_weight_test)
criterion_test.init(y_test, sample_weight_test, weighted_n_samples_test, sample_indices_test, 0, n_samples_test)


In [None]:
# Evaluate one candidate partition; 62 is passed to criterion.update as split_count.
split_pos_test = 62
criterion_test.update(split_pos_test)

In [None]:
# NOTE(review): despite their names, these hold the left/right RisksetCounter
# objects themselves, not their lengths.
len_left = criterion_test.riskset_left
len_right = criterion_test.riskset_right

len_left, len_right

In [None]:
criterion_test.riskset_left.n_events

In [None]:
criterion_test.riskset_right.n_events

In [None]:
len(criterion_test.riskset_left.n_at_risk)

In [None]:
len(criterion_test.riskset_right.n_at_risk)

In [None]:
criterion_test.riskset_left.all_unique_times

In [None]:
criterion_test.riskset_left.n_at_risk

In [None]:
criterion_test.riskset_right.n_at_risk

In [None]:
# Standardized pseudo-score of the current left/right partition.
criterion_test.proxy_impurity_improvement()

In [None]:
import pandas as pd
import numpy as np
from sklearn.utils import check_random_state

def check_random_state(seed):
    """Turn ``seed`` into a ``np.random.RandomState`` instance.

    An existing ``RandomState`` is passed through unchanged. ``None`` or an
    integer (Python or NumPy) seeds a brand-new ``RandomState``. Any other
    value raises ``ValueError``.

    Note: this definition shadows ``sklearn.utils.check_random_state``
    imported just above it.
    """
    if isinstance(seed, np.random.RandomState):
        return seed

    accepts_as_seed = seed is None or isinstance(seed, (int, np.integer))
    if not accepts_as_seed:
        raise ValueError("seed must be None, int or np.random.RandomState")
    return np.random.RandomState(seed)

class PseudoScoreTreeBuilder:
    """Build a decision tree for recurrent-event data using the pseudo-score criterion.

    The tree is returned by :meth:`build` as a one-row DataFrame whose columns
    hold the root node. Internal nodes carry ``feature``/``threshold`` and
    ``left_child``/``right_child`` dicts. Leaves have those keys set to
    ``None``. Every node stores ``node_value`` and ``unique_times``.
    """
    TREE_UNDEFINED = -1  # Placeholder

    def __init__(self, max_depth=None, min_samples_split=2, min_samples_leaf=1,
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        """Store the tree hyperparameters.

        Parameters
        ----------
        max_depth : int or None
            Depth at which nodes become leaves (``None``: no depth limit).
        min_samples_split, min_samples_leaf : int
            A node with ``n_rows < min_samples_split`` or
            ``n_rows <= min_samples_leaf`` becomes a leaf.
        max_features : int or None
            Number of features drawn (without replacement) per split; ``None``
            uses all features.
        max_thresholds : int, float or None
            Candidate thresholds per feature: an int caps the count, a float is
            a fraction of the distinct values, ``None`` uses all of them.
        min_impurity_decrease : float
            Candidate splits scoring below this value are ignored.
        random_state : None, int or RandomState
            Passed through ``check_random_state``.
        """
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = check_random_state(random_state)

    def split_indices(self, X_column, threshold):
        """Return ``(left, right)`` row positions for ``X_column <= threshold`` / ``> threshold``."""
        return np.where(X_column <= threshold)[0], np.where(X_column > threshold)[0]

    def _candidate_thresholds(self, values):
        """Return the thresholds to evaluate for one feature, subsampled per ``max_thresholds``."""
        unique_thresholds = np.unique(values)
        if self.max_thresholds is None:
            return unique_thresholds

        n_thresholds = len(unique_thresholds)
        if isinstance(self.max_thresholds, float):
            # Keep at least one candidate; int() truncation could otherwise give 0.
            n_sample_thresholds = max(1, int(n_thresholds * self.max_thresholds))
        else:
            n_sample_thresholds = self.max_thresholds

        if n_thresholds > n_sample_thresholds:
            unique_thresholds = self.random_state.choice(unique_thresholds, n_sample_thresholds, replace=False)
        return unique_thresholds

    def _split(self, X, criterion, start, end):
        """Find the best (feature, threshold) split for rows ``X[start:end]``.

        Every candidate threshold is scored with
        ``criterion.update(n_left_rows)`` followed by
        ``criterion.proxy_impurity_improvement()``. Thresholds that would
        leave one child empty are skipped. Scores below
        ``min_impurity_decrease`` are ignored, but the search continues.

        Returns
        -------
        dict
            ``feature_index``, ``threshold`` and ``improvement`` of the best
            split; ``improvement`` is ``-np.inf`` when no split qualifies.
        """
        best_split = {
            'feature_index': None,
            'threshold': None,
            'improvement': -np.inf
        }
        n_features = X.shape[1]

        if self.max_features is None:
            features_to_consider = np.arange(n_features)
        else:
            features_to_consider = self.random_state.choice(n_features, self.max_features, replace=False)

        for feature_index in features_to_consider:
            sorted_values = np.sort(X[start:end, feature_index])
            n_node = len(sorted_values)

            for threshold in self._candidate_thresholds(sorted_values):
                new_pos = np.searchsorted(sorted_values, threshold, side='right')
                if new_pos <= 0 or new_pos >= n_node:
                    continue  # one child would be empty

                criterion.update(new_pos)
                improvement = criterion.proxy_impurity_improvement()

                if improvement < self.min_impurity_decrease:
                    continue  # not good enough, but keep searching other thresholds

                if improvement > best_split['improvement']:
                    best_split = {
                        'feature_index': feature_index,
                        'threshold': threshold,
                        'improvement': improvement
                    }

        return best_split

    @staticmethod
    def _leaf(node_value, unique_times):
        """Return a terminal-node dict."""
        return {
            'feature': None,
            'threshold': None,
            'left_child': None,
            'right_child': None,
            'node_value': node_value,
            'unique_times': unique_times.tolist()  # Store the unique times of the node
        }

    def _build(self, X, y, criterion, depth=0, start=0, end=None):
        """Recursively build the (sub)tree for rows ``X[start:end]`` / ``y[start:end]``.

        Parameters
        ----------
        X : ndarray, shape (n_rows, n_features)
        y : ndarray with columns (time_start, time_stop, event)
        criterion : PseudoScoreCriterion
        depth : int
            Depth of this node (root = 0).
        start, end : int
            Row range of this node (``end=None`` means ``len(X)``).

        Returns
        -------
        dict
            The node, with children nested under ``left_child`` / ``right_child``.
        """
        n_samples = X.shape[0]
        if end is None:
            end = n_samples

        node_X = X[start:end]
        node_y = y[start:end]

        unique_times = np.unique(np.concatenate([node_y[:, 0], node_y[:, 1]]))
        criterion.set_unique_times(unique_times)  # Update the criterion with the unique times of this node

        # NOTE(review): evaluated before criterion.init() for this node, so it
        # reflects whatever the criterion's left/right counters currently hold
        # -- confirm this is the intended node estimate.
        node_value = criterion.node_value()

        n_node = end - start
        if depth == self.max_depth or n_node <= self.min_samples_leaf or n_node < self.min_samples_split:
            return self._leaf(node_value, unique_times)

        # samples[start:end] selects this node's rows inside criterion.init.
        # NOTE(review): children receive row subsets of X / y but the same
        # criterion, whose data arrays were set from the full input in build();
        # confirm the criterion maps node rows to the right subject ids.
        criterion.init(y, None, n_samples, np.arange(n_samples), start, end)

        best_split = self._split(X, criterion, start, end)
        if best_split['improvement'] == -np.inf:
            return self._leaf(node_value, unique_times)

        left_indices, right_indices = self.split_indices(
            node_X[:, best_split['feature_index']], best_split['threshold'])

        left_child = self._build(node_X[left_indices], node_y[left_indices], criterion, depth=depth + 1)
        right_child = self._build(node_X[right_indices], node_y[right_indices], criterion, depth=depth + 1)

        return {
            'feature': best_split['feature_index'],
            'threshold': best_split['threshold'],
            'left_child': left_child,
            'right_child': right_child,
            'node_value': node_value,
            'unique_times': unique_times.tolist()  # Store the unique times of the node
        }

    def build(self, X, ids, time_start, time_stop, event):
        """Build the tree and return it as a one-row DataFrame holding the root node.

        Parameters
        ----------
        X : ndarray, shape (n_rows, n_features)
        ids, time_start, time_stop, event : array-like of length n_rows
        """
        n_samples, n_features = X.shape
        y = np.c_[time_start, time_stop, event]

        unique_times = np.unique(np.concatenate([time_start, time_stop]))
        criterion = PseudoScoreCriterion(n_outputs=n_features, n_samples=n_samples,
                                         unique_times=unique_times, x=X, ids=ids,
                                         time_start=time_start, time_stop=time_stop, event=event)

        # Grid position of every row's stop time (all n_samples rows).
        criterion.samples_time_idx = np.searchsorted(unique_times, time_stop)

        tree = self._build(X, y, criterion)

        # Convert tree dictionary to dataframe for consistency
        return pd.DataFrame([tree])
# PseudoScoreTreeBuilder depends on the mutable state of PseudoScoreCriterion and
# has no unit tests here; it is only exercised end-to-end by the cells below.



In [None]:
data

In [None]:
# Row-level arrays and the covariate matrix used to fit the tree.
ids = data['ID'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
x = data[['group','x1','gender']].values

In [None]:
# max_thresholds=0.5 subsamples each feature's candidate thresholds to about
# half of its distinct values.
tree_builder=PseudoScoreTreeBuilder(max_depth=3, min_samples_leaf=5, max_thresholds=0.5, min_impurity_decrease=0.5, random_state=1190)
tree_df = tree_builder.build(x, ids, time_start, time_stop, event)

# Display the tree dataframe
tree_df

In [None]:
import warnings
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

class RecurrentTree:
    """Single recurrent-event tree grown by ``PseudoScoreTreeBuilder``.

    After :meth:`fit`, ``tree_`` is a nested dict. Internal nodes carry
    ``feature``/``threshold`` and ``left_child``/``right_child``. Leaves have
    ``threshold`` set to ``None``. Every node stores ``node_value`` (its
    mean-function estimate) and ``unique_times``.

    Nodes are numbered heap-style: the root is 0, and the children of node
    ``i`` are ``2*i + 1`` (left) and ``2*i + 2`` (right). :meth:`apply` and
    :meth:`decision_path` both use this numbering.
    """

    def __init__(self, max_depth=None, min_samples_split=2, min_samples_leaf=1,
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        """Store the hyperparameters forwarded to ``PseudoScoreTreeBuilder``."""
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = random_state
        self.tree_ = None

    def fit(self, X, ids, time_start, time_stop, event):
        """Grow the tree on row-level recurrent-event data and return ``self``."""
        X = np.array(X)
        ids = np.array(ids)
        time_start = np.array(time_start)
        time_stop = np.array(time_stop)
        event = np.array(event)

        builder = PseudoScoreTreeBuilder(
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            max_features=self.max_features,
            max_thresholds=self.max_thresholds,
            min_impurity_decrease=self.min_impurity_decrease,
            random_state=self.random_state
        )
        tree_df = builder.build(X, ids=ids, time_start=time_start, time_stop=time_stop, event=event)
        # build() wraps the root dict in a one-row DataFrame; unwrap it back to a
        # plain dict so the root is accessed exactly like its (dict) children.
        self.tree_ = tree_df.iloc[0].to_dict()
        return self

    def get_tree(self):
        """Return the tree as a dictionary."""
        return self.tree_

    def _traverse_tree(self, x, node):
        """Return the leaf node reached by sample ``x`` starting from ``node``."""
        if node["threshold"] is None:
            return node
        if x[node["feature"]] <= node["threshold"]:
            return self._traverse_tree(x, node["left_child"])
        else:
            return self._traverse_tree(x, node["right_child"])

    def _predict_for_group(self, X_group):
        """Return ``(rate_function, mean_function)`` for the rows of one subject.

        Only the first row of ``X_group`` is routed through the tree. The mean
        function is the leaf's ``node_value``. The rate function is its
        first difference (with a leading 0).
        """
        x = X_group[0]
        terminal_node = self._traverse_tree(x, self.tree_)

        mean_function = terminal_node.get('node_value', np.array([]))
        if not mean_function.size:
            warnings.warn("No function found. Using a zero array as a placeholder.")
            mean_function = np.zeros_like(self.tree_['node_value'])

        rate_function = np.diff(mean_function, prepend=0)

        return rate_function, mean_function

    def _predict_by_id(self, X, ids):
        """Return ``(rate_functions, mean_functions, used_indices)`` per distinct id.

        ``used_indices`` holds one row-index array per distinct id, in
        ascending id order.
        """
        X = np.array(X)
        # atleast_1d lets a single scalar id be matched element-wise like an array.
        ids = np.atleast_1d(np.array(ids))

        rate_functions, mean_functions, used_indices = [], [], []
        for uid in np.unique(ids):
            idx = np.where(ids == uid)
            rate_function, mean_function = self._predict_for_group(X[idx])
            rate_functions.append(rate_function)
            mean_functions.append(mean_function)
            used_indices.extend(idx)

        return rate_functions, mean_functions, used_indices

    def predict_rate_function(self, X, ids):
        """Predict the nonparametric estimates of dμ(t) = ρ(t)dt for given samples.

        Returns ``(rate_functions, used_indices)``, one entry per distinct id.
        """
        rate_functions, _, used_indices = self._predict_by_id(X, ids)
        return rate_functions, used_indices

    def predict_mean_function(self, X, ids):
        """Predict the Nelson-Aalen estimator of the mean function for given samples.

        Returns ``(mean_functions, used_indices)``, one entry per distinct id.
        """
        _, mean_functions, used_indices = self._predict_by_id(X, ids)
        return mean_functions, used_indices

    def apply(self, X, check_input=True):
        """Return the heap-style index of the leaf each sample falls into."""
        if check_input:
            # float64, matching the dtype the thresholds were taken from; a
            # float32 cast could send values equal to a threshold the wrong way.
            X = np.asarray(X, dtype=np.float64)
        return np.array([self._get_leaf_index(x, self.tree_) for x in X])

    def _get_leaf_index(self, x, node, current_index=0):
        """Return the heap-style index of the leaf reached by ``x`` from ``node``."""
        if node["threshold"] is None:  # This is a leaf node
            return current_index
        if x[node["feature"]] <= node["threshold"]:
            return self._get_leaf_index(x, node["left_child"], current_index*2 + 1)
        else:
            return self._get_leaf_index(x, node["right_child"], current_index*2 + 2)

    def _max_node_index(self, node, current_index=0):
        """Return the largest heap-style node index in the subtree rooted at ``node``."""
        if node["threshold"] is None:
            return current_index
        return max(self._max_node_index(node["left_child"], current_index*2 + 1),
                   self._max_node_index(node["right_child"], current_index*2 + 2))

    def _get_decision_path(self, x, node, path, current_index=0):
        """Append the heap-style indices of the nodes visited by ``x`` to ``path``."""
        path.append(current_index)
        if node["threshold"] is None:
            return path
        if x[node["feature"]] <= node["threshold"]:
            return self._get_decision_path(x, node["left_child"], path, current_index*2 + 1)
        else:
            return self._get_decision_path(x, node["right_child"], path, current_index*2 + 2)

    def decision_path(self, X, check_input=True):
        """Return a CSR indicator matrix (n_samples, n_nodes) of the nodes each sample visits.

        Columns are heap-style node indices, consistent with :meth:`apply`.
        """
        if check_input:
            X = np.asarray(X, dtype=np.float64)

        n_samples, n_nodes = len(X), self._max_node_index(self.tree_) + 1
        data, indices, indptr = [], [], [0]

        for x in X:
            path = self._get_decision_path(x, self.tree_, [])
            data.extend([1] * len(path))
            indices.extend(path)
            indptr.append(len(indices))

        return csr_matrix((data, indices, indptr), shape=(n_samples, n_nodes))

In [None]:
# 2. Train a RecurrentTree and predict.
tree_model = RecurrentTree(max_depth=5, min_samples_leaf=15, max_thresholds=0.5, min_impurity_decrease=0.5, random_state=1190)
tree_model.fit(x, ids, time_start, time_stop, event)

In [None]:
%pip install graphviz

In [None]:
tree_model

In [None]:
# Leaf index (heap-style numbering) for every row of x.
tree_model.apply(x)

In [None]:
# NOTE(review): _predict_for_group routes only the first row of its argument.
tree_model._predict_for_group(x)

In [None]:
# NOTE(review): the second argument is compared element-wise against per-row
# ids; a scalar here makes only the first row of x count -- confirm that an
# id array (one entry per row of x) was intended.
tree_model.predict_mean_function(x,1)

In [None]:
tree_model.predict_mean_function(x,2)

In [None]:
tree_model.predict_mean_function(x,3)

In [None]:
import numpy as np
from numbers import Integral, Real
from sklearn.utils import check_random_state

def _get_n_samples_bootstrap(n_ids, max_samples):
    """
    Modified for recurrent events. Get the number of IDs in a bootstrap sample.
    """
    if max_samples is None:
        return n_ids

    if isinstance(max_samples, Integral):
        if max_samples > n_ids:
            msg = "`max_samples` must be <= n_ids={} but got value {}"
            raise ValueError(msg.format(n_ids, max_samples))
        return max_samples

    if isinstance(max_samples, Real):
        return max(round(n_ids * max_samples), 1)

def _generate_sample_indices(random_state, ids, n_ids_bootstrap):
    """
    Draw a bootstrap sample of IDs, with replacement.

    Despite the name, this returns sampled ID *values*, not row indices; the
    callers expand them to rows themselves (via ``np.isin``).

    Parameters
    ----------
    random_state : None, int or numpy.random.RandomState
        Passed through ``check_random_state``. An int seed makes the draw
        reproducible, which is what allows the out-of-bag rows to be
        regenerated later from the same seed.
    ids : array-like
        Pool of IDs to draw from (the callers pass unique IDs).
    n_ids_bootstrap : int
        Number of IDs to draw.

    Returns
    -------
    numpy.ndarray of length ``n_ids_bootstrap``
    """
    random_instance = check_random_state(random_state)
    # Draw from the seeded generator, not the global np.random, so that the
    # same random_state always yields the same sample.
    return random_instance.choice(ids, n_ids_bootstrap, replace=True)

def _generate_unsampled_indices(random_state, ids, n_ids_bootstrap):
    """
    Return the row indices of all records whose ID was left out of the bootstrap draw.

    The draw is regenerated by calling ``_generate_sample_indices`` on the
    unique IDs with ``random_state``. The result matches a tree's in-bag
    sample only if that function is deterministic for the given
    ``random_state``.

    Parameters
    ----------
    random_state : seed passed to ``_generate_sample_indices``
    ids : array-like
        Per-record IDs; the returned indices are positions in this array.
    n_ids_bootstrap : int
        Number of IDs drawn per bootstrap sample.

    Returns
    -------
    numpy.ndarray of int
    """
    pool = np.unique(ids)
    drawn = np.unique(_generate_sample_indices(random_state, pool, n_ids_bootstrap))
    left_out = np.setdiff1d(pool, drawn)

    # Every record of a left-out ID is out-of-bag.
    is_out_of_bag = np.isin(ids, left_out)
    return np.where(is_out_of_bag)[0]

from warnings import catch_warnings, simplefilter
from sklearn.utils.class_weight import compute_sample_weight

def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,  # Now, y is expected to be a dictionary with 'id', 'time_start', 'time_stop', and 'event' as keys
    sample_weight,
    tree_idx,
    n_trees,
    verbose=0,
    class_weight=None,
    n_ids_bootstrap=None
):
    """
    Private function used to fit a single tree in parallel for recurrent events.

    With ``bootstrap=True`` the tree is fit on the records of an ID-level
    bootstrap sample. Each record is weighted by the number of times its ID
    was drawn.

    Parameters
    ----------
    tree : estimator to fit; its ``random_state`` seeds the bootstrap draw.
    bootstrap : bool
    X : array of shape (n_records, n_features)
    y : dict with keys 'id', 'time_start', 'time_stop', 'event'
        The values are indexed with an integer index array below, so they are
        presumably NumPy arrays aligned with the rows of X.
    sample_weight : array of shape (n_records,) or None
    tree_idx, n_trees : int
        Used only for the progress message.
    verbose : int
        A message is printed when > 1.
    class_weight : None, "subsample" or "balanced_subsample"
    n_ids_bootstrap : int
        Number of IDs to draw.

    Returns
    -------
    The fitted ``tree``.

    NOTE(review): RecurrentRandomForest.fit in this notebook does not call this
    helper; it fits trees through its own ``_fit_tree`` closure. The ``tree.fit``
    calls here use ``fit(X, y_dict, sample_weight=..., check_input=...)``, while
    elsewhere RecurrentTree.fit is called as
    ``fit(X, ids, time_start, time_stop, event)``. Confirm which signature
    RecurrentTree.fit accepts before using this helper.
    """
    # Extract necessary data from y
    ids = y['id']
    time_start = y['time_start']
    time_stop = y['time_stop']
    event = y['event']

    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if bootstrap:
        unique_ids = np.unique(ids)
        n_ids = len(unique_ids)

        # Generate bootstrap samples using IDs
        sampled_ids = _generate_sample_indices(
            tree.random_state, unique_ids, n_ids_bootstrap
        )

        # Expand sampled IDs to all their associated events
        # (each drawn ID's rows appear once here; multiplicity goes into the weights).
        indices = np.where(np.isin(ids, sampled_ids))[0]

        if sample_weight is None:
            curr_sample_weight = np.ones((X.shape[0],), dtype=np.float64)
        else:
            curr_sample_weight = sample_weight.copy()

        # Adjust the sample weight based on how many times each ID was sampled
        # (rows of IDs that were never drawn get weight 0 and are also excluded via `indices`).
        sample_counts_for_ids = np.bincount(np.searchsorted(unique_ids, sampled_ids), minlength=n_ids)
        curr_sample_weight *= sample_counts_for_ids[np.searchsorted(unique_ids, ids)]

        # Optional class re-weighting computed on the bootstrap subsample, using the event indicator as label.
        # NOTE(review): "auto" is a legacy preset name — confirm the installed scikit-learn accepts it.
        if class_weight == "subsample":
            with catch_warnings():
                simplefilter("ignore", DeprecationWarning)
                curr_sample_weight *= compute_sample_weight("auto", event, indices=indices)
        elif class_weight == "balanced_subsample":
            curr_sample_weight *= compute_sample_weight("balanced", event, indices=indices)

        tree.fit(X[indices], {'id': ids[indices], 'time_start': time_start[indices], 'time_stop': time_stop[indices], 'event': event[indices]}, sample_weight=curr_sample_weight[indices], check_input=False)
    else:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)

    return tree



In [None]:
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from joblib import Parallel, delayed
from sklearn.utils.validation import check_is_fitted
import warnings
import numpy as np

class RecurrentRandomForest(BaseEstimator):
    """
    A Random Forest model designed for recurrent event data.

    Bootstrap sampling is done at the subject (ID) level. Each bootstrap
    sample draws IDs with replacement and keeps every record belonging to a
    drawn ID.

    Parameters
    ----------
    n_estimators : int, default=100
        Number of ``RecurrentTree`` estimators.
    max_depth, min_samples_split, min_samples_leaf, min_impurity_decrease, max_features
        Forwarded unchanged to every ``RecurrentTree``.
    bootstrap : bool, default=True
        Fit each tree on an ID-level bootstrap sample instead of the full data.
    oob_score : bool, default=False
        After fitting, compute out-of-bag predictions and a concordance score
        (requires ``bootstrap=True``).
    n_jobs : int or None
        Passed to ``joblib.Parallel`` when fitting the trees.
    random_state : None, int or numpy.random.RandomState
        Seeds the per-tree integer random states drawn in ``fit``.
    verbose : int
        Stored; not used by this class.
    warm_start : bool
        Stored; not used by this class.
    max_samples : None, int or float
        Number (int) or fraction (float) of IDs drawn per bootstrap sample.
    """
    def __init__(self, n_estimators=100, max_depth=None, min_samples_split=2,
                 min_samples_leaf=1, bootstrap=True, oob_score=False, n_jobs=None,
                 random_state=None, verbose=0, warm_start=False, max_samples=None,
                 min_impurity_decrease=0.0, max_features=None):
        # Only hyper-parameters are stored here (scikit-learn convention).
        # The trees are created in ``fit`` so that ``random_state`` and any
        # later ``set_params`` call are honoured.
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.verbose = verbose
        self.warm_start = warm_start
        self.max_samples = max_samples
        self.min_impurity_decrease = min_impurity_decrease
        self.max_features = max_features

    def _make_estimator(self, random_state=None):
        """
        Construct a new ``RecurrentTree`` with this forest's hyper-parameters.

        ``random_state`` is set per tree so that each tree draws a different
        bootstrap sample.
        """
        return RecurrentTree(
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            random_state=random_state,
            min_impurity_decrease=self.min_impurity_decrease,
            max_features=self.max_features
        )

    def fit(self, X, y, sample_weight=None):
        """
        Fit the forest.

        Parameters
        ----------
        X : array-like of shape (n_records, n_features)
        y : mapping with keys 'id', 'time_start', 'time_stop', 'event'
            Per-record sequences aligned with the rows of X.
        sample_weight : ignored
            Accepted for API compatibility; not used.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``oob_score=True`` while ``bootstrap=False``.
        """
        if self.oob_score and not self.bootstrap:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        X = self._validate_data(X, accept_sparse='csc', ensure_min_samples=2)
        ids = np.asarray(y['id'])
        time_start = np.asarray(y['time_start'])
        time_stop = np.asarray(y['time_stop'])
        event = np.asarray(y['event'])
        self.n_features_in_ = X.shape[1]

        unique_ids = np.unique(ids)
        n_samples_bootstrap = _get_n_samples_bootstrap(len(unique_ids), self.max_samples)

        # Derive one integer seed per tree from self.random_state. Integer seeds
        # let the OOB step regenerate each tree's bootstrap draw from
        # ``tree.random_state``.
        rng = check_random_state(self.random_state)
        seeds = rng.randint(np.iinfo(np.int32).max, size=self.n_estimators)
        self.estimators_ = [self._make_estimator(random_state=int(seed)) for seed in seeds]

        def _fit_tree(tree):
            if self.bootstrap:
                sampled_ids = _generate_sample_indices(tree.random_state, unique_ids, n_samples_bootstrap)
                # np.isin keeps every record of a drawn ID once; IDs drawn more
                # than once are not duplicated.
                rows = np.where(np.isin(ids, sampled_ids))[0]
                tree.fit(X[rows], ids[rows], time_start[rows], time_stop[rows], event[rows])
            else:
                tree.fit(X, ids, time_start, time_stop, event)
            return tree

        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_tree)(tree) for tree in self.estimators_
        )

        for estimator in self.estimators_:
            if hasattr(estimator, 'riskset_counter') and estimator.riskset_counter is not None:
                estimator.riskset_counter.reset()

        if self.oob_score:
            self._set_oob_score_and_attributes(X, y)
        return self

    def _set_oob_score_and_attributes(self, X, y):
        """
        Compute out-of-bag (OOB) predictions and the OOB concordance score.

        For each tree, the records whose ID was not in that tree's bootstrap
        sample are predicted with ``predict_mean_function``. The predictions
        are averaged per ID and stored per record in ``oob_prediction_``.
        ``oob_score_`` is the concordance index of those predictions, or NaN
        (with a warning) when no tree has out-of-bag records.
        """
        ids = np.asarray(y['id'])
        event = np.asarray(y['event'])

        # Total number of events per ID (sum of the per-record event indicator).
        total_events = {}
        for uid, ev in zip(ids, event):
            total_events[uid] = total_events.get(uid, 0) + ev

        all_predictions = {uid: [] for uid in np.unique(ids)}
        n_samples_bootstrap = _get_n_samples_bootstrap(len(all_predictions), self.max_samples)
        max_width = None

        for estimator in self.estimators_:
            # Pass the per-record ids so that the returned positions are row
            # indices into X. Passing unique ids would return positions in the
            # unique-id array instead.
            unsampled_indices = _generate_unsampled_indices(estimator.random_state, ids, n_samples_bootstrap)
            if len(unsampled_indices) == 0:
                continue  # every ID was in-bag for this tree

            # NOTE(review): the meaning of predict_mean_function's second argument is
            # defined by RecurrentTree (not visible here); row indices are passed as before.
            p_estimator_result, _ = estimator.predict_mean_function(X[unsampled_indices, :], unsampled_indices)

            width = np.array(p_estimator_result).shape[1]
            if max_width is None or width > max_width:
                max_width = width

            for idx, uid in enumerate(ids[unsampled_indices]):
                all_predictions[uid].append(p_estimator_result[idx])

        if max_width is None:
            warnings.warn("No out-of-bag samples were available for any tree; oob_score_ is set to NaN.")
            self.oob_prediction_ = np.zeros((X.shape[0], 0))
            self.oob_score_ = np.nan
            return

        # Zero-pad shorter prediction vectors so they can be averaged per ID.
        predictions = {}
        for uid, preds in all_predictions.items():
            if not preds:
                continue
            padded = []
            for pred in preds:
                pred = np.asarray(pred)
                if pred.shape[0] < max_width:
                    pred = np.concatenate([pred, np.zeros(max_width - pred.shape[0])])
                padded.append(pred)
            predictions[uid] = np.mean(padded, axis=0)

        # One row per record; IDs that were never out-of-bag get zeros.
        final_predictions = np.array([predictions.get(uid, np.zeros(max_width)) for uid in ids])

        self.oob_prediction_ = final_predictions
        self.oob_score_ = self._estimate_recurrent_concordance_index(final_predictions, X, event, ids, total_events)

    def _estimate_recurrent_concordance_index(self, predictions, X, event, ids, total_events):
        """
        Estimate the C-index for recurrent events from per-record predictions.

        Parameters
        ----------
        predictions : array-like, indexable by a boolean mask over records
            Predicted values per record; averaged per ID.
        X : unused
            Kept for signature compatibility.
        event : array-like
            Per-record event indicators.
        ids : array-like
            Per-record IDs.
        total_events : dict
            Total number of events for each ID (keyed by ID).

        Returns
        -------
        float
            concordant_pairs / permissible_pairs, in [0, 1]; 0 if no pair is permissible.
        """
        ids = np.asarray(ids)
        event = np.asarray(event)
        predictions = np.asarray(predictions)
        unique_ids = np.unique(ids)

        # Per-ID summaries, computed once instead of inside the O(n_ids^2) pair loop.
        avg_prediction = {}
        first_event = {}
        censored = {}
        for uid in unique_ids:
            mask = ids == uid
            avg_prediction[uid] = np.mean(predictions[mask])
            first_event[uid] = event[mask][0]
            # An ID is treated as right-censored when the event indicator of its
            # first record is smaller than its total event count.
            censored[uid] = first_event[uid] < total_events[uid]

        concordant_pairs = 0
        permissible_pairs = 0

        for i, uid_i in enumerate(unique_ids):
            for uid_j in unique_ids[i + 1:]:
                if not censored[uid_i] and not censored[uid_j]:  # neither is right-censored
                    if first_event[uid_i] > first_event[uid_j]:
                        permissible_pairs += 1
                        if avg_prediction[uid_i] > avg_prediction[uid_j]:
                            concordant_pairs += 1
                elif not censored[uid_i]:  # i is not right-censored but j is
                    permissible_pairs += 1
                    if avg_prediction[uid_i] > avg_prediction[uid_j]:
                        concordant_pairs += 1
                elif not censored[uid_j]:  # j is not right-censored but i is
                    permissible_pairs += 1
                    if avg_prediction[uid_i] < avg_prediction[uid_j]:
                        concordant_pairs += 1

        # The C-index is a plain proportion of pairs, so no rescaling is applied.
        return concordant_pairs / permissible_pairs if permissible_pairs > 0 else 0

    def _validate_data(self, X, accept_sparse=False, ensure_min_samples=1):
        """Validate input data('X') to ensure it's in the correct format and meets the necessary conditions for processing."""
        return check_array(X, accept_sparse=accept_sparse, ensure_min_samples=ensure_min_samples)

    def _validate_X_predict(self, X):
        """Validate X whenever one tries to predict."""
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError("Number of features of the model must match the input. Model n_features is {} and input n_features is {}."
                             .format(self.n_features_in_, X.shape[1]))
        return X

    def _average_leaf_value(self, X, ids, column):
        """
        Average ``tree.tree_.value[leaf][0][column]`` over all trees and all records of each ID.

        Raises
        ------
        ValueError
            If ``ids`` is not a 1-D sequence with one entry per row of X.
        """
        X = self._validate_X_predict(X)
        ids = np.asarray(ids)
        if ids.ndim != 1 or ids.shape[0] != X.shape[0]:
            raise ValueError(
                "`ids` must be a 1-D sequence with one entry per row of X; "
                "got shape {} for {} rows.".format(ids.shape, X.shape[0])
            )

        all_predictions = {uid: [] for uid in np.unique(ids)}
        for tree in self.estimators_:
            leaves = tree.apply(X)
            # NOTE(review): RecurrentTree.decision_path indexes ``tree_`` like a dict
            # (tree_["node_count"]); confirm that ``tree_.value`` attribute access
            # here matches the tree's actual structure.
            values = [tree.tree_.value[node][0][column] for node in leaves]
            for uid, value in zip(ids, values):
                all_predictions[uid].append(value)

        return {uid: np.mean(vals) for uid, vals in all_predictions.items()}

    def predict_rate_function(self, X, ids):
        """Return {id: rate prediction averaged over trees and that ID's records} (leaf value column 0)."""
        return self._average_leaf_value(X, ids, 0)

    def predict_mean_function(self, X, ids):
        """Return {id: mean-function prediction averaged over trees and that ID's records} (leaf value column 1)."""
        return self._average_leaf_value(X, ids, 1)





In [None]:
# Fit a 10-tree forest with out-of-bag scoring. `y` bundles the per-record
# id / interval / event arrays loaded from simuDat.csv.
rrf = RecurrentRandomForest(n_estimators=10, max_depth=3, min_samples_leaf=5, min_impurity_decrease=0.2, random_state=42, oob_score=True, n_jobs=6)
y = {
    'id': ids,
    'time_start': time_start,
    'time_stop': time_stop,
    'event': event
}
rrf.fit(x,y)

In [None]:
# NOTE(review): predict_mean_function pairs each row of X with one entry of `ids`.
# The scalar 1 is not such a sequence, so this call is expected to fail. Pass the
# per-record `ids` array instead.
rrf.predict_mean_function(X=x, ids=1)

In [None]:
# NOTE(review): identical to the previous cell.
rrf.predict_mean_function(X=x, ids=1)

In [None]:
# Out-of-bag concordance score set during fit (oob_score=True).
rrf.oob_score_

In [None]:
"""
보류
"""

class PermutationImportance:
    def __init__(self, model, n_repeats=30, random_state=None):
        """
        Constructor of the class
        'model': The trained model for which we want to compute feature importances
        'n_repeats': Number of times to repeat the permutation for each feature to get a reliable estimate.
        'random_state': Seed for reproducibility
        """
        self.model = model
        self.n_repeats = n_repeats
        self.random_state = random_state

    def _compute_baseline_cindex(self, X, event, time_stop, ids):
        """
        Computes the baseline C-index using the original (non-permuted) data.
        The C-index is a metric to evaluate the model's predictions, especially for recurrent events.
        This function first predicts the mean function using the model, averages the predictions, and then calculates the C-index.
        """
        # Get predictions using the model
        predictions = self.model.predict_mean_function(X)

        # Take the mean of the predictions for each individual for the C-index computation
        mean_predictions = np.mean(predictions, axis=1)

        # Compute total number of events for each ID
        unique_ids, total_events = np.unique(ids, return_counts=True)
        total_events_dict = dict(zip(unique_ids, total_events))
        total_events_arr = np.array([total_events_dict[id_] for id_ in ids])

        return self._estimate_concordance_index_recurrent(time_stop, event, mean_predictions, ids, total_events_arr)

    def _estimate_concordance_index_recurrent(self, time_stop, event, predictions, ids, total_events):
        """
        A wrapper around the model's method to estimate the recurrent C-index.
        It's used for convenience and to make the code more readable.
        """
        return self.model._estimate_recurrent_concordance_index(predictions, time_stop, event, ids, total_events)

    def compute_importance(self, X, event, time_stop, ids):
        """
        The main function that computes the feature importances
        For each feature:
            It permutes (shuffles) the feature's values a number of times (specified by 'n_repeats')
            For each permutation, it calculates the drop in C-index (compared to the baseline C-index) due to the permutation
            The drop in performance (C-index) due to the permutation gives an indicartion of the feature's importance.
        It returns the computed importances matrix where each row corresponds to a feature and each column to a permutation repitition
        """
        baseline_cindex = self._compute_baseline_cindex(X, event, time_stop, ids)

        rng = check_random_state(self.random_state)

        n_features = X.shape[1]
        importances = np.zeros((n_features, self.n_repeats))

        unique_ids = np.unique(ids)

        for feature in range(n_features):
            for repeat in range(self.n_repeats):
                X_permuted = X.copy()

                permuted_ids = rng.permutation(unique_ids)  # ID를 섞습니다.

                # ID에 따라 값을 변경합니다.
                for orig_id, new_id in zip(unique_ids, permuted_ids):
                    orig_idx = np.where(ids == orig_id)[0]
                    new_idx = np.where(ids == new_id)[0]
                    X_permuted[orig_idx, feature] = X[new_idx, feature]

                # Calculate c-index for permuted X
                permuted_cindex = self._compute_baseline_cindex(X_permuted, event, time_stop, ids)

                # The importance is the drop in c-index
                importances[feature, repeat] = baseline_cindex - permuted_cindex

        return importances

    def report_importance(self, X, event, time_stop, ids):
        """
        Computes the feature importances and then calculates the mean and standard deviation of the importances for each feature across all the repeats.
        The mean gives an average measure of the importance of each feature, while the standard deviation provides an estimate of the variability or uncertainty in the importance estimates.
        It returns the mean and standard deviation of the feature importances.
        """
        importances = self.compute_importance(X, event, time_stop, ids)

        # Compute mean and std of importances
        importance_mean = np.mean(importances, axis=1)
        importance_std = np.std(importances, axis=1)

        return importance_mean, importance_std


In [None]:
# Permutation importance of each column of x for the fitted forest,
# using 10 shuffles per feature and a fixed seed.
permutation_importance = PermutationImportance(rrf, n_repeats=10, random_state=42)
mean_importances, std_importances = permutation_importance.report_importance(x, event, time_stop, ids)

print("Mean Importances:", mean_importances)
print("STD Importances:", std_importances)