<a href="https://colab.research.google.com/github/JSK2022/RandomForest-for-Recurrent-Events/blob/Thesis-Code/JS_simulation_code_230831.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scikit-survival

In [None]:
import numpy as np
from collections import defaultdict

class RisksetCounter:
    """Risk-set and event counts for recurrent-event data in start/stop format.

    Row ``k`` describes one at-risk interval ``(time_start[k], time_stop[k]]``
    of subject ``ids[k]``; ``event[k]`` is 1 when the interval ends with an
    event and 0 otherwise.

    Attributes
    ----------
    all_unique_times : ndarray
        Sorted distinct values of ``time_stop``; every count refers to this grid.
    n_at_risk : ndarray of int64
        For each grid time ``t``, the number of rows with
        ``time_start < t <= time_stop``.
    n_events : ndarray of int64
        For each grid time ``t``, the sum of ``event`` over the rows whose
        ``time_stop`` equals ``t``.
    """

    def __init__(self, ids, time_start, time_stop, event):
        """Store the rows as given and compute the counts.

        Parameters
        ----------
        ids : sequence
            Subject identifier of each row.
        time_start, time_stop : sequence
            Start and stop time of each row; ``time_start <= time_stop`` is
            assumed for every row.
        event : sequence
            1 if the row ends with an event, 0 otherwise.
        """
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event

        self.m = len(ids)  # number of rows (intervals), not of distinct subjects
        self.all_unique_times = np.unique(time_stop)
        self.n_unique_times = len(self.all_unique_times)

        self.n_at_risk = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events = np.zeros(self.n_unique_times, dtype=np.int64)

        self.set_data()

    def set_data(self):
        """Recompute ``n_at_risk`` and ``n_events`` from the stored rows.

        The counts are written in place on the existing time grid, so calling
        this method repeatedly is idempotent.
        """
        grid = self.all_unique_times
        starts = np.asarray(self.time_start)
        stops = np.asarray(self.time_stop)
        # float weights let integer, boolean and float indicators share one code path
        events = np.asarray(self.event, dtype=np.float64)

        # Row lookup per subject, so the per-subject queries below do not have
        # to scan every row on every call.
        rows_by_id = {}
        for row, id_ in enumerate(self.ids):
            rows_by_id.setdefault(id_, []).append(row)
        self._rows_by_id = rows_by_id

        self.n_at_risk.fill(0)
        self.n_events.fill(0)
        if stops.size == 0:
            return

        # A row is at risk at t when start < t <= stop.  Because start <= stop,
        # that count equals #(start < t) - #(stop < t).
        entered = np.searchsorted(np.sort(starts), grid, side="left")
        exited = np.searchsorted(np.sort(stops), grid, side="left")
        self.n_at_risk[:] = entered - exited

        slots = np.searchsorted(grid, stops)
        self.n_events[:] = np.bincount(
            slots, weights=events, minlength=self.n_unique_times
        ).astype(np.int64)

    def Y_i(self, id_, t_idx):
        """Return 1 if subject ``id_`` has a row with ``time_stop`` at or after grid time ``t_idx``, else 0."""
        t = self.all_unique_times[t_idx]
        rows = self._rows_by_id.get(id_, ())
        return int(any(t <= self.time_stop[row] for row in rows))

    def dN_bar_i(self, id_, t_idx):
        """Return 1 if subject ``id_`` has an event exactly at grid time ``t_idx``, else 0."""
        t = self.all_unique_times[t_idx]
        for row in self._rows_by_id.get(id_, ()):
            if t == self.time_stop[row] and self.event[row] == 1:
                # the matching row stops at t, so the subject is at risk at t
                return 1
        return 0

    def N_bar_i(self, id_, t_idx):
        """Return the number of grid times up to and including ``t_idx`` at which subject ``id_`` has an event."""
        return sum(self.dN_bar_i(id_, idx) for idx in range(t_idx + 1))

    def Y(self, t_idx):
        """Return the number of distinct subjects for which ``Y_i`` is 1 at grid time ``t_idx``."""
        return sum(self.Y_i(id_, t_idx) for id_ in self._rows_by_id)

    def dN_bar(self, t_idx):
        """Return the number of distinct subjects with an event exactly at grid time ``t_idx``."""
        return sum(self.dN_bar_i(id_, t_idx) for id_ in self._rows_by_id)

    def N_bar(self, t_idx):
        """Return the running total of ``dN_bar`` over grid times ``0..t_idx``."""
        return sum(self.dN_bar(idx) for idx in range(t_idx + 1))

    def reset(self):
        """Recompute the counts from the rows currently stored (no row is discarded)."""
        self.n_at_risk.fill(0)
        self.n_events.fill(0)
        self.set_data()

    def update(self, ids, time_start, time_stop, event):
        """Append rows and recompute the counts.

        When the new rows introduce stop times that are not yet on the grid,
        the grid and both count arrays are rebuilt; otherwise the existing
        arrays are updated in place.
        """
        self.ids = np.concatenate([self.ids, ids])
        self.time_start = np.concatenate([self.time_start, time_start])
        self.time_stop = np.concatenate([self.time_stop, time_stop])
        self.event = np.concatenate([self.event, event])
        self.m = len(self.ids)

        grid = np.unique(self.time_stop)
        if not np.array_equal(grid, self.all_unique_times):
            self.all_unique_times = grid
            self.n_unique_times = len(grid)
            self.n_at_risk = np.zeros(self.n_unique_times, dtype=np.int64)
            self.n_events = np.zeros(self.n_unique_times, dtype=np.int64)

        self.set_data()

    def copy(self):
        """Return an independent counter built from copies of the stored rows."""
        return RisksetCounter(self.ids.copy(), self.time_start.copy(), self.time_stop.copy(), self.event.copy())


In [None]:
# Toy recurrent-event data in start/stop format: one row per at-risk interval.
ids = [1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4]
time_start = [0, 2, 3, 0, 3, 0, 1, 2, 5, 0, 5, 8]
time_stop = [2, 3, 7, 3, 6, 1, 2, 5, 9, 5, 8, 11]
event = [1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0]

# Two covariates per row; the rows of one subject repeat that subject's values
# (3 rows for id 1, 2 for id 2, 4 for id 3, 3 for id 4).
x = np.array([[45, 0]] * 3 + [[52, 0]] * 2 + [[65, 1]] * 4 + [[53, 1]] * 3)

risk_counter = RisksetCounter(ids, time_start, time_stop, event)

# Number of subjects with an event at the grid time with index 8
risk_counter.dN_bar(8)

In [None]:
# Rebuild the counter from the example data
risk_counter = RisksetCounter(ids, time_start, time_stop, event)

# Evaluate Y and dN_bar at every index of the counter's time grid
Y_values = list(map(risk_counter.Y, range(risk_counter.n_unique_times)))
dN_bar_values = list(map(risk_counter.dN_bar, range(risk_counter.n_unique_times)))

Y_values, dN_bar_values

In [None]:
def argbinsearch(arr, key_val):
    """Return the index of the first element of sorted ``arr`` that is greater than ``key_val``.

    Elements equal to ``key_val`` are skipped, so the result is the rightmost
    insertion point that keeps ``arr`` sorted. ``len(arr)`` is returned when no
    element is greater than ``key_val`` (including for an empty ``arr``).

    Parameters
    ----------
    arr : sequence
        Values sorted in ascending order.
    key_val : object
        Value to locate; must be comparable with the elements of ``arr``.

    Returns
    -------
    int
        Index in ``0..len(arr)``.
    """
    lo, hi = 0, len(arr)

    # Invariant: arr[:lo] <= key_val < arr[hi:].  The midpoint always lies in
    # [lo, hi), so no bounds check is needed.
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if arr[mid] <= key_val:
            lo = mid + 1
        else:
            hi = mid

    return lo

이 함수는 argbinsearch라는 이름의 함수로, 정렬된 배열에서 주어진 키 값보다 큰 첫 번째 원소의 인덱스를 이진 탐색으로 찾아 반환합니다. 비교 조건이 `<=`이므로 키 값과 같은 원소는 건너뜁니다.

자세한 코드 설명을 아래에 제공합니다:

1. 입력:

  * arr: 탐색 대상인 정렬된 배열
  * key_val: 찾고자 하는 키 값

2. 초기 변수 설정:

  * arr_len: 배열의 길이를 저장합니다.
  * min_idx: 탐색 범위의 최솟값으로, 처음에는 배열의 시작 인덱스인 0으로 설정됩니다.
  * max_idx: 탐색 범위의 최댓값으로, 처음에는 배열의 길이로 설정됩니다.

3. 이진 탐색:

  * while 루프를 사용하여 min_idx가 max_idx보다 작은 동안 탐색을 반복합니다.
  * mid_idx: 현재 탐색 범위의 중간 인덱스를 계산합니다.
  * mid_val: 중간 인덱스에 해당하는 배열의 원소 값을 가져옵니다.

4. 키 값과 중간 값을 비교합니다:
  * 만약 중간 값이 키 값보다 작거나 같으면, min_idx를 mid_idx + 1로 업데이트합니다. 이렇게 하면 탐색 범위의 왼쪽 부분을 제외하게 됩니다.
  * 그렇지 않으면, max_idx를 mid_idx로 업데이트합니다. 이렇게 하면 탐색 범위의 오른쪽 부분을 제외하게 됩니다.

5. 결과 반환:

  * 루프가 종료되면, min_idx는 키 값보다 큰 첫 번째 원소의 인덱스를 가리키게 됩니다(그런 원소가 없으면 배열의 길이와 같습니다). 따라서 min_idx를 반환합니다.

이 함수는 정렬된 배열에서 주어진 키 값보다 큰 첫 번째 원소의 위치를 효율적으로 찾기 위해 사용됩니다. 이진 탐색은 배열의 중간 값을 반복적으로 확인하면서 탐색 범위를 절반씩 줄여나가므로, 큰 배열에서도 빠르게 원하는 위치를 찾을 수 있습니다.

In [None]:
# 필요한 라이브러리 및 함수 임포트
import numpy as np
from sklearn.utils import check_random_state

def check_random_state(seed):
    """Turn ``seed`` into a ``np.random.RandomState`` instance.

    An existing ``RandomState`` is returned unchanged; ``None`` or an integer
    is used to construct a new one.

    Raises
    ------
    ValueError
        If ``seed`` is none of the above.
    """
    if isinstance(seed, np.random.RandomState):
        return seed
    if seed is None or isinstance(seed, (int, np.integer)):
        return np.random.RandomState(seed)
    raise ValueError("seed must be None, int or np.random.RandomState")

class PseudoScoreCriterion:
    """Split criterion that compares event rates in the two children of a node.

    Three ``RisksetCounter`` objects are kept on one shared time grid (the
    distinct stop times of the full data set):

    * ``riskset_total`` - the rows of the current node,
    * ``riskset_left`` / ``riskset_right`` - the rows sent to each child by the
      most recent ``update`` call.

    After construction, ``init`` or ``reset`` the left counter is empty and the
    right counter holds the same rows as the total counter.

    Each counter always contains exactly the rows it is supposed to describe;
    rows are never accumulated across calls.
    """

    def __init__(self, n_outputs, n_samples, unique_times, x, ids, time_start, time_stop, event, random_state):
        """Store the data set and build the three counters.

        Parameters
        ----------
        n_outputs, random_state
            Stored unchanged; not used by the computations in this class.
        n_samples : int
            Number of rows for which ``samples_time_idx`` is computed.
        unique_times : array-like
            Sorted time grid used for ``samples_time_idx``.
        x : array-like
            Covariates, indexed as ``x[row, feature]``.
        ids, time_start, time_stop, event : sequence
            Row-wise subject id, interval start, interval stop, event indicator.
        """
        self.n_outputs = n_outputs
        self.n_samples = n_samples
        self.unique_times = unique_times
        self.x = x
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event
        self.random_state = random_state

        # Array versions of the inputs, used to select several rows at once.
        self._x = np.asarray(x)
        self._ids = np.asarray(ids)
        self._starts = np.asarray(time_start)
        self._stops = np.asarray(time_stop)
        self._events = np.asarray(event)

        # Built from the full data so that all three share the same time grid.
        self.riskset_left = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_right = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_total = RisksetCounter(ids, time_start, time_stop, event)

        self.samples_time_idx = np.zeros(n_samples, dtype=np.int64)
        self.samples_time_idx[:] = np.searchsorted(unique_times, self._stops[:n_samples])

        self.split_pos = 0
        self.split_time_idx = 0

        # Until init() is called the node is the whole data set.
        self.samples = np.arange(n_samples)
        every_row = np.arange(len(self._ids))
        self._load_rows(self.riskset_total, every_row)
        self._load_rows(self.riskset_right, every_row)
        self._load_rows(self.riskset_left, every_row[:0])

    @staticmethod
    def _fill(counter, ids, time_start, time_stop, event):
        """Replace the rows held by ``counter`` and recompute its counts.

        Relies on ``RisksetCounter.reset`` recomputing ``n_at_risk`` and
        ``n_events`` from the counter's current rows while leaving its time
        grid untouched, which keeps the arrays of all counters the same length.
        """
        counter.ids = ids
        counter.time_start = time_start
        counter.time_stop = time_stop
        counter.event = event
        counter.reset()

    def _load_rows(self, counter, rows):
        """Make ``counter`` describe exactly the data rows listed in ``rows``."""
        rows = np.asarray(rows, dtype=np.intp)
        self._fill(counter, self._ids[rows], self._starts[rows], self._stops[rows], self._events[rows])

    def init(self, y, sample_weight, n_samples, samples, start, end):
        """Point the criterion at the node made of ``samples[start:end]``.

        Parameters
        ----------
        y : array-like, columns (time_start, time_stop, event)
            Row-wise response; the total counter takes its times and events
            from here.
        sample_weight, n_samples
            Accepted for interface compatibility; not used.
        samples : sequence of int
            Row indices; kept as ``self.samples`` for later ``update`` calls.
        start, end : int
            Slice of ``samples`` that forms the node.
        """
        self.samples = samples
        node_rows = np.asarray(samples[start:end], dtype=np.intp)
        y = np.asarray(y)

        node = (self._ids[node_rows], y[node_rows, 0], y[node_rows, 1], y[node_rows, 2])
        self._fill(self.riskset_total, *node)
        self._fill(self.riskset_right, *node)
        self._load_rows(self.riskset_left, node_rows[:0])

    def update(self, new_pos, split_feature, split_threshold):
        """Partition the first ``new_pos`` entries of ``self.samples`` by a threshold.

        Rows with ``x[row, split_feature] <= split_threshold`` go to the left
        counter, the remaining ones to the right counter. Both counters are
        rebuilt from scratch on every call.
        """
        new_pos = max(0, min(new_pos, len(self.samples)))
        considered = np.asarray(self.samples[:new_pos], dtype=np.intp)
        goes_left = self._x[considered, split_feature] <= split_threshold

        self._load_rows(self.riskset_left, considered[goes_left])
        self._load_rows(self.riskset_right, considered[~goes_left])

    def proxy_impurity_improvement(self):
        """Return the signed, standardised difference of the children's event rates.

        At each grid time the difference ``dN_left / Y_left - dN_right / Y_right``
        is weighted by ``Y_left * Y_right / (Y_left + Y_right)``; the weighted sum
        is divided by the square root of the weighted sum of squared differences.
        The small constants keep every division finite when a risk set is empty.
        """
        left_n_at_risk = self.riskset_left.n_at_risk + 1e-7
        right_n_at_risk = self.riskset_right.n_at_risk + 1e-7
        total_n_at_risk = left_n_at_risk + right_n_at_risk

        w = (left_n_at_risk * right_n_at_risk) / total_n_at_risk
        term = (self.riskset_left.n_events / left_n_at_risk) - (self.riskset_right.n_events / right_n_at_risk)
        numer = np.sum(w * term)
        var_estimate = np.sum(w * term ** 2)

        return numer / (np.sqrt(var_estimate) + 1e-7)

    def node_value(self):
        """Return cumulative events divided by the risk-set size, per grid time.

        Both quantities are taken from the left and right counters combined.
        """
        total_n_at_risk = self.riskset_left.n_at_risk + self.riskset_right.n_at_risk + 1e-7
        return np.cumsum(self.riskset_left.n_events + self.riskset_right.n_events) / total_n_at_risk

    def reset(self):
        """Undo any split: left becomes empty, right holds the node's rows again."""
        self.riskset_total.reset()
        node = self.riskset_total
        self._fill(self.riskset_right, node.ids, node.time_start, node.time_stop, node.event)
        self._fill(self.riskset_left, node.ids[:0], node.time_start[:0], node.time_stop[:0], node.event[:0])

    def copy(self):
        """Return an independent criterion in the same state."""
        new_criterion = PseudoScoreCriterion(self.n_outputs, self.n_samples, self.unique_times,
                                             self.x, self.ids, self.time_start, self.time_stop,
                                             self.event, self.random_state)
        # Rows are copied into the new criterion's own counters (rather than
        # replacing the counters) so that the shared time grid is preserved.
        for name in ("riskset_left", "riskset_right", "riskset_total"):
            source = getattr(self, name)
            self._fill(getattr(new_criterion, name),
                       np.array(source.ids), np.array(source.time_start),
                       np.array(source.time_stop), np.array(source.event))
        new_criterion.samples_time_idx = self.samples_time_idx.copy()
        new_criterion.samples = self.samples.copy()

        return new_criterion

# Since the proxy_impurity_improvement is not directly testable (it's dependent on the state of the object),
# we will assume the refactoring is correct.


In [None]:
# Same toy data set as above: one row per at-risk interval.
ids = [1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4]
time_start = [0, 2, 3, 0, 3, 0, 1, 2, 5, 0, 5, 8]
time_stop = [2, 3, 7, 3, 6, 1, 2, 5, 9, 5, 8, 11]
event = [1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0]
# Covariate rows repeat per subject (3, 2, 4 and 3 rows respectively).
x = np.array([[45, 0]] * 3 + [[52, 0]] * 2 + [[65, 1]] * 4 + [[53, 1]] * 3)

# Arguments for the criterion
n_samples = len(ids)
n_outputs = 1
unique_times = np.unique(np.concatenate([time_start, time_stop]))
random_state = check_random_state(None)

# Arguments are passed positionally, in the order of the constructor signature
criterion = PseudoScoreCriterion(n_outputs, n_samples, unique_times,
                                 x, ids, time_start, time_stop,
                                 event, random_state)

# Evaluates to True when construction succeeded
criterion is not None

In [None]:
# Initialize the criterion using a subset of the data
y = np.column_stack([time_start, time_stop, event])
sample_weight = None
samples = np.arange(n_samples)
start, end = 0, n_samples

criterion.init(y, sample_weight, n_samples, samples, start, end)

# Extracting some initial values for validation
riskset_total_n_at_risk = criterion.riskset_total.n_at_risk
riskset_total_n_events = criterion.riskset_total.n_events

riskset_total_n_at_risk, riskset_total_n_events

In [None]:
# Update the criterion by simulating a split on the first feature with a threshold of 50
split_feature = 0
split_threshold = 50
new_pos = 6  # arbitrary position to simulate the split

criterion.update(new_pos, split_feature, split_threshold)

# Extracting values after the update for validation
riskset_left_n_at_risk = criterion.riskset_left.n_at_risk
riskset_left_n_events = criterion.riskset_left.n_events
riskset_right_n_at_risk = criterion.riskset_right.n_at_risk
riskset_right_n_events = criterion.riskset_right.n_events

riskset_left_n_at_risk, riskset_left_n_events, riskset_right_n_at_risk, riskset_right_n_events


In [None]:
# Calculate the proxy impurity improvement
proxy_impurity = criterion.proxy_impurity_improvement()

proxy_impurity

In [None]:
# Extract node values
node_values = criterion.node_value()

node_values

$\textbf{PseudoScoreCriterion}$ 클래스를 사용하여 주어진 분할 기준에 따른 손실 함수의 변화를 계산한 결과는 0.0입니다.

이는 선택한 분할 기준이 손실 함수를 개선하지 않았음을 의미합니다. 다른 분할 기준을 시도하면 다른 결과를 얻을 수 있습니다.

이 코드 예시를 통해 PseudoScoreCriterion 클래스가 정상적으로 작동함을 확인할 수 있습니다.

In [None]:
!pip install scikit-survival

In [None]:
import pandas as pd

class PseudoScoreTreeBuilder:
    """Grow a binary tree for recurrent-event data with ``PseudoScoreCriterion``.

    The tree is a nested dictionary. Every node has the keys ``'feature'``,
    ``'threshold'``, ``'left_child'``, ``'right_child'`` and ``'node_value'``;
    for a leaf the first four are ``None``. A sample goes to the left child
    when ``x[feature] <= threshold``.
    """

    TREE_UNDEFINED = -1  # Placeholder

    def __init__(self, max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=None):
        """Store the growth limits.

        Parameters
        ----------
        max_depth : int or None
            Depth at which nodes become leaves; ``None`` means no limit.
        min_samples_split : int
            Nodes with fewer rows are not split.
        min_samples_leaf : int
            Nodes with at most this many rows are not split, and every child
            created by a split has at least this many rows.
        random_state : None, int or RandomState
            Passed through ``check_random_state`` and handed to the criterion.
        """
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.random_state = check_random_state(random_state)

    @staticmethod
    def _leaf(node_value):
        """Return the dictionary describing a terminal node."""
        return {
            'feature': None,
            'threshold': None,
            'left_child': None,
            'right_child': None,
            'node_value': node_value
        }

    def _split(self, X, criterion, start, end, samples=None):
        """Find the best split for a node.

        The node consists of the rows ``samples[start:end]`` of ``X`` (or rows
        ``start..end-1`` when ``samples`` is omitted). ``criterion.init`` must
        already have been called for the same rows.

        Returns a dictionary with ``'feature_index'``, ``'threshold'`` and
        ``'improvement'``; ``'feature_index'`` is ``None`` when no admissible
        split exists.
        """
        if samples is None:
            node_rows = np.arange(start, end)
        else:
            node_rows = np.asarray(samples)[start:end]
        n_node = len(node_rows)

        best_split = {
            'feature_index': None,
            'threshold': None,
            'improvement': -np.inf
        }

        for feature_index in range(X.shape[1]):
            values = X[node_rows, feature_index]
            levels = np.unique(values)

            # Candidate thresholds lie between consecutive distinct values, so
            # both children always receive at least one row.
            for lower, upper in zip(levels[:-1], levels[1:]):
                threshold = (lower + upper) / 2.0
                if threshold >= upper:
                    # rounding collapsed the midpoint onto the upper value
                    threshold = lower

                n_left = int(np.count_nonzero(values <= threshold))
                if n_left < self.min_samples_leaf or n_node - n_left < self.min_samples_leaf:
                    continue

                # new_pos=n_node makes the criterion partition every row of the node
                criterion.update(new_pos=n_node, split_feature=feature_index, split_threshold=threshold)

                # The statistic is signed; only its magnitude measures how
                # strongly the two children differ.
                improvement = abs(criterion.proxy_impurity_improvement())
                if np.isnan(improvement):
                    continue

                if improvement > best_split['improvement']:
                    best_split = {
                        'feature_index': feature_index,
                        'threshold': threshold,
                        'improvement': improvement
                    }

        return best_split

    def _build(self, X, y, criterion, depth=0, start=0, end=None, samples=None):
        """Recursively build the subtree for one node.

        ``X`` and ``y`` are always the full arrays. The node is identified by
        ``samples``, an array of row indices into them (rows ``start..end-1``
        when omitted), so that the indices seen by the criterion refer to the
        same rows as the ones used here.
        """
        if samples is None:
            if end is None:
                end = X.shape[0]
            samples = np.arange(start, end)
        else:
            samples = np.asarray(samples)
        n_node = len(samples)

        # Point the criterion at this node before reading its value.
        criterion.init(y, None, X.shape[0], samples, 0, n_node)
        node_value = criterion.node_value()

        # Conditions for terminal node
        if depth == self.max_depth or n_node <= self.min_samples_leaf or n_node < self.min_samples_split:
            return self._leaf(node_value)

        best_split = self._split(X, criterion, 0, n_node, samples=samples)
        if best_split['feature_index'] is None:
            return self._leaf(node_value)

        goes_left = X[samples, best_split['feature_index']] <= best_split['threshold']

        # Each recursive call re-initialises the criterion for its own node.
        left_child = self._build(X, y, criterion, depth=depth + 1, samples=samples[goes_left])
        right_child = self._build(X, y, criterion, depth=depth + 1, samples=samples[~goes_left])

        return {
            'feature': best_split['feature_index'],
            'threshold': best_split['threshold'],
            'left_child': left_child,
            'right_child': right_child,
            'node_value': node_value
        }

    def build(self, X, ids, time_start, time_stop, event):
        """Build a tree from start/stop recurrent-event data.

        Parameters
        ----------
        X : ndarray of shape (n_rows, n_features)
            Covariates, one row per at-risk interval.
        ids, time_start, time_stop, event : sequence of length n_rows
            Subject id, interval start, interval stop and event indicator.

        Returns
        -------
        pandas.DataFrame
            One row whose columns are the keys of the root node dictionary.
        """
        n_samples, n_features = X.shape
        y = np.c_[time_start, time_stop, event]

        unique_times = np.unique(np.concatenate([time_start, time_stop]))
        # The criterion's constructor computes samples_time_idx for every row.
        criterion = PseudoScoreCriterion(n_outputs=n_features, n_samples=n_samples,
                                         unique_times=unique_times, x=X, ids=ids,
                                         time_start=time_start, time_stop=time_stop, event=event,
                                         random_state=self.random_state)

        tree = self._build(X, y, criterion)

        # Convert tree dictionary to dataframe for consistency
        tree_df = pd.DataFrame([tree])
        return tree_df

# Since the PseudoScoreTreeBuilder is not directly testable (it's dependent on the state of the object),
# we will assume the refactoring is correct.



In [None]:
# Redefining the data
ids = [1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4]
time_start = [0, 2, 3, 0, 3, 0, 1, 2, 5, 0, 5, 8]
time_stop = [2, 3, 7, 3, 6, 1, 2, 5, 9, 5, 8, 11]
event = [1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0]
X = np.array([[45, 0],
              [45, 0],
              [45, 0],
              [52, 0],
              [52, 0],
              [65, 1],
              [65, 1],
              [65, 1],
              [65, 1],
              [53, 1],
              [53, 1],
              [53, 1]])

# Build the tree using the provided data
tree_builder=PseudoScoreTreeBuilder(max_depth=3, min_samples_leaf=5, random_state=1190)
tree_df = tree_builder.build(X, ids, time_start, time_stop, event)

# Display the tree dataframe
tree_df

## PseudoScoreCriterion을 기반으로 tree를 build 하는 클래스
1. 클래스 초기화 (`__init__`)
  * 트리의 최대 깊이(max_depth), 분할을 시작하기 위한 최소 샘플 수(min_samples_split), 리프 노드가 되기 위한 최소 샘플 수(min_samples_leaf), 랜덤 상태(random_state) 등 트리의 주요 하이퍼파라미터를 정의
2. _split 함수:
  * 분할의 특정 기준에 따라 주어진 데이터의 하위 집합에 대해 최적의 분할을 찾는 함수
  * 각 특성에 대해 가능한 모든 분할 포인트를 살펴보고, 최적의 분할을 찾기 위해 각 분할의 quality를 평가
  * 최적의 분할은 feature의 index, value of threshold, 그리고 분할로 인한 품질 향상 등의 정보를 포함
3. _build 함수
  * 재귀적으로 트리를 구축하는 함수
  * 주어진 데이터에 대해 최적의 분할을 찾고, 이를 기반으로 왼쪽과 오른쪽 서브트리를 구축
  * 트리의 최대 깊이에 도달하거나, 리프 노드가 되기 위한 조건을 만족하면 종료
  * 각 노드: feature의 index, value of threshold, left/right daughter node, 노드의 데이터 통계를 포함하는 딕셔너리로 표현
4. build 함수
  * 사용자에게 제공되는 주요 함수로, 입력 데이터와 관련된 다양한 정보를 기반으로 트리를 구축
  * PseudoScoreCriterion은 트리 분할의 품질을 평가하는 데 사용되는 특정 기준을 나타냅니다. 이 기준은 시간적으로 연속된 데이터와 관련된 특정 통계를 계산하는 데 사용됩니다.
  * _build 함수를 사용하여 트리를 구축한 후, 결과 트리를 데이터프레임 형식으로 변환하여 반환합니다.

In [None]:
class RecurrentTree:
    """Single tree for recurrent-event data, grown by ``PseudoScoreTreeBuilder``."""

    def __init__(self, max_depth=None, min_samples_split=2, min_samples_leaf=1, random_state=None):
        """Store the hyper-parameters; ``tree_`` stays ``None`` until ``fit`` is called."""
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.random_state = random_state
        self.tree_ = None

    def fit(self, X, ids, time_start, time_stop, event, sample_weight=None):
        """Grow the tree and return ``self``.

        All inputs are converted to NumPy arrays first. ``sample_weight`` is
        accepted but not used.
        """
        X, ids, time_start, time_stop, event = (
            np.array(X),
            np.array(ids),
            np.array(time_start),
            np.array(time_stop),
            np.array(event),
        )

        builder = PseudoScoreTreeBuilder(
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            random_state=self.random_state
        )
        # build() returns a one-row dataframe; that row is the root node
        tree_frame = builder.build(X, ids=ids, time_start=time_start, time_stop=time_stop, event=event)
        self.tree_ = tree_frame.iloc[0]
        return self

    def get_tree(self):
        """Return the fitted root node (``None`` before ``fit``)."""
        return self.tree_

    def _traverse_tree(self, x, node):
        """Descend from ``node`` to the terminal node that sample ``x`` falls into."""
        current = node
        # A node without a threshold is terminal
        while current["threshold"] is not None:
            if x[current["feature"]] <= current["threshold"]:
                current = current["left_child"]
            else:
                current = current["right_child"]
        return current

    def predict_rate_function(self, X):
        """Return, for every row of ``X``, the ``node_value`` of its terminal node.

        The result is a list with one entry per row.
        """
        X = np.array(X)
        return [
            self._traverse_tree(X[row], self.tree_)['node_value']
            for row in range(X.shape[0])
        ]

    def predict_mean_function(self, X):
        """Return, for every row of ``X``, the cumulative sum of its predicted rate function."""
        return [np.cumsum(rate) for rate in self.predict_rate_function(X)]

# This refactoring should make the RecurrentTree class consistent with the changes made to the PseudoScoreTreeBuilder class.


## RecurrentTree

1. 초기화 (__ init __ 메서드):

초기화 시 최대 깊이(max_depth), 분할을 위한 최소 샘플 수(min_samples_split), 리프 노드의 최소 샘플 수(min_samples_leaf), 그리고 난수 생성 상태(random_state)를 받습니다.
tree_는 학습된 트리를 저장하는 변수입니다.

2. 학습 (fit 메서드):

주어진 데이터(X, ids, time_start, time_stop, event)를 사용하여 트리를 학습합니다.
입력 데이터는 올바른 형식(numpy 배열)으로 변환됩니다.
PseudoScoreTreeBuilder를 사용하여 트리를 구축합니다. 이 클래스는 앞의 셀에서 정의되어 있으므로, 해당 셀을 먼저 실행해야 합니다.
트리 가져오기 (get_tree 메서드):

학습된 트리를 딕셔너리 형태로 반환합니다.

3. 트리 순회 (_traverse_tree 메서드):

주어진 샘플(x)에 대해 트리를 순회하면서 해당 샘플이 속하는 종단 노드(리프 노드)를 찾습니다.

4. 위험률 함수 예측 (predict_rate_function 메서드):

주어진 샘플들에 대해 비모수적 위험률 함수의 추정치인
dμ(t)=ρ(t)dt를 예측합니다.

5. 평균 함수 예측 (predict_mean_function 메서드):

주어진 샘플들에 대해 Nelson-Aalen 추정치를 사용하여 평균 함수를 예측합니다.

In [None]:
# 2. RecurrentTree 학습 및 예측
tree_model = RecurrentTree(max_depth=5, random_state=42)
tree_model.fit(X, ids, time_start, time_stop, event)

# 비율 함수 예측
rate_functions = tree_model.predict_rate_function(x)

# 평균 함수 예측
mean_functions = tree_model.predict_mean_function(x)

# 결과 출력
print("Sample Rate Functions:", rate_functions[:5])
print("Sample Mean Functions:", mean_functions[:5])


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the predicted rate function and mean function of the first sample side by side
plt.figure(figsize=(14, 6))

# Left panel: rate function; the x axis is the index into the prediction array
plt.subplot(1, 2, 1)
plt.plot(rate_functions[0], label="Predicted Rate Function")
plt.xlabel("Time")
plt.ylabel("Rate")
plt.title("Rate Function")
plt.legend()

# Right panel: mean function, plotted the same way
plt.subplot(1, 2, 2)
plt.plot(mean_functions[0], label="Predicted Mean Function")
plt.xlabel("Time")
plt.ylabel("Mean")
plt.title("Mean Function")
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
%pip install graphviz

In [None]:
import numpy as np
from numbers import Integral, Real
from sklearn.utils import check_random_state

def _get_n_samples_bootstrap(n_ids, max_samples):
    """
    Modified for recurrent events. Get the number of IDs in a bootstrap sample.
    """
    if max_samples is None:
        return n_ids

    if isinstance(max_samples, Integral):
        if max_samples > n_ids:
            msg = "`max_samples` must be <= n_ids={} but got value {}"
            raise ValueError(msg.format(n_ids, max_samples))
        return max_samples

    if isinstance(max_samples, Real):
        return max(round(n_ids * max_samples), 1)

def _generate_sample_indices(random_state, ids, n_ids_bootstrap):
    """Draw a bootstrap sample of IDs.

    ``n_ids_bootstrap`` values are drawn from ``ids`` with replacement, using
    the generator obtained from ``random_state``, so the same seed reproduces
    the same sample. The sampled IDs themselves are returned; mapping them to
    data rows is left to the caller.
    """
    random_instance = check_random_state(random_state)
    # Draw from the seeded generator, not from the global NumPy state.
    return random_instance.choice(ids, n_ids_bootstrap, replace=True)

def _generate_unsampled_indices(random_state, ids, n_ids_bootstrap):
    """Return the IDs that a bootstrap draw leaves out.

    A sample is drawn with ``_generate_sample_indices`` using the same
    arguments, and the sorted, distinct values of ``ids`` that do not occur in
    that sample are returned. The result contains IDs, not row indices.
    """
    in_bag_ids = _generate_sample_indices(random_state, ids, n_ids_bootstrap)
    return np.setdiff1d(ids, in_bag_ids)



from warnings import catch_warnings, simplefilter
from sklearn.utils.class_weight import compute_sample_weight

def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,
    ids,  # a list/array of IDs corresponding to each row in X and y
    sample_weight,
    tree_idx,
    n_trees,
    verbose=0,
    class_weight=None,
    n_ids_bootstrap=None,  # number of IDs to draw; defaults to the number of distinct IDs
):
    """Fit a single tree, optionally on an ID-level bootstrap sample.

    With ``bootstrap`` enabled, IDs (not rows) are drawn with replacement. The
    tree is then fitted on every row belonging to a drawn ID, and each row's
    weight is multiplied by the number of times its ID was drawn. Without
    ``bootstrap`` the tree is fitted on all rows with ``sample_weight`` as given.

    Returns the fitted ``tree``.
    """
    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if not bootstrap:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
        return tree

    unique_ids = np.unique(ids)
    n_ids = len(unique_ids)
    if n_ids_bootstrap is None:
        # Without a size the sampler would return a scalar, which the counting
        # below cannot handle; draw as many IDs as there are.
        n_ids_bootstrap = n_ids

    # Generate bootstrap samples using IDs
    sampled_ids = _generate_sample_indices(
        tree.random_state, unique_ids, n_ids_bootstrap
    )

    # Rows that belong to at least one sampled ID
    indices = np.where(np.isin(ids, sampled_ids))[0]

    if sample_weight is None:
        curr_sample_weight = np.ones((X.shape[0],), dtype=np.float64)
    else:
        curr_sample_weight = sample_weight.copy()

    # Multiply each row's weight by how many times its ID was drawn
    sample_counts_for_ids = np.bincount(np.searchsorted(unique_ids, sampled_ids), minlength=n_ids)
    curr_sample_weight *= sample_counts_for_ids[np.searchsorted(unique_ids, ids)]

    if class_weight == "subsample":
        with catch_warnings():
            simplefilter("ignore", DeprecationWarning)
            curr_sample_weight *= compute_sample_weight("auto", y, indices=indices)
    elif class_weight == "balanced_subsample":
        curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

    tree.fit(X[indices], y[indices], sample_weight=curr_sample_weight[indices], check_input=False)
    return tree


1. _get_n_samples_bootstrap(n_ids, max_samples)
 * Recurrent events를 위해 수정된 함수
 * 부트스트랩 샘플에 포함될 ID의 개수를 반환
2. _generate_sample_indices(random_state, ids, n_ids_bootstrap)
 * 고유한 ID들을 복원 추출로 샘플링하여 반환 (샘플링된 ID에 속한 모든 이벤트 행으로의 확장은 호출하는 쪽에서 수행)
 * 부트스트랩의 핵심 기능
3. _generate_unsampled_indices(random_state, ids, n_ids_bootstrap)
 * 샘플링되지 않은 ID를 결정하고, 이 ID와 관련된 모든 이벤트를 확장
4. _parallel_build_trees(...)
 * 병렬로 단일 트리를 구축하는 데 사용되는 주요 함수
 * 부트스트랩 방법을 사용하여 train data에서 샘플을 추출하고, 이 샘플을 사용하여 트리를 구축
 * 앙상블 모델에서 여러 트리를 동시에 훈련시키기 위함


In [None]:
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_array, check_consistent_length
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted

from sksurv.exceptions import NoComparablePairException
from sksurv.nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
from sksurv.util import check_y_survival


def _check_estimate_1d(estimate, test_time):
    """Validate ``estimate`` as a 1-D array with as many entries as ``test_time``.

    Returns the array produced by ``check_array``.

    Raises
    ------
    ValueError
        If the validated array is not one-dimensional.
    """
    estimate = check_array(estimate, ensure_2d=False, input_name="estimate")
    is_one_dimensional = estimate.ndim == 1
    if not is_one_dimensional:
        raise ValueError(f"Expected 1D array, got {estimate.ndim}D array instead:\narray={estimate}.\n")

    check_consistent_length(test_time, estimate)
    return estimate

def _check_inputs(event_indicator, event_time, estimate):
    """Validate the three inputs of a concordance computation.

    All inputs must have the same length. ``event_indicator`` must be boolean
    with at least one ``True`` entry, and at least two samples are required.

    Returns the validated ``(event_indicator, event_time, estimate)``.

    Raises
    ------
    ValueError
        If any of the conditions above is violated.
    """
    check_consistent_length(event_indicator, event_time, estimate)

    event_indicator = check_array(event_indicator, ensure_2d=False, input_name="event_indicator")
    event_time = check_array(event_time, ensure_2d=False, input_name="event_time")
    estimate = _check_estimate_1d(estimate, event_time)

    is_boolean = np.issubdtype(event_indicator.dtype, np.bool_)
    if not is_boolean:
        raise ValueError(
            f"only boolean arrays are supported as class labels for survival analysis, got {event_indicator.dtype}"
        )

    n_obs = len(event_time)
    if n_obs < 2:
        raise ValueError("Need a minimum of two samples")

    has_event = event_indicator.any()
    if not has_event:
        raise ValueError("All samples are censored")

    return event_indicator, event_time, estimate


def _check_times(test_time, times):
    """Return the sorted distinct ``times`` after checking they lie in ``[min(test_time), max(test_time))``.

    Raises
    ------
    ValueError
        If the largest time is not below ``test_time.max()`` or the smallest
        time is below ``test_time.min()``.
    """
    validated = check_array(np.atleast_1d(times), ensure_2d=False, input_name="times")
    times = np.unique(validated)

    too_late = times.max() >= test_time.max()
    if too_late or times.min() < test_time.min():
        raise ValueError(
            f"all times must be within follow-up time of test data: [{test_time.min()}; {test_time.max()}["
        )

    return times

def _check_estimate_2d(estimate, test_time, time_points, estimator):
    """Validate a per-sample, per-time-point ``estimate``.

    ``time_points`` is validated with ``_check_times``, and ``estimate`` must
    have as many rows as ``test_time``. When ``estimate`` is two-dimensional,
    it must have one column per validated time point.

    Returns ``(estimate, time_points)`` after validation.
    """
    estimate = check_array(estimate, ensure_2d=False, allow_nd=False, input_name="estimate", estimator=estimator)
    time_points = _check_times(test_time, time_points)
    check_consistent_length(test_time, estimate)

    if estimate.ndim == 2:
        if estimate.shape[1] != time_points.shape[0]:
            raise ValueError(f"expected estimate with {time_points.shape[0]} columns, but got {estimate.shape[1]}")

    return estimate, time_points


def _iter_comparable(event_indicator, event_time, order):
    n_samples = len(event_time)
    tied_time = 0
    i = 0
    while i < n_samples - 1:
        time_i = event_time[order[i]]
        end = i + 1
        while end < n_samples and event_time[order[end]] == time_i:
            end += 1

        # check for tied event times
        event_at_same_time = event_indicator[order[i:end]]
        censored_at_same_time = ~event_at_same_time
        for j in range(i, end):
            if event_indicator[order[j]]:
                mask = np.zeros(n_samples, dtype=bool)
                mask[end:] = True
                # an event is comparable to censored samples at same time point
                mask[i:end] = censored_at_same_time
                tied_time += censored_at_same_time.sum()
                yield (j, mask, tied_time)
        i = end

def _estimate_recurrent_concordance_index(mu_oob, X, nRE):
    """
    Compute the C-index for recurrent event data.

    Parameters:
    - mu_oob: Out-of-bag predicted risk scores
    - X: Covariate data
    - nRE: Number of events for each subject up to a certain time

    Returns:
    - C-index
    """
    m = len(mu_oob)
    numerator = 0
    denominator = 0

    for i in range(m):
        for j in range(m):
            if i != j:
                min_c = min(mu_oob[i], mu_oob[j])
                if nRE[i][min_c] > nRE[j][min_c]:
                    denominator += 1
                    if mu_oob[i] > mu_oob[j]:
                        numerator += 1

    if denominator == 0:
        return 0.5  # Just random guessing

    return numerator / denominator

# Calculate the Prediction Error rate
def prediction_error_rate(cindex):
    """Return the prediction error rate, defined as ``1 - cindex``."""
    error_rate = 1 - cindex
    return error_rate


## C-Index

1. _check_estimate_1d, _check_estimate_2d, _check_inputs, _check_times:

  * 이 함수들은 입력 데이터의 유효성을 검사하는 유틸리티 함수
  * 주어진 입력 데이터의 일관성, 차원, 형식 등을 검사하여 데이터가 예상된 형식과 일치하는지 확인

2. _iter_comparable:
  * 주어진 이벤트 시간 및 지표에 대해 비교 가능한 샘플 조합을 반복하는 제너레이터 함수.
  * 이 함수는 생존 분석에서 두 샘플이 비교 가능한지를 결정하는 데 사용.

3. _estimate_recurrent_concordance_index:
  * 재발생 이벤트의 경우 Concordance Index (C-index)를 추정합니다.
  * 이 함수는 각 개체의 이벤트 시간, 이벤트 횟수, 그리고 각 개체에 대한 Out-of-Bag 추정치를 입력으로 받아 C-index를 계산.

4. prediction_error_rate:
  * 주어진 C-index를 기반으로 예측 오류율을 계산합니다.
  * 예측 오류율은 1 - C-index로 계산되며, 모델의 성능을 나타내는 또 다른 지표로 사용됨.

In [None]:
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from joblib import Parallel, delayed
from sklearn.utils.validation import check_is_fitted
import warnings

class RecurrentRandomForest(BaseEstimator):
    """Forest of `RecurrentTree` estimators for recurrent-event data.

    The bootstrap unit is the subject (unique id), not the row: every row of
    a sampled subject is handed to the tree.

    `n_jobs` and `verbose` are forwarded to joblib at prediction time.
    `warm_start` and the `sample_weight` argument of `fit` are accepted but
    not used by this implementation.
    """

    def __init__(self, n_estimators=100, max_depth=None, min_samples_split=2,
                 min_samples_leaf=1, bootstrap=True, oob_score=False, n_jobs=None,
                 random_state=None, verbose=0, warm_start=False, max_samples=None):
        # Only hyper-parameters are stored here. The trees are created in
        # `fit`, so `set_params` takes effect and `check_is_fitted` is
        # meaningful.
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.verbose = verbose
        self.warm_start = warm_start
        self.max_samples = max_samples

    def _make_estimator(self, random_state=None):
        """Make and configure a new `RecurrentTree` estimator."""
        return RecurrentTree(
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            random_state=random_state
        )

    def _draw_tree_seeds(self):
        """Return one integer seed per tree, derived from `random_state`.

        Giving every tree the forest's own seed would make all trees draw the
        same bootstrap sample, so each tree gets its own derived seed.
        """
        seed_source = self.random_state
        if not isinstance(seed_source, np.random.RandomState):
            seed_source = np.random.RandomState(seed_source)
        upper = np.iinfo(np.int32).max
        return [int(seed) for seed in seed_source.randint(upper, size=self.n_estimators)]

    def fit(self, X, y, sample_weight=None):
        """Build a forest of recurrent-event trees from the training set (X, y).

        Parameters:
        - X: covariate matrix with one row per observation
        - y: mapping (or structured array) providing the keys 'id',
          'time_start', 'time_stop' and 'event', each with one entry per row
        - sample_weight: accepted for API compatibility; not used

        Returns:
        - self

        Raises:
        - ValueError: if `oob_score` is requested without `bootstrap`
        """
        if self.oob_score and not self.bootstrap:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        X = self._validate_data(X, accept_sparse='csc', ensure_min_samples=2)

        # Extract recurrent events data
        ids = y['id']
        time_start = y['time_start']
        time_stop = y['time_stop']
        event = y['event']
        self.n_features_in_ = X.shape[1]

        unique_ids = np.unique(ids)
        n_samples_bootstrap = _get_n_samples_bootstrap(len(unique_ids), self.max_samples)

        self.estimators_ = [
            self._make_estimator(random_state=seed) for seed in self._draw_tree_seeds()
        ]
        # One boolean row mask per tree marking the rows that tree did NOT see.
        self._oob_row_masks_ = []

        if self.bootstrap:
            ids_arr = np.asarray(ids)
            time_start_arr = np.asarray(time_start)
            time_stop_arr = np.asarray(time_stop)
            event_arr = np.asarray(event)

        for tree in self.estimators_:
            if self.bootstrap:
                sampled_ids = _generate_sample_indices(tree.random_state, unique_ids, n_samples_bootstrap)
                # NOTE: membership testing keeps each sampled subject once,
                # even when the bootstrap drew it several times.
                in_bag_rows = np.isin(ids_arr, sampled_ids)
                self._oob_row_masks_.append(~in_bag_rows)
                bootstrap_indices = np.where(in_bag_rows)[0]
                tree.fit(
                    X[bootstrap_indices],
                    ids_arr[bootstrap_indices],
                    time_start_arr[bootstrap_indices],
                    time_stop_arr[bootstrap_indices],
                    event_arr[bootstrap_indices],
                )
            else:
                tree.fit(X, ids, time_start, time_stop, event)

        if self.oob_score:
            self._set_oob_score_and_attributes(X, y)

        return self

    def _set_oob_score_and_attributes(self, X, y):
        """Calculate out of bag predictions and score.

        For every tree, the rows of subjects that were not in its bootstrap
        sample are predicted, and the per-row predictions are averaged over
        the trees that left the row out. Sets `oob_prediction_` and
        `oob_score_`.
        """
        n_samples = X.shape[0]
        event = y['event']

        predictions = np.zeros(n_samples)
        n_predictions = np.zeros(n_samples)

        for estimator, oob_mask in zip(self.estimators_, self._oob_row_masks_):
            oob_rows = np.flatnonzero(oob_mask)
            if oob_rows.size == 0:
                # Every subject was in this tree's sample; nothing to predict.
                continue
            # Collapse each row's predicted mean function to a single score.
            p_estimator = np.array(estimator.predict_mean_function(X[oob_rows, :])).mean(axis=1)

            predictions[oob_rows] += p_estimator
            n_predictions[oob_rows] += 1

        if (n_predictions == 0).any():
            warnings.warn(
              "Some inputs do not have OOB scores. This probably means too few trees were used to compute any reliable oob estimates.",
              stacklevel=3,
            )
            # Avoid division by zero; rows without OOB votes keep a prediction of 0.
            n_predictions[n_predictions == 0] = 1

        predictions /= n_predictions

        # Compute the C-index
        self.oob_prediction_ = predictions
        # NOTE(review): the scorer indexes its third argument as nRE[i][t],
        # while `event` holds one flag per row. This call probably needs
        # per-subject event counts instead; confirm the intended input before
        # relying on `oob_score_`.
        self.oob_score_ = _estimate_recurrent_concordance_index(predictions, X, event)

    def _validate_data(self, X, accept_sparse=False, ensure_min_samples=1):
        """Validate input data with `check_array` and return the checked array."""
        return check_array(X, accept_sparse=accept_sparse, ensure_min_samples=ensure_min_samples)

    def _validate_X_predict(self, X):
        """Validate X whenever one tries to predict.

        Raises:
        - ValueError: if the number of columns differs from the fitted one
        """
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError("Number of features of the model must match the input. Model n_features is {} and input n_features is {}."
                             .format(self.n_features_in_, X.shape[1]))
        return X

    def predict_rate_function(self, X):
        """
        Predict the nonparametric estimates of dμ(t) = ρ(t)dt for given samples using the forest.

        The per-tree results are averaged element-wise over the trees.
        """
        check_is_fitted(self, "estimators_")
        X = self._validate_X_predict(X)

        rate_functions_results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, require="sharedmem")(
            delayed(tree.predict_rate_function)(X) for tree in self.estimators_
        )

        averaged_rate_functions = np.mean(rate_functions_results, axis=0)

        return averaged_rate_functions

    def predict_mean_function(self, X):
        """
        Predict the Nelson-Aalen estimator of the mean function for given samples using the forest.

        The per-tree results are averaged element-wise over the trees.
        """
        check_is_fitted(self, "estimators_")
        X = self._validate_X_predict(X)

        mean_functions_results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, require="sharedmem")(
            delayed(tree.predict_mean_function)(X) for tree in self.estimators_
        )

        averaged_mean_functions = np.mean(mean_functions_results, axis=0)

        return averaged_mean_functions


## RandomForest for Recurrent Events

1. 클래스 초기화 (__init__):

  * 랜덤 포레스트의 주요 파라미터를 초기화합니다. 이러한 파라미터에는 트리의 개수(n_estimators), 최대 깊이(max_depth), 분할을 위한 최소 샘플 수(min_samples_split), 리프 노드의 최소 샘플 수(min_samples_leaf) 등이 포함됨
  * 또한, 주어진 파라미터를 기반으로 RecurrentTree 객체를 생성하여 estimators_ 리스트에 추가

2. fit 메서드:

  * 주어진 입력 데이터 X와 생존 데이터 (이벤트 지표, 시작 시간, 중지 시간)를 사용하여 랜덤 포레스트를 학습시킴
  * 각 트리는 순차적으로 학습되며, 각 트리는 개체(ID) 단위로 추출한 부트스트랩 샘플을 사용하여 학습됨
  * Out-of-bag (OOB) 점수를 계산할 경우 _set_oob_score_and_attributes 메서드를 호출하여 OOB 예측과 C-index를 계산

3. _set_oob_score_and_attributes 메서드:
  * Out-of-bag (OOB) 예측을 계산하고, 이를 기반으로 C-index를 계산
  * 이 메서드는 OOB 예측을 사용하여 모델의 성능을 추정하는 데 사용.

4. predict_rate_function 메서드:

  * 주어진 입력 데이터 X에 대한 비모수적 추정값 dμ(t)=ρ(t)dt를 예측.
  * 각 트리로부터의 비율 함수 예측을 병렬로 수집하고, 이러한 예측을 평균하여 최종 결과를 반환.

5. predict_mean_function 메서드:

  * 주어진 입력 데이터 X에 대한 Nelson-Aalen estimator 기반 평균 함수를 예측.
  * 각 트리로부터의 평균 함수 예측을 병렬로 수집하고, 이러한 예측을 평균하여 최종 결과를 반환.

In [None]:
# Toy recurrent-event data: four subjects with 3, 2, 4 and 3 rows respectively.
ids = [1] * 3 + [2] * 2 + [3] * 4 + [4] * 3
time_start = [
    0, 2, 3,      # subject 1
    0, 3,         # subject 2
    0, 1, 2, 5,   # subject 3
    0, 5, 8,      # subject 4
]
time_stop = [
    2, 3, 7,
    3, 6,
    1, 2, 5, 9,
    5, 8, 11,
]
event = [
    1, 1, 0,
    1, 0,
    1, 1, 1, 0,
    1, 1, 0,
]
# The two covariates are constant within a subject, so repeat one covariate
# row per subject as many times as that subject has rows.
X = np.repeat(
    np.array([[45, 0], [52, 0], [65, 1], [53, 1]]),
    [3, 2, 4, 3],
    axis=0,
)

In [None]:
# Bundle the recurrent-event response under the keys that `fit` reads.
y = dict(id=ids, time_start=time_start, time_stop=time_stop, event=event)

rrf = RecurrentRandomForest(
    n_estimators=10,
    max_depth=3,
    min_samples_leaf=5,
    random_state=1190,
)
rrf.fit(X, y)

In [None]:
# Evaluate the forest's rate-function and mean-function estimates on X.
rate_function_predictions, mean_function_predictions = (
    rrf.predict_rate_function(X),
    rrf.predict_mean_function(X),
)

# Leave the pair as the final expression so the notebook displays it.
rate_function_predictions, mean_function_predictions

In [None]:
class PermutationImportance:
    """Permutation feature importance measured as the drop in C-index.

    `model` must expose `predict_mean_function(X)`. Each feature column is
    shuffled `n_repeats` times and the C-index of the shuffled data is
    subtracted from the C-index of the untouched data.
    """

    def __init__(self, model, n_repeats=30, random_state=None):
        self.model, self.n_repeats, self.random_state = model, n_repeats, random_state

    def _compute_baseline_cindex(self, X, event, time_stop):
        """Return the C-index of the model's predictions for `X`."""
        # Collapse each row's predicted mean function to a single score.
        row_scores = np.mean(self.model.predict_mean_function(X), axis=1)
        return _estimate_concordance_index_recurrent(time_stop, event, row_scores)

    def compute_importance(self, X, event, time_stop):
        """Return an (n_features, n_repeats) array of C-index drops."""
        reference = self._compute_baseline_cindex(X, event, time_stop)
        rng = check_random_state(self.random_state)

        drops = np.zeros((X.shape[1], self.n_repeats))

        # Row-major iteration: features in the outer position, repeats inner.
        for col, rep in np.ndindex(drops.shape):
            shuffled = X.copy()
            # Shuffle a single column in place; all other columns stay intact.
            rng.shuffle(shuffled[:, col])
            drops[col, rep] = reference - self._compute_baseline_cindex(shuffled, event, time_stop)

        return drops

    def report_importance(self, X, event, time_stop):
        """Return the per-feature mean and standard deviation of the drops."""
        drops = self.compute_importance(X, event, time_stop)
        return np.mean(drops, axis=1), np.std(drops, axis=1)


In [None]:
# Bind the helper to `importer`, the name the following cells use.
importer = PermutationImportance(rrf, n_repeats=30, random_state=1190)

In [None]:
# C-index of the model on the unpermuted covariates (the reference value for the importances).
# NOTE(review): `event_recurrent` and `time_stop_recurrent` are not defined in the nearby cells;
# confirm they exist before running this cell.
importer._compute_baseline_cindex(X, event_recurrent, time_stop_recurrent)

In [None]:
# Per-feature mean and standard deviation of the C-index drop over 30 shuffles.
importer = PermutationImportance(rrf, random_state=1190, n_repeats=30)
importance_mean, importance_std = importer.report_importance(
    X, event_recurrent, time_stop_recurrent
)

print(importance_mean, importance_std, sep="\n")

In [None]:
import numpy as np

def EstCstat(score, N, Cis, RE, nRE):
    """
    Estimate the C-statistic for recurrent-event data over subject pairs.

    For each subject i (row i of `Cis`), the events recorded up to that
    subject's cut-off `Cis[i, 1]` are counted per subject. Subject i's own
    total `nRE[id]` is compared with those counts for every subject in a
    later row of `Cis`. A pair with unequal counts is comparable, and it is
    concordant when `score` orders the two subjects the same way.

    Parameters:
    - score: predicted scores, indexed by subject id
    - N: number of subjects (rows of `Cis`); ids must be integers in
      0..N-1 because they are used as positions into the count table
    - Cis: 2-D array; column 0 is the subject id, column 1 the cut-off time
    - RE: 2-D array with one row per event; column 0 is the subject id and
      column 2 is the time compared against the cut-off
    - nRE: total number of events, indexed by subject id

    Returns:
    - (estcstat, den): concordant / comparable pairs and the number of
      comparable pairs; estcstat is 0 when there is no comparable pair

    NOTE(review): the event time is read from column 2 of `RE` while the
    cut-off comes from column 1 of `Cis` — confirm that `RE` really has its
    time in the third column.
    """
    Cis = np.asarray(Cis)
    RE = np.asarray(RE)

    # Without any event row no pair can be comparable.
    if RE.ndim != 2 or RE.shape[0] == 0:
        return 0, 0

    event_ids = RE[:, 0].astype(int)
    event_times = RE[:, 2]

    den = 0
    num = 0
    for i in range(N - 1):
        ID1 = int(Cis[i, 0])

        before_cut = event_times <= Cis[i, 1]
        if not before_cut.any():
            continue

        # Events per subject id up to the cut-off; ids without events stay 0.
        counts_at_cut = np.bincount(event_ids[before_cut], minlength=N)

        # Only subjects in later rows are paired with subject i.
        later_ids = [int(value) for value in Cis[i + 1:N, 0]]
        others_obs = counts_at_cut[later_ids]
        others_pred = np.array([score[idx] for idx in later_ids])

        own_obs = nRE[ID1]
        own_pred = score[ID1]

        lt_obs = own_obs < others_obs
        gt_obs = own_obs > others_obs
        lt_pred = own_pred < others_pred
        gt_pred = own_pred > others_pred

        den += int(lt_obs.sum() + gt_obs.sum())
        num += int((lt_obs & lt_pred).sum() + (gt_obs & gt_pred).sum())

    estcstat = num / den if den else 0
    return estcstat, den


# 이후에 사용할 때는 score, N, Cis, RE, nRE 값을 인자로 제공하여 함수를 호출하면 됩니다.
